Files
zclaw_openfang/crates/zclaw-saas/src/relay/service.rs
iven a99a3df9dd feat(saas): Phase 3 — 模型请求中转服务
- OpenAI 兼容 API 代理 (/api/v1/relay/chat/completions)
- 中转任务管理 (创建/查询/状态跟踪)
- 可用模型列表端点 (仅 enabled providers+models)
- 任务生命周期 (queued → processing → completed/failed)
- 用量自动记录 (token 统计 + 错误追踪)
- 3 个新集成测试覆盖中转端点
2026-03-27 12:58:02 +08:00

198 lines
7.7 KiB
Rust

//! 中转服务核心逻辑
use sqlx::SqlitePool;
use crate::error::{SaasError, SaasResult};
use super::types::*;
// ============ Relay Task Management ============
/// Inserts a new relay task in the `queued` state and returns the stored row.
///
/// A fresh UUID v4 is used as the task id and the current UTC time (RFC 3339)
/// is written to both `queued_at` and `created_at`. The request body is stored
/// together with the digest produced by `hash_request`. The INSERT statement
/// itself fixes `attempt_count` at 0 and `max_attempts` at 3.
///
/// # Errors
/// Propagates database errors from the INSERT, and whatever `get_relay_task`
/// returns when the freshly inserted row is read back.
pub async fn create_relay_task(
    db: &SqlitePool,
    account_id: &str,
    provider_id: &str,
    model_id: &str,
    request_body: &str,
    priority: i64,
) -> SaasResult<RelayTaskInfo> {
    // `?8` appears twice on purpose: queued_at and created_at share one timestamp.
    const INSERT_TASK_SQL: &str = "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)\n\
        VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, 3, ?8, ?8)";

    let task_id = uuid::Uuid::new_v4().to_string();
    let timestamp = chrono::Utc::now().to_rfc3339();
    let body_hash = hash_request(request_body);

    let insert = sqlx::query(INSERT_TASK_SQL)
        .bind(&task_id)
        .bind(account_id)
        .bind(provider_id)
        .bind(model_id)
        .bind(&body_hash)
        .bind(request_body)
        .bind(priority)
        .bind(&timestamp);
    insert.execute(db).await?;

    // Read the row back so the caller sees exactly what the database stored.
    get_relay_task(db, &task_id).await
}
/// Loads a single relay task by id.
///
/// # Errors
/// Returns `SaasError::NotFound` when no row in `relay_tasks` has the given
/// id; database errors are propagated.
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
    // Column order of the SELECT below; the tuple must stay in sync with it.
    type TaskRow = (
        String, String, String, String, String,
        i64, i64, i64, i64, i64,
        Option<String>, String, Option<String>, Option<String>, String,
    );

    const SELECT_TASK_SQL: &str = "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at\n\
        FROM relay_tasks WHERE id = ?1";

    let fetched = sqlx::query_as::<_, TaskRow>(SELECT_TASK_SQL)
        .bind(task_id)
        .fetch_optional(db)
        .await?;

    match fetched {
        Some((
            id, account_id, provider_id, model_id, status,
            priority, attempt_count, max_attempts, input_tokens, output_tokens,
            error_message, queued_at, started_at, completed_at, created_at,
        )) => Ok(RelayTaskInfo {
            id,
            account_id,
            provider_id,
            model_id,
            status,
            priority,
            attempt_count,
            max_attempts,
            input_tokens,
            output_tokens,
            error_message,
            queued_at,
            started_at,
            completed_at,
            created_at,
        }),
        None => Err(SaasError::NotFound(format!("中转任务 {} 不存在", task_id))),
    }
}
/// Lists the relay tasks of one account, newest first, one page at a time.
///
/// * `query.page` defaults to 1 and is raised to at least 1.
/// * `query.page_size` defaults to 20 and is clamped to `1..=100`.
/// * `query.status`, when present, restricts the result to that status.
///
/// # Errors
/// Propagates database errors.
pub async fn list_relay_tasks(
    db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<Vec<RelayTaskInfo>> {
    // Column order of the SELECTs below; the tuple must stay in sync with them.
    type TaskRow = (
        String, String, String, String, String,
        i64, i64, i64, i64, i64,
        Option<String>, String, Option<String>, Option<String>, String,
    );

    const LIST_BY_STATUS_SQL: &str = "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at\n\
        FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4";
    const LIST_ALL_SQL: &str = "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at\n\
        FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3";

    let page = query.page.unwrap_or(1).max(1);
    // A lower bound is required as well as the upper one: SQLite interprets a
    // negative LIMIT as "no limit", which would let a caller bypass the
    // 100-row cap (if the field is a signed integer), and 0 would always
    // produce an empty page.
    let page_size = query.page_size.unwrap_or(20).clamp(1, 100);
    // Saturate instead of overflowing when an absurdly large page is requested.
    let offset = (page - 1).saturating_mul(page_size);

    let rows = match &query.status {
        Some(status) => {
            sqlx::query_as::<_, TaskRow>(LIST_BY_STATUS_SQL)
                .bind(account_id)
                .bind(status)
                .bind(page_size)
                .bind(offset)
                .fetch_all(db)
                .await?
        }
        None => {
            sqlx::query_as::<_, TaskRow>(LIST_ALL_SQL)
                .bind(account_id)
                .bind(page_size)
                .bind(offset)
                .fetch_all(db)
                .await?
        }
    };

    let tasks = rows
        .into_iter()
        .map(|(
            id, account_id, provider_id, model_id, status,
            priority, attempt_count, max_attempts, input_tokens, output_tokens,
            error_message, queued_at, started_at, completed_at, created_at,
        )| RelayTaskInfo {
            id,
            account_id,
            provider_id,
            model_id,
            status,
            priority,
            attempt_count,
            max_attempts,
            input_tokens,
            output_tokens,
            error_message,
            queued_at,
            started_at,
            completed_at,
            created_at,
        })
        .collect();
    Ok(tasks)
}
/// Moves a relay task to a new lifecycle state and records the matching data.
///
/// Supported values of `status`:
/// * `"processing"` — sets `started_at` to now and increments `attempt_count`.
/// * `"completed"` — sets `completed_at` to now and stores `input_tokens` /
///   `output_tokens`; a `None` keeps the value already in the row.
/// * `"failed"` — sets `completed_at` to now and stores `error_message`.
///
/// `input_tokens` / `output_tokens` are only used for `"completed"` and
/// `error_message` only for `"failed"`.
///
/// # Errors
/// Returns `SaasError::InvalidInput` for any other `status`; database errors
/// are propagated.
pub async fn update_task_status(
    db: &SqlitePool, task_id: &str, status: &str,
    input_tokens: Option<i64>, output_tokens: Option<i64>,
    error_message: Option<&str>,
) -> SaasResult<()> {
    let now = chrono::Utc::now().to_rfc3339();

    // Each status gets its own statement whose placeholder numbers match the
    // number of values actually bound. A shared `WHERE id = ?4` would leave
    // `?4` unbound (NULL) whenever fewer than four values are supplied, and
    // `id = NULL` matches no row, so the update would silently do nothing.
    match status {
        "processing" => {
            sqlx::query(
                "UPDATE relay_tasks SET started_at = ?1, status = 'processing', attempt_count = attempt_count + 1 WHERE id = ?2",
            )
            .bind(&now)
            .bind(task_id)
            .execute(db)
            .await?;
        }
        "completed" => {
            sqlx::query(
                "UPDATE relay_tasks SET completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens) WHERE id = ?4",
            )
            .bind(&now)
            .bind(input_tokens)
            .bind(output_tokens)
            .bind(task_id)
            .execute(db)
            .await?;
        }
        "failed" => {
            sqlx::query(
                "UPDATE relay_tasks SET completed_at = ?1, status = 'failed', error_message = ?2 WHERE id = ?3",
            )
            .bind(&now)
            .bind(error_message)
            .bind(task_id)
            .execute(db)
            .await?;
        }
        _ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
    }
    Ok(())
}
// ============ Relay Execution ============
/// Forwards a chat-completion request to an upstream provider and records the
/// outcome on the relay task.
///
/// The task is first marked `processing`, then `request_body` is POSTed as
/// JSON to `{provider_base_url}/chat/completions` (trailing slashes on the
/// base URL are ignored). When `provider_api_key` is present it is sent as a
/// `Bearer` token in the `Authorization` header.
///
/// * 2xx response: the whole body is read into a string. With `stream` set
///   the task is marked `completed` and the body is returned as
///   [`RelayResponse::Sse`]; otherwise token usage is extracted from the body,
///   stored with the `completed` status, and the body is returned as
///   [`RelayResponse::Json`].
/// * Any other HTTP status, a transport error, or a failure while reading a
///   2xx body: the task is marked `failed` with a description and
///   `SaasError::Relay` is returned.
///
/// # Errors
/// `SaasError::Relay` for upstream failures as described above; errors from
/// `update_task_status` are propagated.
pub async fn execute_relay(
    db: &SqlitePool,
    task_id: &str,
    provider_base_url: &str,
    provider_api_key: Option<&str>,
    request_body: &str,
    stream: bool,
) -> SaasResult<RelayResponse> {
    update_task_status(db, task_id, "processing", None, None, None).await?;

    let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
    let client = reqwest::Client::new();
    let mut req_builder = client.post(&url)
        .header("Content-Type", "application/json")
        .body(request_body.to_string());
    if let Some(key) = provider_api_key {
        req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
    }

    let resp = match req_builder.send().await {
        Ok(resp) => resp,
        Err(e) => {
            let err_msg = format!("请求上游失败: {}", e);
            update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
            return Err(SaasError::Relay(err_msg));
        }
    };

    if !resp.status().is_success() {
        let status = resp.status().as_u16();
        // Best effort: the body is only used to enrich the error message.
        let body = resp.text().await.unwrap_or_default();
        // Keep at most 500 bytes of the upstream body, cut on a character
        // boundary so multi-byte UTF-8 text cannot trigger a slicing panic.
        let err_msg = format!("上游返回 HTTP {}: {}", status, truncate_utf8(&body, 500));
        update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
        return Err(SaasError::Relay(err_msg));
    }

    // A body that cannot be read is a failed relay, not an empty success.
    let body = match resp.text().await {
        Ok(body) => body,
        Err(e) => {
            let err_msg = format!("读取上游响应失败: {}", e);
            update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
            return Err(SaasError::Relay(err_msg));
        }
    };

    if stream {
        // No token usage is extracted for streamed responses.
        update_task_status(db, task_id, "completed", None, None, None).await?;
        Ok(RelayResponse::Sse(body))
    } else {
        let (input_tokens, output_tokens) = extract_token_usage(&body);
        update_task_status(db, task_id, "completed",
            Some(input_tokens), Some(output_tokens), None).await?;
        Ok(RelayResponse::Json(body))
    }
}

/// Returns the longest prefix of `text` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
fn truncate_utf8(text: &str, max_bytes: usize) -> &str {
    if text.len() <= max_bytes {
        return text;
    }
    let mut end = max_bytes;
    // Index 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}
/// Relay response type: the upstream body returned by `execute_relay`,
/// tagged by how the request was made.
#[derive(Debug)]
pub enum RelayResponse {
/// Body text of a non-streaming request (`stream == false`).
Json(String),
/// Body text of a streaming request (`stream == true`); `execute_relay`
/// reads the response into a single string before returning it.
Sse(String),
}
// ============ Helpers ============
/// Returns the SHA-256 digest of `body`, encoded with `hex::encode`.
fn hash_request(body: &str) -> String {
    use sha2::{Digest, Sha256};

    let mut hasher = Sha256::new();
    hasher.update(body.as_bytes());
    let digest = hasher.finalize();
    hex::encode(digest)
}
/// Reads `(prompt_tokens, completion_tokens)` from the `usage` object of a
/// JSON response body.
///
/// Each count falls back to 0 when the body is not valid JSON, when `usage`
/// or the field is missing, or when the value is not representable as `i64`.
fn extract_token_usage(body: &str) -> (i64, i64) {
    let document: Option<serde_json::Value> = serde_json::from_str(body).ok();

    // Looks up one counter under `usage`, defaulting to 0 at every miss.
    let read_count = |field: &str| -> i64 {
        document
            .as_ref()
            .and_then(|root| root.get("usage"))
            .and_then(|usage| usage.get(field))
            .and_then(serde_json::Value::as_i64)
            .unwrap_or(0)
    };

    (read_count("prompt_tokens"), read_count("completion_tokens"))
}