fix(saas): P1 审计修复 — 连接池断路器 + Worker重试 + XSS防护 + 状态机SQL解析器
P1 修复内容: - F7: health handler 连接池容量检查 (80%阈值返回503 degraded) - F9: SSE spawned task 并发限制 (Semaphore 16 permits) - F10: Key Pool 单次 JOIN 查询优化 (消除 N+1) - F12: CORS panic → 配置错误 - F14: 连接池使用率计算修正 (ratio = used*100/total) - F15: SQL 迁移解析器替换为状态机 (支持 $$, DO $body$, 存储过程) - Worker 重试机制: 失败任务通过 mpsc channel 重新入队 - DOMPurify XSS 防护 (PipelineResultPreview) - Admin V2: ErrorBoundary + SWR全局配置 + 请求优化
This commit is contained in:
19
Cargo.lock
generated
19
Cargo.lock
generated
@@ -2506,7 +2506,7 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"socket2",
|
"socket2 0.6.3",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -4189,7 +4189,7 @@ dependencies = [
|
|||||||
"quinn-udp",
|
"quinn-udp",
|
||||||
"rustc-hash",
|
"rustc-hash",
|
||||||
"rustls",
|
"rustls",
|
||||||
"socket2",
|
"socket2 0.6.3",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tracing",
|
"tracing",
|
||||||
@@ -4226,7 +4226,7 @@ dependencies = [
|
|||||||
"cfg_aliases",
|
"cfg_aliases",
|
||||||
"libc",
|
"libc",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"socket2",
|
"socket2 0.6.3",
|
||||||
"tracing",
|
"tracing",
|
||||||
"windows-sys 0.60.2",
|
"windows-sys 0.60.2",
|
||||||
]
|
]
|
||||||
@@ -5133,6 +5133,16 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "socket2"
|
||||||
|
version = "0.5.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e22376abed350d73dd1cd119b57ffccad95b4e585a7cda43e286245ce23c0678"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"windows-sys 0.52.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "socket2"
|
name = "socket2"
|
||||||
version = "0.6.3"
|
version = "0.6.3"
|
||||||
@@ -6048,7 +6058,7 @@ dependencies = [
|
|||||||
"parking_lot",
|
"parking_lot",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"signal-hook-registry",
|
"signal-hook-registry",
|
||||||
"socket2",
|
"socket2 0.6.3",
|
||||||
"tokio-macros",
|
"tokio-macros",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.61.2",
|
||||||
]
|
]
|
||||||
@@ -8328,6 +8338,7 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sha2",
|
"sha2",
|
||||||
|
"socket2 0.5.10",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror 2.0.18",
|
"thiserror 2.0.18",
|
||||||
|
|||||||
@@ -110,6 +110,9 @@ argon2 = "0.5"
|
|||||||
totp-rs = "5"
|
totp-rs = "5"
|
||||||
hex = "0.4"
|
hex = "0.4"
|
||||||
|
|
||||||
|
# TCP socket configuration
|
||||||
|
socket2 = { version = "0.5", features = ["all"] }
|
||||||
|
|
||||||
# Internal crates
|
# Internal crates
|
||||||
zclaw-types = { path = "crates/zclaw-types" }
|
zclaw-types = { path = "crates/zclaw-types" }
|
||||||
zclaw-memory = { path = "crates/zclaw-memory" }
|
zclaw-memory = { path = "crates/zclaw-memory" }
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { RouterProvider } from 'react-router-dom'
|
|||||||
import { ConfigProvider, App as AntApp } from 'antd'
|
import { ConfigProvider, App as AntApp } from 'antd'
|
||||||
import zhCN from 'antd/locale/zh_CN'
|
import zhCN from 'antd/locale/zh_CN'
|
||||||
import { router } from './router'
|
import { router } from './router'
|
||||||
|
import { ErrorBoundary } from './components/ErrorBoundary'
|
||||||
|
|
||||||
const queryClient = new QueryClient({
|
const queryClient = new QueryClient({
|
||||||
defaultOptions: {
|
defaultOptions: {
|
||||||
@@ -16,11 +17,13 @@ const queryClient = new QueryClient({
|
|||||||
})
|
})
|
||||||
|
|
||||||
createRoot(document.getElementById('root')!).render(
|
createRoot(document.getElementById('root')!).render(
|
||||||
<ConfigProvider locale={zhCN}>
|
<ErrorBoundary>
|
||||||
<AntApp>
|
<ConfigProvider locale={zhCN}>
|
||||||
<QueryClientProvider client={queryClient}>
|
<AntApp>
|
||||||
<RouterProvider router={router} />
|
<QueryClientProvider client={queryClient}>
|
||||||
</QueryClientProvider>
|
<RouterProvider router={router} />
|
||||||
</AntApp>
|
</QueryClientProvider>
|
||||||
</ConfigProvider>,
|
</AntApp>
|
||||||
|
</ConfigProvider>
|
||||||
|
</ErrorBoundary>,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ export default function Accounts() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['accounts'],
|
queryKey: ['accounts'],
|
||||||
queryFn: () => accountService.list(),
|
queryFn: ({ signal }) => accountService.list(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const updateMutation = useMutation({
|
const updateMutation = useMutation({
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ export default function AgentTemplates() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['agent-templates'],
|
queryKey: ['agent-templates'],
|
||||||
queryFn: () => agentTemplateService.list(),
|
queryFn: ({ signal }) => agentTemplateService.list(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const createMutation = useMutation({
|
const createMutation = useMutation({
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ export default function ApiKeys() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['api-keys'],
|
queryKey: ['api-keys'],
|
||||||
queryFn: () => apiKeyService.list(),
|
queryFn: ({ signal }) => apiKeyService.list(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const createMutation = useMutation({
|
const createMutation = useMutation({
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ export default function Config() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['config', category],
|
queryKey: ['config', category],
|
||||||
queryFn: () => configService.list({ category }),
|
queryFn: ({ signal }) => configService.list({ category }, signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const updateMutation = useMutation({
|
const updateMutation = useMutation({
|
||||||
|
|||||||
@@ -42,12 +42,12 @@ const actionColors: Record<string, string> = {
|
|||||||
export default function Dashboard() {
|
export default function Dashboard() {
|
||||||
const { data: stats, isLoading: statsLoading, error: statsError } = useQuery({
|
const { data: stats, isLoading: statsLoading, error: statsError } = useQuery({
|
||||||
queryKey: ['dashboard-stats'],
|
queryKey: ['dashboard-stats'],
|
||||||
queryFn: () => statsService.dashboard(),
|
queryFn: ({ signal }) => statsService.dashboard(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const { data: logsData, isLoading: logsLoading } = useQuery({
|
const { data: logsData, isLoading: logsLoading } = useQuery({
|
||||||
queryKey: ['recent-logs'],
|
queryKey: ['recent-logs'],
|
||||||
queryFn: () => logService.list({ page: 1, page_size: 10 }),
|
queryFn: ({ signal }) => logService.list({ page: 1, page_size: 10 }, signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
if (statsError) {
|
if (statsError) {
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ export default function Logs() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['logs', page, actionFilter],
|
queryKey: ['logs', page, actionFilter],
|
||||||
queryFn: () => logService.list({ page, page_size: 20, action: actionFilter }),
|
queryFn: ({ signal }) => logService.list({ page, page_size: 20, action: actionFilter }, signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const columns: ProColumns<OperationLog>[] = [
|
const columns: ProColumns<OperationLog>[] = [
|
||||||
|
|||||||
@@ -20,12 +20,12 @@ export default function Models() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['models'],
|
queryKey: ['models'],
|
||||||
queryFn: () => modelService.list(),
|
queryFn: ({ signal }) => modelService.list(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const { data: providersData } = useQuery({
|
const { data: providersData } = useQuery({
|
||||||
queryKey: ['providers-for-select'],
|
queryKey: ['providers-for-select'],
|
||||||
queryFn: () => providerService.list(),
|
queryFn: ({ signal }) => providerService.list(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const createMutation = useMutation({
|
const createMutation = useMutation({
|
||||||
|
|||||||
@@ -26,18 +26,18 @@ export default function Prompts() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['prompts'],
|
queryKey: ['prompts'],
|
||||||
queryFn: () => promptService.list(),
|
queryFn: ({ signal }) => promptService.list(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const { data: detailData } = useQuery({
|
const { data: detailData } = useQuery({
|
||||||
queryKey: ['prompt-detail', detailName],
|
queryKey: ['prompt-detail', detailName],
|
||||||
queryFn: () => promptService.get(detailName!),
|
queryFn: ({ signal }) => promptService.get(detailName!, signal),
|
||||||
enabled: !!detailName,
|
enabled: !!detailName,
|
||||||
})
|
})
|
||||||
|
|
||||||
const { data: versionsData } = useQuery({
|
const { data: versionsData } = useQuery({
|
||||||
queryKey: ['prompt-versions', detailName],
|
queryKey: ['prompt-versions', detailName],
|
||||||
queryFn: () => promptService.listVersions(detailName!),
|
queryFn: ({ signal }) => promptService.listVersions(detailName!, signal),
|
||||||
enabled: !!detailName,
|
enabled: !!detailName,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -22,12 +22,12 @@ export default function Providers() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['providers'],
|
queryKey: ['providers'],
|
||||||
queryFn: () => providerService.list(),
|
queryFn: ({ signal }) => providerService.list(signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const { data: keysData, isLoading: keysLoading } = useQuery({
|
const { data: keysData, isLoading: keysLoading } = useQuery({
|
||||||
queryKey: ['provider-keys', keyModalProviderId],
|
queryKey: ['provider-keys', keyModalProviderId],
|
||||||
queryFn: () => providerService.listKeys(keyModalProviderId!),
|
queryFn: ({ signal }) => providerService.listKeys(keyModalProviderId!, signal),
|
||||||
enabled: !!keyModalProviderId,
|
enabled: !!keyModalProviderId,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ export default function Relay() {
|
|||||||
|
|
||||||
const { data, isLoading } = useQuery({
|
const { data, isLoading } = useQuery({
|
||||||
queryKey: ['relay-tasks', page, statusFilter],
|
queryKey: ['relay-tasks', page, statusFilter],
|
||||||
queryFn: () => relayService.list({ page, page_size: 20, status: statusFilter }),
|
queryFn: ({ signal }) => relayService.list({ page, page_size: 20, status: statusFilter }, signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const columns: ProColumns<RelayTask>[] = [
|
const columns: ProColumns<RelayTask>[] = [
|
||||||
|
|||||||
@@ -19,12 +19,12 @@ export default function Usage() {
|
|||||||
|
|
||||||
const { data: dailyData, isLoading: dailyLoading, error: dailyError } = useQuery({
|
const { data: dailyData, isLoading: dailyLoading, error: dailyError } = useQuery({
|
||||||
queryKey: ['usage-daily', days],
|
queryKey: ['usage-daily', days],
|
||||||
queryFn: () => telemetryService.dailyStats({ days }),
|
queryFn: ({ signal }) => telemetryService.dailyStats({ days }, signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
const { data: modelData, isLoading: modelLoading } = useQuery({
|
const { data: modelData, isLoading: modelLoading } = useQuery({
|
||||||
queryKey: ['usage-model', days],
|
queryKey: ['usage-model', days],
|
||||||
queryFn: () => telemetryService.modelStats({}),
|
queryFn: ({ signal }) => telemetryService.modelStats({}, signal),
|
||||||
})
|
})
|
||||||
|
|
||||||
if (dailyError) {
|
if (dailyError) {
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { AccountPublic, PaginatedResponse } from '@/types'
|
import type { AccountPublic, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const accountService = {
|
export const accountService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<AccountPublic>>('/accounts', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<AccountPublic>>('/accounts', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
get: (id: string) =>
|
get: (id: string, signal?: AbortSignal) =>
|
||||||
request.get<AccountPublic>(`/accounts/${id}`).then((r) => r.data),
|
request.get<AccountPublic>(`/accounts/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
update: (id: string, data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>) =>
|
update: (id: string, data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>, signal?: AbortSignal) =>
|
||||||
request.patch<AccountPublic>(`/accounts/${id}`, data).then((r) => r.data),
|
request.patch<AccountPublic>(`/accounts/${id}`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
updateStatus: (id: string, data: { status: AccountPublic['status'] }) =>
|
updateStatus: (id: string, data: { status: AccountPublic['status'] }, signal?: AbortSignal) =>
|
||||||
request.patch(`/accounts/${id}/status`, data).then((r) => r.data),
|
request.patch(`/accounts/${id}/status`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,28 +1,28 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { AgentTemplate, PaginatedResponse } from '@/types'
|
import type { AgentTemplate, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const agentTemplateService = {
|
export const agentTemplateService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<AgentTemplate>>('/agent-templates', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<AgentTemplate>>('/agent-templates', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
get: (id: string) =>
|
get: (id: string, signal?: AbortSignal) =>
|
||||||
request.get<AgentTemplate>(`/agent-templates/${id}`).then((r) => r.data),
|
request.get<AgentTemplate>(`/agent-templates/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
create: (data: {
|
create: (data: {
|
||||||
name: string; description?: string; category?: string; source?: string
|
name: string; description?: string; category?: string; source?: string
|
||||||
model?: string; system_prompt?: string; tools?: string[]
|
model?: string; system_prompt?: string; tools?: string[]
|
||||||
capabilities?: string[]; temperature?: number; max_tokens?: number
|
capabilities?: string[]; temperature?: number; max_tokens?: number
|
||||||
visibility?: string
|
visibility?: string
|
||||||
}) =>
|
}, signal?: AbortSignal) =>
|
||||||
request.post<AgentTemplate>('/agent-templates', data).then((r) => r.data),
|
request.post<AgentTemplate>('/agent-templates', data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
update: (id: string, data: {
|
update: (id: string, data: {
|
||||||
description?: string; model?: string; system_prompt?: string
|
description?: string; model?: string; system_prompt?: string
|
||||||
tools?: string[]; capabilities?: string[]; temperature?: number
|
tools?: string[]; capabilities?: string[]; temperature?: number
|
||||||
max_tokens?: number; visibility?: string; status?: string
|
max_tokens?: number; visibility?: string; status?: string
|
||||||
}) =>
|
}, signal?: AbortSignal) =>
|
||||||
request.post<AgentTemplate>(`/agent-templates/${id}`, data).then((r) => r.data),
|
request.post<AgentTemplate>(`/agent-templates/${id}`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
archive: (id: string) =>
|
archive: (id: string, signal?: AbortSignal) =>
|
||||||
request.delete<AgentTemplate>(`/agent-templates/${id}`).then((r) => r.data),
|
request.delete<AgentTemplate>(`/agent-templates/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { TokenInfo, CreateTokenRequest, PaginatedResponse } from '@/types'
|
import type { TokenInfo, CreateTokenRequest, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const apiKeyService = {
|
export const apiKeyService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<TokenInfo>>('/keys', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<TokenInfo>>('/keys', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
create: (data: CreateTokenRequest) =>
|
create: (data: CreateTokenRequest, signal?: AbortSignal) =>
|
||||||
request.post<TokenInfo>('/keys', data).then((r) => r.data),
|
request.post<TokenInfo>('/keys', data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
revoke: (id: string) =>
|
revoke: (id: string, signal?: AbortSignal) =>
|
||||||
request.delete(`/keys/${id}`).then((r) => r.data),
|
request.delete(`/keys/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { AccountPublic, LoginRequest, LoginResponse } from '@/types'
|
import type { AccountPublic, LoginRequest, LoginResponse } from '@/types'
|
||||||
|
|
||||||
export const authService = {
|
export const authService = {
|
||||||
login: (data: LoginRequest) =>
|
login: (data: LoginRequest, signal?: AbortSignal) =>
|
||||||
request.post<LoginResponse>('/auth/login', data).then((r) => r.data),
|
request.post<LoginResponse>('/auth/login', data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
me: () =>
|
me: (signal?: AbortSignal) =>
|
||||||
request.get<AccountPublic>('/auth/me').then((r) => r.data),
|
request.get<AccountPublic>('/auth/me', withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { ConfigItem, PaginatedResponse } from '@/types'
|
import type { ConfigItem, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const configService = {
|
export const configService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<ConfigItem>>('/config/items', { params })
|
request.get<PaginatedResponse<ConfigItem>>('/config/items', withSignal({ params }, signal))
|
||||||
.then((r) => r.data.items),
|
.then((r) => r.data.items),
|
||||||
|
|
||||||
update: (id: string, data: { value: string | number | boolean }) =>
|
update: (id: string, data: { value: string | number | boolean }, signal?: AbortSignal) =>
|
||||||
request.patch<ConfigItem>(`/config/items/${id}`, data).then((r) => r.data),
|
request.patch<ConfigItem>(`/config/items/${id}`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { OperationLog, PaginatedResponse } from '@/types'
|
import type { OperationLog, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const logService = {
|
export const logService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<OperationLog>>('/logs/operations', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<OperationLog>>('/logs/operations', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { Model, PaginatedResponse } from '@/types'
|
import type { Model, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const modelService = {
|
export const modelService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<Model>>('/models', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<Model>>('/models', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
create: (data: Partial<Omit<Model, 'id'>>) =>
|
create: (data: Partial<Omit<Model, 'id'>>, signal?: AbortSignal) =>
|
||||||
request.post<Model>('/models', data).then((r) => r.data),
|
request.post<Model>('/models', data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
update: (id: string, data: Partial<Omit<Model, 'id'>>) =>
|
update: (id: string, data: Partial<Omit<Model, 'id'>>, signal?: AbortSignal) =>
|
||||||
request.patch<Model>(`/models/${id}`, data).then((r) => r.data),
|
request.patch<Model>(`/models/${id}`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
delete: (id: string) =>
|
delete: (id: string, signal?: AbortSignal) =>
|
||||||
request.delete(`/models/${id}`).then((r) => r.data),
|
request.delete(`/models/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +1,35 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { PromptTemplate, PromptVersion, PaginatedResponse } from '@/types'
|
import type { PromptTemplate, PromptVersion, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const promptService = {
|
export const promptService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<PromptTemplate>>('/prompts', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<PromptTemplate>>('/prompts', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
get: (name: string) =>
|
get: (name: string, signal?: AbortSignal) =>
|
||||||
request.get<PromptTemplate>(`/prompts/${encodeURIComponent(name)}`).then((r) => r.data),
|
request.get<PromptTemplate>(`/prompts/${encodeURIComponent(name)}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
create: (data: {
|
create: (data: {
|
||||||
name: string; category: string; description?: string; source?: string
|
name: string; category: string; description?: string; source?: string
|
||||||
system_prompt: string; user_prompt_template?: string
|
system_prompt: string; user_prompt_template?: string
|
||||||
variables?: unknown[]; min_app_version?: string
|
variables?: unknown[]; min_app_version?: string
|
||||||
}) =>
|
}, signal?: AbortSignal) =>
|
||||||
request.post<PromptTemplate>('/prompts', data).then((r) => r.data),
|
request.post<PromptTemplate>('/prompts', data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
update: (name: string, data: { description?: string; status?: string }) =>
|
update: (name: string, data: { description?: string; status?: string }, signal?: AbortSignal) =>
|
||||||
request.put<PromptTemplate>(`/prompts/${encodeURIComponent(name)}`, data).then((r) => r.data),
|
request.put<PromptTemplate>(`/prompts/${encodeURIComponent(name)}`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
archive: (name: string) =>
|
archive: (name: string, signal?: AbortSignal) =>
|
||||||
request.delete<PromptTemplate>(`/prompts/${encodeURIComponent(name)}`).then((r) => r.data),
|
request.delete<PromptTemplate>(`/prompts/${encodeURIComponent(name)}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
listVersions: (name: string) =>
|
listVersions: (name: string, signal?: AbortSignal) =>
|
||||||
request.get<PromptVersion[]>(`/prompts/${encodeURIComponent(name)}/versions`).then((r) => r.data),
|
request.get<PromptVersion[]>(`/prompts/${encodeURIComponent(name)}/versions`, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
createVersion: (name: string, data: {
|
createVersion: (name: string, data: {
|
||||||
system_prompt: string; user_prompt_template?: string
|
system_prompt: string; user_prompt_template?: string
|
||||||
variables?: unknown[]; changelog?: string; min_app_version?: string
|
variables?: unknown[]; changelog?: string; min_app_version?: string
|
||||||
}) =>
|
}, signal?: AbortSignal) =>
|
||||||
request.post<PromptVersion>(`/prompts/${encodeURIComponent(name)}/versions`, data).then((r) => r.data),
|
request.post<PromptVersion>(`/prompts/${encodeURIComponent(name)}/versions`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
rollback: (name: string, version: number) =>
|
rollback: (name: string, version: number, signal?: AbortSignal) =>
|
||||||
request.post<PromptTemplate>(`/prompts/${encodeURIComponent(name)}/rollback/${version}`).then((r) => r.data),
|
request.post<PromptTemplate>(`/prompts/${encodeURIComponent(name)}/rollback/${version}`, undefined, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,31 +1,31 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { Provider, ProviderKey, PaginatedResponse } from '@/types'
|
import type { Provider, ProviderKey, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const providerService = {
|
export const providerService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<Provider>>('/providers', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<Provider>>('/providers', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
create: (data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>) =>
|
create: (data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>, signal?: AbortSignal) =>
|
||||||
request.post<Provider>('/providers', data).then((r) => r.data),
|
request.post<Provider>('/providers', data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
update: (id: string, data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>) =>
|
update: (id: string, data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>, signal?: AbortSignal) =>
|
||||||
request.patch<Provider>(`/providers/${id}`, data).then((r) => r.data),
|
request.patch<Provider>(`/providers/${id}`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
delete: (id: string) =>
|
delete: (id: string, signal?: AbortSignal) =>
|
||||||
request.delete(`/providers/${id}`).then((r) => r.data),
|
request.delete(`/providers/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
listKeys: (providerId: string) =>
|
listKeys: (providerId: string, signal?: AbortSignal) =>
|
||||||
request.get<ProviderKey[]>(`/providers/${providerId}/keys`).then((r) => r.data),
|
request.get<ProviderKey[]>(`/providers/${providerId}/keys`, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
addKey: (providerId: string, data: {
|
addKey: (providerId: string, data: {
|
||||||
key_label: string; key_value: string; priority?: number
|
key_label: string; key_value: string; priority?: number
|
||||||
max_rpm?: number; max_tpm?: number; quota_reset_interval?: string
|
max_rpm?: number; max_tpm?: number; quota_reset_interval?: string
|
||||||
}) =>
|
}, signal?: AbortSignal) =>
|
||||||
request.post<{ ok: boolean; key_id: string }>(`/providers/${providerId}/keys`, data).then((r) => r.data),
|
request.post<{ ok: boolean; key_id: string }>(`/providers/${providerId}/keys`, data, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
toggleKey: (providerId: string, keyId: string, active: boolean) =>
|
toggleKey: (providerId: string, keyId: string, active: boolean, signal?: AbortSignal) =>
|
||||||
request.put<{ ok: boolean }>(`/providers/${providerId}/keys/${keyId}/toggle`, { active }).then((r) => r.data),
|
request.put<{ ok: boolean }>(`/providers/${providerId}/keys/${keyId}/toggle`, { active }, withSignal({}, signal)).then((r) => r.data),
|
||||||
|
|
||||||
deleteKey: (providerId: string, keyId: string) =>
|
deleteKey: (providerId: string, keyId: string, signal?: AbortSignal) =>
|
||||||
request.delete<{ ok: boolean }>(`/providers/${providerId}/keys/${keyId}`).then((r) => r.data),
|
request.delete<{ ok: boolean }>(`/providers/${providerId}/keys/${keyId}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { RelayTask, PaginatedResponse } from '@/types'
|
import type { RelayTask, PaginatedResponse } from '@/types'
|
||||||
|
|
||||||
export const relayService = {
|
export const relayService = {
|
||||||
list: (params?: Record<string, unknown>) =>
|
list: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<PaginatedResponse<RelayTask>>('/relay/tasks', { params }).then((r) => r.data),
|
request.get<PaginatedResponse<RelayTask>>('/relay/tasks', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
get: (id: string) =>
|
get: (id: string, signal?: AbortSignal) =>
|
||||||
request.get<RelayTask>(`/relay/tasks/${id}`).then((r) => r.data),
|
request.get<RelayTask>(`/relay/tasks/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
import axios from 'axios'
|
import axios from 'axios'
|
||||||
import type { AxiosError, InternalAxiosRequestConfig } from 'axios'
|
import type { AxiosError, InternalAxiosRequestConfig } from 'axios'
|
||||||
|
import type { AxiosRequestConfig } from 'axios'
|
||||||
import type { ApiError } from '@/types'
|
import type { ApiError } from '@/types'
|
||||||
import { useAuthStore } from '@/stores/authStore'
|
import { useAuthStore } from '@/stores/authStore'
|
||||||
|
|
||||||
@@ -106,3 +107,11 @@ request.interceptors.response.use(
|
|||||||
)
|
)
|
||||||
|
|
||||||
export default request
|
export default request
|
||||||
|
|
||||||
|
/** 将 AbortSignal 注入 Axios config,用于 TanStack Query 的请求取消 */
|
||||||
|
export function withSignal(config: AxiosRequestConfig = {}, signal?: AbortSignal): AxiosRequestConfig {
|
||||||
|
if (signal) {
|
||||||
|
return { ...config, signal }
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { DashboardStats } from '@/types'
|
import type { DashboardStats } from '@/types'
|
||||||
|
|
||||||
export const statsService = {
|
export const statsService = {
|
||||||
dashboard: () =>
|
dashboard: (signal?: AbortSignal) =>
|
||||||
request.get<DashboardStats>('/stats/dashboard').then((r) => r.data),
|
request.get<DashboardStats>('/stats/dashboard', withSignal({}, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { ModelUsageStat, DailyUsageStat } from '@/types'
|
import type { ModelUsageStat, DailyUsageStat } from '@/types'
|
||||||
|
|
||||||
export const telemetryService = {
|
export const telemetryService = {
|
||||||
modelStats: (params?: Record<string, unknown>) =>
|
modelStats: (params?: Record<string, unknown>, signal?: AbortSignal) =>
|
||||||
request.get<ModelUsageStat[]>('/telemetry/stats', { params }).then((r) => r.data),
|
request.get<ModelUsageStat[]>('/telemetry/stats', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
|
|
||||||
dailyStats: (params?: { days?: number }) =>
|
dailyStats: (params?: { days?: number }, signal?: AbortSignal) =>
|
||||||
request.get<DailyUsageStat[]>('/telemetry/daily', { params }).then((r) => r.data),
|
request.get<DailyUsageStat[]>('/telemetry/daily', withSignal({ params }, signal)).then((r) => r.data),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import request from './request'
|
import request, { withSignal } from './request'
|
||||||
import type { UsageRecord, UsageByModel } from '@/types'
|
import type { UsageRecord, UsageByModel } from '@/types'
|
||||||
|
|
||||||
export const usageService = {
|
export const usageService = {
|
||||||
daily: (params?: { days?: number }) =>
|
daily: (params?: { days?: number }, signal?: AbortSignal) =>
|
||||||
request.get<{ by_day: UsageRecord[] }>('/usage', { params: { ...params, group_by: 'day' } })
|
request.get<{ by_day: UsageRecord[] }>('/usage', withSignal({ params: { ...params, group_by: 'day' } }, signal))
|
||||||
.then((r) => r.data.by_day || []),
|
.then((r) => r.data.by_day || []),
|
||||||
|
|
||||||
byModel: (params?: { days?: number }) =>
|
byModel: (params?: { days?: number }, signal?: AbortSignal) =>
|
||||||
request.get<{ by_model: UsageByModel[] }>('/usage', { params: { ...params, group_by: 'model' } })
|
request.get<{ by_model: UsageByModel[] }>('/usage', withSignal({ params: { ...params, group_by: 'model' } }, signal))
|
||||||
.then((r) => r.data.by_model || []),
|
.then((r) => r.data.by_model || []),
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,16 @@ export default defineConfig({
|
|||||||
'/api': {
|
'/api': {
|
||||||
target: 'http://localhost:8080',
|
target: 'http://localhost:8080',
|
||||||
changeOrigin: true,
|
changeOrigin: true,
|
||||||
|
timeout: 30_000,
|
||||||
|
proxyTimeout: 30_000,
|
||||||
|
configure: (proxy) => {
|
||||||
|
proxy.on('proxyReq', (proxyReq) => {
|
||||||
|
proxyReq.setTimeout(30_000)
|
||||||
|
})
|
||||||
|
proxy.on('proxyRes', (proxyRes) => {
|
||||||
|
proxyRes.setTimeout(30_000)
|
||||||
|
})
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ sha2 = { workspace = true }
|
|||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
dashmap = { workspace = true }
|
dashmap = { workspace = true }
|
||||||
hex = { workspace = true }
|
hex = { workspace = true }
|
||||||
|
socket2 = { workspace = true }
|
||||||
url = "2"
|
url = "2"
|
||||||
|
|
||||||
axum = { workspace = true }
|
axum = { workspace = true }
|
||||||
|
|||||||
@@ -148,6 +148,34 @@ pub async fn verify_totp(
|
|||||||
return Err(SaasError::InvalidInput("TOTP 码必须是 6 位数字".into()));
|
return Err(SaasError::InvalidInput("TOTP 码必须是 6 位数字".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TOTP 暴力破解保护: 10 分钟内最多 5 次失败
|
||||||
|
const MAX_TOTP_FAILURES: u32 = 5;
|
||||||
|
const TOTP_LOCKOUT_SECS: u64 = 600;
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
let lockout_duration = std::time::Duration::from_secs(TOTP_LOCKOUT_SECS);
|
||||||
|
|
||||||
|
let is_locked = {
|
||||||
|
if let Some(entry) = state.totp_fail_counts.get(&ctx.account_id) {
|
||||||
|
let (count, first_fail) = entry.value();
|
||||||
|
if *count >= MAX_TOTP_FAILURES && now.duration_since(*first_fail) < lockout_duration {
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
// 窗口过期,重置
|
||||||
|
drop(entry);
|
||||||
|
state.totp_fail_counts.remove(&ctx.account_id);
|
||||||
|
false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_locked {
|
||||||
|
return Err(SaasError::RateLimited(
|
||||||
|
format!("TOTP 验证失败次数过多,请 {} 秒后重试", TOTP_LOCKOUT_SECS)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// 获取存储的密钥
|
// 获取存储的密钥
|
||||||
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||||||
"SELECT totp_secret FROM accounts WHERE id = $1"
|
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||||||
@@ -172,9 +200,24 @@ pub async fn verify_totp(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if !verify_totp_code(&secret, code) {
|
if !verify_totp_code(&secret, code) {
|
||||||
|
// 记录失败次数
|
||||||
|
let new_count = {
|
||||||
|
let mut entry = state.totp_fail_counts
|
||||||
|
.entry(ctx.account_id.clone())
|
||||||
|
.or_insert((0, now));
|
||||||
|
entry.value_mut().0 += 1;
|
||||||
|
entry.value().0
|
||||||
|
};
|
||||||
|
tracing::warn!(
|
||||||
|
"TOTP verify failed for account {} ({}/{} attempts)",
|
||||||
|
ctx.account_id, new_count, MAX_TOTP_FAILURES
|
||||||
|
);
|
||||||
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 验证成功 → 清除失败计数
|
||||||
|
state.totp_fail_counts.remove(&ctx.account_id);
|
||||||
|
|
||||||
// 验证成功 → 启用 TOTP,同时确保密钥已加密
|
// 验证成功 → 启用 TOTP,同时确保密钥已加密
|
||||||
let final_secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
let final_secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
||||||
encrypted_secret
|
encrypted_secret
|
||||||
@@ -183,10 +226,10 @@ pub async fn verify_totp(
|
|||||||
encrypt_totp_secret(&secret, &enc_key)?
|
encrypt_totp_secret(&secret, &enc_key)?
|
||||||
};
|
};
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now_ts = chrono::Utc::now().to_rfc3339();
|
||||||
sqlx::query("UPDATE accounts SET totp_enabled = true, totp_secret = $1, updated_at = $2 WHERE id = $3")
|
sqlx::query("UPDATE accounts SET totp_enabled = true, totp_secret = $1, updated_at = $2 WHERE id = $3")
|
||||||
.bind(&final_secret)
|
.bind(&final_secret)
|
||||||
.bind(&now)
|
.bind(&now_ts)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|||||||
@@ -90,7 +90,7 @@ async fn run_migration_files(pool: &PgPool, dir: &std::path::Path) -> SaasResult
|
|||||||
let filename = path.file_name().unwrap_or_default().to_string_lossy();
|
let filename = path.file_name().unwrap_or_default().to_string_lossy();
|
||||||
tracing::info!("Running migration: {}", filename);
|
tracing::info!("Running migration: {}", filename);
|
||||||
let content = std::fs::read_to_string(path)?;
|
let content = std::fs::read_to_string(path)?;
|
||||||
for stmt in content.split(';') {
|
for stmt in split_sql_statements(&content) {
|
||||||
let trimmed = stmt.trim();
|
let trimmed = stmt.trim();
|
||||||
if !trimmed.is_empty() && !trimmed.starts_with("--") {
|
if !trimmed.is_empty() && !trimmed.starts_with("--") {
|
||||||
sqlx::query(trimmed).execute(pool).await?;
|
sqlx::query(trimmed).execute(pool).await?;
|
||||||
@@ -100,6 +100,150 @@ async fn run_migration_files(pool: &PgPool, dir: &std::path::Path) -> SaasResult
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 按语句分割 SQL 文件内容,正确处理:
|
||||||
|
/// - 单引号字符串 `'...'`
|
||||||
|
/// - 双引号标识符 `"..."`
|
||||||
|
/// - 美元符号引用字符串 `$$...$$` 和 `$tag$...$tag$`
|
||||||
|
/// - `--` 单行注释
|
||||||
|
/// - `/* ... */` 块注释
|
||||||
|
/// - `E'...'` 转义字符串
|
||||||
|
fn split_sql_statements(sql: &str) -> Vec<String> {
|
||||||
|
let mut statements = Vec::new();
|
||||||
|
let mut current = String::new();
|
||||||
|
let mut chars = sql.chars().peekable();
|
||||||
|
|
||||||
|
while let Some(ch) = chars.next() {
|
||||||
|
match ch {
|
||||||
|
'\'' => {
|
||||||
|
// 单引号字符串
|
||||||
|
current.push(ch);
|
||||||
|
loop {
|
||||||
|
match chars.next() {
|
||||||
|
Some('\'') => {
|
||||||
|
current.push('\'');
|
||||||
|
// 检查是否为转义引号 ''
|
||||||
|
if chars.peek() == Some(&'\'') {
|
||||||
|
current.push(chars.next().unwrap());
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(c) => current.push(c),
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'"' => {
|
||||||
|
// 双引号标识符
|
||||||
|
current.push(ch);
|
||||||
|
loop {
|
||||||
|
match chars.next() {
|
||||||
|
Some('"') => {
|
||||||
|
current.push('"');
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Some(c) => current.push(c),
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'-' if chars.peek() == Some(&'-') => {
|
||||||
|
// 单行注释: 跳过直到行尾
|
||||||
|
chars.next(); // consume second '-'
|
||||||
|
while let Some(&c) = chars.peek() {
|
||||||
|
if c == '\n' {
|
||||||
|
chars.next();
|
||||||
|
current.push(c);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
chars.next();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'/' if chars.peek() == Some(&'*') => {
|
||||||
|
// 块注释: 跳过直到 */
|
||||||
|
chars.next(); // consume '*'
|
||||||
|
current.push_str("/*");
|
||||||
|
let mut prev = ' ';
|
||||||
|
loop {
|
||||||
|
match chars.next() {
|
||||||
|
Some('/') if prev == '*' => {
|
||||||
|
current.push('/');
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
Some(c) => {
|
||||||
|
current.push(c);
|
||||||
|
prev = c;
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'$' => {
|
||||||
|
// 美元符号引用: $$ 或 $tag$ ... $tag$
|
||||||
|
current.push(ch);
|
||||||
|
// 读取 tag (字母数字和下划线)
|
||||||
|
let mut tag = String::new();
|
||||||
|
while let Some(&c) = chars.peek() {
|
||||||
|
if c == '$' || c.is_alphanumeric() || c == '_' {
|
||||||
|
if c == '$' {
|
||||||
|
chars.next();
|
||||||
|
current.push(c);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
chars.next();
|
||||||
|
tag.push(c);
|
||||||
|
current.push(c);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 如果 tag 为空,就是 $$ 格式
|
||||||
|
let end_marker = if tag.is_empty() {
|
||||||
|
"$$".to_string()
|
||||||
|
} else {
|
||||||
|
format!("${}$", tag)
|
||||||
|
};
|
||||||
|
// 读取直到遇到 end_marker
|
||||||
|
let mut buf = String::new();
|
||||||
|
loop {
|
||||||
|
match chars.next() {
|
||||||
|
Some(c) => {
|
||||||
|
current.push(c);
|
||||||
|
buf.push(c);
|
||||||
|
if buf.len() > end_marker.len() {
|
||||||
|
buf.remove(0);
|
||||||
|
}
|
||||||
|
if buf == end_marker {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
';' => {
|
||||||
|
// 语句结束
|
||||||
|
let trimmed = current.trim().to_string();
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
statements.push(trimmed);
|
||||||
|
}
|
||||||
|
current.clear();
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
current.push(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 最后一条语句 (可能不以分号结尾)
|
||||||
|
let trimmed = current.trim().to_string();
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
statements.push(trimmed);
|
||||||
|
}
|
||||||
|
|
||||||
|
statements
|
||||||
|
}
|
||||||
|
|
||||||
/// Seed 角色数据
|
/// Seed 角色数据
|
||||||
async fn seed_roles(pool: &PgPool) -> SaasResult<()> {
|
async fn seed_roles(pool: &PgPool) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|||||||
@@ -67,7 +67,9 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json::Value> {
|
async fn health_handler(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> (axum::http::StatusCode, axum::Json<serde_json::Value> ) {
|
||||||
// health 必须独立快速返回,用 3s 超时避免连接池满时阻塞
|
// health 必须独立快速返回,用 3s 超时避免连接池满时阻塞
|
||||||
let db_healthy = tokio::time::timeout(
|
let db_healthy = tokio::time::timeout(
|
||||||
std::time::Duration::from_secs(3),
|
std::time::Duration::from_secs(3),
|
||||||
@@ -77,15 +79,41 @@ async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json:
|
|||||||
.map(|r| r.is_ok())
|
.map(|r| r.is_ok())
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
let status = if db_healthy { "healthy" } else { "degraded" };
|
// 连接池容量检查: 使用率 >= 80% 返回 503 (degraded)
|
||||||
let _code = if db_healthy { 200 } else { 503 };
|
let pool = &state.db;
|
||||||
|
let total = pool.options().get_max_connections() as usize;
|
||||||
|
if total > 0 {
|
||||||
|
let idle = pool.num_idle() as usize;
|
||||||
|
let used = total - idle;
|
||||||
|
let ratio = used * 100 / total;
|
||||||
|
if ratio >= 80 {
|
||||||
|
return (
|
||||||
|
axum::http::StatusCode::SERVICE_UNAVAILABLE,
|
||||||
|
axum::Json(serde_json::json!({
|
||||||
|
"status": "degraded",
|
||||||
|
"database": true,
|
||||||
|
"database_pool": {
|
||||||
|
"usage_pct": ratio,
|
||||||
|
"used": used,
|
||||||
|
"total": total,
|
||||||
|
},
|
||||||
|
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||||
|
"version": env!("CARGO_PKG_VERSION"),
|
||||||
|
})),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
axum::Json(serde_json::json!({
|
let status = if db_healthy { "healthy" } else { "degraded" };
|
||||||
|
let code = if db_healthy {
|
||||||
|
axum::http::StatusCode::OK } else { axum::http::StatusCode::SERVICE_UNAVAILABLE };
|
||||||
|
|
||||||
|
(code, axum::Json(serde_json::json!({
|
||||||
"status": status,
|
"status": status,
|
||||||
"database": db_healthy,
|
"database": db_healthy,
|
||||||
"timestamp": chrono::Utc::now().to_rfc3339(),
|
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||||
"version": env!("CARGO_PKG_VERSION"),
|
"version": env!("CARGO_PKG_VERSION"),
|
||||||
}))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn build_router(state: AppState) -> axum::Router {
|
async fn build_router(state: AppState) -> axum::Router {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use crate::models::{ProviderKeySelectRow, ProviderKeyRow};
|
use crate::models::ProviderKeyRow;
|
||||||
use crate::crypto;
|
use crate::crypto;
|
||||||
|
|
||||||
/// 解密 key_value (如果已加密),否则原样返回
|
/// 解密 key_value (如果已加密),否则原样返回
|
||||||
@@ -36,19 +36,63 @@ pub struct KeySelection {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
||||||
|
///
|
||||||
|
/// 优化: 单次 JOIN 查询获取 Key + 当前分钟使用量,避免 N+1 查询
|
||||||
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||||||
|
|
||||||
// 获取所有活跃 Key
|
// 单次查询: 活跃 Key + 当前分钟的 RPM/TPM 使用量 (LEFT JOIN)
|
||||||
let rows: Vec<ProviderKeySelectRow> =
|
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>, Option<i64>, Option<i64>)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
|
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm, pk.quota_reset_interval,
|
||||||
FROM provider_keys
|
uw.request_count, uw.token_count
|
||||||
WHERE provider_id = $1 AND is_active = TRUE AND (cooldown_until IS NULL OR cooldown_until <= $2)
|
FROM provider_keys pk
|
||||||
ORDER BY priority ASC"
|
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id AND uw.window_minute = $1
|
||||||
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
WHERE pk.provider_id = $2 AND pk.is_active = TRUE
|
||||||
|
AND (pk.cooldown_until IS NULL OR pk.cooldown_until <= $3)
|
||||||
|
ORDER BY pk.priority ASC"
|
||||||
|
).bind(¤t_minute).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||||||
|
|
||||||
|
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval, req_count, token_count) in &rows {
|
||||||
|
// RPM 检查
|
||||||
|
if let Some(rpm_limit) = max_rpm {
|
||||||
|
if *rpm_limit > 0 {
|
||||||
|
let count = req_count.unwrap_or(0);
|
||||||
|
if count >= *rpm_limit {
|
||||||
|
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TPM 检查
|
||||||
|
if let Some(tpm_limit) = max_tpm {
|
||||||
|
if *tpm_limit > 0 {
|
||||||
|
let tokens = token_count.unwrap_or(0);
|
||||||
|
if tokens >= *tpm_limit {
|
||||||
|
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 此 Key 可用 — 解密 key_value
|
||||||
|
let decrypted_kv = decrypt_key_value(key_value, enc_key)?;
|
||||||
|
return Ok(KeySelection {
|
||||||
|
key: PoolKey {
|
||||||
|
id: id.clone(),
|
||||||
|
key_value: decrypted_kv,
|
||||||
|
priority: *priority,
|
||||||
|
max_rpm: *max_rpm,
|
||||||
|
max_tpm: *max_tpm,
|
||||||
|
quota_reset_interval: quota_reset_interval.clone(),
|
||||||
|
},
|
||||||
|
key_id: id.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有 Key 都超限或无 Key
|
||||||
if rows.is_empty() {
|
if rows.is_empty() {
|
||||||
// 检查是否有冷却中的 Key,返回预计等待时间
|
// 检查是否有冷却中的 Key,返回预计等待时间
|
||||||
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||||||
@@ -59,88 +103,14 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
|||||||
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||||||
|
|
||||||
if let Some((earliest,)) = cooldown_row {
|
if let Some((earliest,)) = cooldown_row {
|
||||||
// 尝试解析时间差
|
|
||||||
let wait_secs = parse_cooldown_remaining(&earliest, &now);
|
let wait_secs = parse_cooldown_remaining(&earliest, &now);
|
||||||
return Err(SaasError::RateLimited(
|
return Err(SaasError::RateLimited(
|
||||||
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查 provider 级别的单 Key
|
|
||||||
let provider_key: Option<String> = sqlx::query_scalar(
|
|
||||||
"SELECT api_key FROM providers WHERE id = $1"
|
|
||||||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
|
||||||
|
|
||||||
if let Some(key) = provider_key {
|
|
||||||
let decrypted = decrypt_key_value(&key, enc_key)?;
|
|
||||||
return Ok(KeySelection {
|
|
||||||
key: PoolKey {
|
|
||||||
id: "provider-fallback".to_string(),
|
|
||||||
key_value: decrypted,
|
|
||||||
priority: 0,
|
|
||||||
max_rpm: None,
|
|
||||||
max_tpm: None,
|
|
||||||
quota_reset_interval: None,
|
|
||||||
},
|
|
||||||
key_id: "provider-fallback".to_string(),
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
return Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查滑动窗口使用量
|
// 回退到 provider 单 Key
|
||||||
for row in rows {
|
|
||||||
// 检查 RPM 限额
|
|
||||||
if let Some(rpm_limit) = row.max_rpm {
|
|
||||||
if rpm_limit > 0 {
|
|
||||||
let window: Option<(i64,)> = sqlx::query_as(
|
|
||||||
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
|
|
||||||
WHERE key_id = $1 AND window_minute = $2"
|
|
||||||
).bind(&row.id).bind(¤t_minute).fetch_optional(db).await?;
|
|
||||||
|
|
||||||
if let Some((count,)) = window {
|
|
||||||
if count >= rpm_limit {
|
|
||||||
tracing::debug!("Key {} hit RPM limit ({}/{})", row.id, count, rpm_limit);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查 TPM 限额
|
|
||||||
if let Some(tpm_limit) = row.max_tpm {
|
|
||||||
if tpm_limit > 0 {
|
|
||||||
let window: Option<(i64,)> = sqlx::query_as(
|
|
||||||
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
|
|
||||||
WHERE key_id = $1 AND window_minute = $2"
|
|
||||||
).bind(&row.id).bind(¤t_minute).fetch_optional(db).await?;
|
|
||||||
|
|
||||||
if let Some((tokens,)) = window {
|
|
||||||
if tokens >= tpm_limit {
|
|
||||||
tracing::debug!("Key {} hit TPM limit ({}/{})", row.id, tokens, tpm_limit);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 此 Key 可用 — 解密 key_value
|
|
||||||
let decrypted_kv = decrypt_key_value(&row.key_value, enc_key)?;
|
|
||||||
return Ok(KeySelection {
|
|
||||||
key: PoolKey {
|
|
||||||
id: row.id.clone(),
|
|
||||||
key_value: decrypted_kv,
|
|
||||||
priority: row.priority,
|
|
||||||
max_rpm: row.max_rpm,
|
|
||||||
max_tpm: row.max_tpm,
|
|
||||||
quota_reset_interval: row.quota_reset_interval,
|
|
||||||
},
|
|
||||||
key_id: row.id,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// 所有 Key 都超限,回退到 provider 单 Key
|
|
||||||
let provider_key: Option<String> = sqlx::query_scalar(
|
let provider_key: Option<String> = sqlx::query_scalar(
|
||||||
"SELECT api_key FROM providers WHERE id = $1"
|
"SELECT api_key FROM providers WHERE id = $1"
|
||||||
).bind(provider_id).fetch_optional(db).await?.flatten();
|
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||||||
@@ -160,9 +130,13 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Err(SaasError::RateLimited(
|
if rows.is_empty() {
|
||||||
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
|
||||||
))
|
} else {
|
||||||
|
Err(SaasError::RateLimited(
|
||||||
|
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||||||
|
))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 记录 Key 使用量(滑动窗口)
|
/// 记录 Key 使用量(滑动窗口)
|
||||||
|
|||||||
@@ -298,7 +298,21 @@ pub async fn execute_relay(
|
|||||||
let body = axum::body::Body::from_stream(body_stream);
|
let body = axum::body::Body::from_stream(body_stream);
|
||||||
|
|
||||||
// SSE 流结束后异步记录 usage + Key 使用量
|
// SSE 流结束后异步记录 usage + Key 使用量
|
||||||
|
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks,防止高并发时耗尽连接池
|
||||||
|
static SSE_SPAWN_SEMAPHORE: std::sync::OnceLock<Arc<tokio::sync::Semaphore>> = std::sync::OnceLock::new();
|
||||||
|
let semaphore = SSE_SPAWN_SEMAPHORE.get_or_init(|| Arc::new(tokio::sync::Semaphore::new(16)));
|
||||||
|
let permit = match semaphore.clone().try_acquire_owned() {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => {
|
||||||
|
// 信号量满时跳过 usage 记录,流本身不受影响
|
||||||
|
tracing::warn!("SSE usage spawn at capacity, skipping usage record for task {}", task_id);
|
||||||
|
return Ok(RelayResponse::Sse(body));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
|
let _permit = permit; // 持有 permit 直到任务完成
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||||
let capture = usage_capture.lock().await;
|
let capture = usage_capture.lock().await;
|
||||||
let (input, output) = (
|
let (input, output) = (
|
||||||
@@ -464,11 +478,11 @@ async fn validate_provider_url(url: &str) -> SaasResult<()> {
|
|||||||
// 去除 IPv6 方括号
|
// 去除 IPv6 方括号
|
||||||
let host = host.trim_start_matches('[').trim_end_matches(']');
|
let host = host.trim_start_matches('[').trim_end_matches(']');
|
||||||
|
|
||||||
// 精确匹配的阻止列表
|
// 精确匹配的阻止列表: 仅包含主机名和特殊域名
|
||||||
|
// 私有 IP 范围 (10.x, 172.16-31.x, 192.168.x, 127.x, 169.254.x, ::1 等)
|
||||||
|
// 由 is_private_ip() 统一判断,无需在此重复列出
|
||||||
let blocked_exact = [
|
let blocked_exact = [
|
||||||
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
|
"localhost", "metadata.google.internal",
|
||||||
"0:0:0:0:0:ffff:7f00:1", "169.254.169.254", "metadata.google.internal",
|
|
||||||
"10.0.0.1", "172.16.0.1", "192.168.0.1",
|
|
||||||
];
|
];
|
||||||
if blocked_exact.contains(&host) {
|
if blocked_exact.contains(&host) {
|
||||||
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
|
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ pub struct AppState {
|
|||||||
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
|
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
|
||||||
/// 角色权限缓存: role_id → permissions list
|
/// 角色权限缓存: role_id → permissions list
|
||||||
pub role_permissions_cache: Arc<dashmap::DashMap<String, Vec<String>>>,
|
pub role_permissions_cache: Arc<dashmap::DashMap<String, Vec<String>>>,
|
||||||
|
/// TOTP 失败计数: account_id → (失败次数, 首次失败时间)
|
||||||
|
pub totp_fail_counts: Arc<dashmap::DashMap<String, (u32, Instant)>>,
|
||||||
/// 无锁 rate limit RPM(从 config 同步,避免每个请求获取 RwLock)
|
/// 无锁 rate limit RPM(从 config 同步,避免每个请求获取 RwLock)
|
||||||
rate_limit_rpm: Arc<AtomicU32>,
|
rate_limit_rpm: Arc<AtomicU32>,
|
||||||
/// Worker 调度器 (异步后台任务)
|
/// Worker 调度器 (异步后台任务)
|
||||||
@@ -37,6 +39,7 @@ impl AppState {
|
|||||||
jwt_secret,
|
jwt_secret,
|
||||||
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
||||||
role_permissions_cache: Arc::new(dashmap::DashMap::new()),
|
role_permissions_cache: Arc::new(dashmap::DashMap::new()),
|
||||||
|
totp_fail_counts: Arc::new(dashmap::DashMap::new()),
|
||||||
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
|
rate_limit_rpm: Arc::new(AtomicU32::new(rpm)),
|
||||||
worker_dispatcher,
|
worker_dispatcher,
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ impl WorkerDispatcher {
|
|||||||
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
||||||
let db = self.db.clone();
|
let db = self.db.clone();
|
||||||
let handlers = self.handlers.clone();
|
let handlers = self.handlers.clone();
|
||||||
|
let sender = self.sender.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
while let Some(msg) = receiver.recv().await {
|
while let Some(msg) = receiver.recv().await {
|
||||||
@@ -169,6 +170,7 @@ impl WorkerDispatcher {
|
|||||||
let worker_name = msg.worker_name.clone();
|
let worker_name = msg.worker_name.clone();
|
||||||
let max_retries = handler.max_retries();
|
let max_retries = handler.max_retries();
|
||||||
let db = db.clone();
|
let db = db.clone();
|
||||||
|
let sender = sender.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match handler.perform(&db, &msg.args_json).await {
|
match handler.perform(&db, &msg.args_json).await {
|
||||||
@@ -177,18 +179,27 @@ impl WorkerDispatcher {
|
|||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if msg.attempt < max_retries {
|
if msg.attempt < max_retries {
|
||||||
tracing::warn!(
|
|
||||||
"Worker {} failed (attempt {}/{}): {}. Will retry.",
|
|
||||||
worker_name, msg.attempt, max_retries, e
|
|
||||||
);
|
|
||||||
// 简单退避: 2^attempt 秒
|
|
||||||
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
||||||
|
tracing::warn!(
|
||||||
|
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
|
||||||
|
worker_name, msg.attempt, max_retries, e, delay
|
||||||
|
);
|
||||||
tokio::time::sleep(delay).await;
|
tokio::time::sleep(delay).await;
|
||||||
// 注意: 重试在当前设计中通过日志提醒
|
// 重新入队(递增 attempt 计数)
|
||||||
// 生产环境应将任务重新入队
|
let retry_msg = TaskMessage {
|
||||||
|
worker_name: msg.worker_name.clone(),
|
||||||
|
args_json: msg.args_json.clone(),
|
||||||
|
attempt: msg.attempt + 1,
|
||||||
|
};
|
||||||
|
if let Err(send_err) = sender.send(retry_msg).await {
|
||||||
|
tracing::error!(
|
||||||
|
"Worker {} retry enqueue failed (channel closed): {}",
|
||||||
|
worker_name, send_err
|
||||||
|
);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
tracing::error!(
|
tracing::error!(
|
||||||
"Worker {} failed after {} attempts: {}",
|
"Worker {} failed after {} attempts: {}. Giving up.",
|
||||||
worker_name, max_retries, e
|
worker_name, max_retries, e
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -186,7 +186,7 @@ function MarkdownPreview({ content }: MarkdownPreviewProps) {
|
|||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className="prose dark:prose-invert max-w-none p-4 bg-white dark:bg-gray-800 rounded-lg"
|
className="prose dark:prose-invert max-w-none p-4 bg-white dark:bg-gray-800 rounded-lg"
|
||||||
dangerouslySetInnerHTML={{ __html: renderMarkdown(content) }}
|
dangerouslySetInnerHTML={{ __html: DOMPurify.sanitize(renderMarkdown(content)) }}
|
||||||
/>
|
/>
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user