chore: 提交所有工作进度 — SaaS 后端增强、Admin UI、桌面端集成
包含大量 SaaS 平台改进、Admin 管理后台更新、桌面端集成完善、 文档同步、测试文件重构等内容。为 QA 测试准备干净工作树。
This commit is contained in:
1
.claude/worktrees/saas-backend
Submodule
1
.claude/worktrees/saas-backend
Submodule
Submodule .claude/worktrees/saas-backend added at 4d8d560d1f
4
.gitignore
vendored
4
.gitignore
vendored
@@ -12,6 +12,10 @@ build/
|
|||||||
.env.local
|
.env.local
|
||||||
.env.*.local
|
.env.*.local
|
||||||
|
|
||||||
|
# SaaS config (contains database credentials)
|
||||||
|
saas-config.toml
|
||||||
|
!saas-config.toml.example
|
||||||
|
|
||||||
# Logs
|
# Logs
|
||||||
logs/
|
logs/
|
||||||
*.log
|
*.log
|
||||||
|
|||||||
30
CLAUDE.md
30
CLAUDE.md
@@ -37,16 +37,20 @@ ZCLAW/
|
|||||||
│ ├── zclaw-skills/ # 技能系统 (SKILL.md解析, 执行器)
|
│ ├── zclaw-skills/ # 技能系统 (SKILL.md解析, 执行器)
|
||||||
│ ├── zclaw-hands/ # 自主能力 (Hand/Trigger 注册管理)
|
│ ├── zclaw-hands/ # 自主能力 (Hand/Trigger 注册管理)
|
||||||
│ ├── zclaw-channels/ # 通道适配器 (仅 ConsoleChannel 测试适配器)
|
│ ├── zclaw-channels/ # 通道适配器 (仅 ConsoleChannel 测试适配器)
|
||||||
│ └── zclaw-protocols/ # 协议支持 (MCP, A2A)
|
│ ├── zclaw-protocols/ # 协议支持 (MCP, A2A)
|
||||||
|
│ └── zclaw-saas/ # SaaS 后端 (账号, 模型配置, 中转, 配置同步)
|
||||||
|
├── admin/ # Next.js 管理后台
|
||||||
├── desktop/ # Tauri 桌面应用
|
├── desktop/ # Tauri 桌面应用
|
||||||
│ ├── src/
|
│ ├── src/
|
||||||
│ │ ├── components/ # React UI 组件
|
│ │ ├── components/ # React UI 组件 (含 SaaS 集成)
|
||||||
│ │ ├── store/ # Zustand 状态管理
|
│ │ ├── store/ # Zustand 状态管理 (含 saasStore)
|
||||||
│ │ └── lib/ # 客户端通信 / 工具函数
|
│ │ └── lib/ # 客户端通信 / 工具函数 (含 saas-client)
|
||||||
│ └── src-tauri/ # Tauri Rust 后端 (集成 Kernel)
|
│ └── src-tauri/ # Tauri Rust 后端 (集成 Kernel)
|
||||||
├── skills/ # SKILL.md 技能定义
|
├── skills/ # SKILL.md 技能定义
|
||||||
├── hands/ # HAND.toml 自主能力配置
|
├── hands/ # HAND.toml 自主能力配置
|
||||||
├── config/ # TOML 配置文件
|
├── config/ # TOML 配置文件
|
||||||
|
├── saas-config.toml # SaaS 后端配置 (PostgreSQL 连接等)
|
||||||
|
├── docker-compose.yml # PostgreSQL 容器配置
|
||||||
├── docs/ # 架构文档和知识库
|
├── docs/ # 架构文档和知识库
|
||||||
└── tests/ # Vitest 回归测试
|
└── tests/ # Vitest 回归测试
|
||||||
```
|
```
|
||||||
@@ -66,7 +70,9 @@ ZCLAW/
|
|||||||
| 桌面框架 | Tauri 2.x |
|
| 桌面框架 | Tauri 2.x |
|
||||||
| 样式方案 | Tailwind CSS |
|
| 样式方案 | Tailwind CSS |
|
||||||
| 配置格式 | TOML |
|
| 配置格式 | TOML |
|
||||||
| 后端核心 | Rust Workspace (8 crates) |
|
| 后端核心 | Rust Workspace (9 crates) |
|
||||||
|
| SaaS 后端 | Axum + PostgreSQL (zclaw-saas) |
|
||||||
|
| 管理后台 | Next.js (admin/) |
|
||||||
|
|
||||||
### 2.3 Crate 依赖关系
|
### 2.3 Crate 依赖关系
|
||||||
|
|
||||||
@@ -79,6 +85,8 @@ zclaw-runtime (→ types, memory)
|
|||||||
↑
|
↑
|
||||||
zclaw-kernel (→ types, memory, runtime)
|
zclaw-kernel (→ types, memory, runtime)
|
||||||
↑
|
↑
|
||||||
|
zclaw-saas (→ types, 独立运行于 8080 端口)
|
||||||
|
↑
|
||||||
desktop/src-tauri (→ kernel, skills, hands, channels, protocols)
|
desktop/src-tauri (→ kernel, skills, hands, channels, protocols)
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -260,6 +268,18 @@ docs/
|
|||||||
- **面向未来** - 文档要帮助未来的开发者快速理解
|
- **面向未来** - 文档要帮助未来的开发者快速理解
|
||||||
- **中文优先** - 所有面向用户的文档使用中文
|
- **中文优先** - 所有面向用户的文档使用中文
|
||||||
|
|
||||||
|
### 8.3 完成工作后的文档同步(强制)
|
||||||
|
|
||||||
|
每次完成功能实现、架构变更、问题修复后,**必须**同步更新以下文档:
|
||||||
|
|
||||||
|
1. **CLAUDE.md** — 如果涉及项目结构、技术栈、工作流程、命令的变化
|
||||||
|
2. **docs/features/** — 如果涉及新功能、功能变更、功能状态更新
|
||||||
|
3. **docs/knowledge-base/** — 如果涉及新知识、故障排查经验、配置说明
|
||||||
|
4. **saas-config.toml 注释** — 如果涉及 SaaS 配置项变更
|
||||||
|
5. **CHANGELOG** — 如果涉及对外可见的行为变化
|
||||||
|
|
||||||
|
**执行时机:** 代码编译通过且验证成功后,在标记任务完成之前,立即执行文档更新。文档更新是任务完成的必要条件,不是可选步骤。
|
||||||
|
|
||||||
***
|
***
|
||||||
|
|
||||||
## 9. 常见问题排查
|
## 9. 常见问题排查
|
||||||
|
|||||||
925
Cargo.lock
generated
925
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -57,7 +57,7 @@ chrono = { version = "0.4", features = ["serde"] }
|
|||||||
uuid = { version = "1", features = ["v4", "v5", "serde"] }
|
uuid = { version = "1", features = ["v4", "v5", "serde"] }
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
|
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres"] }
|
||||||
libsqlite3-sys = { version = "0.27", features = ["bundled"] }
|
libsqlite3-sys = { version = "0.27", features = ["bundled"] }
|
||||||
|
|
||||||
# HTTP client (for LLM drivers)
|
# HTTP client (for LLM drivers)
|
||||||
@@ -94,6 +94,10 @@ regex = "1"
|
|||||||
# Shell parsing
|
# Shell parsing
|
||||||
shlex = "1"
|
shlex = "1"
|
||||||
|
|
||||||
|
# WASM runtime
|
||||||
|
wasmtime = { version = "43", default-features = false, features = ["cranelift"] }
|
||||||
|
wasmtime-wasi = { version = "43" }
|
||||||
|
|
||||||
# Testing
|
# Testing
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,13 @@
|
|||||||
/** @type {import('next').NextConfig} */
|
/** @type {import('next').NextConfig} */
|
||||||
const nextConfig = {}
|
const nextConfig = {
|
||||||
|
async rewrites() {
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
source: '/api/:path*',
|
||||||
|
destination: 'http://localhost:8080/api/:path*',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
module.exports = nextConfig
|
module.exports = nextConfig
|
||||||
|
|||||||
@@ -11,10 +11,10 @@
|
|||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@radix-ui/react-dialog": "^1.1.14",
|
"@radix-ui/react-dialog": "^1.1.14",
|
||||||
"@radix-ui/react-select": "^2.2.5",
|
"@radix-ui/react-select": "^2.2.5",
|
||||||
|
"@radix-ui/react-separator": "^1.1.7",
|
||||||
"@radix-ui/react-switch": "^1.2.5",
|
"@radix-ui/react-switch": "^1.2.5",
|
||||||
"@radix-ui/react-tabs": "^1.1.12",
|
"@radix-ui/react-tabs": "^1.1.12",
|
||||||
"@radix-ui/react-tooltip": "^1.2.7",
|
"@radix-ui/react-tooltip": "^1.2.7",
|
||||||
"@radix-ui/react-separator": "^1.1.7",
|
|
||||||
"class-variance-authority": "^0.7.1",
|
"class-variance-authority": "^0.7.1",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
"lucide-react": "^0.484.0",
|
"lucide-react": "^0.484.0",
|
||||||
@@ -22,6 +22,7 @@
|
|||||||
"react": "^18.3.1",
|
"react": "^18.3.1",
|
||||||
"react-dom": "^18.3.1",
|
"react-dom": "^18.3.1",
|
||||||
"recharts": "^2.15.3",
|
"recharts": "^2.15.3",
|
||||||
|
"swr": "^2.4.1",
|
||||||
"tailwind-merge": "^3.0.2"
|
"tailwind-merge": "^3.0.2"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|||||||
29
admin/pnpm-lock.yaml
generated
29
admin/pnpm-lock.yaml
generated
@@ -47,6 +47,9 @@ importers:
|
|||||||
recharts:
|
recharts:
|
||||||
specifier: ^2.15.3
|
specifier: ^2.15.3
|
||||||
version: 2.15.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
version: 2.15.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||||
|
swr:
|
||||||
|
specifier: ^2.4.1
|
||||||
|
version: 2.4.1(react@18.3.1)
|
||||||
tailwind-merge:
|
tailwind-merge:
|
||||||
specifier: ^3.0.2
|
specifier: ^3.0.2
|
||||||
version: 3.5.0
|
version: 3.5.0
|
||||||
@@ -719,6 +722,10 @@ packages:
|
|||||||
decimal.js-light@2.5.1:
|
decimal.js-light@2.5.1:
|
||||||
resolution: {integrity: sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==}
|
resolution: {integrity: sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==}
|
||||||
|
|
||||||
|
dequal@2.0.3:
|
||||||
|
resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==}
|
||||||
|
engines: {node: '>=6'}
|
||||||
|
|
||||||
detect-node-es@1.1.0:
|
detect-node-es@1.1.0:
|
||||||
resolution: {integrity: sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==}
|
resolution: {integrity: sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==}
|
||||||
|
|
||||||
@@ -1093,6 +1100,11 @@ packages:
|
|||||||
resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==}
|
resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==}
|
||||||
engines: {node: '>= 0.4'}
|
engines: {node: '>= 0.4'}
|
||||||
|
|
||||||
|
swr@2.4.1:
|
||||||
|
resolution: {integrity: sha512-2CC6CiKQtEwaEeNiqWTAw9PGykW8SR5zZX8MZk6TeAvEAnVS7Visz8WzphqgtQ8v2xz/4Q5K+j+SeMaKXeeQIA==}
|
||||||
|
peerDependencies:
|
||||||
|
react: ^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||||
|
|
||||||
tailwind-merge@3.5.0:
|
tailwind-merge@3.5.0:
|
||||||
resolution: {integrity: sha512-I8K9wewnVDkL1NTGoqWmVEIlUcB9gFriAEkXkfCjX5ib8ezGxtR3xD7iZIxrfArjEsH7F1CHD4RFUtxefdqV/A==}
|
resolution: {integrity: sha512-I8K9wewnVDkL1NTGoqWmVEIlUcB9gFriAEkXkfCjX5ib8ezGxtR3xD7iZIxrfArjEsH7F1CHD4RFUtxefdqV/A==}
|
||||||
|
|
||||||
@@ -1159,6 +1171,11 @@ packages:
|
|||||||
'@types/react':
|
'@types/react':
|
||||||
optional: true
|
optional: true
|
||||||
|
|
||||||
|
use-sync-external-store@1.6.0:
|
||||||
|
resolution: {integrity: sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==}
|
||||||
|
peerDependencies:
|
||||||
|
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
||||||
|
|
||||||
util-deprecate@1.0.2:
|
util-deprecate@1.0.2:
|
||||||
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
|
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
|
||||||
|
|
||||||
@@ -1744,6 +1761,8 @@ snapshots:
|
|||||||
|
|
||||||
decimal.js-light@2.5.1: {}
|
decimal.js-light@2.5.1: {}
|
||||||
|
|
||||||
|
dequal@2.0.3: {}
|
||||||
|
|
||||||
detect-node-es@1.1.0: {}
|
detect-node-es@1.1.0: {}
|
||||||
|
|
||||||
didyoumean@1.2.2: {}
|
didyoumean@1.2.2: {}
|
||||||
@@ -2073,6 +2092,12 @@ snapshots:
|
|||||||
|
|
||||||
supports-preserve-symlinks-flag@1.0.0: {}
|
supports-preserve-symlinks-flag@1.0.0: {}
|
||||||
|
|
||||||
|
swr@2.4.1(react@18.3.1):
|
||||||
|
dependencies:
|
||||||
|
dequal: 2.0.3
|
||||||
|
react: 18.3.1
|
||||||
|
use-sync-external-store: 1.6.0(react@18.3.1)
|
||||||
|
|
||||||
tailwind-merge@3.5.0: {}
|
tailwind-merge@3.5.0: {}
|
||||||
|
|
||||||
tailwindcss@3.4.19:
|
tailwindcss@3.4.19:
|
||||||
@@ -2151,6 +2176,10 @@ snapshots:
|
|||||||
optionalDependencies:
|
optionalDependencies:
|
||||||
'@types/react': 18.3.28
|
'@types/react': 18.3.28
|
||||||
|
|
||||||
|
use-sync-external-store@1.6.0(react@18.3.1):
|
||||||
|
dependencies:
|
||||||
|
react: 18.3.1
|
||||||
|
|
||||||
util-deprecate@1.0.2: {}
|
util-deprecate@1.0.2: {}
|
||||||
|
|
||||||
victory-vendor@36.9.2:
|
victory-vendor@36.9.2:
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useCallback } from 'react'
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
import {
|
import {
|
||||||
Search,
|
Search,
|
||||||
Plus,
|
Plus,
|
||||||
@@ -41,6 +42,9 @@ import {
|
|||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate } from '@/lib/utils'
|
import { formatDate } from '@/lib/utils'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
|
import { useDebounce } from '@/hooks/use-debounce'
|
||||||
import type { AccountPublic } from '@/lib/types'
|
import type { AccountPublic } from '@/lib/types'
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
@@ -64,14 +68,28 @@ const statusLabels: Record<string, string> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function AccountsPage() {
|
export default function AccountsPage() {
|
||||||
const [accounts, setAccounts] = useState<AccountPublic[]>([])
|
|
||||||
const [total, setTotal] = useState(0)
|
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [search, setSearch] = useState('')
|
const [search, setSearch] = useState('')
|
||||||
const [roleFilter, setRoleFilter] = useState<string>('all')
|
const [roleFilter, setRoleFilter] = useState<string>('all')
|
||||||
const [statusFilter, setStatusFilter] = useState<string>('all')
|
const [statusFilter, setStatusFilter] = useState<string>('all')
|
||||||
const [loading, setLoading] = useState(true)
|
const [mutationError, setMutationError] = useState('')
|
||||||
const [error, setError] = useState('')
|
|
||||||
|
const debouncedSearch = useDebounce(search, 300)
|
||||||
|
|
||||||
|
const { data, error: swrError, isLoading, mutate } = useSWR(
|
||||||
|
['accounts', page, debouncedSearch, roleFilter, statusFilter],
|
||||||
|
() => {
|
||||||
|
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
||||||
|
if (debouncedSearch.trim()) params.search = debouncedSearch.trim()
|
||||||
|
if (roleFilter !== 'all') params.role = roleFilter
|
||||||
|
if (statusFilter !== 'all') params.status = statusFilter
|
||||||
|
return api.accounts.list(params)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
const accounts = data?.items ?? []
|
||||||
|
const total = data?.total ?? 0
|
||||||
|
const error = swrError?.message || mutationError
|
||||||
|
|
||||||
// 编辑 Dialog
|
// 编辑 Dialog
|
||||||
const [editTarget, setEditTarget] = useState<AccountPublic | null>(null)
|
const [editTarget, setEditTarget] = useState<AccountPublic | null>(null)
|
||||||
@@ -82,33 +100,6 @@ export default function AccountsPage() {
|
|||||||
const [confirmTarget, setConfirmTarget] = useState<{ id: string; action: string; status: string } | null>(null)
|
const [confirmTarget, setConfirmTarget] = useState<{ id: string; action: string; status: string } | null>(null)
|
||||||
const [confirmSaving, setConfirmSaving] = useState(false)
|
const [confirmSaving, setConfirmSaving] = useState(false)
|
||||||
|
|
||||||
const fetchAccounts = useCallback(async () => {
|
|
||||||
setLoading(true)
|
|
||||||
setError('')
|
|
||||||
try {
|
|
||||||
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
|
||||||
if (search.trim()) params.search = search.trim()
|
|
||||||
if (roleFilter !== 'all') params.role = roleFilter
|
|
||||||
if (statusFilter !== 'all') params.status = statusFilter
|
|
||||||
|
|
||||||
const res = await api.accounts.list(params)
|
|
||||||
setAccounts(res.items)
|
|
||||||
setTotal(res.total)
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) {
|
|
||||||
setError(err.body.message)
|
|
||||||
} else {
|
|
||||||
setError('加载失败')
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
setLoading(false)
|
|
||||||
}
|
|
||||||
}, [page, search, roleFilter, statusFilter])
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
fetchAccounts()
|
|
||||||
}, [fetchAccounts])
|
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
function openEditDialog(account: AccountPublic) {
|
function openEditDialog(account: AccountPublic) {
|
||||||
@@ -130,10 +121,10 @@ export default function AccountsPage() {
|
|||||||
role: editForm.role as AccountPublic['role'],
|
role: editForm.role as AccountPublic['role'],
|
||||||
})
|
})
|
||||||
setEditTarget(null)
|
setEditTarget(null)
|
||||||
fetchAccounts()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) {
|
if (err instanceof ApiRequestError) {
|
||||||
setError(err.body.message)
|
setMutationError(err.body.message)
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
setEditSaving(false)
|
setEditSaving(false)
|
||||||
@@ -157,10 +148,10 @@ export default function AccountsPage() {
|
|||||||
status: confirmTarget.status as AccountPublic['status'],
|
status: confirmTarget.status as AccountPublic['status'],
|
||||||
})
|
})
|
||||||
setConfirmTarget(null)
|
setConfirmTarget(null)
|
||||||
fetchAccounts()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) {
|
if (err instanceof ApiRequestError) {
|
||||||
setError(err.body.message)
|
setMutationError(err.body.message)
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
setConfirmSaving(false)
|
setConfirmSaving(false)
|
||||||
@@ -205,24 +196,13 @@ export default function AccountsPage() {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 错误提示 */}
|
{/* 错误提示 */}
|
||||||
{error && (
|
{error && <ErrorBanner message={error} onDismiss={() => { setMutationError('') }} />}
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
|
||||||
{error}
|
|
||||||
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">
|
|
||||||
关闭
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* 表格 */}
|
{/* 表格 */}
|
||||||
{loading ? (
|
{isLoading ? (
|
||||||
<div className="flex h-64 items-center justify-center">
|
<TableSkeleton rows={6} cols={7} />
|
||||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
) : error ? null : accounts.length === 0 ? (
|
||||||
</div>
|
<EmptyState />
|
||||||
) : accounts.length === 0 ? (
|
|
||||||
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
|
||||||
暂无数据
|
|
||||||
</div>
|
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
|
|||||||
290
admin/src/app/(dashboard)/agent-templates/page.tsx
Normal file
290
admin/src/app/(dashboard)/agent-templates/page.tsx
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
'use client'
|
||||||
|
|
||||||
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
|
import { api } from '@/lib/api-client'
|
||||||
|
import type { AgentTemplate } from '@/lib/types'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
|
|
||||||
|
export default function AgentTemplatesPage() {
|
||||||
|
const [page, setPage] = useState(1)
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
const [showCreate, setShowCreate] = useState(false)
|
||||||
|
const [editingId, setEditingId] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const { data, isLoading, mutate } = useSWR(
|
||||||
|
['agentTemplates.list', page],
|
||||||
|
() => api.agentTemplates.list({ page, page_size: 50 }),
|
||||||
|
)
|
||||||
|
|
||||||
|
const templates = data?.items ?? []
|
||||||
|
const total = data?.total ?? 0
|
||||||
|
|
||||||
|
const handleCreate = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||||
|
e.preventDefault()
|
||||||
|
const fd = new FormData(e.currentTarget)
|
||||||
|
try {
|
||||||
|
const tools = (fd.get('tools') as string || '').split(',').map(s => s.trim()).filter(Boolean)
|
||||||
|
const capabilities = (fd.get('capabilities') as string || '').split(',').map(s => s.trim()).filter(Boolean)
|
||||||
|
await api.agentTemplates.create({
|
||||||
|
name: fd.get('name') as string,
|
||||||
|
description: (fd.get('description') as string) || undefined,
|
||||||
|
category: (fd.get('category') as string) || 'general',
|
||||||
|
model: (fd.get('model') as string) || undefined,
|
||||||
|
system_prompt: (fd.get('system_prompt') as string) || undefined,
|
||||||
|
tools: tools.length > 0 ? tools : undefined,
|
||||||
|
capabilities: capabilities.length > 0 ? capabilities : undefined,
|
||||||
|
temperature: (fd.get('temperature') as string) ? parseFloat(fd.get('temperature') as string) : undefined,
|
||||||
|
max_tokens: (fd.get('max_tokens') as string) ? parseInt(fd.get('max_tokens') as string, 10) : undefined,
|
||||||
|
visibility: (fd.get('visibility') as string) || 'public',
|
||||||
|
})
|
||||||
|
setShowCreate(false)
|
||||||
|
mutate()
|
||||||
|
} catch {
|
||||||
|
setError('创建失败')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleArchive = async (id: string, name: string) => {
|
||||||
|
if (!confirm(`确认归档模板 "${name}"?`)) return
|
||||||
|
try {
|
||||||
|
await api.agentTemplates.archive(id)
|
||||||
|
mutate()
|
||||||
|
} catch {
|
||||||
|
setError('归档失败')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const statusBadge = (status: string) => {
|
||||||
|
const colors: Record<string, string> = {
|
||||||
|
active: 'bg-emerald-500/20 text-emerald-400',
|
||||||
|
archived: 'bg-zinc-500/20 text-zinc-400',
|
||||||
|
}
|
||||||
|
return <span className={`px-2 py-0.5 text-xs rounded-full ${colors[status] || colors.archived}`}>{status}</span>
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceBadge = (source: string) => {
|
||||||
|
const colors: Record<string, string> = {
|
||||||
|
builtin: 'bg-blue-500/20 text-blue-400',
|
||||||
|
custom: 'bg-purple-500/20 text-purple-400',
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<span className={`px-2 py-0.5 text-xs rounded-full ${colors[source] || ''}`}>
|
||||||
|
{source === 'builtin' ? '内置' : '自定义'}
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-6">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<h1 className="text-2xl font-bold text-white">Agent 配置模板</h1>
|
||||||
|
<p className="text-sm text-zinc-400 mt-1">管理 Agent 配置模板,支持团队共享和一键复用</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={() => setShowCreate(true)}
|
||||||
|
className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors text-sm"
|
||||||
|
>
|
||||||
|
+ 新建模板
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
||||||
|
|
||||||
|
<div className="bg-zinc-900 rounded-xl border border-zinc-800 overflow-hidden">
|
||||||
|
<table className="w-full text-sm">
|
||||||
|
<thead>
|
||||||
|
<tr className="border-b border-zinc-800">
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">名称</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">分类</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">来源</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">模型</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">工具数</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">可见性</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">状态</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">更新时间</th>
|
||||||
|
<th className="text-right px-4 py-3 text-zinc-400 font-medium">操作</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{isLoading ? (
|
||||||
|
<tr>
|
||||||
|
<td colSpan={9}>
|
||||||
|
<TableSkeleton rows={5} cols={9} hasToolbar={false} />
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
) : templates.length === 0 ? (
|
||||||
|
<tr><td colSpan={9}><EmptyState message="暂无 Agent 模板" /></td></tr>
|
||||||
|
) : (
|
||||||
|
templates.map(t => (
|
||||||
|
<tr key={t.id} className="border-b border-zinc-800/50 hover:bg-zinc-800/30">
|
||||||
|
<td className="px-4 py-3">
|
||||||
|
<div>
|
||||||
|
<span className="text-white font-medium">{t.name}</span>
|
||||||
|
{t.description && (
|
||||||
|
<p className="text-xs text-zinc-500 mt-0.5 truncate max-w-[200px]">{t.description}</p>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-400">{t.category}</td>
|
||||||
|
<td className="px-4 py-3">{sourceBadge(t.source)}</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-300 font-mono text-xs">{t.model || '-'}</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-400">{t.tools.length}</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-400">{t.visibility}</td>
|
||||||
|
<td className="px-4 py-3">{statusBadge(t.status)}</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-500 text-xs">
|
||||||
|
{new Date(t.updated_at).toLocaleString('zh-CN')}
|
||||||
|
</td>
|
||||||
|
<td className="px-4 py-3 text-right">
|
||||||
|
<button
|
||||||
|
onClick={() => setEditingId(editingId === t.id ? null : t.id)}
|
||||||
|
className="text-zinc-400 hover:text-white mr-2"
|
||||||
|
>
|
||||||
|
详情
|
||||||
|
</button>
|
||||||
|
{t.source === 'custom' && (
|
||||||
|
<button
|
||||||
|
onClick={() => handleArchive(t.id, t.name)}
|
||||||
|
className="text-red-400 hover:text-red-300"
|
||||||
|
>
|
||||||
|
归档
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
<div className="px-4 py-2 text-xs text-zinc-500 border-t border-zinc-800">
|
||||||
|
共 {total} 个模板
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 展开详情 */}
|
||||||
|
{editingId && (() => {
|
||||||
|
const t = templates.find(t => t.id === editingId)
|
||||||
|
if (!t) return null
|
||||||
|
return (
|
||||||
|
<div className="bg-zinc-900 rounded-xl border border-zinc-800 p-4">
|
||||||
|
<div className="flex items-center justify-between mb-3">
|
||||||
|
<h2 className="text-lg font-semibold text-white">{t.name} — 详情</h2>
|
||||||
|
<button onClick={() => setEditingId(null)} className="text-zinc-400 hover:text-white text-sm">关闭</button>
|
||||||
|
</div>
|
||||||
|
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||||
|
<div>
|
||||||
|
<span className="text-zinc-500">分类:</span>
|
||||||
|
<span className="text-zinc-300">{t.category}</span>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-zinc-500">模型:</span>
|
||||||
|
<span className="text-zinc-300 font-mono">{t.model || '未指定'}</span>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-zinc-500">温度:</span>
|
||||||
|
<span className="text-zinc-300">{t.temperature?.toFixed(2) || '默认'}</span>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<span className="text-zinc-500">最大 Token:</span>
|
||||||
|
<span className="text-zinc-300">{t.max_tokens || '未限制'}</span>
|
||||||
|
</div>
|
||||||
|
<div className="col-span-2">
|
||||||
|
<span className="text-zinc-500">工具:</span>
|
||||||
|
<div className="flex flex-wrap gap-1 mt-1">
|
||||||
|
{t.tools.length > 0 ? t.tools.map(tool => (
|
||||||
|
<span key={tool} className="px-2 py-0.5 bg-zinc-800 rounded text-xs text-zinc-300">{tool}</span>
|
||||||
|
)) : <span className="text-zinc-600">无</span>}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="col-span-2">
|
||||||
|
<span className="text-zinc-500">能力:</span>
|
||||||
|
<div className="flex flex-wrap gap-1 mt-1">
|
||||||
|
{t.capabilities.length > 0 ? t.capabilities.map(cap => (
|
||||||
|
<span key={cap} className="px-2 py-0.5 bg-blue-500/10 rounded text-xs text-blue-400">{cap}</span>
|
||||||
|
)) : <span className="text-zinc-600">无</span>}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{t.system_prompt && (
|
||||||
|
<div className="col-span-2">
|
||||||
|
<span className="text-zinc-500">系统提示词:</span>
|
||||||
|
<pre className="text-xs text-zinc-400 bg-zinc-800/50 rounded p-2 mt-1 overflow-x-auto max-h-32">
|
||||||
|
{t.system_prompt.substring(0, 500)}{t.system_prompt.length > 500 ? '...' : ''}
|
||||||
|
</pre>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})()}
|
||||||
|
|
||||||
|
{/* Create Modal */}
|
||||||
|
{showCreate && (
|
||||||
|
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||||
|
<form onSubmit={handleCreate} className="bg-zinc-900 rounded-xl border border-zinc-700 p-6 w-full max-w-lg space-y-4 max-h-[80vh] overflow-y-auto">
|
||||||
|
<h2 className="text-lg font-semibold text-white">新建 Agent 模板</h2>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">名称 *</label>
|
||||||
|
<input name="name" required className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="my_agent" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">描述</label>
|
||||||
|
<input name="description" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="可选" />
|
||||||
|
</div>
|
||||||
|
<div className="grid grid-cols-2 gap-4">
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">分类</label>
|
||||||
|
<select name="category" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm">
|
||||||
|
<option value="general">通用</option>
|
||||||
|
<option value="coding">编程</option>
|
||||||
|
<option value="research">研究</option>
|
||||||
|
<option value="creative">创意</option>
|
||||||
|
<option value="assistant">助手</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">模型</label>
|
||||||
|
<input name="model" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="如 glm-4-plus" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">系统提示词</label>
|
||||||
|
<textarea name="system_prompt" rows={4} className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm font-mono" placeholder="Agent 系统提示词" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">工具(逗号分隔)</label>
|
||||||
|
<input name="tools" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="browser, file_system, code_execute" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">能力(逗号分隔)</label>
|
||||||
|
<input name="capabilities" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="streaming, vision, function_calling" />
|
||||||
|
</div>
|
||||||
|
<div className="grid grid-cols-3 gap-4">
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">温度</label>
|
||||||
|
<input name="temperature" type="number" step="0.1" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="默认" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">最大 Token</label>
|
||||||
|
<input name="max_tokens" type="number" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="不限" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">可见性</label>
|
||||||
|
<select name="visibility" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm">
|
||||||
|
<option value="public">公开</option>
|
||||||
|
<option value="team">团队</option>
|
||||||
|
<option value="private">私有</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-2 justify-end">
|
||||||
|
<button type="button" onClick={() => setShowCreate(false)} className="px-4 py-2 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-sm">取消</button>
|
||||||
|
<button type="submit" className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-sm">创建</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useCallback } from 'react'
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
import {
|
import {
|
||||||
Plus,
|
Plus,
|
||||||
Loader2,
|
Loader2,
|
||||||
@@ -32,8 +33,10 @@ import {
|
|||||||
DialogDescription,
|
DialogDescription,
|
||||||
} from '@/components/ui/dialog'
|
} from '@/components/ui/dialog'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate } from '@/lib/utils'
|
import { formatDate } from '@/lib/utils'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
import type { TokenInfo } from '@/lib/types'
|
import type { TokenInfo } from '@/lib/types'
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
@@ -45,11 +48,17 @@ const allPermissions = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
export default function ApiKeysPage() {
|
export default function ApiKeysPage() {
|
||||||
const [tokens, setTokens] = useState<TokenInfo[]>([])
|
|
||||||
const [total, setTotal] = useState(0)
|
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [loading, setLoading] = useState(true)
|
const [mutationError, setMutationError] = useState('')
|
||||||
const [error, setError] = useState('')
|
|
||||||
|
const { data, error: swrError, isLoading, mutate } = useSWR(
|
||||||
|
['tokens', page],
|
||||||
|
() => api.tokens.list({ page, page_size: PAGE_SIZE }),
|
||||||
|
)
|
||||||
|
|
||||||
|
const tokens = data?.items ?? []
|
||||||
|
const total = data?.total ?? 0
|
||||||
|
const error = swrError?.message || mutationError
|
||||||
|
|
||||||
// 创建 Dialog
|
// 创建 Dialog
|
||||||
const [createOpen, setCreateOpen] = useState(false)
|
const [createOpen, setCreateOpen] = useState(false)
|
||||||
@@ -64,25 +73,6 @@ export default function ApiKeysPage() {
|
|||||||
const [revokeTarget, setRevokeTarget] = useState<TokenInfo | null>(null)
|
const [revokeTarget, setRevokeTarget] = useState<TokenInfo | null>(null)
|
||||||
const [revoking, setRevoking] = useState(false)
|
const [revoking, setRevoking] = useState(false)
|
||||||
|
|
||||||
const fetchTokens = useCallback(async () => {
|
|
||||||
setLoading(true)
|
|
||||||
setError('')
|
|
||||||
try {
|
|
||||||
const res = await api.tokens.list({ page, page_size: PAGE_SIZE })
|
|
||||||
setTokens(res.items)
|
|
||||||
setTotal(res.total)
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
|
||||||
else setError('加载失败')
|
|
||||||
} finally {
|
|
||||||
setLoading(false)
|
|
||||||
}
|
|
||||||
}, [page])
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
fetchTokens()
|
|
||||||
}, [fetchTokens])
|
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
function togglePermission(perm: string) {
|
function togglePermission(perm: string) {
|
||||||
@@ -107,9 +97,9 @@ export default function ApiKeysPage() {
|
|||||||
setCreateOpen(false)
|
setCreateOpen(false)
|
||||||
setCreatedToken(res)
|
setCreatedToken(res)
|
||||||
setCreateForm({ name: '', expires_days: '', permissions: ['chat'] })
|
setCreateForm({ name: '', expires_days: '', permissions: ['chat'] })
|
||||||
fetchTokens()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setMutationError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
setCreating(false)
|
setCreating(false)
|
||||||
}
|
}
|
||||||
@@ -121,9 +111,9 @@ export default function ApiKeysPage() {
|
|||||||
try {
|
try {
|
||||||
await api.tokens.revoke(revokeTarget.id)
|
await api.tokens.revoke(revokeTarget.id)
|
||||||
setRevokeTarget(null)
|
setRevokeTarget(null)
|
||||||
fetchTokens()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setMutationError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
setRevoking(false)
|
setRevoking(false)
|
||||||
}
|
}
|
||||||
@@ -158,21 +148,12 @@ export default function ApiKeysPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && (
|
{error && <ErrorBanner message={error} onDismiss={() => setMutationError('')} />}
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
|
||||||
{error}
|
|
||||||
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
{isLoading ? (
|
||||||
<div className="flex h-64 items-center justify-center">
|
<TableSkeleton rows={6} cols={7} />
|
||||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
) : error ? null : tokens.length === 0 ? (
|
||||||
</div>
|
<EmptyState />
|
||||||
) : tokens.length === 0 ? (
|
|
||||||
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
|
||||||
暂无数据
|
|
||||||
</div>
|
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useCallback } from 'react'
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
import {
|
import {
|
||||||
Loader2,
|
Loader2,
|
||||||
Pencil,
|
Pencil,
|
||||||
@@ -35,6 +36,8 @@ import {
|
|||||||
} from '@/components/ui/dialog'
|
} from '@/components/ui/dialog'
|
||||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import type { ConfigItem } from '@/lib/types'
|
import type { ConfigItem } from '@/lib/types'
|
||||||
|
|
||||||
@@ -51,36 +54,24 @@ const sourceVariants: Record<string, 'secondary' | 'info' | 'default'> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function ConfigPage() {
|
export default function ConfigPage() {
|
||||||
const [configs, setConfigs] = useState<ConfigItem[]>([])
|
|
||||||
const [loading, setLoading] = useState(true)
|
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
const [activeTab, setActiveTab] = useState('all')
|
const [activeTab, setActiveTab] = useState('all')
|
||||||
|
|
||||||
|
// SWR for config list
|
||||||
|
const { data: configs = [], isLoading, mutate } = useSWR(
|
||||||
|
['config', activeTab],
|
||||||
|
() => {
|
||||||
|
const params: Record<string, unknown> = {}
|
||||||
|
if (activeTab !== 'all') params.category = activeTab
|
||||||
|
return api.config.list(params)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
// 编辑 Dialog
|
// 编辑 Dialog
|
||||||
const [editTarget, setEditTarget] = useState<ConfigItem | null>(null)
|
const [editTarget, setEditTarget] = useState<ConfigItem | null>(null)
|
||||||
const [editValue, setEditValue] = useState('')
|
const [editValue, setEditValue] = useState('')
|
||||||
const [saving, setSaving] = useState(false)
|
const [saving, setSaving] = useState(false)
|
||||||
|
|
||||||
const fetchConfigs = useCallback(async (category?: string) => {
|
|
||||||
setLoading(true)
|
|
||||||
setError('')
|
|
||||||
try {
|
|
||||||
const params: Record<string, unknown> = {}
|
|
||||||
if (category && category !== 'all') params.category = category
|
|
||||||
const res = await api.config.list(params)
|
|
||||||
setConfigs(res)
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
|
||||||
else setError('加载失败')
|
|
||||||
} finally {
|
|
||||||
setLoading(false)
|
|
||||||
}
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
fetchConfigs(activeTab)
|
|
||||||
}, [fetchConfigs, activeTab])
|
|
||||||
|
|
||||||
function openEditDialog(config: ConfigItem) {
|
function openEditDialog(config: ConfigItem) {
|
||||||
setEditTarget(config)
|
setEditTarget(config)
|
||||||
setEditValue(config.current_value !== undefined ? String(config.current_value) : '')
|
setEditValue(config.current_value !== undefined ? String(config.current_value) : '')
|
||||||
@@ -98,7 +89,7 @@ export default function ConfigPage() {
|
|||||||
}
|
}
|
||||||
await api.config.update(editTarget.id, { value: parsedValue })
|
await api.config.update(editTarget.id, { value: parsedValue })
|
||||||
setEditTarget(null)
|
setEditTarget(null)
|
||||||
fetchConfigs(activeTab)
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -112,7 +103,15 @@ export default function ConfigPage() {
|
|||||||
return String(value)
|
return String(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
const categories = ['all', 'auth', 'relay', 'model', 'system']
|
const categoryLabels: Record<string, string> = {
|
||||||
|
all: '全部',
|
||||||
|
server: '服务器',
|
||||||
|
agent: 'Agent',
|
||||||
|
memory: '记忆',
|
||||||
|
llm: 'LLM',
|
||||||
|
security: '安全策略',
|
||||||
|
}
|
||||||
|
const categories = Object.keys(categoryLabels)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
@@ -121,27 +120,18 @@ export default function ConfigPage() {
|
|||||||
<TabsList>
|
<TabsList>
|
||||||
{categories.map((cat) => (
|
{categories.map((cat) => (
|
||||||
<TabsTrigger key={cat} value={cat}>
|
<TabsTrigger key={cat} value={cat}>
|
||||||
{cat === 'all' ? '全部' : cat}
|
{categoryLabels[cat] || cat}
|
||||||
</TabsTrigger>
|
</TabsTrigger>
|
||||||
))}
|
))}
|
||||||
</TabsList>
|
</TabsList>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
{error && (
|
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
|
||||||
{error}
|
|
||||||
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
{isLoading ? (
|
||||||
<div className="flex h-64 items-center justify-center">
|
<TableSkeleton rows={8} cols={8} hasToolbar={false} />
|
||||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
) : error ? null : configs.length === 0 ? (
|
||||||
</div>
|
<EmptyState message="暂无配置项" />
|
||||||
) : configs.length === 0 ? (
|
|
||||||
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
|
||||||
暂无配置项
|
|
||||||
</div>
|
|
||||||
) : (
|
) : (
|
||||||
<Table>
|
<Table>
|
||||||
<TableHeader>
|
<TableHeader>
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import {
|
|||||||
ArrowLeftRight,
|
ArrowLeftRight,
|
||||||
Settings,
|
Settings,
|
||||||
FileText,
|
FileText,
|
||||||
|
MessageSquare,
|
||||||
|
Bot,
|
||||||
LogOut,
|
LogOut,
|
||||||
ChevronLeft,
|
ChevronLeft,
|
||||||
Menu,
|
Menu,
|
||||||
@@ -22,16 +24,44 @@ import { AuthGuard, useAuth } from '@/components/auth-guard'
|
|||||||
import { logout } from '@/lib/auth'
|
import { logout } from '@/lib/auth'
|
||||||
import { cn } from '@/lib/utils'
|
import { cn } from '@/lib/utils'
|
||||||
|
|
||||||
|
/** 权限常量 — 与后端 db.rs SEED_ROLES 保持同步 */
|
||||||
|
const ROLE_PERMISSIONS: Record<string, string[]> = {
|
||||||
|
super_admin: ['admin:full', 'account:admin', 'provider:manage', 'model:manage', 'relay:admin', 'config:write', 'prompt:read', 'prompt:write', 'prompt:publish', 'prompt:admin'],
|
||||||
|
admin: ['account:read', 'account:admin', 'provider:manage', 'model:read', 'model:manage', 'relay:use', 'relay:admin', 'config:read', 'config:write', 'prompt:read', 'prompt:write', 'prompt:publish'],
|
||||||
|
user: ['model:read', 'relay:use', 'config:read', 'prompt:read'],
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 从后端获取权限列表(运行时同步) */
|
||||||
|
async function fetchRolePermissions(role: string): Promise<string[]> {
|
||||||
|
try {
|
||||||
|
const res = await fetch('/api/v1/roles/' + role)
|
||||||
|
if (res.ok) {
|
||||||
|
const data = await res.json()
|
||||||
|
return data.permissions || []
|
||||||
|
}
|
||||||
|
return ROLE_PERMISSIONS[role] ?? []
|
||||||
|
} catch {
|
||||||
|
return ROLE_PERMISSIONS[role] ?? []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 根据 role 获取权限列表 */
|
||||||
|
function getPermissionsForRole(role: string): string[] {
|
||||||
|
return ROLE_PERMISSIONS[role] ?? []
|
||||||
|
}
|
||||||
|
|
||||||
const navItems = [
|
const navItems = [
|
||||||
{ href: '/', label: '仪表盘', icon: LayoutDashboard },
|
{ href: '/', label: '仪表盘', icon: LayoutDashboard },
|
||||||
{ href: '/accounts', label: '账号管理', icon: Users },
|
{ href: '/accounts', label: '账号管理', icon: Users, permission: 'account:admin' },
|
||||||
{ href: '/providers', label: '服务商', icon: Server },
|
{ href: '/providers', label: '服务商', icon: Server, permission: 'provider:manage' },
|
||||||
{ href: '/models', label: '模型管理', icon: Cpu },
|
{ href: '/models', label: '模型管理', icon: Cpu, permission: 'model:read' },
|
||||||
{ href: '/api-keys', label: 'API 密钥', icon: Key },
|
{ href: '/agent-templates', label: 'Agent 模板', icon: Bot, permission: 'model:read' },
|
||||||
{ href: '/usage', label: '用量统计', icon: BarChart3 },
|
{ href: '/api-keys', label: 'API 密钥', icon: Key, permission: 'admin:full' },
|
||||||
{ href: '/relay', label: '中转任务', icon: ArrowLeftRight },
|
{ href: '/usage', label: '用量统计', icon: BarChart3, permission: 'admin:full' },
|
||||||
{ href: '/config', label: '系统配置', icon: Settings },
|
{ href: '/relay', label: '中转任务', icon: ArrowLeftRight, permission: 'relay:use' },
|
||||||
{ href: '/logs', label: '操作日志', icon: FileText },
|
{ href: '/config', label: '系统配置', icon: Settings, permission: 'config:read' },
|
||||||
|
{ href: '/prompts', label: '提示词管理', icon: MessageSquare, permission: 'prompt:read' },
|
||||||
|
{ href: '/logs', label: '操作日志', icon: FileText, permission: 'admin:full' },
|
||||||
]
|
]
|
||||||
|
|
||||||
function Sidebar({
|
function Sidebar({
|
||||||
@@ -45,11 +75,18 @@ function Sidebar({
|
|||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const { account } = useAuth()
|
const { account } = useAuth()
|
||||||
|
|
||||||
|
const permissions = account ? getPermissionsForRole(account.role) : []
|
||||||
|
|
||||||
function handleLogout() {
|
function handleLogout() {
|
||||||
logout()
|
logout()
|
||||||
router.replace('/login')
|
router.replace('/login')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const filteredNavItems = navItems.filter((item) => {
|
||||||
|
if (!item.permission) return true
|
||||||
|
return permissions.includes(item.permission) || permissions.includes('admin:full')
|
||||||
|
})
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<aside
|
<aside
|
||||||
className={cn(
|
className={cn(
|
||||||
@@ -75,7 +112,7 @@ function Sidebar({
|
|||||||
{/* 导航 */}
|
{/* 导航 */}
|
||||||
<nav className="flex-1 overflow-y-auto scrollbar-thin py-2 px-2">
|
<nav className="flex-1 overflow-y-auto scrollbar-thin py-2 px-2">
|
||||||
<ul className="space-y-1">
|
<ul className="space-y-1">
|
||||||
{navItems.map((item) => {
|
{filteredNavItems.map((item) => {
|
||||||
const isActive =
|
const isActive =
|
||||||
item.href === '/'
|
item.href === '/'
|
||||||
? pathname === '/'
|
? pathname === '/'
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useCallback } from 'react'
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
import {
|
import {
|
||||||
Plus,
|
Plus,
|
||||||
Loader2,
|
Loader2,
|
||||||
@@ -37,6 +38,8 @@ import {
|
|||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatNumber } from '@/lib/utils'
|
import { formatNumber } from '@/lib/utils'
|
||||||
@@ -71,14 +74,29 @@ const emptyForm: ModelForm = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function ModelsPage() {
|
export default function ModelsPage() {
|
||||||
const [models, setModels] = useState<Model[]>([])
|
|
||||||
const [providers, setProviders] = useState<Provider[]>([])
|
|
||||||
const [total, setTotal] = useState(0)
|
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [providerFilter, setProviderFilter] = useState<string>('all')
|
const [providerFilter, setProviderFilter] = useState<string>('all')
|
||||||
const [loading, setLoading] = useState(true)
|
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
|
// SWR for models list
|
||||||
|
const { data, isLoading, mutate } = useSWR(
|
||||||
|
['models', page, providerFilter],
|
||||||
|
() => {
|
||||||
|
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
||||||
|
if (providerFilter !== 'all') params.provider_id = providerFilter
|
||||||
|
return api.models.list(params)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
const models = data?.items ?? []
|
||||||
|
const total = data?.total ?? 0
|
||||||
|
|
||||||
|
// SWR for providers list (dropdown)
|
||||||
|
const { data: providersData } = useSWR(
|
||||||
|
['providers.all'],
|
||||||
|
() => api.providers.list({ page: 1, page_size: 100 })
|
||||||
|
)
|
||||||
|
const providers = providersData?.items ?? []
|
||||||
|
|
||||||
// Dialog
|
// Dialog
|
||||||
const [dialogOpen, setDialogOpen] = useState(false)
|
const [dialogOpen, setDialogOpen] = useState(false)
|
||||||
const [editTarget, setEditTarget] = useState<Model | null>(null)
|
const [editTarget, setEditTarget] = useState<Model | null>(null)
|
||||||
@@ -89,37 +107,6 @@ export default function ModelsPage() {
|
|||||||
const [deleteTarget, setDeleteTarget] = useState<Model | null>(null)
|
const [deleteTarget, setDeleteTarget] = useState<Model | null>(null)
|
||||||
const [deleting, setDeleting] = useState(false)
|
const [deleting, setDeleting] = useState(false)
|
||||||
|
|
||||||
const fetchModels = useCallback(async () => {
|
|
||||||
setLoading(true)
|
|
||||||
setError('')
|
|
||||||
try {
|
|
||||||
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
|
||||||
if (providerFilter !== 'all') params.provider_id = providerFilter
|
|
||||||
const res = await api.models.list(params)
|
|
||||||
setModels(res.items)
|
|
||||||
setTotal(res.total)
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
|
||||||
else setError('加载失败')
|
|
||||||
} finally {
|
|
||||||
setLoading(false)
|
|
||||||
}
|
|
||||||
}, [page, providerFilter])
|
|
||||||
|
|
||||||
const fetchProviders = useCallback(async () => {
|
|
||||||
try {
|
|
||||||
const res = await api.providers.list({ page: 1, page_size: 100 })
|
|
||||||
setProviders(res.items)
|
|
||||||
} catch {
|
|
||||||
// ignore
|
|
||||||
}
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
fetchModels()
|
|
||||||
fetchProviders()
|
|
||||||
}, [fetchModels, fetchProviders])
|
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
const providerMap = new Map(providers.map((p) => [p.id, p.display_name || p.name]))
|
const providerMap = new Map(providers.map((p) => [p.id, p.display_name || p.name]))
|
||||||
@@ -169,7 +156,7 @@ export default function ModelsPage() {
|
|||||||
await api.models.create(payload)
|
await api.models.create(payload)
|
||||||
}
|
}
|
||||||
setDialogOpen(false)
|
setDialogOpen(false)
|
||||||
fetchModels()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -183,7 +170,7 @@ export default function ModelsPage() {
|
|||||||
try {
|
try {
|
||||||
await api.models.delete(deleteTarget.id)
|
await api.models.delete(deleteTarget.id)
|
||||||
setDeleteTarget(null)
|
setDeleteTarget(null)
|
||||||
fetchModels()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -213,21 +200,12 @@ export default function ModelsPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && (
|
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
|
||||||
{error}
|
|
||||||
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
{isLoading ? (
|
||||||
<div className="flex h-64 items-center justify-center">
|
<TableSkeleton rows={8} cols={9} hasToolbar={false} />
|
||||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
) : error ? null : models.length === 0 ? (
|
||||||
</div>
|
<EmptyState />
|
||||||
) : models.length === 0 ? (
|
|
||||||
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
|
||||||
暂无数据
|
|
||||||
</div>
|
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState } from 'react'
|
|
||||||
import {
|
import {
|
||||||
Users,
|
Users,
|
||||||
Server,
|
Server,
|
||||||
ArrowLeftRight,
|
ArrowLeftRight,
|
||||||
Zap,
|
Zap,
|
||||||
Loader2,
|
|
||||||
TrendingUp,
|
TrendingUp,
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
import {
|
import {
|
||||||
@@ -21,8 +19,12 @@ import {
|
|||||||
Bar,
|
Bar,
|
||||||
Legend,
|
Legend,
|
||||||
} from 'recharts'
|
} from 'recharts'
|
||||||
|
import useSWR from 'swr'
|
||||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
||||||
import { Badge } from '@/components/ui/badge'
|
import { Badge } from '@/components/ui/badge'
|
||||||
|
import { StatsSkeleton } from '@/components/ui/skeleton'
|
||||||
|
import { ChartSkeleton } from '@/components/ui/skeleton'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
TableBody,
|
TableBody,
|
||||||
@@ -86,61 +88,24 @@ function StatusBadge({ status }: { status: string }) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function DashboardPage() {
|
export default function DashboardPage() {
|
||||||
const [stats, setStats] = useState<DashboardStats | null>(null)
|
const { data: stats, isLoading: statsLoading } = useSWR(
|
||||||
const [usageData, setUsageData] = useState<UsageRecord[]>([])
|
['stats.dashboard'],
|
||||||
const [recentLogs, setRecentLogs] = useState<OperationLog[]>([])
|
() => api.stats.dashboard(),
|
||||||
const [loading, setLoading] = useState(true)
|
|
||||||
const [error, setError] = useState('')
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
async function fetchData() {
|
|
||||||
try {
|
|
||||||
const [statsRes, usageRes, logsRes] = await Promise.allSettled([
|
|
||||||
api.stats.dashboard(),
|
|
||||||
api.usage.daily({ days: 30 }),
|
|
||||||
api.logs.list({ page: 1, page_size: 5 }),
|
|
||||||
])
|
|
||||||
|
|
||||||
if (statsRes.status === 'fulfilled') setStats(statsRes.value)
|
|
||||||
if (usageRes.status === 'fulfilled') setUsageData(usageRes.value)
|
|
||||||
if (logsRes.status === 'fulfilled') setRecentLogs(logsRes.value.items)
|
|
||||||
} catch (err) {
|
|
||||||
setError('加载数据失败,请检查后端服务是否启动')
|
|
||||||
} finally {
|
|
||||||
setLoading(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fetchData()
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
if (loading) {
|
|
||||||
return (
|
|
||||||
<div className="flex h-[60vh] items-center justify-center">
|
|
||||||
<div className="flex flex-col items-center gap-3">
|
|
||||||
<Loader2 className="h-8 w-8 animate-spin text-primary" />
|
|
||||||
<p className="text-sm text-muted-foreground">加载中...</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
if (error) {
|
const { data: usageData = [], isLoading: usageLoading } = useSWR(
|
||||||
return (
|
['usage.daily.30'],
|
||||||
<div className="flex h-[60vh] items-center justify-center">
|
() => api.usage.daily({ days: 30 }),
|
||||||
<div className="text-center">
|
|
||||||
<p className="text-destructive">{error}</p>
|
|
||||||
<button
|
|
||||||
onClick={() => window.location.reload()}
|
|
||||||
className="mt-4 text-sm text-primary hover:underline cursor-pointer"
|
|
||||||
>
|
|
||||||
重新加载
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)
|
)
|
||||||
}
|
|
||||||
|
|
||||||
const chartData = usageData.map((r) => ({
|
const { data: logsData, isLoading: logsLoading } = useSWR(
|
||||||
|
['logs.recent'],
|
||||||
|
() => api.logs.list({ page: 1, page_size: 5 }),
|
||||||
|
)
|
||||||
|
|
||||||
|
const recentLogs: OperationLog[] = logsData?.items ?? []
|
||||||
|
|
||||||
|
const chartData = usageData.map((r: UsageRecord) => ({
|
||||||
day: r.day.slice(5), // MM-DD
|
day: r.day.slice(5), // MM-DD
|
||||||
请求量: r.count,
|
请求量: r.count,
|
||||||
Input: r.input_tokens,
|
Input: r.input_tokens,
|
||||||
@@ -150,6 +115,9 @@ export default function DashboardPage() {
|
|||||||
return (
|
return (
|
||||||
<div className="space-y-6">
|
<div className="space-y-6">
|
||||||
{/* 统计卡片 */}
|
{/* 统计卡片 */}
|
||||||
|
{statsLoading ? (
|
||||||
|
<StatsSkeleton count={4} />
|
||||||
|
) : (
|
||||||
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-4">
|
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-4">
|
||||||
<StatCard
|
<StatCard
|
||||||
title="总账号数"
|
title="总账号数"
|
||||||
@@ -180,10 +148,14 @@ export default function DashboardPage() {
|
|||||||
subtitle={`In: ${formatNumber(stats?.tokens_today_input ?? 0)} / Out: ${formatNumber(stats?.tokens_today_output ?? 0)}`}
|
subtitle={`In: ${formatNumber(stats?.tokens_today_input ?? 0)} / Out: ${formatNumber(stats?.tokens_today_output ?? 0)}`}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* 图表 */}
|
{/* 图表 */}
|
||||||
<div className="grid grid-cols-1 gap-4 lg:grid-cols-2">
|
<div className="grid grid-cols-1 gap-4 lg:grid-cols-2">
|
||||||
{/* 请求趋势 */}
|
{/* 请求趋势 */}
|
||||||
|
{usageLoading ? (
|
||||||
|
<ChartSkeleton height={280} />
|
||||||
|
) : (
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle className="flex items-center gap-2 text-base">
|
<CardTitle className="flex items-center gap-2 text-base">
|
||||||
@@ -237,8 +209,12 @@ export default function DashboardPage() {
|
|||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* Token 用量 */}
|
{/* Token 用量 */}
|
||||||
|
{usageLoading ? (
|
||||||
|
<ChartSkeleton height={280} />
|
||||||
|
) : (
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle className="flex items-center gap-2 text-base">
|
<CardTitle className="flex items-center gap-2 text-base">
|
||||||
@@ -283,6 +259,7 @@ export default function DashboardPage() {
|
|||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 最近操作日志 */}
|
{/* 最近操作日志 */}
|
||||||
@@ -291,7 +268,9 @@ export default function DashboardPage() {
|
|||||||
<CardTitle className="text-base">最近操作</CardTitle>
|
<CardTitle className="text-base">最近操作</CardTitle>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
{recentLogs.length > 0 ? (
|
{logsLoading ? (
|
||||||
|
<TableSkeleton rows={5} cols={5} hasToolbar={false} />
|
||||||
|
) : recentLogs.length > 0 ? (
|
||||||
<Table>
|
<Table>
|
||||||
<TableHeader>
|
<TableHeader>
|
||||||
<TableRow>
|
<TableRow>
|
||||||
|
|||||||
341
admin/src/app/(dashboard)/prompts/page.tsx
Normal file
341
admin/src/app/(dashboard)/prompts/page.tsx
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
'use client'
|
||||||
|
|
||||||
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
|
import { api } from '@/lib/api-client'
|
||||||
|
import type { PromptTemplate, PromptVersion } from '@/lib/types'
|
||||||
|
import { EmptyState } from '@/components/ui/state'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
|
|
||||||
|
export default function PromptsPage() {
|
||||||
|
const [page, setPage] = useState(1)
|
||||||
|
const [selectedName, setSelectedName] = useState<string | null>(null)
|
||||||
|
const [versions, setVersions] = useState<PromptVersion[]>([])
|
||||||
|
const [showCreate, setShowCreate] = useState(false)
|
||||||
|
const [showNewVersion, setShowNewVersion] = useState(false)
|
||||||
|
const [filter, setFilter] = useState<{ source?: string; status?: string }>({})
|
||||||
|
|
||||||
|
const { data, error, isLoading, mutate } = useSWR(
|
||||||
|
['prompts.list', page, filter.source, filter.status],
|
||||||
|
() => api.prompts.list({ page, page_size: 50, ...filter }),
|
||||||
|
)
|
||||||
|
|
||||||
|
const templates = data?.items ?? []
|
||||||
|
const total = data?.total ?? 0
|
||||||
|
|
||||||
|
const fetchVersions = async (name: string) => {
|
||||||
|
try {
|
||||||
|
const res = await api.prompts.listVersions(name)
|
||||||
|
setVersions(res)
|
||||||
|
setSelectedName(name)
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to fetch versions:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleCreate = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||||
|
e.preventDefault()
|
||||||
|
const fd = new FormData(e.currentTarget)
|
||||||
|
try {
|
||||||
|
await api.prompts.create({
|
||||||
|
name: fd.get('name') as string,
|
||||||
|
category: fd.get('category') as string,
|
||||||
|
description: (fd.get('description') as string) || undefined,
|
||||||
|
source: 'custom',
|
||||||
|
system_prompt: fd.get('system_prompt') as string,
|
||||||
|
})
|
||||||
|
setShowCreate(false)
|
||||||
|
mutate()
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to create prompt:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleNewVersion = async (e: React.FormEvent<HTMLFormElement>) => {
|
||||||
|
e.preventDefault()
|
||||||
|
if (!selectedName) return
|
||||||
|
const fd = new FormData(e.currentTarget)
|
||||||
|
try {
|
||||||
|
await api.prompts.createVersion(selectedName, {
|
||||||
|
system_prompt: fd.get('system_prompt') as string,
|
||||||
|
changelog: (fd.get('changelog') as string) || undefined,
|
||||||
|
})
|
||||||
|
setShowNewVersion(false)
|
||||||
|
fetchVersions(selectedName)
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to create version:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleRollback = async (name: string, version: number) => {
|
||||||
|
if (!confirm(`确认回退到版本 ${version}?`)) return
|
||||||
|
try {
|
||||||
|
await api.prompts.rollback(name, version)
|
||||||
|
fetchVersions(name)
|
||||||
|
mutate()
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to rollback:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const handleArchive = async (name: string) => {
|
||||||
|
if (!confirm(`确认归档 ${name}?`)) return
|
||||||
|
try {
|
||||||
|
await api.prompts.archive(name)
|
||||||
|
mutate()
|
||||||
|
} catch (err) {
|
||||||
|
console.error('Failed to archive:', err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const statusBadge = (status: string) => {
|
||||||
|
const colors: Record<string, string> = {
|
||||||
|
active: 'bg-emerald-500/20 text-emerald-400',
|
||||||
|
deprecated: 'bg-amber-500/20 text-amber-400',
|
||||||
|
archived: 'bg-zinc-500/20 text-zinc-400',
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<span className={`px-2 py-0.5 text-xs rounded-full ${colors[status] || colors.archived}`}>
|
||||||
|
{status}
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const sourceBadge = (source: string) => {
|
||||||
|
const colors: Record<string, string> = {
|
||||||
|
builtin: 'bg-blue-500/20 text-blue-400',
|
||||||
|
custom: 'bg-purple-500/20 text-purple-400',
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<span className={`px-2 py-0.5 text-xs rounded-full ${colors[source] || ''}`}>
|
||||||
|
{source === 'builtin' ? '内置' : '自定义'}
|
||||||
|
</span>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-6">
|
||||||
|
{/* Header */}
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<h1 className="text-2xl font-bold text-white">提示词管理</h1>
|
||||||
|
<p className="text-sm text-zinc-400 mt-1">管理内置和自定义提示词模板,支持版本控制和 OTA 分发</p>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={() => setShowCreate(true)}
|
||||||
|
className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors text-sm"
|
||||||
|
>
|
||||||
|
+ 新建模板
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Filters */}
|
||||||
|
<div className="flex gap-2">
|
||||||
|
{(['all', 'builtin', 'custom'] as const).map(s => (
|
||||||
|
<button
|
||||||
|
key={s}
|
||||||
|
onClick={() => setFilter(s === 'all' ? {} : { source: s })}
|
||||||
|
className={`px-3 py-1 text-sm rounded-lg transition-colors ${
|
||||||
|
(filter.source || 'all') === s
|
||||||
|
? 'bg-zinc-700 text-white'
|
||||||
|
: 'bg-zinc-800 text-zinc-400 hover:text-white'
|
||||||
|
}`}
|
||||||
|
>
|
||||||
|
{s === 'all' ? '全部' : s === 'builtin' ? '内置' : '自定义'}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Template List */}
|
||||||
|
<div className="bg-zinc-900 rounded-xl border border-zinc-800 overflow-hidden">
|
||||||
|
<table className="w-full text-sm">
|
||||||
|
<thead>
|
||||||
|
<tr className="border-b border-zinc-800">
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">名称</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">分类</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">来源</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">版本</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">状态</th>
|
||||||
|
<th className="text-left px-4 py-3 text-zinc-400 font-medium">更新时间</th>
|
||||||
|
<th className="text-right px-4 py-3 text-zinc-400 font-medium">操作</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody>
|
||||||
|
{isLoading ? (
|
||||||
|
<tr>
|
||||||
|
<td colSpan={7}>
|
||||||
|
<TableSkeleton rows={5} cols={7} hasToolbar={false} />
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
) : error ? (
|
||||||
|
<tr><td colSpan={7} className="px-4 py-8 text-center text-red-400">加载失败</td></tr>
|
||||||
|
) : templates.length === 0 ? (
|
||||||
|
<tr><td colSpan={7}><EmptyState message="暂无提示词模板" /></td></tr>
|
||||||
|
) : (
|
||||||
|
templates.map(t => (
|
||||||
|
<tr key={t.id} className="border-b border-zinc-800/50 hover:bg-zinc-800/30">
|
||||||
|
<td className="px-4 py-3">
|
||||||
|
<button
|
||||||
|
onClick={() => fetchVersions(t.name)}
|
||||||
|
className="text-blue-400 hover:text-blue-300 font-mono"
|
||||||
|
>
|
||||||
|
{t.name}
|
||||||
|
</button>
|
||||||
|
</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-400">{t.category}</td>
|
||||||
|
<td className="px-4 py-3">{sourceBadge(t.source)}</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-300">v{t.current_version}</td>
|
||||||
|
<td className="px-4 py-3">{statusBadge(t.status)}</td>
|
||||||
|
<td className="px-4 py-3 text-zinc-500 text-xs">
|
||||||
|
{new Date(t.updated_at).toLocaleString('zh-CN')}
|
||||||
|
</td>
|
||||||
|
<td className="px-4 py-3 text-right">
|
||||||
|
<button
|
||||||
|
onClick={() => fetchVersions(t.name)}
|
||||||
|
className="text-zinc-400 hover:text-white mr-2"
|
||||||
|
>
|
||||||
|
历史
|
||||||
|
</button>
|
||||||
|
{t.source === 'custom' && (
|
||||||
|
<button
|
||||||
|
onClick={() => handleArchive(t.name)}
|
||||||
|
className="text-red-400 hover:text-red-300"
|
||||||
|
>
|
||||||
|
归档
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
))
|
||||||
|
)}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
<div className="px-4 py-2 text-xs text-zinc-500 border-t border-zinc-800">
|
||||||
|
共 {total} 个模板
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* Version History Panel */}
|
||||||
|
{selectedName && (
|
||||||
|
<div className="bg-zinc-900 rounded-xl border border-zinc-800 p-4">
|
||||||
|
<div className="flex items-center justify-between mb-4">
|
||||||
|
<h2 className="text-lg font-semibold text-white">
|
||||||
|
{selectedName} — 版本历史
|
||||||
|
</h2>
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<button
|
||||||
|
onClick={() => setShowNewVersion(true)}
|
||||||
|
className="px-3 py-1.5 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-xs"
|
||||||
|
>
|
||||||
|
发布新版本
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => { setSelectedName(null); setVersions([]) }}
|
||||||
|
className="px-3 py-1.5 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-xs"
|
||||||
|
>
|
||||||
|
关闭
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-3">
|
||||||
|
{versions.map(v => (
|
||||||
|
<div key={v.id} className="bg-zinc-800/50 rounded-lg p-3">
|
||||||
|
<div className="flex items-center justify-between mb-2">
|
||||||
|
<span className="text-sm font-mono text-zinc-300">v{v.version}</span>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<span className="text-xs text-zinc-500">
|
||||||
|
{new Date(v.created_at).toLocaleString('zh-CN')}
|
||||||
|
</span>
|
||||||
|
{v.changelog && (
|
||||||
|
<span className="text-xs text-zinc-400">— {v.changelog}</span>
|
||||||
|
)}
|
||||||
|
{v.min_app_version && (
|
||||||
|
<span className="text-xs text-amber-400">最低版本: {v.min_app_version}</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<pre className="text-xs text-zinc-400 bg-zinc-900 rounded p-2 overflow-x-auto max-h-32">
|
||||||
|
{v.system_prompt.substring(0, 300)}{v.system_prompt.length > 300 ? '...' : ''}
|
||||||
|
</pre>
|
||||||
|
<div className="mt-2 flex gap-2">
|
||||||
|
<button
|
||||||
|
onClick={() => {
|
||||||
|
navigator.clipboard.writeText(v.system_prompt)
|
||||||
|
}}
|
||||||
|
className="text-xs text-zinc-500 hover:text-white"
|
||||||
|
>
|
||||||
|
复制
|
||||||
|
</button>
|
||||||
|
<button
|
||||||
|
onClick={() => handleRollback(selectedName, v.version)}
|
||||||
|
className="text-xs text-amber-500 hover:text-amber-400"
|
||||||
|
>
|
||||||
|
回退到此版本
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
{versions.length === 0 && (
|
||||||
|
<EmptyState message="暂无版本历史" />
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* Create Modal */}
|
||||||
|
{showCreate && (
|
||||||
|
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||||
|
<form onSubmit={handleCreate} className="bg-zinc-900 rounded-xl border border-zinc-700 p-6 w-full max-w-lg space-y-4">
|
||||||
|
<h2 className="text-lg font-semibold text-white">新建提示词模板</h2>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">名称</label>
|
||||||
|
<input name="name" required className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="my_prompt" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">分类</label>
|
||||||
|
<select name="category" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm">
|
||||||
|
<option value="custom_system">系统提示词</option>
|
||||||
|
<option value="custom_extraction">提取提示词</option>
|
||||||
|
<option value="custom_compaction">压缩提示词</option>
|
||||||
|
<option value="custom_other">其他</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">描述</label>
|
||||||
|
<input name="description" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="可选" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">系统提示词</label>
|
||||||
|
<textarea name="system_prompt" required rows={6} className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm font-mono" />
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-2 justify-end">
|
||||||
|
<button type="button" onClick={() => setShowCreate(false)} className="px-4 py-2 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-sm">取消</button>
|
||||||
|
<button type="submit" className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-sm">创建</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* New Version Modal */}
|
||||||
|
{showNewVersion && selectedName && (
|
||||||
|
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||||
|
<form onSubmit={handleNewVersion} className="bg-zinc-900 rounded-xl border border-zinc-700 p-6 w-full max-w-lg space-y-4">
|
||||||
|
<h2 className="text-lg font-semibold text-white">发布 {selectedName} 新版本</h2>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">系统提示词</label>
|
||||||
|
<textarea name="system_prompt" required rows={6} className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm font-mono" />
|
||||||
|
</div>
|
||||||
|
<div>
|
||||||
|
<label className="block text-sm text-zinc-400 mb-1">变更说明</label>
|
||||||
|
<input name="changelog" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="描述本次变更" />
|
||||||
|
</div>
|
||||||
|
<div className="flex gap-2 justify-end">
|
||||||
|
<button type="button" onClick={() => setShowNewVersion(false)} className="px-4 py-2 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-sm">取消</button>
|
||||||
|
<button type="submit" className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-sm">发布</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useCallback } from 'react'
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
import {
|
import {
|
||||||
Plus,
|
Plus,
|
||||||
Loader2,
|
Loader2,
|
||||||
@@ -8,6 +9,9 @@ import {
|
|||||||
ChevronRight,
|
ChevronRight,
|
||||||
Pencil,
|
Pencil,
|
||||||
Trash2,
|
Trash2,
|
||||||
|
KeyRound,
|
||||||
|
Power,
|
||||||
|
PowerOff,
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
import { Button } from '@/components/ui/button'
|
import { Button } from '@/components/ui/button'
|
||||||
import { Input } from '@/components/ui/input'
|
import { Input } from '@/components/ui/input'
|
||||||
@@ -37,10 +41,18 @@ import {
|
|||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate, maskApiKey } from '@/lib/utils'
|
import { formatDate, maskApiKey } from '@/lib/utils'
|
||||||
import type { Provider } from '@/lib/types'
|
|
||||||
|
function formatTokens(tokens: number): string {
|
||||||
|
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`
|
||||||
|
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(1)}K`
|
||||||
|
return String(tokens)
|
||||||
|
}
|
||||||
|
import type { Provider, ProviderKey } from '@/lib/types'
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
|
|
||||||
@@ -67,12 +79,17 @@ const emptyForm: ProviderForm = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function ProvidersPage() {
|
export default function ProvidersPage() {
|
||||||
const [providers, setProviders] = useState<Provider[]>([])
|
|
||||||
const [total, setTotal] = useState(0)
|
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [loading, setLoading] = useState(true)
|
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
|
// SWR for providers list
|
||||||
|
const { data, isLoading, mutate } = useSWR(
|
||||||
|
['providers', page],
|
||||||
|
() => api.providers.list({ page, page_size: PAGE_SIZE })
|
||||||
|
)
|
||||||
|
const providers = data?.items ?? []
|
||||||
|
const total = data?.total ?? 0
|
||||||
|
|
||||||
// 创建/编辑 Dialog
|
// 创建/编辑 Dialog
|
||||||
const [dialogOpen, setDialogOpen] = useState(false)
|
const [dialogOpen, setDialogOpen] = useState(false)
|
||||||
const [editTarget, setEditTarget] = useState<Provider | null>(null)
|
const [editTarget, setEditTarget] = useState<Provider | null>(null)
|
||||||
@@ -83,24 +100,24 @@ export default function ProvidersPage() {
|
|||||||
const [deleteTarget, setDeleteTarget] = useState<Provider | null>(null)
|
const [deleteTarget, setDeleteTarget] = useState<Provider | null>(null)
|
||||||
const [deleting, setDeleting] = useState(false)
|
const [deleting, setDeleting] = useState(false)
|
||||||
|
|
||||||
const fetchProviders = useCallback(async () => {
|
// Key Pool 管理
|
||||||
setLoading(true)
|
const [keyPoolProvider, setKeyPoolProvider] = useState<Provider | null>(null)
|
||||||
setError('')
|
const [showAddKey, setShowAddKey] = useState(false)
|
||||||
try {
|
const [addKeyForm, setAddKeyForm] = useState({
|
||||||
const res = await api.providers.list({ page, page_size: PAGE_SIZE })
|
key_label: '',
|
||||||
setProviders(res.items)
|
key_value: '',
|
||||||
setTotal(res.total)
|
priority: 0,
|
||||||
} catch (err) {
|
max_rpm: '',
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
max_tpm: '',
|
||||||
else setError('加载失败')
|
quota_reset_interval: '',
|
||||||
} finally {
|
})
|
||||||
setLoading(false)
|
const [addingKey, setAddingKey] = useState(false)
|
||||||
}
|
|
||||||
}, [page])
|
|
||||||
|
|
||||||
useEffect(() => {
|
// SWR for key pool — only fetches when dialog is open
|
||||||
fetchProviders()
|
const { data: providerKeys = [], isLoading: keysLoading, mutate: mutateKeys } = useSWR(
|
||||||
}, [fetchProviders])
|
keyPoolProvider ? ['provider.keys', keyPoolProvider.id] : null,
|
||||||
|
() => api.providers.listKeys(keyPoolProvider!.id)
|
||||||
|
)
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
@@ -145,7 +162,7 @@ export default function ProvidersPage() {
|
|||||||
await api.providers.create(payload)
|
await api.providers.create(payload)
|
||||||
}
|
}
|
||||||
setDialogOpen(false)
|
setDialogOpen(false)
|
||||||
fetchProviders()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -159,7 +176,7 @@ export default function ProvidersPage() {
|
|||||||
try {
|
try {
|
||||||
await api.providers.delete(deleteTarget.id)
|
await api.providers.delete(deleteTarget.id)
|
||||||
setDeleteTarget(null)
|
setDeleteTarget(null)
|
||||||
fetchProviders()
|
mutate()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -167,6 +184,55 @@ export default function ProvidersPage() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Key Pool 管理 ─────────────────────────────────────
|
||||||
|
|
||||||
|
function openKeyPool(provider: Provider) {
|
||||||
|
setKeyPoolProvider(provider)
|
||||||
|
setShowAddKey(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleAddKey() {
|
||||||
|
if (!keyPoolProvider || !addKeyForm.key_label.trim() || !addKeyForm.key_value.trim()) return
|
||||||
|
setAddingKey(true)
|
||||||
|
try {
|
||||||
|
await api.providers.addKey(keyPoolProvider.id, {
|
||||||
|
key_label: addKeyForm.key_label.trim(),
|
||||||
|
key_value: addKeyForm.key_value.trim(),
|
||||||
|
priority: addKeyForm.priority,
|
||||||
|
max_rpm: addKeyForm.max_rpm ? parseInt(addKeyForm.max_rpm, 10) : undefined,
|
||||||
|
max_tpm: addKeyForm.max_tpm ? parseInt(addKeyForm.max_tpm, 10) : undefined,
|
||||||
|
quota_reset_interval: addKeyForm.quota_reset_interval.trim() || undefined,
|
||||||
|
})
|
||||||
|
setAddKeyForm({ key_label: '', key_value: '', priority: 0, max_rpm: '', max_tpm: '', quota_reset_interval: '' })
|
||||||
|
setShowAddKey(false)
|
||||||
|
mutateKeys()
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
} finally {
|
||||||
|
setAddingKey(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleToggleKey(keyId: string, active: boolean) {
|
||||||
|
if (!keyPoolProvider) return
|
||||||
|
try {
|
||||||
|
await api.providers.toggleKey(keyPoolProvider.id, keyId, active)
|
||||||
|
mutateKeys()
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleDeleteKey(keyId: string) {
|
||||||
|
if (!keyPoolProvider || !confirm('确认删除此 Key?')) return
|
||||||
|
try {
|
||||||
|
await api.providers.deleteKey(keyPoolProvider.id, keyId)
|
||||||
|
mutateKeys()
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
{/* 工具栏 */}
|
{/* 工具栏 */}
|
||||||
@@ -178,21 +244,12 @@ export default function ProvidersPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && (
|
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
|
||||||
{error}
|
|
||||||
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
{isLoading ? (
|
||||||
<div className="flex h-64 items-center justify-center">
|
<TableSkeleton rows={6} cols={9} hasToolbar={false} />
|
||||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
) : error ? null : providers.length === 0 ? (
|
||||||
</div>
|
<EmptyState />
|
||||||
) : providers.length === 0 ? (
|
|
||||||
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
|
||||||
暂无数据
|
|
||||||
</div>
|
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
@@ -238,6 +295,9 @@ export default function ProvidersPage() {
|
|||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell className="text-right">
|
<TableCell className="text-right">
|
||||||
<div className="flex items-center justify-end gap-1">
|
<div className="flex items-center justify-end gap-1">
|
||||||
|
<Button variant="ghost" size="icon" onClick={() => openKeyPool(p)} title="Key Pool">
|
||||||
|
<KeyRound className="h-4 w-4" />
|
||||||
|
</Button>
|
||||||
<Button variant="ghost" size="icon" onClick={() => openEditDialog(p)} title="编辑">
|
<Button variant="ghost" size="icon" onClick={() => openEditDialog(p)} title="编辑">
|
||||||
<Pencil className="h-4 w-4" />
|
<Pencil className="h-4 w-4" />
|
||||||
</Button>
|
</Button>
|
||||||
@@ -381,6 +441,165 @@ export default function ProvidersPage() {
|
|||||||
</DialogFooter>
|
</DialogFooter>
|
||||||
</DialogContent>
|
</DialogContent>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
|
{/* Key Pool 管理 Dialog */}
|
||||||
|
<Dialog open={!!keyPoolProvider} onOpenChange={() => setKeyPoolProvider(null)}>
|
||||||
|
<DialogContent className="max-w-2xl">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>Key Pool 管理 — {keyPoolProvider?.display_name || keyPoolProvider?.name}</DialogTitle>
|
||||||
|
<DialogDescription>
|
||||||
|
管理此服务商的多个 API Key,实现智能轮转绕过限额。优先级数字越小越优先。
|
||||||
|
</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div className="max-h-[50vh] overflow-y-auto scrollbar-thin">
|
||||||
|
{keysLoading ? (
|
||||||
|
<TableSkeleton rows={4} cols={8} hasToolbar={false} />
|
||||||
|
) : providerKeys.length === 0 && !showAddKey ? (
|
||||||
|
<div className="text-center py-8 text-muted-foreground text-sm">
|
||||||
|
<p>尚未配置 Key Pool</p>
|
||||||
|
<p className="mt-1 text-xs">将使用服务商主 API Key 作为回退</p>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<Table>
|
||||||
|
<TableHeader>
|
||||||
|
<TableRow>
|
||||||
|
<TableHead>标签</TableHead>
|
||||||
|
<TableHead>优先级</TableHead>
|
||||||
|
<TableHead>RPM</TableHead>
|
||||||
|
<TableHead>TPM</TableHead>
|
||||||
|
<TableHead>状态</TableHead>
|
||||||
|
<TableHead>请求/Token</TableHead>
|
||||||
|
<TableHead>最后 429</TableHead>
|
||||||
|
<TableHead className="text-right">操作</TableHead>
|
||||||
|
</TableRow>
|
||||||
|
</TableHeader>
|
||||||
|
<TableBody>
|
||||||
|
{providerKeys.map((k) => {
|
||||||
|
const isCooling = k.cooldown_until && new Date(k.cooldown_until) > new Date()
|
||||||
|
return (
|
||||||
|
<TableRow key={k.id} className={isCooling ? 'opacity-60' : ''}>
|
||||||
|
<TableCell className="font-medium">{k.key_label}</TableCell>
|
||||||
|
<TableCell>{k.priority}</TableCell>
|
||||||
|
<TableCell className="text-muted-foreground">{k.max_rpm ?? '-'}</TableCell>
|
||||||
|
<TableCell className="text-muted-foreground">{k.max_tpm ?? '-'}</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
<Badge variant={k.is_active ? 'success' : 'secondary'}>
|
||||||
|
{isCooling ? '冷却中' : k.is_active ? '活跃' : '禁用'}
|
||||||
|
</Badge>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-xs text-muted-foreground">
|
||||||
|
{k.total_requests} / {formatTokens(k.total_tokens)}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-xs text-muted-foreground">
|
||||||
|
{k.last_429_at ? formatDate(k.last_429_at) : '-'}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-right">
|
||||||
|
<div className="flex items-center justify-end gap-1">
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
onClick={() => handleToggleKey(k.id, !k.is_active)}
|
||||||
|
title={k.is_active ? '禁用' : '启用'}
|
||||||
|
>
|
||||||
|
{k.is_active ? <PowerOff className="h-3.5 w-3.5 text-amber-500" /> : <Power className="h-3.5 w-3.5 text-green-500" />}
|
||||||
|
</Button>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
onClick={() => handleDeleteKey(k.id)}
|
||||||
|
title="删除"
|
||||||
|
>
|
||||||
|
<Trash2 className="h-3.5 w-3.5 text-destructive" />
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{!showAddKey ? (
|
||||||
|
<DialogFooter>
|
||||||
|
<Button variant="outline" onClick={() => setKeyPoolProvider(null)}>关闭</Button>
|
||||||
|
<Button onClick={() => setShowAddKey(true)}>
|
||||||
|
<Plus className="h-4 w-4 mr-2" />
|
||||||
|
添加 Key
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
) : (
|
||||||
|
<div className="space-y-3 border-t pt-4">
|
||||||
|
<p className="text-sm font-medium">添加新 Key</p>
|
||||||
|
<div className="grid grid-cols-2 gap-3">
|
||||||
|
<div className="space-y-1">
|
||||||
|
<Label className="text-xs">标签 *</Label>
|
||||||
|
<Input
|
||||||
|
value={addKeyForm.key_label}
|
||||||
|
onChange={(e) => setAddKeyForm({ ...addKeyForm, key_label: e.target.value })}
|
||||||
|
placeholder="如 zhipu-coding-1"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-1">
|
||||||
|
<Label className="text-xs">优先级</Label>
|
||||||
|
<Input
|
||||||
|
type="number"
|
||||||
|
value={addKeyForm.priority}
|
||||||
|
onChange={(e) => setAddKeyForm({ ...addKeyForm, priority: parseInt(e.target.value, 10) || 0 })}
|
||||||
|
placeholder="0"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="col-span-2 space-y-1">
|
||||||
|
<Label className="text-xs">API Key *</Label>
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
value={addKeyForm.key_value}
|
||||||
|
onChange={(e) => setAddKeyForm({ ...addKeyForm, key_value: e.target.value })}
|
||||||
|
placeholder="输入 API Key"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-1">
|
||||||
|
<Label className="text-xs">RPM 限额</Label>
|
||||||
|
<Input
|
||||||
|
type="number"
|
||||||
|
value={addKeyForm.max_rpm}
|
||||||
|
onChange={(e) => setAddKeyForm({ ...addKeyForm, max_rpm: e.target.value })}
|
||||||
|
placeholder="不限"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-1">
|
||||||
|
<Label className="text-xs">TPM 限额</Label>
|
||||||
|
<Input
|
||||||
|
type="number"
|
||||||
|
value={addKeyForm.max_tpm}
|
||||||
|
onChange={(e) => setAddKeyForm({ ...addKeyForm, max_tpm: e.target.value })}
|
||||||
|
placeholder="不限"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="col-span-2 space-y-1">
|
||||||
|
<Label className="text-xs">限额重置周期</Label>
|
||||||
|
<Input
|
||||||
|
value={addKeyForm.quota_reset_interval}
|
||||||
|
onChange={(e) => setAddKeyForm({ ...addKeyForm, quota_reset_interval: e.target.value })}
|
||||||
|
placeholder="如 5h, 1d(可选)"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<DialogFooter>
|
||||||
|
<Button variant="outline" onClick={() => { setShowAddKey(false); setAddKeyForm({ key_label: '', key_value: '', priority: 0, max_rpm: '', max_tpm: '', quota_reset_interval: '' }) }}>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
|
<Button onClick={handleAddKey} disabled={addingKey || !addKeyForm.key_label.trim() || !addKeyForm.key_value.trim()}>
|
||||||
|
{addingKey && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||||
|
添加
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useCallback } from 'react'
|
import { useState } from 'react'
|
||||||
|
import useSWR from 'swr'
|
||||||
import {
|
import {
|
||||||
Search,
|
Search,
|
||||||
Loader2,
|
Loader2,
|
||||||
@@ -29,6 +30,8 @@ import {
|
|||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate, formatNumber } from '@/lib/utils'
|
import { formatDate, formatNumber } from '@/lib/utils'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
|
import { TableSkeleton } from '@/components/ui/skeleton'
|
||||||
import type { RelayTask } from '@/lib/types'
|
import type { RelayTask } from '@/lib/types'
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
@@ -48,34 +51,22 @@ const statusLabels: Record<string, string> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function RelayPage() {
|
export default function RelayPage() {
|
||||||
const [tasks, setTasks] = useState<RelayTask[]>([])
|
|
||||||
const [total, setTotal] = useState(0)
|
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [statusFilter, setStatusFilter] = useState<string>('all')
|
const [statusFilter, setStatusFilter] = useState<string>('all')
|
||||||
const [loading, setLoading] = useState(true)
|
|
||||||
const [error, setError] = useState('')
|
|
||||||
const [expandedId, setExpandedId] = useState<string | null>(null)
|
const [expandedId, setExpandedId] = useState<string | null>(null)
|
||||||
|
|
||||||
const fetchTasks = useCallback(async () => {
|
const { data, error: swrError, isLoading } = useSWR(
|
||||||
setLoading(true)
|
['relay', page, statusFilter],
|
||||||
setError('')
|
() => {
|
||||||
try {
|
|
||||||
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
||||||
if (statusFilter !== 'all') params.status = statusFilter
|
if (statusFilter !== 'all') params.status = statusFilter
|
||||||
const res = await api.relay.list(params)
|
return api.relay.list(params)
|
||||||
setTasks(res.items)
|
},
|
||||||
setTotal(res.total)
|
)
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
|
||||||
else setError('加载失败')
|
|
||||||
} finally {
|
|
||||||
setLoading(false)
|
|
||||||
}
|
|
||||||
}, [page, statusFilter])
|
|
||||||
|
|
||||||
useEffect(() => {
|
const tasks = data?.items ?? []
|
||||||
fetchTasks()
|
const total = data?.total ?? 0
|
||||||
}, [fetchTasks])
|
const error = swrError?.message
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
@@ -101,21 +92,12 @@ export default function RelayPage() {
|
|||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && (
|
{error && <ErrorBanner message={error} onDismiss={() => {}} />}
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
|
||||||
{error}
|
|
||||||
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{loading ? (
|
{isLoading ? (
|
||||||
<div className="flex h-64 items-center justify-center">
|
<TableSkeleton rows={6} cols={10} />
|
||||||
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
) : error ? null : tasks.length === 0 ? (
|
||||||
</div>
|
<EmptyState />
|
||||||
) : tasks.length === 0 ? (
|
|
||||||
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
|
||||||
暂无数据
|
|
||||||
</div>
|
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useCallback } from 'react'
|
import { useState } from 'react'
|
||||||
import { Loader2, Zap } from 'lucide-react'
|
import useSWR from 'swr'
|
||||||
|
import { Zap, Monitor, Smartphone } from 'lucide-react'
|
||||||
import {
|
import {
|
||||||
LineChart,
|
LineChart,
|
||||||
Line,
|
Line,
|
||||||
@@ -15,6 +16,8 @@ import {
|
|||||||
Legend,
|
Legend,
|
||||||
} from 'recharts'
|
} from 'recharts'
|
||||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
||||||
|
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
||||||
|
import { TableSkeleton, ChartSkeleton } from '@/components/ui/skeleton'
|
||||||
import {
|
import {
|
||||||
Select,
|
Select,
|
||||||
SelectContent,
|
SelectContent,
|
||||||
@@ -22,84 +25,87 @@ import {
|
|||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
|
import {
|
||||||
|
Table,
|
||||||
|
TableBody,
|
||||||
|
TableCell,
|
||||||
|
TableHead,
|
||||||
|
TableHeader,
|
||||||
|
TableRow,
|
||||||
|
} from '@/components/ui/table'
|
||||||
|
import { Badge } from '@/components/ui/badge'
|
||||||
|
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
|
||||||
import { formatNumber } from '@/lib/utils'
|
import { formatNumber } from '@/lib/utils'
|
||||||
import type { UsageRecord, UsageByModel } from '@/lib/types'
|
import type { UsageRecord, UsageByModel, ModelUsageStat, DailyUsageStat } from '@/lib/types'
|
||||||
|
|
||||||
export default function UsagePage() {
|
export default function UsagePage() {
|
||||||
const [days, setDays] = useState(7)
|
const [days, setDays] = useState(7)
|
||||||
const [dailyData, setDailyData] = useState<UsageRecord[]>([])
|
const [activeTab, setActiveTab] = useState('relay')
|
||||||
const [modelData, setModelData] = useState<UsageByModel[]>([])
|
|
||||||
const [loading, setLoading] = useState(true)
|
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
const fetchData = useCallback(async () => {
|
// 4 parallel SWR calls — each loads independently
|
||||||
setLoading(true)
|
const { data: dailyData = [], isLoading: dailyLoading } = useSWR(
|
||||||
setError('')
|
['usage.daily', days],
|
||||||
try {
|
() => api.usage.daily({ days })
|
||||||
const [dailyRes, modelRes] = await Promise.allSettled([
|
)
|
||||||
api.usage.daily({ days }),
|
const { data: modelData = [], isLoading: modelLoading } = useSWR(
|
||||||
api.usage.byModel({ days }),
|
['usage.byModel', days],
|
||||||
])
|
() => api.usage.byModel({ days })
|
||||||
if (dailyRes.status === 'fulfilled') setDailyData(dailyRes.value)
|
)
|
||||||
else throw new Error('Failed to fetch daily usage')
|
const { data: telemetryModels = [] } = useSWR(
|
||||||
if (modelRes.status === 'fulfilled') setModelData(modelRes.value)
|
['telemetry.modelStats'],
|
||||||
} catch (err) {
|
() => api.telemetry.modelStats()
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
)
|
||||||
else setError('加载数据失败')
|
const { data: telemetryDaily = [] } = useSWR(
|
||||||
} finally {
|
['telemetry.dailyStats', days],
|
||||||
setLoading(false)
|
() => api.telemetry.dailyStats({ days })
|
||||||
}
|
)
|
||||||
}, [days])
|
|
||||||
|
|
||||||
useEffect(() => {
|
const relayLoading = dailyLoading || modelLoading
|
||||||
fetchData()
|
const telemetryLoading = !telemetryModels.length && !telemetryDaily.length && (dailyLoading || modelLoading)
|
||||||
}, [fetchData])
|
|
||||||
|
|
||||||
const lineChartData = dailyData.map((r) => ({
|
// === Relay 用量图表数据 ===
|
||||||
|
|
||||||
|
const relayLineData = dailyData.map((r) => ({
|
||||||
day: r.day.slice(5),
|
day: r.day.slice(5),
|
||||||
Input: r.input_tokens,
|
Input: r.input_tokens,
|
||||||
Output: r.output_tokens,
|
Output: r.output_tokens,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
const barChartData = modelData.map((r) => ({
|
const relayBarData = modelData.map((r) => ({
|
||||||
model: r.model_id,
|
model: r.model_id,
|
||||||
请求量: r.count,
|
请求量: r.count,
|
||||||
Input: r.input_tokens,
|
Input: r.input_tokens,
|
||||||
Output: r.output_tokens,
|
Output: r.output_tokens,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
const totalInput = dailyData.reduce((s, r) => s + r.input_tokens, 0)
|
const relayTotalInput = dailyData.reduce((s, r) => s + r.input_tokens, 0)
|
||||||
const totalOutput = dailyData.reduce((s, r) => s + r.output_tokens, 0)
|
const relayTotalOutput = dailyData.reduce((s, r) => s + r.output_tokens, 0)
|
||||||
const totalRequests = dailyData.reduce((s, r) => s + r.count, 0)
|
const relayTotalRequests = dailyData.reduce((s, r) => s + r.count, 0)
|
||||||
|
|
||||||
if (loading) {
|
// === 遥测图表数据 ===
|
||||||
return (
|
|
||||||
<div className="flex h-[60vh] items-center justify-center">
|
|
||||||
<div className="flex flex-col items-center gap-3">
|
|
||||||
<Loader2 className="h-8 w-8 animate-spin text-primary" />
|
|
||||||
<p className="text-sm text-muted-foreground">加载中...</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (error) {
|
const telemetryLineData = telemetryDaily.map((r) => ({
|
||||||
return (
|
day: r.day.slice(5),
|
||||||
<div className="flex h-[60vh] items-center justify-center">
|
Input: r.input_tokens,
|
||||||
<div className="text-center">
|
Output: r.output_tokens,
|
||||||
<p className="text-destructive">{error}</p>
|
设备数: r.unique_devices,
|
||||||
<button onClick={() => fetchData()} className="mt-4 text-sm text-primary hover:underline cursor-pointer">
|
}))
|
||||||
重新加载
|
|
||||||
</button>
|
const telemetryTotalInput = telemetryDaily.reduce((s, r) => s + r.input_tokens, 0)
|
||||||
</div>
|
const telemetryTotalOutput = telemetryDaily.reduce((s, r) => s + r.output_tokens, 0)
|
||||||
</div>
|
const telemetryTotalRequests = telemetryDaily.reduce((s, r) => s + r.request_count, 0)
|
||||||
)
|
|
||||||
}
|
// === 合计 ===
|
||||||
|
|
||||||
|
const totalInput = relayTotalInput + telemetryTotalInput
|
||||||
|
const totalOutput = relayTotalOutput + telemetryTotalOutput
|
||||||
|
const totalRequests = relayTotalRequests + telemetryTotalRequests
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-6">
|
<div className="space-y-6">
|
||||||
|
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
||||||
{/* 时间范围 */}
|
{/* 时间范围 */}
|
||||||
<div className="flex items-center gap-3">
|
<div className="flex items-center gap-3">
|
||||||
<span className="text-sm text-muted-foreground">时间范围:</span>
|
<span className="text-sm text-muted-foreground">时间范围:</span>
|
||||||
@@ -115,8 +121,8 @@ export default function UsagePage() {
|
|||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 汇总统计 */}
|
{/* 汇总统计 — render immediately, use 0 while loading */}
|
||||||
<div className="grid grid-cols-1 gap-4 sm:grid-cols-3">
|
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-5">
|
||||||
<Card>
|
<Card>
|
||||||
<CardContent className="p-6">
|
<CardContent className="p-6">
|
||||||
<p className="text-sm text-muted-foreground">总请求数</p>
|
<p className="text-sm text-muted-foreground">总请求数</p>
|
||||||
@@ -127,7 +133,7 @@ export default function UsagePage() {
|
|||||||
</Card>
|
</Card>
|
||||||
<Card>
|
<Card>
|
||||||
<CardContent className="p-6">
|
<CardContent className="p-6">
|
||||||
<p className="text-sm text-muted-foreground">Input Tokens</p>
|
<p className="text-sm text-muted-foreground">总 Input Tokens</p>
|
||||||
<p className="mt-1 text-2xl font-bold text-blue-400">
|
<p className="mt-1 text-2xl font-bold text-blue-400">
|
||||||
{formatNumber(totalInput)}
|
{formatNumber(totalInput)}
|
||||||
</p>
|
</p>
|
||||||
@@ -135,101 +141,190 @@ export default function UsagePage() {
|
|||||||
</Card>
|
</Card>
|
||||||
<Card>
|
<Card>
|
||||||
<CardContent className="p-6">
|
<CardContent className="p-6">
|
||||||
<p className="text-sm text-muted-foreground">Output Tokens</p>
|
<p className="text-sm text-muted-foreground">总 Output Tokens</p>
|
||||||
<p className="mt-1 text-2xl font-bold text-orange-400">
|
<p className="mt-1 text-2xl font-bold text-orange-400">
|
||||||
{formatNumber(totalOutput)}
|
{formatNumber(totalOutput)}
|
||||||
</p>
|
</p>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
<Card>
|
||||||
|
<CardContent className="p-6">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<Monitor className="h-4 w-4 text-green-400" />
|
||||||
|
<p className="text-sm text-muted-foreground">中转请求</p>
|
||||||
|
</div>
|
||||||
|
<p className="mt-1 text-2xl font-bold text-green-400">
|
||||||
|
{formatNumber(relayTotalRequests)}
|
||||||
|
</p>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
<Card>
|
||||||
|
<CardContent className="p-6">
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<Smartphone className="h-4 w-4 text-purple-400" />
|
||||||
|
<p className="text-sm text-muted-foreground">桌面端调用</p>
|
||||||
|
</div>
|
||||||
|
<p className="mt-1 text-2xl font-bold text-purple-400">
|
||||||
|
{formatNumber(telemetryTotalRequests)}
|
||||||
|
</p>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Token 用量趋势 */}
|
{/* Tab 切换 */}
|
||||||
|
<Tabs value={activeTab} onValueChange={setActiveTab}>
|
||||||
|
<TabsList>
|
||||||
|
<TabsTrigger value="relay">
|
||||||
|
<Monitor className="h-4 w-4 mr-1" />
|
||||||
|
中转用量
|
||||||
|
</TabsTrigger>
|
||||||
|
<TabsTrigger value="telemetry">
|
||||||
|
<Smartphone className="h-4 w-4 mr-1" />
|
||||||
|
桌面端遥测
|
||||||
|
</TabsTrigger>
|
||||||
|
</TabsList>
|
||||||
|
|
||||||
|
{/* Relay 用量 Tab */}
|
||||||
|
<TabsContent value="relay" className="space-y-6">
|
||||||
|
{relayLoading ? (
|
||||||
|
<>
|
||||||
|
<ChartSkeleton height={320} />
|
||||||
|
<ChartSkeleton height={280} />
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle className="flex items-center gap-2 text-base">
|
<CardTitle className="flex items-center gap-2 text-base">
|
||||||
<Zap className="h-4 w-4 text-primary" />
|
<Zap className="h-4 w-4 text-primary" />
|
||||||
Token 用量趋势
|
中转 Token 用量趋势
|
||||||
</CardTitle>
|
</CardTitle>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
{lineChartData.length > 0 ? (
|
{relayLineData.length > 0 ? (
|
||||||
<ResponsiveContainer width="100%" height={320}>
|
<ResponsiveContainer width="100%" height={320}>
|
||||||
<LineChart data={lineChartData}>
|
<LineChart data={relayLineData}>
|
||||||
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
||||||
<XAxis
|
<XAxis dataKey="day" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
||||||
dataKey="day"
|
<YAxis tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
<Tooltip contentStyle={{ backgroundColor: '#0F172A', border: '1px solid #1E293B', borderRadius: '8px', color: '#F8FAFC', fontSize: '12px' }} />
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
|
||||||
/>
|
|
||||||
<YAxis
|
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
|
||||||
/>
|
|
||||||
<Tooltip
|
|
||||||
contentStyle={{
|
|
||||||
backgroundColor: '#0F172A',
|
|
||||||
border: '1px solid #1E293B',
|
|
||||||
borderRadius: '8px',
|
|
||||||
color: '#F8FAFC',
|
|
||||||
fontSize: '12px',
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
||||||
<Line type="monotone" dataKey="Input" stroke="#3B82F6" strokeWidth={2} dot={false} />
|
<Line type="monotone" dataKey="Input" stroke="#3B82F6" strokeWidth={2} dot={false} />
|
||||||
<Line type="monotone" dataKey="Output" stroke="#F97316" strokeWidth={2} dot={false} />
|
<Line type="monotone" dataKey="Output" stroke="#F97316" strokeWidth={2} dot={false} />
|
||||||
</LineChart>
|
</LineChart>
|
||||||
</ResponsiveContainer>
|
</ResponsiveContainer>
|
||||||
) : (
|
) : (
|
||||||
<div className="flex h-[320px] items-center justify-center text-muted-foreground text-sm">
|
<EmptyState message="暂无中转数据" />
|
||||||
暂无数据
|
|
||||||
</div>
|
|
||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
{/* 按模型分布 */}
|
|
||||||
<Card>
|
<Card>
|
||||||
<CardHeader>
|
<CardHeader>
|
||||||
<CardTitle className="text-base">按模型分布</CardTitle>
|
<CardTitle className="text-base">中转按模型分布</CardTitle>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
{barChartData.length > 0 ? (
|
{relayBarData.length > 0 ? (
|
||||||
<ResponsiveContainer width="100%" height={320}>
|
<ResponsiveContainer width="100%" height={Math.max(200, relayBarData.length * 40)}>
|
||||||
<BarChart data={barChartData} layout="vertical">
|
<BarChart data={relayBarData} layout="vertical">
|
||||||
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
||||||
<XAxis
|
<XAxis type="number" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
||||||
type="number"
|
<YAxis type="category" dataKey="model" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} width={120} />
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
<Tooltip contentStyle={{ backgroundColor: '#0F172A', border: '1px solid #1E293B', borderRadius: '8px', color: '#F8FAFC', fontSize: '12px' }} />
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
|
||||||
/>
|
|
||||||
<YAxis
|
|
||||||
type="category"
|
|
||||||
dataKey="model"
|
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
|
||||||
width={120}
|
|
||||||
/>
|
|
||||||
<Tooltip
|
|
||||||
contentStyle={{
|
|
||||||
backgroundColor: '#0F172A',
|
|
||||||
border: '1px solid #1E293B',
|
|
||||||
borderRadius: '8px',
|
|
||||||
color: '#F8FAFC',
|
|
||||||
fontSize: '12px',
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
||||||
<Bar dataKey="Input" fill="#3B82F6" radius={[0, 2, 2, 0]} />
|
<Bar dataKey="Input" fill="#3B82F6" radius={[0, 2, 2, 0]} />
|
||||||
<Bar dataKey="Output" fill="#F97316" radius={[0, 2, 2, 0]} />
|
<Bar dataKey="Output" fill="#F97316" radius={[0, 2, 2, 0]} />
|
||||||
</BarChart>
|
</BarChart>
|
||||||
</ResponsiveContainer>
|
</ResponsiveContainer>
|
||||||
) : (
|
) : (
|
||||||
<div className="flex h-[320px] items-center justify-center text-muted-foreground text-sm">
|
<EmptyState />
|
||||||
暂无数据
|
|
||||||
</div>
|
|
||||||
)}
|
)}
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</TabsContent>
|
||||||
|
|
||||||
|
{/* 遥测 Tab */}
|
||||||
|
<TabsContent value="telemetry" className="space-y-6">
|
||||||
|
{telemetryLoading ? (
|
||||||
|
<>
|
||||||
|
<ChartSkeleton height={320} />
|
||||||
|
<TableSkeleton rows={5} cols={6} hasToolbar={false} />
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<>
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle className="flex items-center gap-2 text-base">
|
||||||
|
<Smartphone className="h-4 w-4 text-purple-400" />
|
||||||
|
桌面端 Token 用量趋势
|
||||||
|
</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
{telemetryLineData.length > 0 ? (
|
||||||
|
<ResponsiveContainer width="100%" height={320}>
|
||||||
|
<LineChart data={telemetryLineData}>
|
||||||
|
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
||||||
|
<XAxis dataKey="day" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
||||||
|
<YAxis tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
||||||
|
<Tooltip contentStyle={{ backgroundColor: '#0F172A', border: '1px solid #1E293B', borderRadius: '8px', color: '#F8FAFC', fontSize: '12px' }} />
|
||||||
|
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
||||||
|
<Line type="monotone" dataKey="Input" stroke="#3B82F6" strokeWidth={2} dot={false} />
|
||||||
|
<Line type="monotone" dataKey="Output" stroke="#F97316" strokeWidth={2} dot={false} />
|
||||||
|
</LineChart>
|
||||||
|
</ResponsiveContainer>
|
||||||
|
) : (
|
||||||
|
<EmptyState message="暂无桌面端遥测数据(需要桌面端上报)" />
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle className="text-base">桌面端按模型统计</CardTitle>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
{telemetryModels.length > 0 ? (
|
||||||
|
<Table>
|
||||||
|
<TableHeader>
|
||||||
|
<TableRow>
|
||||||
|
<TableHead>模型</TableHead>
|
||||||
|
<TableHead className="text-right">请求数</TableHead>
|
||||||
|
<TableHead className="text-right">Input Tokens</TableHead>
|
||||||
|
<TableHead className="text-right">Output Tokens</TableHead>
|
||||||
|
<TableHead className="text-right">平均延迟</TableHead>
|
||||||
|
<TableHead className="text-right">成功率</TableHead>
|
||||||
|
</TableRow>
|
||||||
|
</TableHeader>
|
||||||
|
<TableBody>
|
||||||
|
{telemetryModels.map((stat) => (
|
||||||
|
<TableRow key={stat.model_id}>
|
||||||
|
<TableCell className="font-mono text-sm">{stat.model_id}</TableCell>
|
||||||
|
<TableCell className="text-right">{formatNumber(stat.request_count)}</TableCell>
|
||||||
|
<TableCell className="text-right text-blue-400">{formatNumber(stat.input_tokens)}</TableCell>
|
||||||
|
<TableCell className="text-right text-orange-400">{formatNumber(stat.output_tokens)}</TableCell>
|
||||||
|
<TableCell className="text-right">
|
||||||
|
{stat.avg_latency_ms !== null ? `${Math.round(stat.avg_latency_ms)}ms` : '-'}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-right">
|
||||||
|
<Badge variant={stat.success_rate >= 0.95 ? 'default' : 'destructive'}>
|
||||||
|
{(stat.success_rate * 100).toFixed(1)}%
|
||||||
|
</Badge>
|
||||||
|
</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
))}
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
) : (
|
||||||
|
<EmptyState />
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</TabsContent>
|
||||||
|
</Tabs>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
4
admin/src/app/icon.svg
Normal file
4
admin/src/app/icon.svg
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 32 32">
|
||||||
|
<rect width="32" height="32" rx="6" fill="#0f172a"/>
|
||||||
|
<text x="16" y="22" font-family="system-ui, sans-serif" font-size="16" font-weight="700" fill="#60a5fa" text-anchor="middle">Z</text>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 282 B |
@@ -1,4 +1,5 @@
|
|||||||
import type { Metadata } from 'next'
|
import type { Metadata } from 'next'
|
||||||
|
import { SWRProvider } from '@/lib/swr-provider'
|
||||||
import './globals.css'
|
import './globals.css'
|
||||||
|
|
||||||
export const metadata: Metadata = {
|
export const metadata: Metadata = {
|
||||||
@@ -20,7 +21,9 @@ export default function RootLayout({
|
|||||||
/>
|
/>
|
||||||
</head>
|
</head>
|
||||||
<body className="min-h-screen bg-background font-sans antialiased">
|
<body className="min-h-screen bg-background font-sans antialiased">
|
||||||
|
<SWRProvider>
|
||||||
{children}
|
{children}
|
||||||
|
</SWRProvider>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
import { useState, type FormEvent } from 'react'
|
import { useState, type FormEvent } from 'react'
|
||||||
import { useRouter } from 'next/navigation'
|
import { useRouter } from 'next/navigation'
|
||||||
import { Lock, User, Loader2, Eye, EyeOff } from 'lucide-react'
|
import { Lock, User, Loader2, Eye, EyeOff, ShieldCheck } from 'lucide-react'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { login } from '@/lib/auth'
|
import { login } from '@/lib/auth'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
@@ -11,7 +11,9 @@ export default function LoginPage() {
|
|||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const [username, setUsername] = useState('')
|
const [username, setUsername] = useState('')
|
||||||
const [password, setPassword] = useState('')
|
const [password, setPassword] = useState('')
|
||||||
|
const [totpCode, setTotpCode] = useState('')
|
||||||
const [showPassword, setShowPassword] = useState(false)
|
const [showPassword, setShowPassword] = useState(false)
|
||||||
|
const [needTotp, setNeedTotp] = useState(false)
|
||||||
const [remember, setRemember] = useState(false)
|
const [remember, setRemember] = useState(false)
|
||||||
const [loading, setLoading] = useState(false)
|
const [loading, setLoading] = useState(false)
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
@@ -31,12 +33,23 @@ export default function LoginPage() {
|
|||||||
|
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
try {
|
try {
|
||||||
const res = await api.auth.login({ username: username.trim(), password })
|
const res = await api.auth.login({
|
||||||
|
username: username.trim(),
|
||||||
|
password,
|
||||||
|
totp_code: totpCode.trim() || undefined,
|
||||||
|
})
|
||||||
login(res.token, res.account)
|
login(res.token, res.account)
|
||||||
router.replace('/')
|
router.replace('/')
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) {
|
if (err instanceof ApiRequestError) {
|
||||||
setError(err.body.message || '登录失败,请检查用户名和密码')
|
const msg = err.body.message || ''
|
||||||
|
// 后端返回 "需要 TOTP" 时显示 TOTP 输入框
|
||||||
|
if (msg.includes('TOTP') || msg.includes('totp') || msg.includes('2FA') || msg.includes('验证码') || err.status === 403) {
|
||||||
|
setNeedTotp(true)
|
||||||
|
setError(msg || '请输入两步验证码')
|
||||||
|
} else {
|
||||||
|
setError(msg || '登录失败,请检查用户名和密码')
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
setError('网络错误,请稍后重试')
|
setError('网络错误,请稍后重试')
|
||||||
}
|
}
|
||||||
@@ -152,6 +165,35 @@ export default function LoginPage() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* TOTP 验证码 */}
|
||||||
|
{needTotp && (
|
||||||
|
<div className="space-y-2">
|
||||||
|
<label
|
||||||
|
htmlFor="totp"
|
||||||
|
className="text-sm font-medium text-foreground"
|
||||||
|
>
|
||||||
|
两步验证码
|
||||||
|
</label>
|
||||||
|
<div className="relative">
|
||||||
|
<ShieldCheck className="absolute left-3 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
||||||
|
<input
|
||||||
|
id="totp"
|
||||||
|
type="text"
|
||||||
|
placeholder="请输入 6 位验证码"
|
||||||
|
value={totpCode}
|
||||||
|
onChange={(e) => setTotpCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
|
||||||
|
maxLength={6}
|
||||||
|
className="flex h-10 w-full rounded-md border border-input bg-transparent pl-10 pr-3 py-2 text-sm shadow-sm transition-colors duration-200 placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring tracking-widest"
|
||||||
|
autoComplete="one-time-code"
|
||||||
|
inputMode="numeric"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<p className="text-xs text-muted-foreground">
|
||||||
|
请使用身份验证器 App(如 Google Authenticator)扫描二维码后生成的验证码
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* 记住我 */}
|
{/* 记住我 */}
|
||||||
<div className="flex items-center gap-2">
|
<div className="flex items-center gap-2">
|
||||||
<input
|
<input
|
||||||
|
|||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
import { useEffect, useState, type ReactNode } from 'react'
|
import { useEffect, useState, type ReactNode } from 'react'
|
||||||
import { useRouter } from 'next/navigation'
|
import { useRouter } from 'next/navigation'
|
||||||
import { isAuthenticated, getAccount } from '@/lib/auth'
|
import { isAuthenticated, getAccount, clearAuth } from '@/lib/auth'
|
||||||
|
import { api } from '@/lib/api-client'
|
||||||
import type { AccountPublic } from '@/lib/types'
|
import type { AccountPublic } from '@/lib/types'
|
||||||
|
|
||||||
interface AuthGuardProps {
|
interface AuthGuardProps {
|
||||||
@@ -13,17 +14,31 @@ export function AuthGuard({ children }: AuthGuardProps) {
|
|||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const [authorized, setAuthorized] = useState(false)
|
const [authorized, setAuthorized] = useState(false)
|
||||||
const [account, setAccount] = useState<AccountPublic | null>(null)
|
const [account, setAccount] = useState<AccountPublic | null>(null)
|
||||||
|
const [verifying, setVerifying] = useState(true)
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
async function verifyAuth() {
|
||||||
if (!isAuthenticated()) {
|
if (!isAuthenticated()) {
|
||||||
router.replace('/login')
|
router.replace('/login')
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
setAccount(getAccount())
|
|
||||||
|
try {
|
||||||
|
const serverAccount = await api.auth.me()
|
||||||
|
setAccount(serverAccount)
|
||||||
setAuthorized(true)
|
setAuthorized(true)
|
||||||
|
} catch {
|
||||||
|
clearAuth()
|
||||||
|
router.replace('/login')
|
||||||
|
} finally {
|
||||||
|
setVerifying(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyAuth()
|
||||||
}, [router])
|
}, [router])
|
||||||
|
|
||||||
if (!authorized) {
|
if (verifying) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-screen w-screen items-center justify-center bg-background">
|
<div className="flex h-screen w-screen items-center justify-center bg-background">
|
||||||
<div className="h-8 w-8 animate-spin rounded-full border-2 border-primary border-t-transparent" />
|
<div className="h-8 w-8 animate-spin rounded-full border-2 border-primary border-t-transparent" />
|
||||||
@@ -31,6 +46,10 @@ export function AuthGuard({ children }: AuthGuardProps) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!authorized) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
return <>{children}</>
|
return <>{children}</>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
115
admin/src/components/ui/skeleton.tsx
Normal file
115
admin/src/components/ui/skeleton.tsx
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
// ============================================================
|
||||||
|
// Skeleton 组件 — 替代全屏 spinner 的骨架屏
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
import { cn } from '@/lib/utils'
|
||||||
|
|
||||||
|
function SkeletonBase({ className }: { className?: string }) {
|
||||||
|
return (
|
||||||
|
<div
|
||||||
|
className={cn(
|
||||||
|
'animate-pulse rounded-md bg-muted',
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 表格骨架屏 */
|
||||||
|
export function TableSkeleton({
|
||||||
|
rows = 5,
|
||||||
|
cols = 5,
|
||||||
|
hasToolbar = true,
|
||||||
|
}: {
|
||||||
|
rows?: number
|
||||||
|
cols?: number
|
||||||
|
hasToolbar?: boolean
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
{hasToolbar && (
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<SkeletonBase className="h-9 w-[200px]" />
|
||||||
|
<SkeletonBase className="h-9 w-[120px]" />
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<div className="rounded-md border border-border overflow-hidden">
|
||||||
|
{/* Header */}
|
||||||
|
<div className="border-b border-border bg-muted/30 px-4 py-3">
|
||||||
|
<div className="flex gap-4">
|
||||||
|
{Array.from({ length: cols }).map((_, i) => (
|
||||||
|
<SkeletonBase
|
||||||
|
key={i}
|
||||||
|
className={cn(
|
||||||
|
'h-4',
|
||||||
|
i === 0 ? 'w-[120px]' : i === cols - 1 ? 'w-[80px]' : 'w-[100px]',
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{/* Rows */}
|
||||||
|
{Array.from({ length: rows }).map((_, rowIdx) => (
|
||||||
|
<div
|
||||||
|
key={rowIdx}
|
||||||
|
className={cn(
|
||||||
|
'px-4 py-3',
|
||||||
|
rowIdx < rows - 1 && 'border-b border-border',
|
||||||
|
)}
|
||||||
|
>
|
||||||
|
<div className="flex gap-4">
|
||||||
|
{Array.from({ length: cols }).map((_, colIdx) => (
|
||||||
|
<SkeletonBase
|
||||||
|
key={colIdx}
|
||||||
|
className={cn(
|
||||||
|
'h-4',
|
||||||
|
colIdx === 0 ? 'w-[120px]' : colIdx === cols - 1 ? 'w-[80px]' : 'w-[100px]',
|
||||||
|
)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
{/* Pagination */}
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<SkeletonBase className="h-4 w-[140px]" />
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<SkeletonBase className="h-8 w-[80px]" />
|
||||||
|
<SkeletonBase className="h-8 w-[80px]" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 统计卡片骨架屏 */
|
||||||
|
export function StatsSkeleton({ count = 4 }: { count?: number }) {
|
||||||
|
return (
|
||||||
|
<div className={`grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-${count}`}>
|
||||||
|
{Array.from({ length: count }).map((_, i) => (
|
||||||
|
<div key={i} className="rounded-lg border border-border p-6">
|
||||||
|
<SkeletonBase className="h-4 w-[80px]" />
|
||||||
|
<SkeletonBase className="mt-2 h-8 w-[100px]" />
|
||||||
|
<SkeletonBase className="mt-1 h-3 w-[120px]" />
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 图表骨架屏 */
|
||||||
|
export function ChartSkeleton({ height }: { height?: number }) {
|
||||||
|
return (
|
||||||
|
<div className="rounded-lg border border-border">
|
||||||
|
<div className="border-b border-border px-6 py-4">
|
||||||
|
<SkeletonBase className="h-5 w-[140px]" />
|
||||||
|
</div>
|
||||||
|
<div className="p-6">
|
||||||
|
<SkeletonBase className="w-full" />
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export { SkeletonBase as Skeleton }
|
||||||
63
admin/src/components/ui/state.tsx
Normal file
63
admin/src/components/ui/state.tsx
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
'use client'
|
||||||
|
|
||||||
|
import { AlertCircle, Inbox } from 'lucide-react'
|
||||||
|
|
||||||
|
/** 统一的错误提示横幅 */
|
||||||
|
export function ErrorBanner({
|
||||||
|
message,
|
||||||
|
onDismiss,
|
||||||
|
}: {
|
||||||
|
message: string
|
||||||
|
onDismiss?: () => void
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive flex items-center gap-2">
|
||||||
|
<AlertCircle className="h-4 w-4 shrink-0" />
|
||||||
|
<span className="flex-1">{message}</span>
|
||||||
|
{onDismiss && (
|
||||||
|
<button
|
||||||
|
onClick={onDismiss}
|
||||||
|
className="underline cursor-pointer shrink-0"
|
||||||
|
>
|
||||||
|
关闭
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 统一的空状态占位 */
|
||||||
|
export function EmptyState({
|
||||||
|
message = '暂无数据',
|
||||||
|
}: {
|
||||||
|
message?: string
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="flex h-64 flex-col items-center justify-center gap-2 text-muted-foreground">
|
||||||
|
<Inbox className="h-8 w-8" />
|
||||||
|
<span className="text-sm">{message}</span>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 统一的加载失败提示 + 重试 */
|
||||||
|
export function ErrorRetry({
|
||||||
|
message = '请求失败,请重试',
|
||||||
|
onRetry,
|
||||||
|
}: {
|
||||||
|
message?: string
|
||||||
|
onRetry: () => void
|
||||||
|
}) {
|
||||||
|
return (
|
||||||
|
<div className="flex h-64 flex-col items-center justify-center gap-3 text-muted-foreground">
|
||||||
|
<AlertCircle className="h-8 w-8 text-destructive" />
|
||||||
|
<span className="text-sm">{message}</span>
|
||||||
|
<button
|
||||||
|
onClick={onRetry}
|
||||||
|
className="rounded-md bg-primary px-4 py-2 text-sm text-primary-foreground hover:bg-primary/90 transition-colors cursor-pointer"
|
||||||
|
>
|
||||||
|
重新加载
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
16
admin/src/hooks/use-debounce.ts
Normal file
16
admin/src/hooks/use-debounce.ts
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
// ============================================================
|
||||||
|
// useDebounce — 防抖 hook
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
import { useState, useEffect } from 'react'
|
||||||
|
|
||||||
|
export function useDebounce<T>(value: T, delay = 300): T {
|
||||||
|
const [debouncedValue, setDebouncedValue] = useState<T>(value)
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handler = setTimeout(() => setDebouncedValue(value), delay)
|
||||||
|
return () => clearTimeout(handler)
|
||||||
|
}, [value, delay])
|
||||||
|
|
||||||
|
return debouncedValue
|
||||||
|
}
|
||||||
@@ -2,19 +2,25 @@
|
|||||||
// ZCLAW SaaS Admin — 类型化 HTTP 客户端
|
// ZCLAW SaaS Admin — 类型化 HTTP 客户端
|
||||||
// ============================================================
|
// ============================================================
|
||||||
|
|
||||||
import { getToken, logout } from './auth'
|
import { getToken, login as saveToken, logout, getAccount } from './auth'
|
||||||
import type {
|
import type {
|
||||||
AccountPublic,
|
AccountPublic,
|
||||||
|
AgentTemplate,
|
||||||
ApiError,
|
ApiError,
|
||||||
ConfigItem,
|
ConfigItem,
|
||||||
CreateTokenRequest,
|
CreateTokenRequest,
|
||||||
DashboardStats,
|
DashboardStats,
|
||||||
|
DailyUsageStat,
|
||||||
LoginRequest,
|
LoginRequest,
|
||||||
LoginResponse,
|
LoginResponse,
|
||||||
Model,
|
Model,
|
||||||
|
ModelUsageStat,
|
||||||
OperationLog,
|
OperationLog,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
|
PromptTemplate,
|
||||||
|
PromptVersion,
|
||||||
Provider,
|
Provider,
|
||||||
|
ProviderKey,
|
||||||
RelayTask,
|
RelayTask,
|
||||||
TokenInfo,
|
TokenInfo,
|
||||||
UsageByModel,
|
UsageByModel,
|
||||||
@@ -35,13 +41,67 @@ export class ApiRequestError extends Error {
|
|||||||
|
|
||||||
// ── 基础请求 ──────────────────────────────────────────────
|
// ── 基础请求 ──────────────────────────────────────────────
|
||||||
|
|
||||||
const BASE_URL = process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080'
|
const BASE_URL = process.env.NEXT_PUBLIC_SAAS_API_URL || '/api/v1'
|
||||||
|
|
||||||
|
const DEFAULT_TIMEOUT_MS = 10_000
|
||||||
|
const MAX_RETRIES = 2
|
||||||
|
|
||||||
|
function sleep(ms: number): Promise<void> {
|
||||||
|
return new Promise(resolve => setTimeout(resolve, ms))
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 判断是否为可重试的网络错误(不含 AbortError) */
|
||||||
|
function isRetryableNetworkError(err: unknown): boolean {
|
||||||
|
// AbortError 不重试:可能是组件卸载或路由切换导致的外部取消
|
||||||
|
if (err instanceof DOMException && err.name === 'AbortError') return false
|
||||||
|
if (err instanceof TypeError) {
|
||||||
|
const msg = (err as TypeError).message
|
||||||
|
return msg.includes('Failed to fetch') || msg.includes('NetworkError') || msg.includes('ECONNREFUSED')
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 尝试刷新 Token,成功返回新 token,失败返回 null */
|
||||||
|
async function tryRefreshToken(): Promise<string | null> {
|
||||||
|
try {
|
||||||
|
const token = getToken()
|
||||||
|
if (!token) return null
|
||||||
|
|
||||||
|
const res = await fetch(`${BASE_URL}/auth/refresh`, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${token}`,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!res.ok) return null
|
||||||
|
|
||||||
|
const data = await res.json()
|
||||||
|
const newToken = data.token as string
|
||||||
|
const account = getAccount()
|
||||||
|
if (account && newToken) {
|
||||||
|
saveToken(newToken, account)
|
||||||
|
}
|
||||||
|
return newToken
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async function request<T>(
|
async function request<T>(
|
||||||
method: string,
|
method: string,
|
||||||
path: string,
|
path: string,
|
||||||
body?: unknown,
|
body?: unknown,
|
||||||
|
_isRetry = false,
|
||||||
): Promise<T> {
|
): Promise<T> {
|
||||||
|
let lastError: unknown
|
||||||
|
|
||||||
|
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
||||||
|
const controller = new AbortController()
|
||||||
|
const timeoutId = setTimeout(() => controller.abort(), DEFAULT_TIMEOUT_MS)
|
||||||
|
|
||||||
|
try {
|
||||||
const token = getToken()
|
const token = getToken()
|
||||||
const headers: Record<string, string> = {
|
const headers: Record<string, string> = {
|
||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
@@ -54,9 +114,16 @@ async function request<T>(
|
|||||||
method,
|
method,
|
||||||
headers,
|
headers,
|
||||||
body: body ? JSON.stringify(body) : undefined,
|
body: body ? JSON.stringify(body) : undefined,
|
||||||
|
signal: controller.signal,
|
||||||
})
|
})
|
||||||
|
clearTimeout(timeoutId)
|
||||||
|
|
||||||
if (res.status === 401) {
|
// 401: 尝试刷新 Token 后重试
|
||||||
|
if (res.status === 401 && !_isRetry) {
|
||||||
|
const newToken = await tryRefreshToken()
|
||||||
|
if (newToken) {
|
||||||
|
return request<T>(method, path, body, true)
|
||||||
|
}
|
||||||
logout()
|
logout()
|
||||||
if (typeof window !== 'undefined') {
|
if (typeof window !== 'undefined') {
|
||||||
window.location.href = '/login'
|
window.location.href = '/login'
|
||||||
@@ -80,6 +147,26 @@ async function request<T>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
return res.json() as Promise<T>
|
return res.json() as Promise<T>
|
||||||
|
} catch (err) {
|
||||||
|
clearTimeout(timeoutId)
|
||||||
|
|
||||||
|
// API 错误和外部取消的 AbortError 直接抛出,不重试
|
||||||
|
if (err instanceof ApiRequestError) throw err
|
||||||
|
if (err instanceof DOMException && err.name === 'AbortError') throw err
|
||||||
|
|
||||||
|
lastError = err
|
||||||
|
|
||||||
|
// 仅对可重试的网络错误重试
|
||||||
|
if (attempt < MAX_RETRIES && isRetryableNetworkError(err)) {
|
||||||
|
await sleep(1000 * Math.pow(2, attempt))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
throw err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw lastError
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── API 客户端 ────────────────────────────────────────────
|
// ── API 客户端 ────────────────────────────────────────────
|
||||||
@@ -88,7 +175,7 @@ export const api = {
|
|||||||
// ── 认证 ──────────────────────────────────────────────
|
// ── 认证 ──────────────────────────────────────────────
|
||||||
auth: {
|
auth: {
|
||||||
async login(data: LoginRequest): Promise<LoginResponse> {
|
async login(data: LoginRequest): Promise<LoginResponse> {
|
||||||
return request<LoginResponse>('POST', '/api/auth/login', data)
|
return request<LoginResponse>('POST', '/auth/login', data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async register(data: {
|
async register(data: {
|
||||||
@@ -97,11 +184,11 @@ export const api = {
|
|||||||
email: string
|
email: string
|
||||||
display_name?: string
|
display_name?: string
|
||||||
}): Promise<LoginResponse> {
|
}): Promise<LoginResponse> {
|
||||||
return request<LoginResponse>('POST', '/api/auth/register', data)
|
return request<LoginResponse>('POST', '/auth/register', data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async me(): Promise<AccountPublic> {
|
async me(): Promise<AccountPublic> {
|
||||||
return request<AccountPublic>('GET', '/api/auth/me')
|
return request<AccountPublic>('GET', '/auth/me')
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -115,25 +202,25 @@ export const api = {
|
|||||||
status?: string
|
status?: string
|
||||||
}): Promise<PaginatedResponse<AccountPublic>> {
|
}): Promise<PaginatedResponse<AccountPublic>> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<AccountPublic>>('GET', `/api/accounts${qs}`)
|
return request<PaginatedResponse<AccountPublic>>('GET', `/accounts${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
async get(id: string): Promise<AccountPublic> {
|
async get(id: string): Promise<AccountPublic> {
|
||||||
return request<AccountPublic>('GET', `/api/accounts/${id}`)
|
return request<AccountPublic>('GET', `/accounts/${id}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
async update(
|
async update(
|
||||||
id: string,
|
id: string,
|
||||||
data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>,
|
data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>,
|
||||||
): Promise<AccountPublic> {
|
): Promise<AccountPublic> {
|
||||||
return request<AccountPublic>('PATCH', `/api/accounts/${id}`, data)
|
return request<AccountPublic>('PATCH', `/accounts/${id}`, data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async updateStatus(
|
async updateStatus(
|
||||||
id: string,
|
id: string,
|
||||||
data: { status: AccountPublic['status'] },
|
data: { status: AccountPublic['status'] },
|
||||||
): Promise<void> {
|
): Promise<void> {
|
||||||
return request<void>('PATCH', `/api/accounts/${id}/status`, data)
|
return request<void>('PATCH', `/accounts/${id}/status`, data)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -144,22 +231,46 @@ export const api = {
|
|||||||
page_size?: number
|
page_size?: number
|
||||||
}): Promise<PaginatedResponse<Provider>> {
|
}): Promise<PaginatedResponse<Provider>> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<Provider>>('GET', `/api/providers${qs}`)
|
return request<PaginatedResponse<Provider>>('GET', `/providers${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
async create(data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>): Promise<Provider> {
|
async create(data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>): Promise<Provider> {
|
||||||
return request<Provider>('POST', '/api/providers', data)
|
return request<Provider>('POST', '/providers', data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async update(
|
async update(
|
||||||
id: string,
|
id: string,
|
||||||
data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>,
|
data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>,
|
||||||
): Promise<Provider> {
|
): Promise<Provider> {
|
||||||
return request<Provider>('PATCH', `/api/providers/${id}`, data)
|
return request<Provider>('PATCH', `/providers/${id}`, data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async delete(id: string): Promise<void> {
|
async delete(id: string): Promise<void> {
|
||||||
return request<void>('DELETE', `/api/providers/${id}`)
|
return request<void>('DELETE', `/providers/${id}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
// Key Pool 管理
|
||||||
|
async listKeys(providerId: string): Promise<ProviderKey[]> {
|
||||||
|
return request<ProviderKey[]>('GET', `/providers/${providerId}/keys`)
|
||||||
|
},
|
||||||
|
|
||||||
|
async addKey(providerId: string, data: {
|
||||||
|
key_label: string
|
||||||
|
key_value: string
|
||||||
|
priority?: number
|
||||||
|
max_rpm?: number
|
||||||
|
max_tpm?: number
|
||||||
|
quota_reset_interval?: string
|
||||||
|
}): Promise<{ ok: boolean; key_id: string }> {
|
||||||
|
return request<{ ok: boolean; key_id: string }>('POST', `/providers/${providerId}/keys`, data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async toggleKey(providerId: string, keyId: string, active: boolean): Promise<{ ok: boolean }> {
|
||||||
|
return request<{ ok: boolean }>('PUT', `/providers/${providerId}/keys/${keyId}/toggle`, { active })
|
||||||
|
},
|
||||||
|
|
||||||
|
async deleteKey(providerId: string, keyId: string): Promise<{ ok: boolean }> {
|
||||||
|
return request<{ ok: boolean }>('DELETE', `/providers/${providerId}/keys/${keyId}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -171,19 +282,19 @@ export const api = {
|
|||||||
provider_id?: string
|
provider_id?: string
|
||||||
}): Promise<PaginatedResponse<Model>> {
|
}): Promise<PaginatedResponse<Model>> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<Model>>('GET', `/api/models${qs}`)
|
return request<PaginatedResponse<Model>>('GET', `/models${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
async create(data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
async create(data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
||||||
return request<Model>('POST', '/api/models', data)
|
return request<Model>('POST', '/models', data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async update(id: string, data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
async update(id: string, data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
||||||
return request<Model>('PATCH', `/api/models/${id}`, data)
|
return request<Model>('PATCH', `/models/${id}`, data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async delete(id: string): Promise<void> {
|
async delete(id: string): Promise<void> {
|
||||||
return request<void>('DELETE', `/api/models/${id}`)
|
return request<void>('DELETE', `/models/${id}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -194,28 +305,30 @@ export const api = {
|
|||||||
page_size?: number
|
page_size?: number
|
||||||
}): Promise<PaginatedResponse<TokenInfo>> {
|
}): Promise<PaginatedResponse<TokenInfo>> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<TokenInfo>>('GET', `/api/tokens${qs}`)
|
return request<PaginatedResponse<TokenInfo>>('GET', `/keys${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
async create(data: CreateTokenRequest): Promise<TokenInfo> {
|
async create(data: CreateTokenRequest): Promise<TokenInfo> {
|
||||||
return request<TokenInfo>('POST', '/api/tokens', data)
|
return request<TokenInfo>('POST', '/keys', data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async revoke(id: string): Promise<void> {
|
async revoke(id: string): Promise<void> {
|
||||||
return request<void>('DELETE', `/api/tokens/${id}`)
|
return request<void>('DELETE', `/keys/${id}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// ── 用量统计 ──────────────────────────────────────────
|
// ── 用量统计 ──────────────────────────────────────────
|
||||||
usage: {
|
usage: {
|
||||||
async daily(params?: { days?: number }): Promise<UsageRecord[]> {
|
async daily(params?: { days?: number }): Promise<UsageRecord[]> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString({ ...params, group_by: 'day' })
|
||||||
return request<UsageRecord[]>('GET', `/api/usage/daily${qs}`)
|
const result = await request<{ by_day: UsageRecord[] }>('GET', `/usage${qs}`)
|
||||||
|
return result.by_day || []
|
||||||
},
|
},
|
||||||
|
|
||||||
async byModel(params?: { days?: number }): Promise<UsageByModel[]> {
|
async byModel(params?: { days?: number }): Promise<UsageByModel[]> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString({ ...params, group_by: 'model' })
|
||||||
return request<UsageByModel[]>('GET', `/api/usage/by-model${qs}`)
|
const result = await request<{ by_model: UsageByModel[] }>('GET', `/usage${qs}`)
|
||||||
|
return result.by_model || []
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -227,11 +340,11 @@ export const api = {
|
|||||||
status?: string
|
status?: string
|
||||||
}): Promise<PaginatedResponse<RelayTask>> {
|
}): Promise<PaginatedResponse<RelayTask>> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<RelayTask>>('GET', `/api/relay/tasks${qs}`)
|
return request<PaginatedResponse<RelayTask>>('GET', `/relay/tasks${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
async get(id: string): Promise<RelayTask> {
|
async get(id: string): Promise<RelayTask> {
|
||||||
return request<RelayTask>('GET', `/api/relay/tasks/${id}`)
|
return request<RelayTask>('GET', `/relay/tasks/${id}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -239,13 +352,16 @@ export const api = {
|
|||||||
config: {
|
config: {
|
||||||
async list(params?: {
|
async list(params?: {
|
||||||
category?: string
|
category?: string
|
||||||
|
page?: number
|
||||||
|
page_size?: number
|
||||||
}): Promise<ConfigItem[]> {
|
}): Promise<ConfigItem[]> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<ConfigItem[]>('GET', `/api/config${qs}`)
|
const result = await request<PaginatedResponse<ConfigItem>>('GET', `/config/items${qs}`)
|
||||||
|
return result.items
|
||||||
},
|
},
|
||||||
|
|
||||||
async update(id: string, data: { value: string | number | boolean }): Promise<ConfigItem> {
|
async update(id: string, data: { value: string | number | boolean }): Promise<ConfigItem> {
|
||||||
return request<ConfigItem>('PATCH', `/api/config/${id}`, data)
|
return request<ConfigItem>('PATCH', `/config/items/${id}`, data)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -257,14 +373,149 @@ export const api = {
|
|||||||
action?: string
|
action?: string
|
||||||
}): Promise<PaginatedResponse<OperationLog>> {
|
}): Promise<PaginatedResponse<OperationLog>> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<OperationLog>>('GET', `/api/logs${qs}`)
|
return request<PaginatedResponse<OperationLog>>('GET', `/logs/operations${qs}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// ── 仪表盘 ────────────────────────────────────────────
|
// ── 仪表盘 ────────────────────────────────────────────
|
||||||
stats: {
|
stats: {
|
||||||
async dashboard(): Promise<DashboardStats> {
|
async dashboard(): Promise<DashboardStats> {
|
||||||
return request<DashboardStats>('GET', '/api/stats/dashboard')
|
return request<DashboardStats>('GET', '/stats/dashboard')
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// ── 提示词管理 ────────────────────────────────────────
|
||||||
|
prompts: {
|
||||||
|
async list(params?: {
|
||||||
|
category?: string
|
||||||
|
source?: string
|
||||||
|
status?: string
|
||||||
|
page?: number
|
||||||
|
page_size?: number
|
||||||
|
}): Promise<PaginatedResponse<PromptTemplate>> {
|
||||||
|
const qs = buildQueryString(params)
|
||||||
|
return request<PaginatedResponse<PromptTemplate>>('GET', `/prompts${qs}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
async get(name: string): Promise<PromptTemplate> {
|
||||||
|
return request<PromptTemplate>('GET', `/prompts/${encodeURIComponent(name)}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
async create(data: {
|
||||||
|
name: string
|
||||||
|
category: string
|
||||||
|
description?: string
|
||||||
|
source?: string
|
||||||
|
system_prompt: string
|
||||||
|
user_prompt_template?: string
|
||||||
|
variables?: unknown[]
|
||||||
|
min_app_version?: string
|
||||||
|
}): Promise<PromptTemplate> {
|
||||||
|
return request<PromptTemplate>('POST', '/prompts', data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async update(name: string, data: {
|
||||||
|
description?: string
|
||||||
|
status?: string
|
||||||
|
}): Promise<PromptTemplate> {
|
||||||
|
return request<PromptTemplate>('PUT', `/prompts/${encodeURIComponent(name)}`, data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async archive(name: string): Promise<PromptTemplate> {
|
||||||
|
return request<PromptTemplate>('DELETE', `/prompts/${encodeURIComponent(name)}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
async listVersions(name: string): Promise<PromptVersion[]> {
|
||||||
|
return request<PromptVersion[]>('GET', `/prompts/${encodeURIComponent(name)}/versions`)
|
||||||
|
},
|
||||||
|
|
||||||
|
async createVersion(name: string, data: {
|
||||||
|
system_prompt: string
|
||||||
|
user_prompt_template?: string
|
||||||
|
variables?: unknown[]
|
||||||
|
changelog?: string
|
||||||
|
min_app_version?: string
|
||||||
|
}): Promise<PromptVersion> {
|
||||||
|
return request<PromptVersion>('POST', `/prompts/${encodeURIComponent(name)}/versions`, data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async rollback(name: string, version: number): Promise<PromptTemplate> {
|
||||||
|
return request<PromptTemplate>('POST', `/prompts/${encodeURIComponent(name)}/rollback/${version}`)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// ── Agent 配置模板 ──────────────────────────────────
|
||||||
|
agentTemplates: {
|
||||||
|
async list(params?: {
|
||||||
|
category?: string
|
||||||
|
source?: string
|
||||||
|
visibility?: string
|
||||||
|
status?: string
|
||||||
|
page?: number
|
||||||
|
page_size?: number
|
||||||
|
}): Promise<PaginatedResponse<AgentTemplate>> {
|
||||||
|
const qs = buildQueryString(params)
|
||||||
|
return request<PaginatedResponse<AgentTemplate>>('GET', `/agent-templates${qs}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
async get(id: string): Promise<AgentTemplate> {
|
||||||
|
return request<AgentTemplate>('GET', `/agent-templates/${id}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
async create(data: {
|
||||||
|
name: string
|
||||||
|
description?: string
|
||||||
|
category?: string
|
||||||
|
source?: string
|
||||||
|
model?: string
|
||||||
|
system_prompt?: string
|
||||||
|
tools?: string[]
|
||||||
|
capabilities?: string[]
|
||||||
|
temperature?: number
|
||||||
|
max_tokens?: number
|
||||||
|
visibility?: string
|
||||||
|
}): Promise<AgentTemplate> {
|
||||||
|
return request<AgentTemplate>('POST', '/agent-templates', data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async update(id: string, data: {
|
||||||
|
description?: string
|
||||||
|
model?: string
|
||||||
|
system_prompt?: string
|
||||||
|
tools?: string[]
|
||||||
|
capabilities?: string[]
|
||||||
|
temperature?: number
|
||||||
|
max_tokens?: number
|
||||||
|
visibility?: string
|
||||||
|
status?: string
|
||||||
|
}): Promise<AgentTemplate> {
|
||||||
|
return request<AgentTemplate>('POST', `/agent-templates/${id}`, data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async archive(id: string): Promise<AgentTemplate> {
|
||||||
|
return request<AgentTemplate>('DELETE', `/agent-templates/${id}`)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// ── 遥测统计 ──────────────────────────────────────────
|
||||||
|
telemetry: {
|
||||||
|
/** 按模型聚合用量统计 */
|
||||||
|
async modelStats(params?: {
|
||||||
|
from?: string
|
||||||
|
to?: string
|
||||||
|
model_id?: string
|
||||||
|
connection_mode?: string
|
||||||
|
}): Promise<ModelUsageStat[]> {
|
||||||
|
const qs = buildQueryString(params)
|
||||||
|
return request<ModelUsageStat[]>('GET', `/telemetry/stats${qs}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
/** 按天聚合用量统计 */
|
||||||
|
async dailyStats(params?: {
|
||||||
|
days?: number
|
||||||
|
}): Promise<DailyUsageStat[]> {
|
||||||
|
const qs = buildQueryString(params)
|
||||||
|
return request<DailyUsageStat[]>('GET', `/telemetry/daily${qs}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
13
admin/src/lib/api-error.ts
Normal file
13
admin/src/lib/api-error.ts
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
// ============================================================
|
||||||
|
// API Error 类 — 与 swr-fetcher 共享
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
export class ApiRequestError extends Error {
|
||||||
|
constructor(
|
||||||
|
public status: number,
|
||||||
|
public body: { error?: string; message?: string },
|
||||||
|
) {
|
||||||
|
super(body.message || `Request failed with status ${status}`)
|
||||||
|
this.name = 'ApiRequestError'
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,6 +21,13 @@ export function logout(): void {
|
|||||||
localStorage.removeItem(ACCOUNT_KEY)
|
localStorage.removeItem(ACCOUNT_KEY)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** 清除认证状态(用于 Token 验证失败时) */
|
||||||
|
export function clearAuth(): void {
|
||||||
|
if (typeof window === 'undefined') return
|
||||||
|
localStorage.removeItem(TOKEN_KEY)
|
||||||
|
localStorage.removeItem(ACCOUNT_KEY)
|
||||||
|
}
|
||||||
|
|
||||||
/** 获取 JWT token */
|
/** 获取 JWT token */
|
||||||
export function getToken(): string | null {
|
export function getToken(): string | null {
|
||||||
if (typeof window === 'undefined') return null
|
if (typeof window === 'undefined') return null
|
||||||
|
|||||||
60
admin/src/lib/swr-fetcher.ts
Normal file
60
admin/src/lib/swr-fetcher.ts
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// ============================================================
|
||||||
|
// SWR fetcher — 将 SWR key 映射到 api-client 调用
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
import { api } from './api-client'
|
||||||
|
import { ApiRequestError } from './api-client'
|
||||||
|
|
||||||
|
type ApiMethod = typeof api
|
||||||
|
|
||||||
|
/** SWR fetcher: key 可以是字符串或 [method-path, params] 元组 */
|
||||||
|
type SwrKey =
|
||||||
|
| string
|
||||||
|
| [string, ...unknown[]]
|
||||||
|
|
||||||
|
async function resolveApiCall(key: SwrKey): Promise<unknown> {
|
||||||
|
if (typeof key === 'string') {
|
||||||
|
// 简单字符串 key,直接 fetch
|
||||||
|
return fetchGeneric(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
const [path, ...args] = key
|
||||||
|
return callByPath(path, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchGeneric(path: string): Promise<unknown> {
|
||||||
|
const res = await fetch(path, {
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if (!res.ok) {
|
||||||
|
const body = await res.json().catch(() => ({ error: 'unknown', message: `请求失败 (${res.status})` }))
|
||||||
|
throw new ApiRequestError(res.status, body)
|
||||||
|
}
|
||||||
|
if (res.status === 204) return null
|
||||||
|
return res.json()
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 根据 path 调用对应的 api 方法 */
|
||||||
|
async function callByPath(path: string, args: unknown[]): Promise<unknown> {
|
||||||
|
const parts = path.split('.')
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||||
|
let target: any = api
|
||||||
|
for (const part of parts) {
|
||||||
|
target = target[part]
|
||||||
|
if (!target) throw new Error(`API method not found: ${path}`)
|
||||||
|
}
|
||||||
|
return target(...args)
|
||||||
|
}
|
||||||
|
|
||||||
|
export const swrFetcher = <T = unknown>(key: SwrKey): Promise<T> =>
|
||||||
|
resolveApiCall(key) as Promise<T>
|
||||||
|
|
||||||
|
/** 创建 SWR key helper — 类型安全 */
|
||||||
|
export function createKey<TMethod extends string>(
|
||||||
|
method: TMethod,
|
||||||
|
...args: unknown[]
|
||||||
|
): [TMethod, ...unknown[]] {
|
||||||
|
return [method, ...args]
|
||||||
|
}
|
||||||
26
admin/src/lib/swr-provider.tsx
Normal file
26
admin/src/lib/swr-provider.tsx
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
'use client'
|
||||||
|
|
||||||
|
import { SWRConfig } from 'swr'
|
||||||
|
import type { ReactNode } from 'react'
|
||||||
|
|
||||||
|
export function SWRProvider({ children }: { children: ReactNode }) {
|
||||||
|
return (
|
||||||
|
<SWRConfig
|
||||||
|
value={{
|
||||||
|
revalidateOnFocus: false,
|
||||||
|
dedupingInterval: 5000,
|
||||||
|
errorRetryCount: 2,
|
||||||
|
errorRetryInterval: 3000,
|
||||||
|
shouldRetryOnError: (err: unknown) => {
|
||||||
|
if (err && typeof err === 'object' && 'status' in err) {
|
||||||
|
const status = (err as { status: number }).status
|
||||||
|
return status !== 401 && status !== 403
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{children}
|
||||||
|
</SWRConfig>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -18,6 +18,7 @@ export interface AccountPublic {
|
|||||||
export interface LoginRequest {
|
export interface LoginRequest {
|
||||||
username: string
|
username: string
|
||||||
password: string
|
password: string
|
||||||
|
totp_code?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 登录响应 */
|
/** 登录响应 */
|
||||||
@@ -167,3 +168,127 @@ export interface ApiError {
|
|||||||
message: string
|
message: string
|
||||||
status?: number
|
status?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── 提示词模板 ────────────────────────────────────────────
|
||||||
|
|
||||||
|
/** 提示词模板 */
|
||||||
|
export interface PromptTemplate {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
category: string
|
||||||
|
description?: string
|
||||||
|
source: 'builtin' | 'custom'
|
||||||
|
current_version: number
|
||||||
|
status: 'active' | 'deprecated' | 'archived'
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 提示词版本 */
|
||||||
|
export interface PromptVersion {
|
||||||
|
id: string
|
||||||
|
template_id: string
|
||||||
|
version: number
|
||||||
|
system_prompt: string
|
||||||
|
user_prompt_template?: string
|
||||||
|
variables: PromptVariable[]
|
||||||
|
changelog?: string
|
||||||
|
min_app_version?: string
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 提示词变量定义 */
|
||||||
|
export interface PromptVariable {
|
||||||
|
name: string
|
||||||
|
type: 'string' | 'number' | 'select' | 'boolean'
|
||||||
|
default_value?: string
|
||||||
|
description?: string
|
||||||
|
required?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
/** OTA 更新检查请求 */
|
||||||
|
export interface PromptCheckRequest {
|
||||||
|
device_id: string
|
||||||
|
versions: Record<string, number>
|
||||||
|
}
|
||||||
|
|
||||||
|
/** OTA 更新响应 */
|
||||||
|
export interface PromptCheckResponse {
|
||||||
|
updates: PromptUpdatePayload[]
|
||||||
|
server_time: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 单个更新载荷 */
|
||||||
|
export interface PromptUpdatePayload {
|
||||||
|
name: string
|
||||||
|
version: number
|
||||||
|
system_prompt: string
|
||||||
|
user_prompt_template?: string
|
||||||
|
variables: PromptVariable[]
|
||||||
|
source: string
|
||||||
|
min_app_version?: string
|
||||||
|
changelog?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Agent 配置模板 ────────────────────────────────────────
|
||||||
|
|
||||||
|
/** Agent 模板 */
|
||||||
|
export interface AgentTemplate {
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
description?: string
|
||||||
|
category: string
|
||||||
|
source: 'builtin' | 'custom'
|
||||||
|
model?: string
|
||||||
|
system_prompt?: string
|
||||||
|
tools: string[]
|
||||||
|
capabilities: string[]
|
||||||
|
temperature?: number
|
||||||
|
max_tokens?: number
|
||||||
|
visibility: 'public' | 'team' | 'private'
|
||||||
|
status: 'active' | 'archived'
|
||||||
|
current_version: number
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Provider Key Pool ─────────────────────────────────────
|
||||||
|
|
||||||
|
/** Provider Key */
|
||||||
|
export interface ProviderKey {
|
||||||
|
id: string
|
||||||
|
provider_id: string
|
||||||
|
key_label: string
|
||||||
|
priority: number
|
||||||
|
max_rpm?: number
|
||||||
|
max_tpm?: number
|
||||||
|
quota_reset_interval?: string
|
||||||
|
is_active: boolean
|
||||||
|
last_429_at?: string
|
||||||
|
cooldown_until?: string
|
||||||
|
total_requests: number
|
||||||
|
total_tokens: number
|
||||||
|
created_at: string
|
||||||
|
updated_at: string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 遥测统计 ────────────────────────────────────────────
|
||||||
|
|
||||||
|
/** 按模型聚合的用量统计 */
|
||||||
|
export interface ModelUsageStat {
|
||||||
|
model_id: string
|
||||||
|
request_count: number
|
||||||
|
input_tokens: number
|
||||||
|
output_tokens: number
|
||||||
|
avg_latency_ms: number | null
|
||||||
|
success_rate: number
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 按天的用量统计 */
|
||||||
|
export interface DailyUsageStat {
|
||||||
|
day: string
|
||||||
|
request_count: number
|
||||||
|
input_tokens: number
|
||||||
|
output_tokens: number
|
||||||
|
unique_devices: number
|
||||||
|
}
|
||||||
|
|||||||
@@ -289,6 +289,44 @@ impl sqlx::FromRow<'_, SqliteRow> for MemoryRow {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Private helper methods on SqliteStorage (NOT in impl VikingStorage block)
|
||||||
|
impl SqliteStorage {
|
||||||
|
/// Fetch memories by scope with importance-based ordering.
|
||||||
|
/// Used internally by find() for scope-based queries.
|
||||||
|
pub(crate) async fn fetch_by_scope_priv(&self, scope: Option<&str>, limit: usize) -> Result<Vec<MemoryRow>> {
|
||||||
|
let rows = if let Some(scope) = scope {
|
||||||
|
sqlx::query_as::<_, MemoryRow>(
|
||||||
|
r#"
|
||||||
|
SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary
|
||||||
|
FROM memories
|
||||||
|
WHERE uri LIKE ?
|
||||||
|
ORDER BY importance DESC, access_count DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#
|
||||||
|
)
|
||||||
|
.bind(format!("{}%", scope))
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(format!("Failed to fetch by scope: {}", e)))?
|
||||||
|
} else {
|
||||||
|
sqlx::query_as::<_, MemoryRow>(
|
||||||
|
r#"
|
||||||
|
SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary
|
||||||
|
FROM memories
|
||||||
|
ORDER BY importance DESC
|
||||||
|
LIMIT ?
|
||||||
|
"#
|
||||||
|
)
|
||||||
|
.bind(limit as i64)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(format!("Failed to fetch by scope: {}", e)))?
|
||||||
|
};
|
||||||
|
Ok(rows)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl VikingStorage for SqliteStorage {
|
impl VikingStorage for SqliteStorage {
|
||||||
async fn store(&self, entry: &MemoryEntry) -> Result<()> {
|
async fn store(&self, entry: &MemoryEntry) -> Result<()> {
|
||||||
@@ -374,22 +412,61 @@ impl VikingStorage for SqliteStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
|
async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
|
||||||
// Get all matching entries
|
let limit = options.limit.unwrap_or(50).max(20); // Fetch more candidates for reranking
|
||||||
let rows = if let Some(ref scope) = options.scope {
|
|
||||||
|
// Strategy: use FTS5 for initial filtering when query is non-empty,
|
||||||
|
// then score candidates with TF-IDF / embedding for precise ranking.
|
||||||
|
// Fallback to scope-only scan when query is empty (e.g., "list all").
|
||||||
|
let rows = if !query.is_empty() {
|
||||||
|
// FTS5-powered candidate retrieval (fast, index-based)
|
||||||
|
let fts_candidates = if let Some(ref scope) = options.scope {
|
||||||
sqlx::query_as::<_, MemoryRow>(
|
sqlx::query_as::<_, MemoryRow>(
|
||||||
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary FROM memories WHERE uri LIKE ?"
|
r#"
|
||||||
|
SELECT m.uri, m.memory_type, m.content, m.keywords, m.importance,
|
||||||
|
m.access_count, m.created_at, m.last_accessed, m.overview, m.abstract_summary
|
||||||
|
FROM memories m
|
||||||
|
INNER JOIN memories_fts f ON m.uri = f.uri
|
||||||
|
WHERE f.memories_fts MATCH ?
|
||||||
|
AND m.uri LIKE ?
|
||||||
|
ORDER BY f.rank
|
||||||
|
LIMIT ?
|
||||||
|
"#
|
||||||
)
|
)
|
||||||
|
.bind(query)
|
||||||
.bind(format!("{}%", scope))
|
.bind(format!("{}%", scope))
|
||||||
|
.bind(limit as i64)
|
||||||
.fetch_all(&self.pool)
|
.fetch_all(&self.pool)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
|
|
||||||
} else {
|
} else {
|
||||||
sqlx::query_as::<_, MemoryRow>(
|
sqlx::query_as::<_, MemoryRow>(
|
||||||
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary FROM memories"
|
r#"
|
||||||
|
SELECT m.uri, m.memory_type, m.content, m.keywords, m.importance,
|
||||||
|
m.access_count, m.created_at, m.last_accessed, m.overview, m.abstract_summary
|
||||||
|
FROM memories m
|
||||||
|
INNER JOIN memories_fts f ON m.uri = f.uri
|
||||||
|
WHERE f.memories_fts MATCH ?
|
||||||
|
ORDER BY f.rank
|
||||||
|
LIMIT ?
|
||||||
|
"#
|
||||||
)
|
)
|
||||||
|
.bind(query)
|
||||||
|
.bind(limit as i64)
|
||||||
.fetch_all(&self.pool)
|
.fetch_all(&self.pool)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
|
};
|
||||||
|
|
||||||
|
match fts_candidates {
|
||||||
|
Ok(rows) if !rows.is_empty() => rows,
|
||||||
|
Ok(_) | Err(_) => {
|
||||||
|
// FTS5 returned nothing or query syntax was invalid —
|
||||||
|
// fallback to scope-based scan (no full table scan unless no scope)
|
||||||
|
tracing::debug!("[SqliteStorage] FTS5 returned no results, falling back to scope scan");
|
||||||
|
self.fetch_by_scope_priv(options.scope.as_deref(), limit).await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Empty query: scope-based scan only (no FTS5 needed)
|
||||||
|
self.fetch_by_scope_priv(options.scope.as_deref(), limit).await?
|
||||||
};
|
};
|
||||||
|
|
||||||
// Convert to entries and compute semantic scores
|
// Convert to entries and compute semantic scores
|
||||||
@@ -464,16 +541,8 @@ impl VikingStorage for SqliteStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
|
async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
|
||||||
let rows = sqlx::query_as::<_, MemoryRow>(
|
let rows = self.fetch_by_scope_priv(Some(prefix), 100).await?;
|
||||||
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary FROM memories WHERE uri LIKE ?"
|
|
||||||
)
|
|
||||||
.bind(format!("{}%", prefix))
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Failed to find by prefix: {}", e)))?;
|
|
||||||
|
|
||||||
let entries = rows.iter().map(|row| self.row_to_entry(row)).collect();
|
let entries = rows.iter().map(|row| self.row_to_entry(row)).collect();
|
||||||
|
|
||||||
Ok(entries)
|
Ok(entries)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -484,13 +553,13 @@ impl VikingStorage for SqliteStorage {
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Failed to delete memory: {}", e)))?;
|
.map_err(|e| ZclawError::StorageError(format!("Failed to delete memory: {}", e)))?;
|
||||||
|
|
||||||
// Remove from FTS
|
// Remove from FTS index
|
||||||
let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
|
let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
|
||||||
.bind(uri)
|
.bind(uri)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Remove from scorer
|
// Remove from in-memory scorer
|
||||||
let mut scorer = self.scorer.write().await;
|
let mut scorer = self.scorer.write().await;
|
||||||
scorer.remove_entry(uri);
|
scorer.remove_entry(uri);
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,11 @@ pub struct LlmConfig {
|
|||||||
/// Temperature
|
/// Temperature
|
||||||
#[serde(default = "default_temperature")]
|
#[serde(default = "default_temperature")]
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
|
|
||||||
|
/// Context window size in tokens (default: 128000)
|
||||||
|
/// Used to calculate dynamic compaction threshold.
|
||||||
|
#[serde(default = "default_context_window")]
|
||||||
|
pub context_window: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmConfig {
|
impl LlmConfig {
|
||||||
@@ -66,6 +71,7 @@ impl LlmConfig {
|
|||||||
api_protocol: ApiProtocol::OpenAI,
|
api_protocol: ApiProtocol::OpenAI,
|
||||||
max_tokens: default_max_tokens(),
|
max_tokens: default_max_tokens(),
|
||||||
temperature: default_temperature(),
|
temperature: default_temperature(),
|
||||||
|
context_window: default_context_window(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,6 +146,10 @@ fn default_temperature() -> f32 {
|
|||||||
0.7
|
0.7
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_context_window() -> u32 {
|
||||||
|
128000
|
||||||
|
}
|
||||||
|
|
||||||
impl Default for KernelConfig {
|
impl Default for KernelConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -151,6 +161,7 @@ impl Default for KernelConfig {
|
|||||||
api_protocol: ApiProtocol::OpenAI,
|
api_protocol: ApiProtocol::OpenAI,
|
||||||
max_tokens: default_max_tokens(),
|
max_tokens: default_max_tokens(),
|
||||||
temperature: default_temperature(),
|
temperature: default_temperature(),
|
||||||
|
context_window: default_context_window(),
|
||||||
},
|
},
|
||||||
skills_dir: default_skills_dir(),
|
skills_dir: default_skills_dir(),
|
||||||
}
|
}
|
||||||
@@ -345,6 +356,17 @@ impl KernelConfig {
|
|||||||
pub fn temperature(&self) -> f32 {
|
pub fn temperature(&self) -> f32 {
|
||||||
self.llm.temperature
|
self.llm.temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get context window size in tokens
|
||||||
|
pub fn context_window(&self) -> u32 {
|
||||||
|
self.llm.context_window
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Dynamic compaction threshold = context_window * 0.6
|
||||||
|
/// Leaves 40% headroom for system prompt + response tokens
|
||||||
|
pub fn compaction_threshold(&self) -> usize {
|
||||||
|
(self.llm.context_window as f64 * 0.6) as usize
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Preset configurations for common providers ===
|
// === Preset configurations for common providers ===
|
||||||
|
|||||||
@@ -3,7 +3,9 @@
|
|||||||
use std::pin::Pin;
|
use std::pin::Pin;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::{broadcast, mpsc, Mutex};
|
use tokio::sync::{broadcast, mpsc, Mutex};
|
||||||
use zclaw_types::{AgentConfig, AgentId, AgentInfo, Event, Result};
|
use zclaw_types::{AgentConfig, AgentId, AgentInfo, Capability, Event, Result, HandRun, HandRunId, HandRunStatus, HandRunFilter, TriggerSource};
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
use zclaw_protocols::{A2aRouter, A2aAgentProfile, A2aCapability, A2aEnvelope, A2aMessageType, A2aRecipient};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|
||||||
@@ -12,7 +14,7 @@ use crate::capabilities::CapabilityManager;
|
|||||||
use crate::events::EventBus;
|
use crate::events::EventBus;
|
||||||
use crate::config::KernelConfig;
|
use crate::config::KernelConfig;
|
||||||
use zclaw_memory::MemoryStore;
|
use zclaw_memory::MemoryStore;
|
||||||
use zclaw_runtime::{AgentLoop, LlmDriver, ToolRegistry, tool::SkillExecutor};
|
use zclaw_runtime::{AgentLoop, LlmDriver, ToolRegistry, tool::SkillExecutor, tool::builtin::PathValidator};
|
||||||
use zclaw_skills::SkillRegistry;
|
use zclaw_skills::SkillRegistry;
|
||||||
use zclaw_skills::LlmCompleter;
|
use zclaw_skills::LlmCompleter;
|
||||||
use zclaw_hands::{HandRegistry, HandContext, HandResult, hands::{BrowserHand, SlideshowHand, SpeechHand, QuizHand, WhiteboardHand, ResearcherHand, CollectorHand, ClipHand, TwitterHand, quiz::LlmQuizGenerator}};
|
use zclaw_hands::{HandRegistry, HandContext, HandResult, hands::{BrowserHand, SlideshowHand, SpeechHand, QuizHand, WhiteboardHand, ResearcherHand, CollectorHand, ClipHand, TwitterHand, quiz::LlmQuizGenerator}};
|
||||||
@@ -20,6 +22,8 @@ use zclaw_hands::{HandRegistry, HandContext, HandResult, hands::{BrowserHand, Sl
|
|||||||
/// Adapter that bridges `zclaw_runtime::LlmDriver` → `zclaw_skills::LlmCompleter`
|
/// Adapter that bridges `zclaw_runtime::LlmDriver` → `zclaw_skills::LlmCompleter`
|
||||||
struct LlmDriverAdapter {
|
struct LlmDriverAdapter {
|
||||||
driver: Arc<dyn LlmDriver>,
|
driver: Arc<dyn LlmDriver>,
|
||||||
|
max_tokens: u32,
|
||||||
|
temperature: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl zclaw_skills::LlmCompleter for LlmDriverAdapter {
|
impl zclaw_skills::LlmCompleter for LlmDriverAdapter {
|
||||||
@@ -32,8 +36,8 @@ impl zclaw_skills::LlmCompleter for LlmDriverAdapter {
|
|||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let request = zclaw_runtime::CompletionRequest {
|
let request = zclaw_runtime::CompletionRequest {
|
||||||
messages: vec![zclaw_types::Message::user(prompt)],
|
messages: vec![zclaw_types::Message::user(prompt)],
|
||||||
max_tokens: Some(4096),
|
max_tokens: Some(self.max_tokens),
|
||||||
temperature: Some(0.7),
|
temperature: Some(self.temperature),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let response = driver.complete(request).await
|
let response = driver.complete(request).await
|
||||||
@@ -59,7 +63,7 @@ pub struct KernelSkillExecutor {
|
|||||||
|
|
||||||
impl KernelSkillExecutor {
|
impl KernelSkillExecutor {
|
||||||
pub fn new(skills: Arc<SkillRegistry>, driver: Arc<dyn LlmDriver>) -> Self {
|
pub fn new(skills: Arc<SkillRegistry>, driver: Arc<dyn LlmDriver>) -> Self {
|
||||||
let llm: Arc<dyn zclaw_skills::LlmCompleter> = Arc::new(LlmDriverAdapter { driver });
|
let llm: Arc<dyn zclaw_skills::LlmCompleter> = Arc::new(LlmDriverAdapter { driver, max_tokens: 4096, temperature: 0.7 });
|
||||||
Self { skills, llm }
|
Self { skills, llm }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -98,6 +102,14 @@ pub struct Kernel {
|
|||||||
hands: Arc<HandRegistry>,
|
hands: Arc<HandRegistry>,
|
||||||
trigger_manager: crate::trigger_manager::TriggerManager,
|
trigger_manager: crate::trigger_manager::TriggerManager,
|
||||||
pending_approvals: Arc<Mutex<Vec<ApprovalEntry>>>,
|
pending_approvals: Arc<Mutex<Vec<ApprovalEntry>>>,
|
||||||
|
/// Running hand runs that can be cancelled (run_id -> cancelled flag)
|
||||||
|
running_hand_runs: Arc<dashmap::DashMap<HandRunId, Arc<std::sync::atomic::AtomicBool>>>,
|
||||||
|
/// A2A router for inter-agent messaging (gated by multi-agent feature)
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
a2a_router: Arc<A2aRouter>,
|
||||||
|
/// Per-agent A2A inbox receivers
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
a2a_inboxes: Arc<dashmap::DashMap<AgentId, Arc<Mutex<mpsc::Receiver<A2aEnvelope>>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Kernel {
|
impl Kernel {
|
||||||
@@ -143,7 +155,11 @@ impl Kernel {
|
|||||||
|
|
||||||
// Create LLM completer for skill system (shared with skill_executor)
|
// Create LLM completer for skill system (shared with skill_executor)
|
||||||
let llm_completer: Arc<dyn zclaw_skills::LlmCompleter> =
|
let llm_completer: Arc<dyn zclaw_skills::LlmCompleter> =
|
||||||
Arc::new(LlmDriverAdapter { driver: driver.clone() });
|
Arc::new(LlmDriverAdapter {
|
||||||
|
driver: driver.clone(),
|
||||||
|
max_tokens: config.max_tokens(),
|
||||||
|
temperature: config.temperature(),
|
||||||
|
});
|
||||||
|
|
||||||
// Initialize trigger manager
|
// Initialize trigger manager
|
||||||
let trigger_manager = crate::trigger_manager::TriggerManager::new(hands.clone());
|
let trigger_manager = crate::trigger_manager::TriggerManager::new(hands.clone());
|
||||||
@@ -154,6 +170,13 @@ impl Kernel {
|
|||||||
registry.register(agent);
|
registry.register(agent);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize A2A router for multi-agent support
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
let a2a_router = {
|
||||||
|
let kernel_agent_id = AgentId::new();
|
||||||
|
Arc::new(A2aRouter::new(kernel_agent_id))
|
||||||
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
config,
|
config,
|
||||||
registry,
|
registry,
|
||||||
@@ -167,6 +190,11 @@ impl Kernel {
|
|||||||
hands,
|
hands,
|
||||||
trigger_manager,
|
trigger_manager,
|
||||||
pending_approvals: Arc::new(Mutex::new(Vec::new())),
|
pending_approvals: Arc::new(Mutex::new(Vec::new())),
|
||||||
|
running_hand_runs: Arc::new(dashmap::DashMap::new()),
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
a2a_router,
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
a2a_inboxes: Arc::new(dashmap::DashMap::new()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,8 +322,17 @@ impl Kernel {
|
|||||||
self.memory.save_agent(&config).await?;
|
self.memory.save_agent(&config).await?;
|
||||||
|
|
||||||
// Register in registry
|
// Register in registry
|
||||||
|
let config_clone = config.clone();
|
||||||
self.registry.register(config);
|
self.registry.register(config);
|
||||||
|
|
||||||
|
// Register with A2A router for multi-agent messaging
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
{
|
||||||
|
let profile = Self::agent_config_to_a2a_profile(&config_clone);
|
||||||
|
let rx = self.a2a_router.register_agent(profile).await;
|
||||||
|
self.a2a_inboxes.insert(id, Arc::new(Mutex::new(rx)));
|
||||||
|
}
|
||||||
|
|
||||||
// Emit event
|
// Emit event
|
||||||
self.events.publish(Event::AgentSpawned {
|
self.events.publish(Event::AgentSpawned {
|
||||||
agent_id: id,
|
agent_id: id,
|
||||||
@@ -313,6 +350,13 @@ impl Kernel {
|
|||||||
// Remove from memory
|
// Remove from memory
|
||||||
self.memory.delete_agent(id).await?;
|
self.memory.delete_agent(id).await?;
|
||||||
|
|
||||||
|
// Unregister from A2A router
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
{
|
||||||
|
self.a2a_router.unregister_agent(id).await;
|
||||||
|
self.a2a_inboxes.remove(id);
|
||||||
|
}
|
||||||
|
|
||||||
// Emit event
|
// Emit event
|
||||||
self.events.publish(Event::AgentTerminated {
|
self.events.publish(Event::AgentTerminated {
|
||||||
agent_id: *id,
|
agent_id: *id,
|
||||||
@@ -346,7 +390,7 @@ impl Kernel {
|
|||||||
|
|
||||||
// Create agent loop with model configuration
|
// Create agent loop with model configuration
|
||||||
let tools = self.create_tool_registry();
|
let tools = self.create_tool_registry();
|
||||||
let loop_runner = AgentLoop::new(
|
let mut loop_runner = AgentLoop::new(
|
||||||
*agent_id,
|
*agent_id,
|
||||||
self.driver.clone(),
|
self.driver.clone(),
|
||||||
tools,
|
tools,
|
||||||
@@ -356,7 +400,22 @@ impl Kernel {
|
|||||||
.with_skill_executor(self.skill_executor.clone())
|
.with_skill_executor(self.skill_executor.clone())
|
||||||
.with_max_tokens(agent_config.max_tokens.unwrap_or_else(|| self.config.max_tokens()))
|
.with_max_tokens(agent_config.max_tokens.unwrap_or_else(|| self.config.max_tokens()))
|
||||||
.with_temperature(agent_config.temperature.unwrap_or_else(|| self.config.temperature()))
|
.with_temperature(agent_config.temperature.unwrap_or_else(|| self.config.temperature()))
|
||||||
.with_compaction_threshold(15_000); // Compact when context exceeds ~15k tokens
|
.with_compaction_threshold(
|
||||||
|
agent_config.compaction_threshold
|
||||||
|
.map(|t| t as usize)
|
||||||
|
.unwrap_or_else(|| self.config.compaction_threshold()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Set path validator from agent's workspace directory (if configured)
|
||||||
|
if let Some(ref workspace) = agent_config.workspace {
|
||||||
|
let path_validator = PathValidator::new().with_workspace(workspace.clone());
|
||||||
|
tracing::info!(
|
||||||
|
"[Kernel] Setting path_validator with workspace: {} for agent {}",
|
||||||
|
workspace.display(),
|
||||||
|
agent_id
|
||||||
|
);
|
||||||
|
loop_runner = loop_runner.with_path_validator(path_validator);
|
||||||
|
}
|
||||||
|
|
||||||
// Build system prompt with skill information injected
|
// Build system prompt with skill information injected
|
||||||
let system_prompt = self.build_system_prompt_with_skills(agent_config.system_prompt.as_ref()).await;
|
let system_prompt = self.build_system_prompt_with_skills(agent_config.system_prompt.as_ref()).await;
|
||||||
@@ -378,21 +437,35 @@ impl Kernel {
|
|||||||
agent_id: &AgentId,
|
agent_id: &AgentId,
|
||||||
message: String,
|
message: String,
|
||||||
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
|
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
|
||||||
self.send_message_stream_with_prompt(agent_id, message, None).await
|
self.send_message_stream_with_prompt(agent_id, message, None, None).await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send a message with streaming and optional external system prompt
|
/// Send a message with streaming, optional system prompt, and optional session reuse
|
||||||
pub async fn send_message_stream_with_prompt(
|
pub async fn send_message_stream_with_prompt(
|
||||||
&self,
|
&self,
|
||||||
agent_id: &AgentId,
|
agent_id: &AgentId,
|
||||||
message: String,
|
message: String,
|
||||||
system_prompt_override: Option<String>,
|
system_prompt_override: Option<String>,
|
||||||
|
session_id_override: Option<zclaw_types::SessionId>,
|
||||||
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
|
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
|
||||||
let agent_config = self.registry.get(agent_id)
|
let agent_config = self.registry.get(agent_id)
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?;
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?;
|
||||||
|
|
||||||
// Create session
|
// Reuse existing session or create new one
|
||||||
let session_id = self.memory.create_session(agent_id).await?;
|
let session_id = match session_id_override {
|
||||||
|
Some(id) => {
|
||||||
|
// Verify the session exists; if not, create a new one
|
||||||
|
let existing = self.memory.get_messages(&id).await;
|
||||||
|
match existing {
|
||||||
|
Ok(msgs) if !msgs.is_empty() => id,
|
||||||
|
_ => {
|
||||||
|
tracing::debug!("Session {} not found or empty, creating new session", id);
|
||||||
|
self.memory.create_session(agent_id).await?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None => self.memory.create_session(agent_id).await?,
|
||||||
|
};
|
||||||
|
|
||||||
// Always use Kernel's current model configuration
|
// Always use Kernel's current model configuration
|
||||||
// This ensures user's "模型与 API" settings are respected
|
// This ensures user's "模型与 API" settings are respected
|
||||||
@@ -400,7 +473,7 @@ impl Kernel {
|
|||||||
|
|
||||||
// Create agent loop with model configuration
|
// Create agent loop with model configuration
|
||||||
let tools = self.create_tool_registry();
|
let tools = self.create_tool_registry();
|
||||||
let loop_runner = AgentLoop::new(
|
let mut loop_runner = AgentLoop::new(
|
||||||
*agent_id,
|
*agent_id,
|
||||||
self.driver.clone(),
|
self.driver.clone(),
|
||||||
tools,
|
tools,
|
||||||
@@ -410,7 +483,23 @@ impl Kernel {
|
|||||||
.with_skill_executor(self.skill_executor.clone())
|
.with_skill_executor(self.skill_executor.clone())
|
||||||
.with_max_tokens(agent_config.max_tokens.unwrap_or_else(|| self.config.max_tokens()))
|
.with_max_tokens(agent_config.max_tokens.unwrap_or_else(|| self.config.max_tokens()))
|
||||||
.with_temperature(agent_config.temperature.unwrap_or_else(|| self.config.temperature()))
|
.with_temperature(agent_config.temperature.unwrap_or_else(|| self.config.temperature()))
|
||||||
.with_compaction_threshold(15_000); // Compact when context exceeds ~15k tokens
|
.with_compaction_threshold(
|
||||||
|
agent_config.compaction_threshold
|
||||||
|
.map(|t| t as usize)
|
||||||
|
.unwrap_or_else(|| self.config.compaction_threshold()),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Set path validator from agent's workspace directory (if configured)
|
||||||
|
// This enables file_read / file_write tools to access the workspace
|
||||||
|
if let Some(ref workspace) = agent_config.workspace {
|
||||||
|
let path_validator = PathValidator::new().with_workspace(workspace.clone());
|
||||||
|
tracing::info!(
|
||||||
|
"[Kernel] Setting path_validator with workspace: {} for agent {}",
|
||||||
|
workspace.display(),
|
||||||
|
agent_id
|
||||||
|
);
|
||||||
|
loop_runner = loop_runner.with_path_validator(path_validator);
|
||||||
|
}
|
||||||
|
|
||||||
// Use external prompt if provided, otherwise build default
|
// Use external prompt if provided, otherwise build default
|
||||||
let system_prompt = match system_prompt_override {
|
let system_prompt = match system_prompt_override {
|
||||||
@@ -489,15 +578,194 @@ impl Kernel {
|
|||||||
self.hands.list().await
|
self.hands.list().await
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute a hand with the given input
|
/// Execute a hand with the given input, tracking the run
|
||||||
pub async fn execute_hand(
|
pub async fn execute_hand(
|
||||||
&self,
|
&self,
|
||||||
hand_id: &str,
|
hand_id: &str,
|
||||||
input: serde_json::Value,
|
input: serde_json::Value,
|
||||||
) -> Result<HandResult> {
|
) -> Result<(HandResult, HandRunId)> {
|
||||||
// Use default context (agent_id will be generated)
|
let run_id = HandRunId::new();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
// Create the initial HandRun record
|
||||||
|
let mut run = HandRun {
|
||||||
|
id: run_id,
|
||||||
|
hand_name: hand_id.to_string(),
|
||||||
|
trigger_source: TriggerSource::Manual,
|
||||||
|
params: input.clone(),
|
||||||
|
status: HandRunStatus::Pending,
|
||||||
|
result: None,
|
||||||
|
error: None,
|
||||||
|
duration_ms: None,
|
||||||
|
created_at: now.clone(),
|
||||||
|
started_at: None,
|
||||||
|
completed_at: None,
|
||||||
|
};
|
||||||
|
self.memory.save_hand_run(&run).await?;
|
||||||
|
|
||||||
|
// Transition to Running
|
||||||
|
run.status = HandRunStatus::Running;
|
||||||
|
run.started_at = Some(chrono::Utc::now().to_rfc3339());
|
||||||
|
self.memory.update_hand_run(&run).await?;
|
||||||
|
|
||||||
|
// Register cancellation flag
|
||||||
|
let cancel_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||||
|
self.running_hand_runs.insert(run_id, cancel_flag.clone());
|
||||||
|
|
||||||
|
// Execute the hand
|
||||||
let context = HandContext::default();
|
let context = HandContext::default();
|
||||||
self.hands.execute(hand_id, &context, input).await
|
let start = std::time::Instant::now();
|
||||||
|
let hand_result = self.hands.execute(hand_id, &context, input).await;
|
||||||
|
let duration = start.elapsed();
|
||||||
|
|
||||||
|
// Check if cancelled during execution
|
||||||
|
if cancel_flag.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
let mut run_update = run.clone();
|
||||||
|
run_update.status = HandRunStatus::Cancelled;
|
||||||
|
run_update.completed_at = Some(chrono::Utc::now().to_rfc3339());
|
||||||
|
run_update.duration_ms = Some(duration.as_millis() as u64);
|
||||||
|
self.memory.update_hand_run(&run_update).await?;
|
||||||
|
self.running_hand_runs.remove(&run_id);
|
||||||
|
return Err(zclaw_types::ZclawError::Internal("Hand execution cancelled".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove from running map
|
||||||
|
self.running_hand_runs.remove(&run_id);
|
||||||
|
|
||||||
|
// Update HandRun with result
|
||||||
|
let completed_at = chrono::Utc::now().to_rfc3339();
|
||||||
|
match &hand_result {
|
||||||
|
Ok(res) => {
|
||||||
|
run.status = HandRunStatus::Completed;
|
||||||
|
run.result = Some(res.output.clone());
|
||||||
|
run.error = res.error.clone();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
run.status = HandRunStatus::Failed;
|
||||||
|
run.error = Some(e.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run.duration_ms = Some(duration.as_millis() as u64);
|
||||||
|
run.completed_at = Some(completed_at);
|
||||||
|
self.memory.update_hand_run(&run).await?;
|
||||||
|
|
||||||
|
hand_result.map(|res| (res, run_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a hand with a specific trigger source (for scheduled/event triggers)
|
||||||
|
pub async fn execute_hand_with_source(
|
||||||
|
&self,
|
||||||
|
hand_id: &str,
|
||||||
|
input: serde_json::Value,
|
||||||
|
trigger_source: TriggerSource,
|
||||||
|
) -> Result<(HandResult, HandRunId)> {
|
||||||
|
let run_id = HandRunId::new();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
let mut run = HandRun {
|
||||||
|
id: run_id,
|
||||||
|
hand_name: hand_id.to_string(),
|
||||||
|
trigger_source,
|
||||||
|
params: input.clone(),
|
||||||
|
status: HandRunStatus::Pending,
|
||||||
|
result: None,
|
||||||
|
error: None,
|
||||||
|
duration_ms: None,
|
||||||
|
created_at: now,
|
||||||
|
started_at: None,
|
||||||
|
completed_at: None,
|
||||||
|
};
|
||||||
|
self.memory.save_hand_run(&run).await?;
|
||||||
|
|
||||||
|
run.status = HandRunStatus::Running;
|
||||||
|
run.started_at = Some(chrono::Utc::now().to_rfc3339());
|
||||||
|
self.memory.update_hand_run(&run).await?;
|
||||||
|
|
||||||
|
let cancel_flag = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||||
|
self.running_hand_runs.insert(run_id, cancel_flag.clone());
|
||||||
|
|
||||||
|
let context = HandContext::default();
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let hand_result = self.hands.execute(hand_id, &context, input).await;
|
||||||
|
let duration = start.elapsed();
|
||||||
|
|
||||||
|
// Check if cancelled during execution
|
||||||
|
if cancel_flag.load(std::sync::atomic::Ordering::Relaxed) {
|
||||||
|
run.status = HandRunStatus::Cancelled;
|
||||||
|
run.completed_at = Some(chrono::Utc::now().to_rfc3339());
|
||||||
|
run.duration_ms = Some(duration.as_millis() as u64);
|
||||||
|
self.memory.update_hand_run(&run).await?;
|
||||||
|
self.running_hand_runs.remove(&run_id);
|
||||||
|
return Err(zclaw_types::ZclawError::Internal("Hand execution cancelled".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.running_hand_runs.remove(&run_id);
|
||||||
|
|
||||||
|
let completed_at = chrono::Utc::now().to_rfc3339();
|
||||||
|
match &hand_result {
|
||||||
|
Ok(res) => {
|
||||||
|
run.status = HandRunStatus::Completed;
|
||||||
|
run.result = Some(res.output.clone());
|
||||||
|
run.error = res.error.clone();
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
run.status = HandRunStatus::Failed;
|
||||||
|
run.error = Some(e.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run.duration_ms = Some(duration.as_millis() as u64);
|
||||||
|
run.completed_at = Some(completed_at);
|
||||||
|
self.memory.update_hand_run(&run).await?;
|
||||||
|
|
||||||
|
hand_result.map(|res| (res, run_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Hand Run Tracking
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
/// Get a hand run by ID
|
||||||
|
pub async fn get_hand_run(&self, id: &HandRunId) -> Result<Option<HandRun>> {
|
||||||
|
self.memory.get_hand_run(id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List hand runs with filter
|
||||||
|
pub async fn list_hand_runs(&self, filter: &HandRunFilter) -> Result<Vec<HandRun>> {
|
||||||
|
self.memory.list_hand_runs(filter).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count hand runs matching filter
|
||||||
|
pub async fn count_hand_runs(&self, filter: &HandRunFilter) -> Result<u32> {
|
||||||
|
self.memory.count_hand_runs(filter).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Cancel a running hand execution
|
||||||
|
pub async fn cancel_hand_run(&self, id: &HandRunId) -> Result<()> {
|
||||||
|
if let Some((_, flag)) = self.running_hand_runs.remove(id) {
|
||||||
|
flag.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
|
||||||
|
// Note: the actual status update happens in execute_hand_with_source
|
||||||
|
// when it detects the cancel flag
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
// Not currently running — check if exists at all
|
||||||
|
let run = self.memory.get_hand_run(id).await?;
|
||||||
|
match run {
|
||||||
|
Some(r) if r.status == HandRunStatus::Pending => {
|
||||||
|
let mut updated = r;
|
||||||
|
updated.status = HandRunStatus::Cancelled;
|
||||||
|
updated.completed_at = Some(chrono::Utc::now().to_rfc3339());
|
||||||
|
self.memory.update_hand_run(&updated).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Some(r) => Err(zclaw_types::ZclawError::InvalidInput(
|
||||||
|
format!("Cannot cancel hand run {} with status {}", id, r.status)
|
||||||
|
)),
|
||||||
|
None => Err(zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("Hand run {} not found", id)
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================
|
// ============================================================
|
||||||
@@ -563,6 +831,7 @@ impl Kernel {
|
|||||||
status: "pending".to_string(),
|
status: "pending".to_string(),
|
||||||
created_at: chrono::Utc::now(),
|
created_at: chrono::Utc::now(),
|
||||||
input,
|
input,
|
||||||
|
reject_reason: None,
|
||||||
};
|
};
|
||||||
let mut approvals = self.pending_approvals.lock().await;
|
let mut approvals = self.pending_approvals.lock().await;
|
||||||
approvals.push(entry.clone());
|
approvals.push(entry.clone());
|
||||||
@@ -574,13 +843,16 @@ impl Kernel {
|
|||||||
&self,
|
&self,
|
||||||
id: &str,
|
id: &str,
|
||||||
approved: bool,
|
approved: bool,
|
||||||
_reason: Option<String>,
|
reason: Option<String>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut approvals = self.pending_approvals.lock().await;
|
let mut approvals = self.pending_approvals.lock().await;
|
||||||
let entry = approvals.iter_mut().find(|a| a.id == id && a.status == "pending")
|
let entry = approvals.iter_mut().find(|a| a.id == id && a.status == "pending")
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Approval not found: {}", id)))?;
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Approval not found: {}", id)))?;
|
||||||
|
|
||||||
entry.status = if approved { "approved".to_string() } else { "rejected".to_string() };
|
entry.status = if approved { "approved".to_string() } else { "rejected".to_string() };
|
||||||
|
if let Some(r) = reason {
|
||||||
|
entry.reject_reason = Some(r);
|
||||||
|
}
|
||||||
|
|
||||||
if approved {
|
if approved {
|
||||||
let hand_id = entry.hand_id.clone();
|
let hand_id = entry.hand_id.clone();
|
||||||
@@ -623,9 +895,268 @@ impl Kernel {
|
|||||||
entry.status = "cancelled".to_string();
|
entry.status = "cancelled".to_string();
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// A2A (Agent-to-Agent) Messaging
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
/// Derive an A2A agent profile from an AgentConfig
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
fn agent_config_to_a2a_profile(config: &AgentConfig) -> A2aAgentProfile {
|
||||||
|
let caps: Vec<A2aCapability> = config.tools.iter().map(|tool_name| {
|
||||||
|
A2aCapability {
|
||||||
|
name: tool_name.clone(),
|
||||||
|
description: format!("Tool: {}", tool_name),
|
||||||
|
input_schema: None,
|
||||||
|
output_schema: None,
|
||||||
|
requires_approval: false,
|
||||||
|
version: "1.0.0".to_string(),
|
||||||
|
tags: vec![],
|
||||||
|
}
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
A2aAgentProfile {
|
||||||
|
id: config.id,
|
||||||
|
name: config.name.clone(),
|
||||||
|
description: config.description.clone().unwrap_or_default(),
|
||||||
|
capabilities: caps,
|
||||||
|
protocols: vec!["a2a".to_string()],
|
||||||
|
role: "worker".to_string(),
|
||||||
|
priority: 5,
|
||||||
|
metadata: std::collections::HashMap::new(),
|
||||||
|
groups: vec![],
|
||||||
|
last_seen: 0,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Approval entry for pending approvals
|
/// Check if an agent is authorized to send messages to a target
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
fn check_a2a_permission(&self, from: &AgentId, to: &AgentId) -> Result<()> {
|
||||||
|
let caps = self.capabilities.get(from);
|
||||||
|
match caps {
|
||||||
|
Some(cap_set) => {
|
||||||
|
let has_permission = cap_set.capabilities.iter().any(|cap| {
|
||||||
|
match cap {
|
||||||
|
Capability::AgentMessage { pattern } => {
|
||||||
|
pattern == "*" || to.to_string().starts_with(pattern)
|
||||||
|
}
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if !has_permission {
|
||||||
|
return Err(zclaw_types::ZclawError::PermissionDenied(
|
||||||
|
format!("Agent {} does not have AgentMessage capability for {}", from, to)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
// No capabilities registered — deny by default
|
||||||
|
Err(zclaw_types::ZclawError::PermissionDenied(
|
||||||
|
format!("Agent {} has no capabilities registered", from)
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a direct A2A message from one agent to another
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
pub async fn a2a_send(
|
||||||
|
&self,
|
||||||
|
from: &AgentId,
|
||||||
|
to: &AgentId,
|
||||||
|
payload: serde_json::Value,
|
||||||
|
message_type: Option<A2aMessageType>,
|
||||||
|
) -> Result<()> {
|
||||||
|
// Validate sender exists
|
||||||
|
self.registry.get(from)
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("Sender agent not found: {}", from)
|
||||||
|
))?;
|
||||||
|
|
||||||
|
// Validate receiver exists and is running
|
||||||
|
self.registry.get(to)
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("Target agent not found: {}", to)
|
||||||
|
))?;
|
||||||
|
|
||||||
|
// Check capability permission
|
||||||
|
self.check_a2a_permission(from, to)?;
|
||||||
|
|
||||||
|
// Build and route envelope
|
||||||
|
let envelope = A2aEnvelope::new(
|
||||||
|
*from,
|
||||||
|
A2aRecipient::Direct { agent_id: *to },
|
||||||
|
message_type.unwrap_or(A2aMessageType::Notification),
|
||||||
|
payload,
|
||||||
|
);
|
||||||
|
|
||||||
|
self.a2a_router.route(envelope).await?;
|
||||||
|
|
||||||
|
// Emit event
|
||||||
|
self.events.publish(Event::A2aMessageSent {
|
||||||
|
from: *from,
|
||||||
|
to: format!("{}", to),
|
||||||
|
message_type: "direct".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Broadcast a message from one agent to all other agents
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
pub async fn a2a_broadcast(
|
||||||
|
&self,
|
||||||
|
from: &AgentId,
|
||||||
|
payload: serde_json::Value,
|
||||||
|
) -> Result<()> {
|
||||||
|
// Validate sender exists
|
||||||
|
self.registry.get(from)
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("Sender agent not found: {}", from)
|
||||||
|
))?;
|
||||||
|
|
||||||
|
let envelope = A2aEnvelope::new(
|
||||||
|
*from,
|
||||||
|
A2aRecipient::Broadcast,
|
||||||
|
A2aMessageType::Notification,
|
||||||
|
payload,
|
||||||
|
);
|
||||||
|
|
||||||
|
self.a2a_router.route(envelope).await?;
|
||||||
|
|
||||||
|
self.events.publish(Event::A2aMessageSent {
|
||||||
|
from: *from,
|
||||||
|
to: "broadcast".to_string(),
|
||||||
|
message_type: "broadcast".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Discover agents that have a specific capability
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
pub async fn a2a_discover(&self, capability: &str) -> Result<Vec<A2aAgentProfile>> {
|
||||||
|
let result = self.a2a_router.discover(capability).await?;
|
||||||
|
|
||||||
|
self.events.publish(Event::A2aAgentDiscovered {
|
||||||
|
agent_id: AgentId::new(),
|
||||||
|
capabilities: vec![capability.to_string()],
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to receive a pending A2A message for an agent (non-blocking)
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
pub async fn a2a_receive(&self, agent_id: &AgentId) -> Result<Option<A2aEnvelope>> {
|
||||||
|
let inbox = self.a2a_inboxes.get(agent_id)
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("No A2A inbox for agent: {}", agent_id)
|
||||||
|
))?;
|
||||||
|
|
||||||
|
let mut rx = inbox.lock().await;
|
||||||
|
match rx.try_recv() {
|
||||||
|
Ok(envelope) => {
|
||||||
|
self.events.publish(Event::A2aMessageReceived {
|
||||||
|
from: envelope.from,
|
||||||
|
to: format!("{}", agent_id),
|
||||||
|
message_type: "direct".to_string(),
|
||||||
|
});
|
||||||
|
Ok(Some(envelope))
|
||||||
|
}
|
||||||
|
Err(_) => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delegate a task to another agent and wait for response with timeout
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
pub async fn a2a_delegate_task(
|
||||||
|
&self,
|
||||||
|
from: &AgentId,
|
||||||
|
to: &AgentId,
|
||||||
|
task_description: String,
|
||||||
|
timeout_ms: u64,
|
||||||
|
) -> Result<serde_json::Value> {
|
||||||
|
// Validate both agents exist
|
||||||
|
self.registry.get(from)
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("Sender agent not found: {}", from)
|
||||||
|
))?;
|
||||||
|
self.registry.get(to)
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("Target agent not found: {}", to)
|
||||||
|
))?;
|
||||||
|
|
||||||
|
// Check capability permission
|
||||||
|
self.check_a2a_permission(from, to)?;
|
||||||
|
|
||||||
|
// Send task request
|
||||||
|
let task_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let envelope = A2aEnvelope::new(
|
||||||
|
*from,
|
||||||
|
A2aRecipient::Direct { agent_id: *to },
|
||||||
|
A2aMessageType::Task,
|
||||||
|
serde_json::json!({
|
||||||
|
"task_id": task_id,
|
||||||
|
"description": task_description,
|
||||||
|
}),
|
||||||
|
).with_conversation(task_id.clone());
|
||||||
|
|
||||||
|
let envelope_id = envelope.id.clone();
|
||||||
|
self.a2a_router.route(envelope).await?;
|
||||||
|
|
||||||
|
self.events.publish(Event::A2aMessageSent {
|
||||||
|
from: *from,
|
||||||
|
to: format!("{}", to),
|
||||||
|
message_type: "task".to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait for response with timeout
|
||||||
|
let timeout = tokio::time::Duration::from_millis(timeout_ms);
|
||||||
|
let result = tokio::time::timeout(timeout, async {
|
||||||
|
let inbox = self.a2a_inboxes.get(from)
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(
|
||||||
|
format!("No A2A inbox for agent: {}", from)
|
||||||
|
))?;
|
||||||
|
let mut rx = inbox.lock().await;
|
||||||
|
|
||||||
|
// Poll for matching response
|
||||||
|
loop {
|
||||||
|
match rx.recv().await {
|
||||||
|
Some(msg) => {
|
||||||
|
// Check if this is a response to our task
|
||||||
|
if msg.message_type == A2aMessageType::Response
|
||||||
|
&& msg.reply_to.as_deref() == Some(&envelope_id) {
|
||||||
|
return Ok::<_, zclaw_types::ZclawError>(msg.payload);
|
||||||
|
}
|
||||||
|
// Not our response — put it back by logging it (would need a re-queue mechanism for production)
|
||||||
|
tracing::warn!("Received non-matching A2A response, discarding: {}", msg.id);
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
return Err(zclaw_types::ZclawError::Internal(
|
||||||
|
"A2A inbox channel closed".to_string()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}).await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(payload)) => Ok(payload),
|
||||||
|
Ok(Err(e)) => Err(e),
|
||||||
|
Err(_) => Err(zclaw_types::ZclawError::Timeout(
|
||||||
|
format!("A2A task delegation timed out after {}ms", timeout_ms)
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all online agents via A2A profiles
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
pub async fn a2a_get_online_agents(&self) -> Result<Vec<A2aAgentProfile>> {
|
||||||
|
Ok(self.a2a_router.list_profiles().await)
|
||||||
|
}
|
||||||
|
}
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ApprovalEntry {
|
pub struct ApprovalEntry {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
@@ -633,6 +1164,7 @@ pub struct ApprovalEntry {
|
|||||||
pub status: String,
|
pub status: String,
|
||||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||||
pub input: serde_json::Value,
|
pub input: serde_json::Value,
|
||||||
|
pub reject_reason: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Response from sending a message
|
/// Response from sending a message
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ mod capabilities;
|
|||||||
mod events;
|
mod events;
|
||||||
pub mod trigger_manager;
|
pub mod trigger_manager;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod scheduler;
|
||||||
#[cfg(feature = "multi-agent")]
|
#[cfg(feature = "multi-agent")]
|
||||||
pub mod director;
|
pub mod director;
|
||||||
pub mod generation;
|
pub mod generation;
|
||||||
@@ -21,8 +22,16 @@ pub use config::*;
|
|||||||
pub use trigger_manager::{TriggerManager, TriggerEntry, TriggerUpdateRequest, TriggerManagerConfig};
|
pub use trigger_manager::{TriggerManager, TriggerEntry, TriggerUpdateRequest, TriggerManagerConfig};
|
||||||
#[cfg(feature = "multi-agent")]
|
#[cfg(feature = "multi-agent")]
|
||||||
pub use director::*;
|
pub use director::*;
|
||||||
|
#[cfg(feature = "multi-agent")]
|
||||||
|
pub use zclaw_protocols::{
|
||||||
|
A2aRouter, A2aAgentProfile, A2aCapability, A2aEnvelope, A2aMessageType, A2aRecipient,
|
||||||
|
A2aReceiver,
|
||||||
|
BasicA2aClient,
|
||||||
|
A2aClient,
|
||||||
|
};
|
||||||
pub use generation::*;
|
pub use generation::*;
|
||||||
pub use export::{ExportFormat, ExportOptions, ExportResult, Exporter, export_classroom};
|
pub use export::{ExportFormat, ExportOptions, ExportResult, Exporter, export_classroom};
|
||||||
|
|
||||||
// Re-export hands types for convenience
|
// Re-export hands types for convenience
|
||||||
pub use zclaw_hands::{HandRegistry, HandContext, HandResult, HandConfig, Hand, HandStatus};
|
pub use zclaw_hands::{HandRegistry, HandContext, HandResult, HandConfig, Hand, HandStatus};
|
||||||
|
pub use scheduler::SchedulerService;
|
||||||
|
|||||||
341
crates/zclaw-kernel/src/scheduler.rs
Normal file
341
crates/zclaw-kernel/src/scheduler.rs
Normal file
@@ -0,0 +1,341 @@
|
|||||||
|
//! Scheduler service for automatic trigger execution
|
||||||
|
//!
|
||||||
|
//! Periodically scans scheduled triggers and fires them at the appropriate time.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use chrono::{Datelike, Timelike};
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tokio::time::{self, Duration};
|
||||||
|
use zclaw_types::Result;
|
||||||
|
use crate::Kernel;
|
||||||
|
|
||||||
|
/// Scheduler service that runs in the background and executes scheduled triggers
|
||||||
|
pub struct SchedulerService {
|
||||||
|
kernel: Arc<RwLock<Option<Kernel>>>,
|
||||||
|
running: Arc<AtomicBool>,
|
||||||
|
check_interval: Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SchedulerService {
|
||||||
|
/// Create a new scheduler service
|
||||||
|
pub fn new(kernel: Arc<RwLock<Option<Kernel>>>, check_interval_secs: u64) -> Self {
|
||||||
|
Self {
|
||||||
|
kernel,
|
||||||
|
running: Arc::new(AtomicBool::new(false)),
|
||||||
|
check_interval: Duration::from_secs(check_interval_secs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Start the scheduler loop in the background
|
||||||
|
pub fn start(&self) {
|
||||||
|
if self.running.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
|
||||||
|
tracing::warn!("[Scheduler] Already running, ignoring start request");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let kernel = self.kernel.clone();
|
||||||
|
let running = self.running.clone();
|
||||||
|
let interval = self.check_interval;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tracing::info!("[Scheduler] Starting scheduler loop with {}s interval", interval.as_secs());
|
||||||
|
|
||||||
|
let mut ticker = time::interval(interval);
|
||||||
|
// First tick fires immediately — skip it
|
||||||
|
ticker.tick().await;
|
||||||
|
|
||||||
|
while running.load(Ordering::Relaxed) {
|
||||||
|
ticker.tick().await;
|
||||||
|
|
||||||
|
if !running.load(Ordering::Relaxed) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = Self::check_and_fire_scheduled_triggers(&kernel).await {
|
||||||
|
tracing::error!("[Scheduler] Error checking triggers: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("[Scheduler] Scheduler loop stopped");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Stop the scheduler loop
|
||||||
|
pub fn stop(&self) {
|
||||||
|
self.running.store(false, Ordering::Relaxed);
|
||||||
|
tracing::info!("[Scheduler] Stop requested");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the scheduler is running
|
||||||
|
pub fn is_running(&self) -> bool {
|
||||||
|
self.running.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check all scheduled triggers and fire those that are due
|
||||||
|
async fn check_and_fire_scheduled_triggers(
|
||||||
|
kernel_lock: &Arc<RwLock<Option<Kernel>>>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let kernel_read = kernel_lock.read().await;
|
||||||
|
let kernel = match kernel_read.as_ref() {
|
||||||
|
Some(k) => k,
|
||||||
|
None => return Ok(()),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get all triggers
|
||||||
|
let triggers = kernel.list_triggers().await;
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
|
||||||
|
// Filter to enabled Schedule triggers
|
||||||
|
let scheduled: Vec<_> = triggers.iter()
|
||||||
|
.filter(|t| {
|
||||||
|
t.config.enabled && matches!(t.config.trigger_type, zclaw_hands::TriggerType::Schedule { .. })
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if scheduled.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::debug!("[Scheduler] Checking {} scheduled triggers", scheduled.len());
|
||||||
|
|
||||||
|
// Drop the read lock before executing
|
||||||
|
let to_execute: Vec<(String, String, String)> = scheduled.iter()
|
||||||
|
.filter_map(|t| {
|
||||||
|
if let zclaw_hands::TriggerType::Schedule { ref cron } = t.config.trigger_type {
|
||||||
|
// Simple cron matching: check if we should fire now
|
||||||
|
if Self::should_fire_cron(cron, &now) {
|
||||||
|
Some((t.config.id.clone(), t.config.hand_id.clone(), cron.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
drop(kernel_read);
|
||||||
|
|
||||||
|
// Execute due triggers (with write lock since execute_hand may need it)
|
||||||
|
for (trigger_id, hand_id, cron_expr) in to_execute {
|
||||||
|
tracing::info!(
|
||||||
|
"[Scheduler] Firing scheduled trigger '{}' → hand '{}' (cron: {})",
|
||||||
|
trigger_id, hand_id, cron_expr
|
||||||
|
);
|
||||||
|
|
||||||
|
let kernel_read = kernel_lock.read().await;
|
||||||
|
if let Some(kernel) = kernel_read.as_ref() {
|
||||||
|
let trigger_source = zclaw_types::TriggerSource::Scheduled {
|
||||||
|
trigger_id: trigger_id.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let input = serde_json::json!({
|
||||||
|
"trigger_id": trigger_id,
|
||||||
|
"trigger_type": "schedule",
|
||||||
|
"cron": cron_expr,
|
||||||
|
"fired_at": now.to_rfc3339(),
|
||||||
|
});
|
||||||
|
|
||||||
|
match kernel.execute_hand_with_source(&hand_id, input, trigger_source).await {
|
||||||
|
Ok((_result, run_id)) => {
|
||||||
|
tracing::info!(
|
||||||
|
"[Scheduler] Successfully fired trigger '{}' → run {}",
|
||||||
|
trigger_id, run_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(
|
||||||
|
"[Scheduler] Failed to execute trigger '{}': {}",
|
||||||
|
trigger_id, e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Simple cron expression matcher
|
||||||
|
///
|
||||||
|
/// Supports basic cron format: `minute hour day month weekday`
|
||||||
|
/// Also supports interval shorthand: `every:Ns`, `every:Nm`, `every:Nh`
|
||||||
|
fn should_fire_cron(cron: &str, now: &chrono::DateTime<chrono::Utc>) -> bool {
|
||||||
|
let cron = cron.trim();
|
||||||
|
|
||||||
|
// Handle interval shorthand: "every:30s", "every:5m", "every:1h"
|
||||||
|
if let Some(interval_str) = cron.strip_prefix("every:") {
|
||||||
|
return Self::check_interval_shorthand(interval_str, now);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle ISO timestamp for one-shot: "2026-03-29T10:00:00Z"
|
||||||
|
if cron.contains('T') && cron.contains('-') {
|
||||||
|
if let Ok(target) = chrono::DateTime::parse_from_rfc3339(cron) {
|
||||||
|
let target_utc = target.with_timezone(&chrono::Utc);
|
||||||
|
// Fire if within the check window (± check_interval/2, approx 30s)
|
||||||
|
let diff = (*now - target_utc).num_seconds().abs();
|
||||||
|
return diff <= 30;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard 5-field cron: minute hour day_of_month month day_of_week
|
||||||
|
let parts: Vec<&str> = cron.split_whitespace().collect();
|
||||||
|
if parts.len() != 5 {
|
||||||
|
tracing::warn!("[Scheduler] Invalid cron expression (expected 5 fields): '{}'", cron);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let minute = now.minute() as i32;
|
||||||
|
let hour = now.hour() as i32;
|
||||||
|
let day = now.day() as i32;
|
||||||
|
let month = now.month() as i32;
|
||||||
|
let weekday = now.weekday().num_days_from_monday() as i32; // Mon=0..Sun=6
|
||||||
|
|
||||||
|
Self::cron_field_matches(parts[0], minute)
|
||||||
|
&& Self::cron_field_matches(parts[1], hour)
|
||||||
|
&& Self::cron_field_matches(parts[2], day)
|
||||||
|
&& Self::cron_field_matches(parts[3], month)
|
||||||
|
&& Self::cron_field_matches(parts[4], weekday)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a single cron field matches the current value
|
||||||
|
fn cron_field_matches(field: &str, value: i32) -> bool {
|
||||||
|
if field == "*" || field == "?" {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle step: */N
|
||||||
|
if let Some(step_str) = field.strip_prefix("*/") {
|
||||||
|
if let Ok(step) = step_str.parse::<i32>() {
|
||||||
|
if step > 0 {
|
||||||
|
return value % step == 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle range: N-M
|
||||||
|
if field.contains('-') {
|
||||||
|
let range_parts: Vec<&str> = field.split('-').collect();
|
||||||
|
if range_parts.len() == 2 {
|
||||||
|
if let (Ok(start), Ok(end)) = (range_parts[0].parse::<i32>(), range_parts[1].parse::<i32>()) {
|
||||||
|
return value >= start && value <= end;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle list: N,M,O
|
||||||
|
if field.contains(',') {
|
||||||
|
return field.split(',').any(|part| {
|
||||||
|
part.trim().parse::<i32>().map(|p| p == value).unwrap_or(false)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simple value
|
||||||
|
field.parse::<i32>().map(|p| p == value).unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check interval shorthand expressions
|
||||||
|
fn check_interval_shorthand(interval: &str, now: &chrono::DateTime<chrono::Utc>) -> bool {
|
||||||
|
let (num_str, unit) = if interval.ends_with('s') {
|
||||||
|
(&interval[..interval.len()-1], 's')
|
||||||
|
} else if interval.ends_with('m') {
|
||||||
|
(&interval[..interval.len()-1], 'm')
|
||||||
|
} else if interval.ends_with('h') {
|
||||||
|
(&interval[..interval.len()-1], 'h')
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
|
let num: i64 = match num_str.parse() {
|
||||||
|
Ok(n) => n,
|
||||||
|
Err(_) => return false,
|
||||||
|
};
|
||||||
|
|
||||||
|
if num <= 0 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
let interval_secs = match unit {
|
||||||
|
's' => num,
|
||||||
|
'm' => num * 60,
|
||||||
|
'h' => num * 3600,
|
||||||
|
_ => return false,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if current timestamp aligns with the interval
|
||||||
|
let timestamp = now.timestamp();
|
||||||
|
timestamp % interval_secs == 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use chrono::Timelike;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cron_field_wildcard() {
|
||||||
|
assert!(SchedulerService::cron_field_matches("*", 5));
|
||||||
|
assert!(SchedulerService::cron_field_matches("?", 5));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cron_field_exact() {
|
||||||
|
assert!(SchedulerService::cron_field_matches("5", 5));
|
||||||
|
assert!(!SchedulerService::cron_field_matches("5", 6));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cron_field_step() {
|
||||||
|
assert!(SchedulerService::cron_field_matches("*/5", 0));
|
||||||
|
assert!(SchedulerService::cron_field_matches("*/5", 5));
|
||||||
|
assert!(SchedulerService::cron_field_matches("*/5", 10));
|
||||||
|
assert!(!SchedulerService::cron_field_matches("*/5", 3));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cron_field_range() {
|
||||||
|
assert!(SchedulerService::cron_field_matches("1-5", 1));
|
||||||
|
assert!(SchedulerService::cron_field_matches("1-5", 3));
|
||||||
|
assert!(SchedulerService::cron_field_matches("1-5", 5));
|
||||||
|
assert!(!SchedulerService::cron_field_matches("1-5", 0));
|
||||||
|
assert!(!SchedulerService::cron_field_matches("1-5", 6));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cron_field_list() {
|
||||||
|
assert!(SchedulerService::cron_field_matches("1,3,5", 1));
|
||||||
|
assert!(SchedulerService::cron_field_matches("1,3,5", 3));
|
||||||
|
assert!(SchedulerService::cron_field_matches("1,3,5", 5));
|
||||||
|
assert!(!SchedulerService::cron_field_matches("1,3,5", 2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_fire_every_minute() {
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
assert!(SchedulerService::should_fire_cron("every:1m", &now));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_fire_cron_wildcard() {
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
// Every minute match
|
||||||
|
assert!(SchedulerService::should_fire_cron(
|
||||||
|
&format!("{} * * * *", now.minute()),
|
||||||
|
&now,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_should_not_fire_cron() {
|
||||||
|
let now = chrono::Utc::now();
|
||||||
|
let wrong_minute = if now.minute() < 59 { now.minute() + 1 } else { 0 };
|
||||||
|
assert!(!SchedulerService::should_fire_cron(
|
||||||
|
&format!("{} * * * *", wrong_minute),
|
||||||
|
&now,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -49,8 +49,26 @@ CREATE TABLE IF NOT EXISTS schema_version (
|
|||||||
version INTEGER PRIMARY KEY
|
version INTEGER PRIMARY KEY
|
||||||
);
|
);
|
||||||
|
|
||||||
|
-- Hand execution runs table
|
||||||
|
CREATE TABLE IF NOT EXISTS hand_runs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
hand_name TEXT NOT NULL,
|
||||||
|
trigger_source TEXT NOT NULL,
|
||||||
|
params TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'pending',
|
||||||
|
result TEXT,
|
||||||
|
error TEXT,
|
||||||
|
duration_ms INTEGER,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
started_at TEXT,
|
||||||
|
completed_at TEXT
|
||||||
|
);
|
||||||
|
|
||||||
-- Indexes
|
-- Indexes
|
||||||
CREATE INDEX IF NOT EXISTS idx_sessions_agent ON sessions(agent_id);
|
CREATE INDEX IF NOT EXISTS idx_sessions_agent ON sessions(agent_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
|
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_kv_agent ON kv_store(agent_id);
|
CREATE INDEX IF NOT EXISTS idx_kv_agent ON kv_store(agent_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_hand_runs_hand ON hand_runs(hand_name);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_hand_runs_status ON hand_runs(status);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_hand_runs_created ON hand_runs(created_at);
|
||||||
"#;
|
"#;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
//! Memory store implementation
|
//! Memory store implementation
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError};
|
use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError, HandRun, HandRunId, HandRunStatus, HandRunFilter};
|
||||||
|
|
||||||
/// Memory store for persisting ZCLAW data
|
/// Memory store for persisting ZCLAW data
|
||||||
pub struct MemoryStore {
|
pub struct MemoryStore {
|
||||||
@@ -283,6 +283,193 @@ impl MemoryStore {
|
|||||||
|
|
||||||
Ok(rows.into_iter().map(|(key,)| key).collect())
|
Ok(rows.into_iter().map(|(key,)| key).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === Hand Run Tracking ===
|
||||||
|
|
||||||
|
/// Save a new hand run record
|
||||||
|
pub async fn save_hand_run(&self, run: &HandRun) -> Result<()> {
|
||||||
|
let id = run.id.to_string();
|
||||||
|
let trigger_source = serde_json::to_string(&run.trigger_source)?;
|
||||||
|
let params = serde_json::to_string(&run.params)?;
|
||||||
|
let result = run.result.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
|
||||||
|
let error = run.error.as_ref().map(|e| serde_json::to_string(e)).transpose()?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
INSERT INTO hand_runs (id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(&run.hand_name)
|
||||||
|
.bind(&trigger_source)
|
||||||
|
.bind(¶ms)
|
||||||
|
.bind(run.status.to_string())
|
||||||
|
.bind(result.as_deref())
|
||||||
|
.bind(error.as_deref())
|
||||||
|
.bind(run.duration_ms.map(|d| d as i64))
|
||||||
|
.bind(&run.created_at)
|
||||||
|
.bind(run.started_at.as_deref())
|
||||||
|
.bind(run.completed_at.as_deref())
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update an existing hand run record
|
||||||
|
pub async fn update_hand_run(&self, run: &HandRun) -> Result<()> {
|
||||||
|
let id = run.id.to_string();
|
||||||
|
let trigger_source = serde_json::to_string(&run.trigger_source)?;
|
||||||
|
let params = serde_json::to_string(&run.params)?;
|
||||||
|
let result = run.result.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
|
||||||
|
let error = run.error.as_ref().map(|e| serde_json::to_string(e)).transpose()?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
r#"
|
||||||
|
UPDATE hand_runs SET
|
||||||
|
hand_name = ?, trigger_source = ?, params = ?, status = ?,
|
||||||
|
result = ?, error = ?, duration_ms = ?,
|
||||||
|
started_at = ?, completed_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.bind(&run.hand_name)
|
||||||
|
.bind(&trigger_source)
|
||||||
|
.bind(¶ms)
|
||||||
|
.bind(run.status.to_string())
|
||||||
|
.bind(result.as_deref())
|
||||||
|
.bind(error.as_deref())
|
||||||
|
.bind(run.duration_ms.map(|d| d as i64))
|
||||||
|
.bind(run.started_at.as_deref())
|
||||||
|
.bind(run.completed_at.as_deref())
|
||||||
|
.bind(&id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a hand run by ID
|
||||||
|
pub async fn get_hand_run(&self, id: &HandRunId) -> Result<Option<HandRun>> {
|
||||||
|
let id_str = id.to_string();
|
||||||
|
|
||||||
|
let row = sqlx::query_as::<_, (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>)>(
|
||||||
|
"SELECT id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at FROM hand_runs WHERE id = ?"
|
||||||
|
)
|
||||||
|
.bind(&id_str)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||||
|
|
||||||
|
match row {
|
||||||
|
Some(r) => Ok(Some(Self::row_to_hand_run(r)?)),
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List hand runs with optional filter
|
||||||
|
pub async fn list_hand_runs(&self, filter: &HandRunFilter) -> Result<Vec<HandRun>> {
|
||||||
|
let mut query = String::from(
|
||||||
|
"SELECT id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at FROM hand_runs WHERE 1=1"
|
||||||
|
);
|
||||||
|
let mut bind_values: Vec<String> = Vec::new();
|
||||||
|
|
||||||
|
if let Some(ref hand_name) = filter.hand_name {
|
||||||
|
query.push_str(" AND hand_name = ?");
|
||||||
|
bind_values.push(hand_name.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref status) = filter.status {
|
||||||
|
query.push_str(" AND status = ?");
|
||||||
|
bind_values.push(status.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
query.push_str(" ORDER BY created_at DESC");
|
||||||
|
|
||||||
|
if let Some(limit) = filter.limit {
|
||||||
|
query.push_str(&format!(" LIMIT {}", limit));
|
||||||
|
}
|
||||||
|
if let Some(offset) = filter.offset {
|
||||||
|
query.push_str(&format!(" OFFSET {}", offset));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sql_query = sqlx::query_as::<_, (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>)>(&query);
|
||||||
|
|
||||||
|
for val in &bind_values {
|
||||||
|
sql_query = sql_query.bind(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
let rows = sql_query
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||||
|
|
||||||
|
rows.into_iter()
|
||||||
|
.map(|r| Self::row_to_hand_run(r))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count hand runs matching filter
|
||||||
|
pub async fn count_hand_runs(&self, filter: &HandRunFilter) -> Result<u32> {
|
||||||
|
let mut query = String::from("SELECT COUNT(*) FROM hand_runs WHERE 1=1");
|
||||||
|
let mut bind_values: Vec<String> = Vec::new();
|
||||||
|
|
||||||
|
if let Some(ref hand_name) = filter.hand_name {
|
||||||
|
query.push_str(" AND hand_name = ?");
|
||||||
|
bind_values.push(hand_name.clone());
|
||||||
|
}
|
||||||
|
if let Some(ref status) = filter.status {
|
||||||
|
query.push_str(" AND status = ?");
|
||||||
|
bind_values.push(status.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut sql_query = sqlx::query_scalar::<_, i64>(&query);
|
||||||
|
for val in &bind_values {
|
||||||
|
sql_query = sql_query.bind(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
let count = sql_query
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||||
|
|
||||||
|
Ok(count as u32)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn row_to_hand_run(
|
||||||
|
row: (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>),
|
||||||
|
) -> Result<HandRun> {
|
||||||
|
let (id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at) = row;
|
||||||
|
|
||||||
|
let run_id: HandRunId = id.parse()
|
||||||
|
.map_err(|e| ZclawError::StorageError(format!("Invalid HandRunId: {}", e)))?;
|
||||||
|
let trigger: zclaw_types::TriggerSource = serde_json::from_str(&trigger_source)?;
|
||||||
|
let params_val: serde_json::Value = serde_json::from_str(¶ms)?;
|
||||||
|
let run_status: HandRunStatus = status.parse()
|
||||||
|
.map_err(|e| ZclawError::StorageError(e))?;
|
||||||
|
let result_val: Option<serde_json::Value> = result.map(|r| serde_json::from_str(&r)).transpose()?;
|
||||||
|
let error_val: Option<String> = error.as_ref()
|
||||||
|
.map(|e| serde_json::from_str::<String>(e))
|
||||||
|
.transpose()
|
||||||
|
.unwrap_or_else(|_| error.clone());
|
||||||
|
|
||||||
|
Ok(HandRun {
|
||||||
|
id: run_id,
|
||||||
|
hand_name,
|
||||||
|
trigger_source: trigger,
|
||||||
|
params: params_val,
|
||||||
|
status: run_status,
|
||||||
|
result: result_val,
|
||||||
|
error: error_val,
|
||||||
|
duration_ms: duration_ms.map(|d| d as u64),
|
||||||
|
created_at,
|
||||||
|
started_at,
|
||||||
|
completed_at,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -427,6 +427,28 @@ impl A2aRouter {
|
|||||||
pub fn agent_id(&self) -> &AgentId {
|
pub fn agent_id(&self) -> &AgentId {
|
||||||
&self.agent_id
|
&self.agent_id
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Discover agents that have a specific capability
|
||||||
|
pub async fn discover(&self, capability: &str) -> Result<Vec<A2aAgentProfile>> {
|
||||||
|
let cap_index = self.capability_index.read().await;
|
||||||
|
let profiles = self.profiles.read().await;
|
||||||
|
|
||||||
|
match cap_index.get(capability) {
|
||||||
|
Some(agent_ids) => {
|
||||||
|
let result: Vec<A2aAgentProfile> = agent_ids.iter()
|
||||||
|
.filter_map(|id| profiles.get(id).cloned())
|
||||||
|
.collect();
|
||||||
|
Ok(result)
|
||||||
|
}
|
||||||
|
None => Ok(Vec::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get all registered agent profiles
|
||||||
|
pub async fn list_profiles(&self) -> Vec<A2aAgentProfile> {
|
||||||
|
let profiles = self.profiles.read().await;
|
||||||
|
profiles.values().cloned().collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Basic A2A client implementation
|
/// Basic A2A client implementation
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
//! Optionally flushes old messages to the growth/memory system before discarding.
|
//! Optionally flushes old messages to the growth/memory system before discarding.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::sync::atomic::{AtomicU64, Ordering};
|
||||||
use zclaw_types::{AgentId, Message, SessionId};
|
use zclaw_types::{AgentId, Message, SessionId};
|
||||||
|
|
||||||
use crate::driver::{CompletionRequest, ContentBlock, LlmDriver};
|
use crate::driver::{CompletionRequest, ContentBlock, LlmDriver};
|
||||||
@@ -40,9 +41,18 @@ pub fn estimate_tokens(text: &str) -> usize {
|
|||||||
{
|
{
|
||||||
// CJK ideographs — ~1.5 tokens
|
// CJK ideographs — ~1.5 tokens
|
||||||
tokens += 1.5;
|
tokens += 1.5;
|
||||||
|
} else if (0xAC00..=0xD7AF).contains(&code) || (0x1100..=0x11FF).contains(&code) {
|
||||||
|
// Korean Hangul syllables + Jamo — ~1.5 tokens
|
||||||
|
tokens += 1.5;
|
||||||
|
} else if (0x3040..=0x309F).contains(&code) || (0x30A0..=0x30FF).contains(&code) {
|
||||||
|
// Japanese Hiragana + Katakana — ~1.5 tokens
|
||||||
|
tokens += 1.5;
|
||||||
} else if (0x3000..=0x303F).contains(&code) || (0xFF00..=0xFFEF).contains(&code) {
|
} else if (0x3000..=0x303F).contains(&code) || (0xFF00..=0xFFEF).contains(&code) {
|
||||||
// CJK / fullwidth punctuation — ~1.0 token
|
// CJK / fullwidth punctuation — ~1.0 token
|
||||||
tokens += 1.0;
|
tokens += 1.0;
|
||||||
|
} else if (0x1F000..=0x1FAFF).contains(&code) || (0x2600..=0x27BF).contains(&code) {
|
||||||
|
// Emoji & Symbols — ~2.0 tokens
|
||||||
|
tokens += 2.0;
|
||||||
} else if char == ' ' || char == '\n' || char == '\t' {
|
} else if char == ' ' || char == '\n' || char == '\t' {
|
||||||
// whitespace
|
// whitespace
|
||||||
tokens += 0.25;
|
tokens += 0.25;
|
||||||
@@ -88,6 +98,54 @@ pub fn estimate_messages_tokens(messages: &[Message]) -> usize {
|
|||||||
total
|
total
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============================================================
|
||||||
|
// Calibration: adjust heuristic estimates using API feedback
|
||||||
|
// ============================================================
|
||||||
|
|
||||||
|
const F64_1_0_BITS: u64 = 4607182418800017408u64; // 1.0f64.to_bits()
|
||||||
|
|
||||||
|
/// Global calibration factor for token estimation (stored as f64 bits).
|
||||||
|
///
|
||||||
|
/// Updated via exponential moving average when API returns actual token counts.
|
||||||
|
/// Initial value is 1.0 (no adjustment).
|
||||||
|
static CALIBRATION_FACTOR_BITS: AtomicU64 = AtomicU64::new(F64_1_0_BITS);
|
||||||
|
|
||||||
|
/// Get the current calibration factor.
|
||||||
|
pub fn get_calibration_factor() -> f64 {
|
||||||
|
f64::from_bits(CALIBRATION_FACTOR_BITS.load(Ordering::Relaxed))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update calibration factor using exponential moving average.
|
||||||
|
///
|
||||||
|
/// Compares estimated tokens with actual tokens from API response:
|
||||||
|
/// - `ratio = actual / estimated` so underestimates push factor UP
|
||||||
|
/// - EMA: `new = current * 0.7 + ratio * 0.3`
|
||||||
|
/// - Clamped to [0.5, 2.0] to prevent runaway values
|
||||||
|
pub fn update_calibration(estimated: usize, actual: u32) {
|
||||||
|
if actual == 0 || estimated == 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let ratio = actual as f64 / estimated as f64;
|
||||||
|
let current = get_calibration_factor();
|
||||||
|
let new_factor = (current * 0.7 + ratio * 0.3).clamp(0.5, 2.0);
|
||||||
|
CALIBRATION_FACTOR_BITS.store(new_factor.to_bits(), Ordering::Relaxed);
|
||||||
|
tracing::debug!(
|
||||||
|
"[Compaction] Calibration: estimated={}, actual={}, ratio={:.2}, factor {:.2} → {:.2}",
|
||||||
|
estimated, actual, ratio, current, new_factor
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Estimate total tokens for messages with calibration applied.
|
||||||
|
fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize {
|
||||||
|
let raw = estimate_messages_tokens(messages);
|
||||||
|
let factor = get_calibration_factor();
|
||||||
|
if (factor - 1.0).abs() < f64::EPSILON {
|
||||||
|
raw
|
||||||
|
} else {
|
||||||
|
((raw as f64 * factor).ceil()) as usize
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Compact a message list by summarizing old messages and keeping recent ones.
|
/// Compact a message list by summarizing old messages and keeping recent ones.
|
||||||
///
|
///
|
||||||
/// When `messages.len() > keep_recent`, the oldest messages are summarized
|
/// When `messages.len() > keep_recent`, the oldest messages are summarized
|
||||||
@@ -134,7 +192,7 @@ pub fn compact_messages(messages: Vec<Message>, keep_recent: usize) -> (Vec<Mess
|
|||||||
///
|
///
|
||||||
/// Returns the (possibly compacted) message list.
|
/// Returns the (possibly compacted) message list.
|
||||||
pub fn maybe_compact(messages: Vec<Message>, threshold: usize) -> Vec<Message> {
|
pub fn maybe_compact(messages: Vec<Message>, threshold: usize) -> Vec<Message> {
|
||||||
let tokens = estimate_messages_tokens(&messages);
|
let tokens = estimate_messages_tokens_calibrated(&messages);
|
||||||
if tokens < threshold {
|
if tokens < threshold {
|
||||||
return messages;
|
return messages;
|
||||||
}
|
}
|
||||||
@@ -208,7 +266,7 @@ pub async fn maybe_compact_with_config(
|
|||||||
driver: Option<&Arc<dyn LlmDriver>>,
|
driver: Option<&Arc<dyn LlmDriver>>,
|
||||||
growth: Option<&GrowthIntegration>,
|
growth: Option<&GrowthIntegration>,
|
||||||
) -> CompactionOutcome {
|
) -> CompactionOutcome {
|
||||||
let tokens = estimate_messages_tokens(&messages);
|
let tokens = estimate_messages_tokens_calibrated(&messages);
|
||||||
if tokens < threshold {
|
if tokens < threshold {
|
||||||
return CompactionOutcome {
|
return CompactionOutcome {
|
||||||
messages,
|
messages,
|
||||||
@@ -475,10 +533,11 @@ fn generate_summary(messages: &[Message]) -> String {
|
|||||||
|
|
||||||
let summary = sections.join("\n");
|
let summary = sections.join("\n");
|
||||||
|
|
||||||
// Enforce max length
|
// Enforce max length (char-safe for CJK)
|
||||||
let max_chars = 800;
|
let max_chars = 800;
|
||||||
if summary.len() > max_chars {
|
if summary.chars().count() > max_chars {
|
||||||
format!("{}...\n(摘要已截断)", &summary[..max_chars])
|
let truncated: String = summary.chars().take(max_chars).collect();
|
||||||
|
format!("{}...\n(摘要已截断)", truncated)
|
||||||
} else {
|
} else {
|
||||||
summary
|
summary
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -130,7 +130,8 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
let api_key = self.api_key.expose_secret().to_string();
|
let api_key = self.api_key.expose_secret().to_string();
|
||||||
|
|
||||||
Box::pin(stream! {
|
Box::pin(stream! {
|
||||||
tracing::debug!("[OpenAiDriver:stream] Starting HTTP request...");
|
println!("[OpenAI:stream] POST to {}/chat/completions", base_url);
|
||||||
|
println!("[OpenAI:stream] Request model={}, stream={}", stream_request.model, stream_request.stream);
|
||||||
let response = match self.client
|
let response = match self.client
|
||||||
.post(format!("{}/chat/completions", base_url))
|
.post(format!("{}/chat/completions", base_url))
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
@@ -141,11 +142,11 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(r) => {
|
Ok(r) => {
|
||||||
tracing::debug!("[OpenAiDriver:stream] Got response, status: {}", r.status());
|
println!("[OpenAI:stream] Response status: {}, content-type: {:?}", r.status(), r.headers().get("content-type"));
|
||||||
r
|
r
|
||||||
},
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("[OpenAiDriver:stream] HTTP request failed: {:?}", e);
|
println!("[OpenAI:stream] HTTP request FAILED: {:?}", e);
|
||||||
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -154,6 +155,7 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let body = response.text().await.unwrap_or_default();
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
println!("[OpenAI:stream] API error {}: {}", status, &body[..body.len().min(500)]);
|
||||||
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -161,21 +163,45 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
let mut byte_stream = response.bytes_stream();
|
let mut byte_stream = response.bytes_stream();
|
||||||
let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
|
let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
|
||||||
let mut current_tool_id: Option<String> = None;
|
let mut current_tool_id: Option<String> = None;
|
||||||
|
let mut sse_event_count: usize = 0;
|
||||||
|
let mut raw_bytes_total: usize = 0;
|
||||||
|
|
||||||
while let Some(chunk_result) = byte_stream.next().await {
|
while let Some(chunk_result) = byte_stream.next().await {
|
||||||
let chunk = match chunk_result {
|
let chunk = match chunk_result {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
|
println!("[OpenAI:stream] Byte stream error: {:?}", e);
|
||||||
yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
|
yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
raw_bytes_total += chunk.len();
|
||||||
let text = String::from_utf8_lossy(&chunk);
|
let text = String::from_utf8_lossy(&chunk);
|
||||||
|
// Log first 500 bytes of raw data for debugging SSE format
|
||||||
|
if raw_bytes_total <= 600 {
|
||||||
|
println!("[OpenAI:stream] RAW chunk ({} bytes): {:?}", text.len(), &text[..text.len().min(500)]);
|
||||||
|
}
|
||||||
for line in text.lines() {
|
for line in text.lines() {
|
||||||
if let Some(data) = line.strip_prefix("data: ") {
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty() || trimmed.starts_with(':') {
|
||||||
|
continue; // Skip empty lines and SSE comments
|
||||||
|
}
|
||||||
|
// Handle both "data: " (standard) and "data:" (no space)
|
||||||
|
let data = if let Some(d) = trimmed.strip_prefix("data: ") {
|
||||||
|
Some(d)
|
||||||
|
} else if let Some(d) = trimmed.strip_prefix("data:") {
|
||||||
|
Some(d.trim_start())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
if let Some(data) = data {
|
||||||
|
sse_event_count += 1;
|
||||||
|
if sse_event_count <= 3 || data == "[DONE]" {
|
||||||
|
println!("[OpenAI:stream] SSE #{}: {}", sse_event_count, &data[..data.len().min(300)]);
|
||||||
|
}
|
||||||
if data == "[DONE]" {
|
if data == "[DONE]" {
|
||||||
tracing::debug!("[OpenAI] Stream done, accumulated_tool_calls: {:?}", accumulated_tool_calls.len());
|
println!("[OpenAI:stream] Received [DONE], total SSE events: {}, raw bytes: {}", sse_event_count, raw_bytes_total);
|
||||||
|
|
||||||
// Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name)
|
// Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name)
|
||||||
for (id, (name, args)) in &accumulated_tool_calls {
|
for (id, (name, args)) in &accumulated_tool_calls {
|
||||||
@@ -216,10 +242,19 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
// Handle text content
|
// Handle text content
|
||||||
if let Some(content) = &delta.content {
|
if let Some(content) = &delta.content {
|
||||||
if !content.is_empty() {
|
if !content.is_empty() {
|
||||||
|
tracing::debug!("[OpenAI:stream] TextDelta: {} chars", content.len());
|
||||||
yield Ok(StreamChunk::TextDelta { delta: content.clone() });
|
yield Ok(StreamChunk::TextDelta { delta: content.clone() });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle reasoning_content (Kimi, Qwen, DeepSeek, GLM thinking)
|
||||||
|
if let Some(reasoning) = &delta.reasoning_content {
|
||||||
|
if !reasoning.is_empty() {
|
||||||
|
tracing::debug!("[OpenAI:stream] ThinkingDelta (reasoning_content): {} chars", reasoning.len());
|
||||||
|
yield Ok(StreamChunk::ThinkingDelta { delta: reasoning.clone() });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Handle tool calls
|
// Handle tool calls
|
||||||
if let Some(tool_calls) = &delta.tool_calls {
|
if let Some(tool_calls) = &delta.tool_calls {
|
||||||
tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
||||||
@@ -284,6 +319,7 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
println!("[OpenAI:stream] Byte stream ended. Total: {} SSE events, {} raw bytes", sse_event_count, raw_bytes_total);
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -304,55 +340,122 @@ impl OpenAiDriver {
|
|||||||
request.system.clone()
|
request.system.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
let messages: Vec<OpenAiMessage> = request.messages
|
// Build messages with tool result truncation to prevent payload overflow.
|
||||||
.iter()
|
// Most LLM APIs have a 2-4MB HTTP payload limit.
|
||||||
.filter_map(|msg| match msg {
|
const MAX_TOOL_RESULT_BYTES: usize = 32_768; // 32KB per tool result
|
||||||
zclaw_types::Message::User { content } => Some(OpenAiMessage {
|
const MAX_PAYLOAD_BYTES: usize = 1_800_000; // 1.8MB (under 2MB API limit)
|
||||||
|
|
||||||
|
let mut messages: Vec<OpenAiMessage> = Vec::new();
|
||||||
|
let mut pending_tool_calls: Option<Vec<OpenAiToolCall>> = None;
|
||||||
|
let mut pending_content: Option<String> = None;
|
||||||
|
let mut pending_reasoning: Option<String> = None;
|
||||||
|
|
||||||
|
let flush_pending = |tc: &mut Option<Vec<OpenAiToolCall>>,
|
||||||
|
c: &mut Option<String>,
|
||||||
|
r: &mut Option<String>,
|
||||||
|
out: &mut Vec<OpenAiMessage>| {
|
||||||
|
let calls = tc.take();
|
||||||
|
let content = c.take();
|
||||||
|
let reasoning = r.take();
|
||||||
|
|
||||||
|
if let Some(calls) = calls {
|
||||||
|
if !calls.is_empty() {
|
||||||
|
// Merge assistant content + reasoning into the tool call message
|
||||||
|
out.push(OpenAiMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: content.filter(|s| !s.is_empty()),
|
||||||
|
reasoning_content: reasoning.filter(|s| !s.is_empty()),
|
||||||
|
tool_calls: Some(calls),
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// No tool calls — emit a plain assistant message
|
||||||
|
if content.is_some() || reasoning.is_some() {
|
||||||
|
out.push(OpenAiMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: content.filter(|s| !s.is_empty()),
|
||||||
|
reasoning_content: reasoning.filter(|s| !s.is_empty()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for msg in &request.messages {
|
||||||
|
match msg {
|
||||||
|
zclaw_types::Message::User { content } => {
|
||||||
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
||||||
|
messages.push(OpenAiMessage {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: Some(content.clone()),
|
content: Some(content.clone()),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}),
|
tool_call_id: None,
|
||||||
zclaw_types::Message::Assistant { content, thinking: _ } => Some(OpenAiMessage {
|
reasoning_content: None,
|
||||||
role: "assistant".to_string(),
|
});
|
||||||
content: Some(content.clone()),
|
}
|
||||||
tool_calls: None,
|
zclaw_types::Message::Assistant { content, thinking } => {
|
||||||
}),
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
||||||
zclaw_types::Message::System { content } => Some(OpenAiMessage {
|
// Don't push immediately — wait to see if next messages are ToolUse
|
||||||
|
pending_content = Some(content.clone());
|
||||||
|
pending_reasoning = thinking.clone();
|
||||||
|
}
|
||||||
|
zclaw_types::Message::System { content } => {
|
||||||
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
||||||
|
messages.push(OpenAiMessage {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: Some(content.clone()),
|
content: Some(content.clone()),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}),
|
tool_call_id: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
zclaw_types::Message::ToolUse { id, tool, input } => {
|
zclaw_types::Message::ToolUse { id, tool, input } => {
|
||||||
// Ensure arguments is always a valid JSON object, never null or invalid
|
// Accumulate tool calls — they'll be merged with the pending assistant message
|
||||||
let args = if input.is_null() {
|
let args = if input.is_null() {
|
||||||
"{}".to_string()
|
"{}".to_string()
|
||||||
} else {
|
} else {
|
||||||
serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string())
|
serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string())
|
||||||
};
|
};
|
||||||
Some(OpenAiMessage {
|
pending_tool_calls
|
||||||
role: "assistant".to_string(),
|
.get_or_insert_with(Vec::new)
|
||||||
content: None,
|
.push(OpenAiToolCall {
|
||||||
tool_calls: Some(vec![OpenAiToolCall {
|
|
||||||
id: id.clone(),
|
id: id.clone(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: tool.to_string(),
|
name: tool.to_string(),
|
||||||
arguments: args,
|
arguments: args,
|
||||||
},
|
},
|
||||||
}]),
|
});
|
||||||
})
|
|
||||||
}
|
}
|
||||||
zclaw_types::Message::ToolResult { tool_call_id: _, output, is_error, .. } => Some(OpenAiMessage {
|
zclaw_types::Message::ToolResult { tool_call_id, output, is_error, .. } => {
|
||||||
role: "tool".to_string(),
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
||||||
content: Some(if *is_error {
|
let content_str = if *is_error {
|
||||||
format!("Error: {}", output)
|
format!("Error: {}", output)
|
||||||
} else {
|
} else {
|
||||||
output.to_string()
|
output.to_string()
|
||||||
}),
|
};
|
||||||
|
// Truncate oversized tool results to prevent payload overflow
|
||||||
|
let truncated = if content_str.len() > MAX_TOOL_RESULT_BYTES {
|
||||||
|
let mut s = String::from(&content_str[..MAX_TOOL_RESULT_BYTES]);
|
||||||
|
s.push_str("\n\n... [内容已截断,原文过大]");
|
||||||
|
s
|
||||||
|
} else {
|
||||||
|
content_str
|
||||||
|
};
|
||||||
|
messages.push(OpenAiMessage {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content: Some(truncated),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}),
|
tool_call_id: Some(tool_call_id.clone()),
|
||||||
})
|
reasoning_content: None,
|
||||||
.collect();
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Flush any remaining accumulated assistant content and/or tool calls
|
||||||
|
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
||||||
|
|
||||||
// Add system prompt if provided
|
// Add system prompt if provided
|
||||||
let mut messages = messages;
|
let mut messages = messages;
|
||||||
@@ -361,6 +464,8 @@ impl OpenAiDriver {
|
|||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: Some(system.clone()),
|
content: Some(system.clone()),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
reasoning_content: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -376,7 +481,7 @@ impl OpenAiDriver {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
OpenAiRequest {
|
let api_request = OpenAiRequest {
|
||||||
model: request.model.clone(), // Use model ID directly without any transformation
|
model: request.model.clone(), // Use model ID directly without any transformation
|
||||||
messages,
|
messages,
|
||||||
max_tokens: request.max_tokens,
|
max_tokens: request.max_tokens,
|
||||||
@@ -384,7 +489,75 @@ impl OpenAiDriver {
|
|||||||
stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) },
|
stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) },
|
||||||
stream: request.stream,
|
stream: request.stream,
|
||||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||||
|
};
|
||||||
|
|
||||||
|
// Pre-send payload size validation
|
||||||
|
if let Ok(serialized) = serde_json::to_string(&api_request) {
|
||||||
|
if serialized.len() > MAX_PAYLOAD_BYTES {
|
||||||
|
tracing::warn!(
|
||||||
|
target: "openai_driver",
|
||||||
|
"Request payload too large: {} bytes (limit: {}), truncating messages",
|
||||||
|
serialized.len(),
|
||||||
|
MAX_PAYLOAD_BYTES
|
||||||
|
);
|
||||||
|
return Self::truncate_messages_to_fit(api_request, MAX_PAYLOAD_BYTES);
|
||||||
}
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
target: "openai_driver",
|
||||||
|
"Request payload size: {} bytes (limit: {})",
|
||||||
|
serialized.len(),
|
||||||
|
MAX_PAYLOAD_BYTES
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
api_request
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Emergency truncation: drop oldest non-system messages until payload fits
|
||||||
|
fn truncate_messages_to_fit(mut request: OpenAiRequest, _max_bytes: usize) -> OpenAiRequest {
|
||||||
|
// Keep system message (if any) and last 4 non-system messages
|
||||||
|
let has_system = request.messages.first()
|
||||||
|
.map(|m| m.role == "system")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
let non_system: Vec<OpenAiMessage> = request.messages.into_iter()
|
||||||
|
.filter(|m| m.role != "system")
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Keep last N messages and truncate any remaining large tool results
|
||||||
|
let keep_count = 4.min(non_system.len());
|
||||||
|
let start = non_system.len() - keep_count;
|
||||||
|
let kept: Vec<OpenAiMessage> = non_system.into_iter()
|
||||||
|
.skip(start)
|
||||||
|
.map(|mut msg| {
|
||||||
|
// Additional per-message truncation for tool results
|
||||||
|
if msg.role == "tool" {
|
||||||
|
if let Some(ref content) = msg.content {
|
||||||
|
if content.len() > 16_384 {
|
||||||
|
let mut s = String::from(&content[..16_384]);
|
||||||
|
s.push_str("\n\n... [上下文压缩截断]");
|
||||||
|
msg.content = Some(s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
msg
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
if has_system {
|
||||||
|
messages.push(OpenAiMessage {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: Some("You are a helpful AI assistant. (注意:对话历史已被压缩以适应上下文大小限制)".to_string()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
reasoning_content: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
messages.extend(kept);
|
||||||
|
|
||||||
|
request.messages = messages;
|
||||||
|
request
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_response(&self, api_response: OpenAiResponse, model: String) -> CompletionResponse {
|
fn convert_response(&self, api_response: OpenAiResponse, model: String) -> CompletionResponse {
|
||||||
@@ -398,6 +571,7 @@ impl OpenAiDriver {
|
|||||||
// This is important because some providers return empty content with tool_calls
|
// This is important because some providers return empty content with tool_calls
|
||||||
let has_tool_calls = c.message.tool_calls.as_ref().map(|tc| !tc.is_empty()).unwrap_or(false);
|
let has_tool_calls = c.message.tool_calls.as_ref().map(|tc| !tc.is_empty()).unwrap_or(false);
|
||||||
let has_content = c.message.content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
let has_content = c.message.content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
||||||
|
let has_reasoning = c.message.reasoning_content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
||||||
|
|
||||||
let blocks = if has_tool_calls {
|
let blocks = if has_tool_calls {
|
||||||
// Tool calls take priority
|
// Tool calls take priority
|
||||||
@@ -413,6 +587,11 @@ impl OpenAiDriver {
|
|||||||
let text = c.message.content.as_ref().unwrap();
|
let text = c.message.content.as_ref().unwrap();
|
||||||
tracing::debug!("[OpenAiDriver:convert_response] Using text content: {} chars", text.len());
|
tracing::debug!("[OpenAiDriver:convert_response] Using text content: {} chars", text.len());
|
||||||
vec![ContentBlock::Text { text: text.clone() }]
|
vec![ContentBlock::Text { text: text.clone() }]
|
||||||
|
} else if has_reasoning {
|
||||||
|
// Content empty but reasoning_content present (Kimi, Qwen, DeepSeek)
|
||||||
|
let reasoning = c.message.reasoning_content.as_ref().unwrap();
|
||||||
|
tracing::debug!("[OpenAiDriver:convert_response] Using reasoning_content: {} chars", reasoning.len());
|
||||||
|
vec![ContentBlock::Text { text: reasoning.clone() }]
|
||||||
} else {
|
} else {
|
||||||
// No content or tool_calls
|
// No content or tool_calls
|
||||||
tracing::debug!("[OpenAiDriver:convert_response] No content or tool_calls, using empty text");
|
tracing::debug!("[OpenAiDriver:convert_response] No content or tool_calls, using empty text");
|
||||||
@@ -594,6 +773,10 @@ struct OpenAiMessage {
|
|||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
tool_calls: Option<Vec<OpenAiToolCall>>,
|
tool_calls: Option<Vec<OpenAiToolCall>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_call_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
@@ -656,6 +839,8 @@ struct OpenAiResponseMessage {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
tool_calls: Option<Vec<OpenAiToolCallResponse>>,
|
tool_calls: Option<Vec<OpenAiToolCallResponse>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -705,6 +890,8 @@ struct OpenAiDelta {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
|
reasoning_content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
tool_calls: Option<Vec<OpenAiToolCallDelta>>,
|
tool_calls: Option<Vec<OpenAiToolCallDelta>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,22 +4,14 @@
|
|||||||
//! enabling automatic memory retrieval before conversations and memory extraction
|
//! enabling automatic memory retrieval before conversations and memory extraction
|
||||||
//! after conversations.
|
//! after conversations.
|
||||||
//!
|
//!
|
||||||
//! # Usage
|
//! **Note (2026-03-27 audit)**: In the Tauri desktop deployment, this module is
|
||||||
|
//! NOT wired into the Kernel. The intelligence_hooks module in desktop/src-tauri
|
||||||
|
//! provides the same functionality (memory retrieval, heartbeat, reflection) via
|
||||||
|
//! direct VikingStorage calls. GrowthIntegration remains available for future
|
||||||
|
//! use (e.g., headless/server deployments where intelligence_hooks is not available).
|
||||||
//!
|
//!
|
||||||
//! ```rust,ignore
|
//! The `AgentLoop.growth` field defaults to `None` and the code gracefully falls
|
||||||
//! use zclaw_runtime::growth::GrowthIntegration;
|
//! through to normal behavior when not set.
|
||||||
//! use zclaw_growth::{VikingAdapter, MemoryExtractor, MemoryRetriever, PromptInjector};
|
|
||||||
//!
|
|
||||||
//! // Create growth integration
|
|
||||||
//! let viking = Arc::new(VikingAdapter::in_memory());
|
|
||||||
//! let growth = GrowthIntegration::new(viking);
|
|
||||||
//!
|
|
||||||
//! // Before conversation: enhance system prompt
|
|
||||||
//! let enhanced_prompt = growth.enhance_prompt(&agent_id, &base_prompt, &user_input).await?;
|
|
||||||
//!
|
|
||||||
//! // After conversation: extract and store memories
|
|
||||||
//! growth.process_conversation(&agent_id, &messages, session_id).await?;
|
|
||||||
//! ```
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use zclaw_growth::{
|
use zclaw_growth::{
|
||||||
|
|||||||
@@ -3,8 +3,10 @@
|
|||||||
//! LLM drivers, tool system, and agent loop implementation.
|
//! LLM drivers, tool system, and agent loop implementation.
|
||||||
|
|
||||||
/// Default User-Agent header sent with all outgoing HTTP requests.
|
/// Default User-Agent header sent with all outgoing HTTP requests.
|
||||||
/// Some LLM providers (e.g. Moonshot, Qwen, DashScope Coding Plan) reject requests without one.
|
/// Coding Plan providers (Kimi, Bailian/DashScope, Zhipu) validate the User-Agent against a
|
||||||
pub const USER_AGENT: &str = "ZCLAW/0.1.0";
|
/// whitelist of known Coding Agents (e.g. claude-code, kimi-cli, roo-code, kilo-code).
|
||||||
|
/// Must use the exact lowercase format to pass validation.
|
||||||
|
pub const USER_AGENT: &str = "claude-code/0.1.0";
|
||||||
|
|
||||||
pub mod driver;
|
pub mod driver;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
|
|||||||
@@ -131,12 +131,30 @@ impl AgentLoop {
|
|||||||
|
|
||||||
/// Create tool context for tool execution
|
/// Create tool context for tool execution
|
||||||
fn create_tool_context(&self, session_id: SessionId) -> ToolContext {
|
fn create_tool_context(&self, session_id: SessionId) -> ToolContext {
|
||||||
|
// If no path_validator is configured, create a default one with user home as workspace.
|
||||||
|
// This allows file_read/file_write tools to work without explicit workspace config,
|
||||||
|
// while still restricting access to the user's home directory for security.
|
||||||
|
let path_validator = self.path_validator.clone().unwrap_or_else(|| {
|
||||||
|
let home = std::env::var("USERPROFILE")
|
||||||
|
.or_else(|_| std::env::var("HOME"))
|
||||||
|
.unwrap_or_else(|_| ".".to_string());
|
||||||
|
let home_path = std::path::PathBuf::from(&home);
|
||||||
|
tracing::info!(
|
||||||
|
"[AgentLoop] No path_validator configured, using user home as workspace: {}",
|
||||||
|
home_path.display()
|
||||||
|
);
|
||||||
|
PathValidator::new().with_workspace(home_path)
|
||||||
|
});
|
||||||
|
|
||||||
|
let working_dir = path_validator.workspace_root()
|
||||||
|
.map(|p| p.to_string_lossy().to_string());
|
||||||
|
|
||||||
ToolContext {
|
ToolContext {
|
||||||
agent_id: self.agent_id.clone(),
|
agent_id: self.agent_id.clone(),
|
||||||
working_directory: None,
|
working_directory: working_dir,
|
||||||
session_id: Some(session_id.to_string()),
|
session_id: Some(session_id.to_string()),
|
||||||
skill_executor: self.skill_executor.clone(),
|
skill_executor: self.skill_executor.clone(),
|
||||||
path_validator: self.path_validator.clone(),
|
path_validator: Some(path_validator),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -222,6 +240,14 @@ impl AgentLoop {
|
|||||||
total_input_tokens += response.input_tokens;
|
total_input_tokens += response.input_tokens;
|
||||||
total_output_tokens += response.output_tokens;
|
total_output_tokens += response.output_tokens;
|
||||||
|
|
||||||
|
// Calibrate token estimation on first iteration
|
||||||
|
if iterations == 1 {
|
||||||
|
compaction::update_calibration(
|
||||||
|
compaction::estimate_messages_tokens(&messages),
|
||||||
|
response.input_tokens,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Extract tool calls from response
|
// Extract tool calls from response
|
||||||
let tool_calls: Vec<(String, String, serde_json::Value)> = response.content.iter()
|
let tool_calls: Vec<(String, String, serde_json::Value)> = response.content.iter()
|
||||||
.filter_map(|block| match block {
|
.filter_map(|block| match block {
|
||||||
@@ -230,30 +256,49 @@ impl AgentLoop {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// If no tool calls, we have the final response
|
// Extract text and thinking separately
|
||||||
if tool_calls.is_empty() {
|
let text_parts: Vec<String> = response.content.iter()
|
||||||
// Extract text content
|
|
||||||
let text = response.content.iter()
|
|
||||||
.filter_map(|block| match block {
|
.filter_map(|block| match block {
|
||||||
ContentBlock::Text { text } => Some(text.clone()),
|
ContentBlock::Text { text } => Some(text.clone()),
|
||||||
ContentBlock::Thinking { thinking } => Some(format!("[思考] {}", thinking)),
|
|
||||||
_ => None,
|
_ => None,
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>()
|
.collect();
|
||||||
.join("\n");
|
let thinking_parts: Vec<String> = response.content.iter()
|
||||||
|
.filter_map(|block| match block {
|
||||||
|
ContentBlock::Thinking { thinking } => Some(thinking.clone()),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let text_content = text_parts.join("\n");
|
||||||
|
let thinking_content = if thinking_parts.is_empty() { None } else { Some(thinking_parts.join("")) };
|
||||||
|
|
||||||
// Save final assistant message
|
// If no tool calls, we have the final response
|
||||||
self.memory.append_message(&session_id, &Message::assistant(&text)).await?;
|
if tool_calls.is_empty() {
|
||||||
|
// Save final assistant message with thinking
|
||||||
|
let msg = if let Some(thinking) = &thinking_content {
|
||||||
|
Message::assistant_with_thinking(&text_content, thinking)
|
||||||
|
} else {
|
||||||
|
Message::assistant(&text_content)
|
||||||
|
};
|
||||||
|
self.memory.append_message(&session_id, &msg).await?;
|
||||||
|
|
||||||
break AgentLoopResult {
|
break AgentLoopResult {
|
||||||
response: text,
|
response: text_content,
|
||||||
input_tokens: total_input_tokens,
|
input_tokens: total_input_tokens,
|
||||||
output_tokens: total_output_tokens,
|
output_tokens: total_output_tokens,
|
||||||
iterations,
|
iterations,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// There are tool calls - add assistant message with tool calls to history
|
// There are tool calls - push assistant message with thinking before tool calls
|
||||||
|
// (required by Kimi and other thinking-enabled APIs)
|
||||||
|
let assistant_msg = if let Some(thinking) = &thinking_content {
|
||||||
|
Message::assistant_with_thinking(&text_content, thinking)
|
||||||
|
} else {
|
||||||
|
Message::assistant(&text_content)
|
||||||
|
};
|
||||||
|
messages.push(assistant_msg);
|
||||||
|
|
||||||
for (id, name, input) in &tool_calls {
|
for (id, name, input) in &tool_calls {
|
||||||
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
||||||
}
|
}
|
||||||
@@ -417,19 +462,29 @@ impl AgentLoop {
|
|||||||
let mut stream = driver.stream(request);
|
let mut stream = driver.stream(request);
|
||||||
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
|
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
|
||||||
let mut iteration_text = String::new();
|
let mut iteration_text = String::new();
|
||||||
|
let mut reasoning_text = String::new(); // Track reasoning separately for API requirement
|
||||||
|
|
||||||
// Process stream chunks
|
// Process stream chunks
|
||||||
tracing::debug!("[AgentLoop] Starting to process stream chunks");
|
tracing::debug!("[AgentLoop] Starting to process stream chunks");
|
||||||
|
let mut chunk_count: usize = 0;
|
||||||
|
let mut text_delta_count: usize = 0;
|
||||||
|
let mut thinking_delta_count: usize = 0;
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
|
chunk_count += 1;
|
||||||
match &chunk {
|
match &chunk {
|
||||||
StreamChunk::TextDelta { delta } => {
|
StreamChunk::TextDelta { delta } => {
|
||||||
|
text_delta_count += 1;
|
||||||
|
tracing::debug!("[AgentLoop] TextDelta #{}: {} chars", text_delta_count, delta.len());
|
||||||
iteration_text.push_str(delta);
|
iteration_text.push_str(delta);
|
||||||
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
|
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
|
||||||
}
|
}
|
||||||
StreamChunk::ThinkingDelta { delta } => {
|
StreamChunk::ThinkingDelta { delta } => {
|
||||||
let _ = tx.send(LoopEvent::Delta(format!("[思考] {}", delta))).await;
|
thinking_delta_count += 1;
|
||||||
|
tracing::debug!("[AgentLoop] ThinkingDelta #{}: {} chars", thinking_delta_count, delta.len());
|
||||||
|
// Accumulate reasoning separately — not mixed into iteration_text
|
||||||
|
reasoning_text.push_str(delta);
|
||||||
}
|
}
|
||||||
StreamChunk::ToolUseStart { id, name } => {
|
StreamChunk::ToolUseStart { id, name } => {
|
||||||
tracing::debug!("[AgentLoop] ToolUseStart: id={}, name={}", id, name);
|
tracing::debug!("[AgentLoop] ToolUseStart: id={}, name={}", id, name);
|
||||||
@@ -458,6 +513,13 @@ impl AgentLoop {
|
|||||||
tracing::debug!("[AgentLoop] Stream complete: input_tokens={}, output_tokens={}", it, ot);
|
tracing::debug!("[AgentLoop] Stream complete: input_tokens={}, output_tokens={}", it, ot);
|
||||||
total_input_tokens += *it;
|
total_input_tokens += *it;
|
||||||
total_output_tokens += *ot;
|
total_output_tokens += *ot;
|
||||||
|
// Calibrate token estimation on first iteration
|
||||||
|
if iteration == 1 {
|
||||||
|
compaction::update_calibration(
|
||||||
|
compaction::estimate_messages_tokens(&messages),
|
||||||
|
*it,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
StreamChunk::Error { message } => {
|
StreamChunk::Error { message } => {
|
||||||
tracing::error!("[AgentLoop] Stream error: {}", message);
|
tracing::error!("[AgentLoop] Stream error: {}", message);
|
||||||
@@ -471,16 +533,27 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tracing::debug!("[AgentLoop] Stream ended, pending_tool_calls count: {}", pending_tool_calls.len());
|
tracing::info!("[AgentLoop] Stream ended: {} total chunks (text={}, thinking={}, tools={}), iteration_text={} chars",
|
||||||
|
chunk_count, text_delta_count, thinking_delta_count, pending_tool_calls.len(),
|
||||||
|
iteration_text.len());
|
||||||
|
if iteration_text.is_empty() {
|
||||||
|
tracing::warn!("[AgentLoop] WARNING: iteration_text is EMPTY after {} chunks! text_delta={}, thinking_delta={}",
|
||||||
|
chunk_count, text_delta_count, thinking_delta_count);
|
||||||
|
}
|
||||||
|
|
||||||
// If no tool calls, we have the final response
|
// If no tool calls, we have the final response
|
||||||
if pending_tool_calls.is_empty() {
|
if pending_tool_calls.is_empty() {
|
||||||
tracing::debug!("[AgentLoop] No tool calls, returning final response");
|
tracing::info!("[AgentLoop] No tool calls, returning final response: {} chars (reasoning: {} chars)", iteration_text.len(), reasoning_text.len());
|
||||||
// Save final assistant message
|
// Save final assistant message with reasoning
|
||||||
let _ = memory.append_message(&session_id_clone, &Message::assistant(&iteration_text)).await;
|
if let Err(e) = memory.append_message(&session_id_clone, &Message::assistant_with_thinking(
|
||||||
|
&iteration_text,
|
||||||
|
&reasoning_text,
|
||||||
|
)).await {
|
||||||
|
tracing::warn!("[AgentLoop] Failed to save final assistant message: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
||||||
response: iteration_text,
|
response: iteration_text.clone(),
|
||||||
input_tokens: total_input_tokens,
|
input_tokens: total_input_tokens,
|
||||||
output_tokens: total_output_tokens,
|
output_tokens: total_output_tokens,
|
||||||
iterations: iteration,
|
iterations: iteration,
|
||||||
@@ -488,7 +561,13 @@ impl AgentLoop {
|
|||||||
break 'outer;
|
break 'outer;
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::debug!("[AgentLoop] Processing {} tool calls", pending_tool_calls.len());
|
tracing::debug!("[AgentLoop] Processing {} tool calls (reasoning: {} chars)", pending_tool_calls.len(), reasoning_text.len());
|
||||||
|
|
||||||
|
// Push assistant message with reasoning before tool calls (required by Kimi and other thinking-enabled APIs)
|
||||||
|
messages.push(Message::assistant_with_thinking(
|
||||||
|
&iteration_text,
|
||||||
|
&reasoning_text,
|
||||||
|
));
|
||||||
|
|
||||||
// There are tool calls - add to message history
|
// There are tool calls - add to message history
|
||||||
for (id, name, input) in &pending_tool_calls {
|
for (id, name, input) in &pending_tool_calls {
|
||||||
@@ -519,12 +598,21 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
LoopGuardResult::Allowed => {}
|
LoopGuardResult::Allowed => {}
|
||||||
}
|
}
|
||||||
|
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
|
||||||
|
let pv = path_validator.clone().unwrap_or_else(|| {
|
||||||
|
let home = std::env::var("USERPROFILE")
|
||||||
|
.or_else(|_| std::env::var("HOME"))
|
||||||
|
.unwrap_or_else(|_| ".".to_string());
|
||||||
|
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
||||||
|
});
|
||||||
|
let working_dir = pv.workspace_root()
|
||||||
|
.map(|p| p.to_string_lossy().to_string());
|
||||||
let tool_context = ToolContext {
|
let tool_context = ToolContext {
|
||||||
agent_id: agent_id.clone(),
|
agent_id: agent_id.clone(),
|
||||||
working_directory: None,
|
working_directory: working_dir,
|
||||||
session_id: Some(session_id_clone.to_string()),
|
session_id: Some(session_id_clone.to_string()),
|
||||||
skill_executor: skill_executor.clone(),
|
skill_executor: skill_executor.clone(),
|
||||||
path_validator: path_validator.clone(),
|
path_validator: Some(pv),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
||||||
|
|||||||
@@ -160,6 +160,11 @@ impl PathValidator {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the workspace root directory
|
||||||
|
pub fn workspace_root(&self) -> Option<&PathBuf> {
|
||||||
|
self.workspace_root.as_ref()
|
||||||
|
}
|
||||||
|
|
||||||
/// Validate a path for read access
|
/// Validate a path for read access
|
||||||
pub fn validate_read(&self, path: &str) -> Result<PathBuf> {
|
pub fn validate_read(&self, path: &str) -> Result<PathBuf> {
|
||||||
let canonical = self.resolve_and_validate(path)?;
|
let canonical = self.resolve_and_validate(path)?;
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ path = "src/main.rs"
|
|||||||
zclaw-types = { workspace = true }
|
zclaw-types = { workspace = true }
|
||||||
|
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
|
tokio-stream = { workspace = true }
|
||||||
futures = { workspace = true }
|
futures = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
@@ -23,7 +24,6 @@ chrono = { workspace = true }
|
|||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
tracing-subscriber = { workspace = true }
|
tracing-subscriber = { workspace = true }
|
||||||
sqlx = { workspace = true }
|
sqlx = { workspace = true }
|
||||||
libsqlite3-sys = { workspace = true }
|
|
||||||
reqwest = { workspace = true }
|
reqwest = { workspace = true }
|
||||||
secrecy = { workspace = true }
|
secrecy = { workspace = true }
|
||||||
sha2 = { workspace = true }
|
sha2 = { workspace = true }
|
||||||
@@ -41,6 +41,9 @@ argon2 = { workspace = true }
|
|||||||
totp-rs = { workspace = true }
|
totp-rs = { workspace = true }
|
||||||
urlencoding = "2"
|
urlencoding = "2"
|
||||||
data-encoding = "2"
|
data-encoding = "2"
|
||||||
|
regex = "1"
|
||||||
|
aes-gcm = "0.10"
|
||||||
|
bytes = "1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ pub async fn get_account(
|
|||||||
service::get_account(&state.db, &id).await.map(Json)
|
service::get_account(&state.db, &id).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// PUT /api/v1/accounts/:id (admin or self for limited fields)
|
/// PATCH /api/v1/accounts/:id (admin or self for limited fields)
|
||||||
pub async fn update_account(
|
pub async fn update_account(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Path(id): Path<String>,
|
Path(id): Path<String>,
|
||||||
@@ -80,12 +80,15 @@ pub async fn update_status(
|
|||||||
Ok(Json(serde_json::json!({"ok": true})))
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/tokens
|
/// GET /api/v1/tokens?page=1&page_size=20
|
||||||
pub async fn list_tokens(
|
pub async fn list_tokens(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<TokenInfo>>> {
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
service::list_api_tokens(&state.db, &ctx.account_id).await.map(Json)
|
) -> SaasResult<Json<PaginatedResponse<TokenInfo>>> {
|
||||||
|
let page = params.get("page").and_then(|v| v.parse().ok());
|
||||||
|
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
|
||||||
|
service::list_api_tokens(&state.db, &ctx.account_id, page, page_size).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/tokens
|
/// POST /api/v1/tokens
|
||||||
@@ -94,9 +97,24 @@ pub async fn create_token(
|
|||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<CreateTokenRequest>,
|
Json(req): Json<CreateTokenRequest>,
|
||||||
) -> SaasResult<Json<TokenInfo>> {
|
) -> SaasResult<Json<TokenInfo>> {
|
||||||
let token = service::create_api_token(&state.db, &ctx.account_id, &req).await?;
|
// 权限校验: 创建的 token 不能超出创建者已有的权限
|
||||||
|
let allowed_permissions: Vec<String> = req.permissions
|
||||||
|
.into_iter()
|
||||||
|
.filter(|p| ctx.permissions.contains(p))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if allowed_permissions.is_empty() {
|
||||||
|
return Err(SaasError::InvalidInput("请求的权限均不被允许".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let filtered_req = CreateTokenRequest {
|
||||||
|
name: req.name,
|
||||||
|
permissions: allowed_permissions,
|
||||||
|
expires_days: req.expires_days,
|
||||||
|
};
|
||||||
|
let token = service::create_api_token(&state.db, &ctx.account_id, &filtered_req).await?;
|
||||||
log_operation(&state.db, &ctx.account_id, "token.create", "api_token", &token.id,
|
log_operation(&state.db, &ctx.account_id, "token.create", "api_token", &token.id,
|
||||||
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
Some(serde_json::json!({"name": &filtered_req.name})), ctx.client_ip.as_deref()).await?;
|
||||||
Ok(Json(token))
|
Ok(Json(token))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,18 +134,21 @@ pub async fn list_operation_logs(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
) -> SaasResult<Json<PaginatedResponse<serde_json::Value>>> {
|
||||||
require_admin(&ctx)?;
|
require_admin(&ctx)?;
|
||||||
let page: i64 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1);
|
let page: u32 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1).max(1);
|
||||||
let page_size: i64 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50);
|
let page_size: u32 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50).min(100);
|
||||||
let offset = (page - 1) * page_size;
|
let offset = ((page - 1) * page_size) as i64;
|
||||||
|
|
||||||
|
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM operation_logs")
|
||||||
|
.fetch_one(&state.db).await?;
|
||||||
|
|
||||||
let rows: Vec<(i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, String)> =
|
let rows: Vec<(i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
|
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
|
||||||
FROM operation_logs ORDER BY created_at DESC LIMIT ?1 OFFSET ?2"
|
FROM operation_logs ORDER BY created_at DESC LIMIT $1 OFFSET $2"
|
||||||
)
|
)
|
||||||
.bind(page_size)
|
.bind(page_size as i64)
|
||||||
.bind(offset)
|
.bind(offset)
|
||||||
.fetch_all(&state.db)
|
.fetch_all(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -141,7 +162,7 @@ pub async fn list_operation_logs(
|
|||||||
})
|
})
|
||||||
}).collect();
|
}).collect();
|
||||||
|
|
||||||
Ok(Json(items))
|
Ok(Json(PaginatedResponse { items, total, page, page_size }))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/stats/dashboard — 仪表盘聚合统计 (需要 admin 权限)
|
/// GET /api/v1/stats/dashboard — 仪表盘聚合统计 (需要 admin 权限)
|
||||||
@@ -151,32 +172,34 @@ pub async fn dashboard_stats(
|
|||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
require_admin(&ctx)?;
|
require_admin(&ctx)?;
|
||||||
|
|
||||||
let total_accounts: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM accounts")
|
// 查询 1: 账号 + Provider + Model 聚合 (一次查询)
|
||||||
.fetch_one(&state.db).await?;
|
let stats_row: (i64, i64, i64, i64) = sqlx::query_as(
|
||||||
let active_accounts: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM accounts WHERE status = 'active'")
|
"SELECT
|
||||||
.fetch_one(&state.db).await?;
|
(SELECT COUNT(*) FROM accounts) as total_accounts,
|
||||||
let tasks_today: (i64,) = sqlx::query_as(
|
(SELECT COUNT(*) FROM accounts WHERE status = 'active') as active_accounts,
|
||||||
"SELECT COUNT(*) FROM relay_tasks WHERE date(created_at) = date('now')"
|
(SELECT COUNT(*) FROM providers WHERE enabled = true) as active_providers,
|
||||||
).fetch_one(&state.db).await?;
|
(SELECT COUNT(*) FROM models WHERE enabled = true) as active_models"
|
||||||
let active_providers: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM providers WHERE enabled = 1")
|
|
||||||
.fetch_one(&state.db).await?;
|
|
||||||
let active_models: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM models WHERE enabled = 1")
|
|
||||||
.fetch_one(&state.db).await?;
|
|
||||||
let tokens_today_input: (i64,) = sqlx::query_as(
|
|
||||||
"SELECT COALESCE(SUM(input_tokens), 0) FROM usage_records WHERE date(created_at) = date('now')"
|
|
||||||
).fetch_one(&state.db).await?;
|
|
||||||
let tokens_today_output: (i64,) = sqlx::query_as(
|
|
||||||
"SELECT COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE date(created_at) = date('now')"
|
|
||||||
).fetch_one(&state.db).await?;
|
).fetch_one(&state.db).await?;
|
||||||
|
let (total_accounts, active_accounts, active_providers, active_models) = stats_row;
|
||||||
|
|
||||||
|
// 查询 2: 今日中转统计 (一次查询)
|
||||||
|
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||||
|
let today_row: (i64, i64, i64) = sqlx::query_as(
|
||||||
|
"SELECT
|
||||||
|
(SELECT COUNT(*) FROM relay_tasks WHERE SUBSTRING(created_at, 1, 10) = $1) as tasks_today,
|
||||||
|
COALESCE((SELECT SUM(input_tokens) FROM usage_records WHERE SUBSTRING(created_at, 1, 10) = $1), 0) as tokens_input,
|
||||||
|
COALESCE((SELECT SUM(output_tokens) FROM usage_records WHERE SUBSTRING(created_at, 1, 10) = $1), 0) as tokens_output"
|
||||||
|
).bind(&today).fetch_one(&state.db).await?;
|
||||||
|
let (tasks_today, tokens_today_input, tokens_today_output) = today_row;
|
||||||
|
|
||||||
Ok(Json(serde_json::json!({
|
Ok(Json(serde_json::json!({
|
||||||
"total_accounts": total_accounts.0,
|
"total_accounts": total_accounts,
|
||||||
"active_accounts": active_accounts.0,
|
"active_accounts": active_accounts,
|
||||||
"tasks_today": tasks_today.0,
|
"tasks_today": tasks_today,
|
||||||
"active_providers": active_providers.0,
|
"active_providers": active_providers,
|
||||||
"active_models": active_models.0,
|
"active_models": active_models,
|
||||||
"tokens_today_input": tokens_today_input.0,
|
"tokens_today_input": tokens_today_input,
|
||||||
"tokens_today_output": tokens_today_output.0,
|
"tokens_today_output": tokens_today_output,
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,9 +224,9 @@ pub async fn register_device(
|
|||||||
// UPSERT: 已存在则更新 last_seen_at,不存在则插入
|
// UPSERT: 已存在则更新 last_seen_at,不存在则插入
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at)
|
"INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?7)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $7)
|
||||||
ON CONFLICT(account_id, device_id) DO UPDATE SET
|
ON CONFLICT(account_id, device_id) DO UPDATE SET
|
||||||
device_name = ?4, platform = ?5, app_version = ?6, last_seen_at = ?7"
|
device_name = $4, platform = $5, app_version = $6, last_seen_at = $7"
|
||||||
)
|
)
|
||||||
.bind(&device_uuid)
|
.bind(&device_uuid)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
@@ -233,14 +256,32 @@ pub async fn device_heartbeat(
|
|||||||
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
|
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let result = sqlx::query(
|
|
||||||
"UPDATE devices SET last_seen_at = ?1 WHERE account_id = ?2 AND device_id = ?3"
|
// Also update platform/app_version if provided (supports client upgrades)
|
||||||
|
let platform = req.get("platform").and_then(|v| v.as_str());
|
||||||
|
let app_version = req.get("app_version").and_then(|v| v.as_str());
|
||||||
|
|
||||||
|
let result = if platform.is_some() || app_version.is_some() {
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE devices SET last_seen_at = $1, platform = COALESCE($4, platform), app_version = COALESCE($5, app_version) WHERE account_id = $2 AND device_id = $3"
|
||||||
|
)
|
||||||
|
.bind(&now)
|
||||||
|
.bind(&ctx.account_id)
|
||||||
|
.bind(device_id)
|
||||||
|
.bind(platform)
|
||||||
|
.bind(app_version)
|
||||||
|
.execute(&state.db)
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE devices SET last_seen_at = $1 WHERE account_id = $2 AND device_id = $3"
|
||||||
)
|
)
|
||||||
.bind(&now)
|
.bind(&now)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.bind(device_id)
|
.bind(device_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?
|
||||||
|
};
|
||||||
|
|
||||||
if result.rows_affected() == 0 {
|
if result.rows_affected() == 0 {
|
||||||
return Err(SaasError::NotFound("设备未注册".into()));
|
return Err(SaasError::NotFound("设备未注册".into()));
|
||||||
@@ -249,27 +290,13 @@ pub async fn device_heartbeat(
|
|||||||
Ok(Json(serde_json::json!({"ok": true})))
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/devices — 列出当前用户的设备
|
/// GET /api/v1/devices?page=1&page_size=20 — 列出当前用户的设备
|
||||||
pub async fn list_devices(
|
pub async fn list_devices(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, String, String)> =
|
) -> SaasResult<Json<PaginatedResponse<serde_json::Value>>> {
|
||||||
sqlx::query_as(
|
let page = params.get("page").and_then(|v| v.parse().ok());
|
||||||
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
|
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
|
||||||
FROM devices WHERE account_id = ?1 ORDER BY last_seen_at DESC"
|
service::list_devices(&state.db, &ctx.account_id, page, page_size).await.map(Json)
|
||||||
)
|
|
||||||
.bind(&ctx.account_id)
|
|
||||||
.fetch_all(&state.db)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
|
|
||||||
serde_json::json!({
|
|
||||||
"id": r.0, "device_id": r.1,
|
|
||||||
"device_name": r.2, "platform": r.3, "app_version": r.4,
|
|
||||||
"last_seen_at": r.5, "created_at": r.6,
|
|
||||||
})
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
Ok(Json(items))
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ pub mod types;
|
|||||||
pub mod service;
|
pub mod service;
|
||||||
pub mod handlers;
|
pub mod handlers;
|
||||||
|
|
||||||
use axum::routing::{delete, get, patch, post, put};
|
use axum::routing::{delete, get, patch, post};
|
||||||
|
|
||||||
pub fn routes() -> axum::Router<crate::state::AppState> {
|
pub fn routes() -> axum::Router<crate::state::AppState> {
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/api/v1/accounts", get(handlers::list_accounts))
|
.route("/api/v1/accounts", get(handlers::list_accounts))
|
||||||
.route("/api/v1/accounts/{id}", get(handlers::get_account))
|
.route("/api/v1/accounts/:id", get(handlers::get_account))
|
||||||
.route("/api/v1/accounts/{id}", put(handlers::update_account))
|
.route("/api/v1/accounts/:id", patch(handlers::update_account))
|
||||||
.route("/api/v1/accounts/{id}/status", patch(handlers::update_status))
|
.route("/api/v1/accounts/:id/status", patch(handlers::update_status))
|
||||||
.route("/api/v1/tokens", get(handlers::list_tokens))
|
.route("/api/v1/tokens", get(handlers::list_tokens))
|
||||||
.route("/api/v1/tokens", post(handlers::create_token))
|
.route("/api/v1/tokens", post(handlers::create_token))
|
||||||
.route("/api/v1/tokens/{id}", delete(handlers::revoke_token))
|
.route("/api/v1/tokens/:id", delete(handlers::revoke_token))
|
||||||
.route("/api/v1/logs/operations", get(handlers::list_operation_logs))
|
.route("/api/v1/logs/operations", get(handlers::list_operation_logs))
|
||||||
.route("/api/v1/stats/dashboard", get(handlers::dashboard_stats))
|
.route("/api/v1/stats/dashboard", get(handlers::dashboard_stats))
|
||||||
.route("/api/v1/devices", get(handlers::list_devices))
|
.route("/api/v1/devices", get(handlers::list_devices))
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
//! 账号管理业务逻辑
|
//! 账号管理业务逻辑
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::PgPool;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
use crate::common::{PaginatedResponse, normalize_pagination};
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
|
||||||
pub async fn list_accounts(
|
pub async fn list_accounts(
|
||||||
db: &SqlitePool,
|
db: &PgPool,
|
||||||
query: &ListAccountsQuery,
|
query: &ListAccountsQuery,
|
||||||
) -> SaasResult<PaginatedResponse<serde_json::Value>> {
|
) -> SaasResult<PaginatedResponse<serde_json::Value>> {
|
||||||
let page = query.page.unwrap_or(1).max(1);
|
let page = query.page.unwrap_or(1).max(1);
|
||||||
@@ -14,17 +15,21 @@ pub async fn list_accounts(
|
|||||||
|
|
||||||
let mut where_clauses = Vec::new();
|
let mut where_clauses = Vec::new();
|
||||||
let mut params: Vec<String> = Vec::new();
|
let mut params: Vec<String> = Vec::new();
|
||||||
|
let mut param_idx = 1usize;
|
||||||
|
|
||||||
if let Some(role) = &query.role {
|
if let Some(role) = &query.role {
|
||||||
where_clauses.push("role = ?".to_string());
|
where_clauses.push(format!("role = ${}", param_idx));
|
||||||
|
param_idx += 1;
|
||||||
params.push(role.clone());
|
params.push(role.clone());
|
||||||
}
|
}
|
||||||
if let Some(status) = &query.status {
|
if let Some(status) = &query.status {
|
||||||
where_clauses.push("status = ?".to_string());
|
where_clauses.push(format!("status = ${}", param_idx));
|
||||||
|
param_idx += 1;
|
||||||
params.push(status.clone());
|
params.push(status.clone());
|
||||||
}
|
}
|
||||||
if let Some(search) = &query.search {
|
if let Some(search) = &query.search {
|
||||||
where_clauses.push("(username LIKE ? OR email LIKE ? OR display_name LIKE ?)".to_string());
|
where_clauses.push(format!("(username LIKE ${} OR email LIKE ${} OR display_name LIKE ${})", param_idx, param_idx + 1, param_idx + 2));
|
||||||
|
param_idx += 3;
|
||||||
let pattern = format!("%{}%", search);
|
let pattern = format!("%{}%", search);
|
||||||
params.push(pattern.clone());
|
params.push(pattern.clone());
|
||||||
params.push(pattern.clone());
|
params.push(pattern.clone());
|
||||||
@@ -44,10 +49,12 @@ pub async fn list_accounts(
|
|||||||
}
|
}
|
||||||
let total: i64 = count_query.fetch_one(db).await?;
|
let total: i64 = count_query.fetch_one(db).await?;
|
||||||
|
|
||||||
|
let limit_idx = param_idx;
|
||||||
|
let offset_idx = param_idx + 1;
|
||||||
let data_sql = format!(
|
let data_sql = format!(
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||||
FROM accounts {} ORDER BY created_at DESC LIMIT ? OFFSET ?",
|
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
|
||||||
where_sql
|
where_sql, limit_idx, offset_idx
|
||||||
);
|
);
|
||||||
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<String>, String)>(&data_sql);
|
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<String>, String)>(&data_sql);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
@@ -69,11 +76,11 @@ pub async fn list_accounts(
|
|||||||
Ok(PaginatedResponse { items, total, page, page_size })
|
Ok(PaginatedResponse { items, total, page, page_size })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_account(db: &SqlitePool, account_id: &str) -> SaasResult<serde_json::Value> {
|
pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
|
||||||
let row: Option<(String, String, String, String, String, String, bool, Option<String>, String)> =
|
let row: Option<(String, String, String, String, String, String, bool, Option<String>, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||||
FROM accounts WHERE id = ?1"
|
FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
@@ -90,28 +97,30 @@ pub async fn get_account(db: &SqlitePool, account_id: &str) -> SaasResult<serde_
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_account(
|
pub async fn update_account(
|
||||||
db: &SqlitePool,
|
db: &PgPool,
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
req: &UpdateAccountRequest,
|
req: &UpdateAccountRequest,
|
||||||
) -> SaasResult<serde_json::Value> {
|
) -> SaasResult<serde_json::Value> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let mut updates = Vec::new();
|
let mut updates = Vec::new();
|
||||||
let mut params: Vec<String> = Vec::new();
|
let mut params: Vec<String> = Vec::new();
|
||||||
|
let mut param_idx = 1usize;
|
||||||
|
|
||||||
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(v.clone()); }
|
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||||
if let Some(ref v) = req.email { updates.push("email = ?"); params.push(v.clone()); }
|
if let Some(ref v) = req.email { updates.push(format!("email = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||||
if let Some(ref v) = req.role { updates.push("role = ?"); params.push(v.clone()); }
|
if let Some(ref v) = req.role { updates.push(format!("role = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||||
if let Some(ref v) = req.avatar_url { updates.push("avatar_url = ?"); params.push(v.clone()); }
|
if let Some(ref v) = req.avatar_url { updates.push(format!("avatar_url = ${}", param_idx)); param_idx += 1; params.push(v.clone()); }
|
||||||
|
|
||||||
if updates.is_empty() {
|
if updates.is_empty() {
|
||||||
return get_account(db, account_id).await;
|
return get_account(db, account_id).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
updates.push("updated_at = ?");
|
updates.push(format!("updated_at = ${}", param_idx));
|
||||||
|
param_idx += 1;
|
||||||
params.push(now.clone());
|
params.push(now.clone());
|
||||||
params.push(account_id.to_string());
|
params.push(account_id.to_string());
|
||||||
|
|
||||||
let sql = format!("UPDATE accounts SET {} WHERE id = ?", updates.join(", "));
|
let sql = format!("UPDATE accounts SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||||
let mut query = sqlx::query(&sql);
|
let mut query = sqlx::query(&sql);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
query = query.bind(p);
|
query = query.bind(p);
|
||||||
@@ -121,7 +130,7 @@ pub async fn update_account(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_account_status(
|
pub async fn update_account_status(
|
||||||
db: &SqlitePool,
|
db: &PgPool,
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
status: &str,
|
status: &str,
|
||||||
) -> SaasResult<()> {
|
) -> SaasResult<()> {
|
||||||
@@ -130,7 +139,7 @@ pub async fn update_account_status(
|
|||||||
return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid)));
|
return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid)));
|
||||||
}
|
}
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let result = sqlx::query("UPDATE accounts SET status = ?1, updated_at = ?2 WHERE id = ?3")
|
let result = sqlx::query("UPDATE accounts SET status = $1, updated_at = $2 WHERE id = $3")
|
||||||
.bind(status).bind(&now).bind(account_id)
|
.bind(status).bind(&now).bind(account_id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
|
|
||||||
@@ -141,7 +150,7 @@ pub async fn update_account_status(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_api_token(
|
pub async fn create_api_token(
|
||||||
db: &SqlitePool,
|
db: &PgPool,
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
req: &CreateTokenRequest,
|
req: &CreateTokenRequest,
|
||||||
) -> SaasResult<TokenInfo> {
|
) -> SaasResult<TokenInfo> {
|
||||||
@@ -163,7 +172,7 @@ pub async fn create_api_token(
|
|||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO api_tokens (id, account_id, name, token_hash, token_prefix, permissions, created_at, expires_at)
|
"INSERT INTO api_tokens (id, account_id, name, token_hash, token_prefix, permissions, created_at, expires_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
)
|
)
|
||||||
.bind(&token_id)
|
.bind(&token_id)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
@@ -189,28 +198,80 @@ pub async fn create_api_token(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_api_tokens(
|
pub async fn list_api_tokens(
|
||||||
db: &SqlitePool,
|
db: &PgPool,
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
) -> SaasResult<Vec<TokenInfo>> {
|
page: Option<u32>,
|
||||||
|
page_size: Option<u32>,
|
||||||
|
) -> SaasResult<PaginatedResponse<TokenInfo>> {
|
||||||
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||||||
|
|
||||||
|
let total: (i64,) = sqlx::query_as(
|
||||||
|
"SELECT COUNT(*) FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.fetch_one(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let rows: Vec<(String, String, String, String, Option<String>, Option<String>, String)> =
|
let rows: Vec<(String, String, String, String, Option<String>, Option<String>, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
|
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
|
||||||
FROM api_tokens WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||||
)
|
)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
|
.bind(ps as i64)
|
||||||
|
.bind(offset)
|
||||||
.fetch_all(db)
|
.fetch_all(db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
|
let items = rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
|
||||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||||
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used, expires_at: expires, created_at: created, token: None, }
|
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used, expires_at: expires, created_at: created, token: None, }
|
||||||
}).collect())
|
}).collect();
|
||||||
|
|
||||||
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn revoke_api_token(db: &SqlitePool, token_id: &str, account_id: &str) -> SaasResult<()> {
|
pub async fn list_devices(
|
||||||
|
db: &PgPool,
|
||||||
|
account_id: &str,
|
||||||
|
page: Option<u32>,
|
||||||
|
page_size: Option<u32>,
|
||||||
|
) -> SaasResult<PaginatedResponse<serde_json::Value>> {
|
||||||
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||||||
|
|
||||||
|
let total: (i64,) = sqlx::query_as(
|
||||||
|
"SELECT COUNT(*) FROM devices WHERE account_id = $1"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.fetch_one(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
|
||||||
|
FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC LIMIT $2 OFFSET $3"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.bind(ps as i64)
|
||||||
|
.bind(offset)
|
||||||
|
.fetch_all(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
|
||||||
|
serde_json::json!({
|
||||||
|
"id": r.0, "device_id": r.1,
|
||||||
|
"device_name": r.2, "platform": r.3, "app_version": r.4,
|
||||||
|
"last_seen_at": r.5, "created_at": r.6,
|
||||||
|
})
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn revoke_api_token(db: &PgPool, token_id: &str, account_id: &str) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let result = sqlx::query(
|
let result = sqlx::query(
|
||||||
"UPDATE api_tokens SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL"
|
"UPDATE api_tokens SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
||||||
)
|
)
|
||||||
.bind(&now).bind(token_id).bind(account_id)
|
.bind(&now).bind(token_id).bind(account_id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// Re-export from common module
|
||||||
|
pub use crate::common::PaginatedResponse;
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct UpdateAccountRequest {
|
pub struct UpdateAccountRequest {
|
||||||
pub display_name: Option<String>,
|
pub display_name: Option<String>,
|
||||||
@@ -24,14 +27,6 @@ pub struct ListAccountsQuery {
|
|||||||
pub search: Option<String>,
|
pub search: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
pub struct PaginatedResponse<T: Serialize> {
|
|
||||||
pub items: Vec<T>,
|
|
||||||
pub total: i64,
|
|
||||||
pub page: u32,
|
|
||||||
pub page_size: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
pub struct CreateTokenRequest {
|
pub struct CreateTokenRequest {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
|
|||||||
104
crates/zclaw-saas/src/agent_template/handlers.rs
Normal file
104
crates/zclaw-saas/src/agent_template/handlers.rs
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
//! Agent 配置模板 HTTP 处理器
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Extension, Path, Query, State},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::error::SaasResult;
|
||||||
|
use crate::auth::types::AuthContext;
|
||||||
|
use crate::auth::handlers::{log_operation, check_permission};
|
||||||
|
use super::types::*;
|
||||||
|
use super::service;
|
||||||
|
|
||||||
|
/// GET /api/v1/agent-templates — 列出 Agent 模板
|
||||||
|
pub async fn list_templates(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Query(query): Query<AgentTemplateListQuery>,
|
||||||
|
) -> SaasResult<Json<crate::common::PaginatedResponse<AgentTemplateInfo>>> {
|
||||||
|
check_permission(&ctx, "model:read")?;
|
||||||
|
Ok(Json(service::list_templates(&state.db, &query).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/agent-templates — 创建 Agent 模板
|
||||||
|
pub async fn create_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<CreateAgentTemplateRequest>,
|
||||||
|
) -> SaasResult<Json<AgentTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "model:manage")?;
|
||||||
|
|
||||||
|
let category = req.category.as_deref().unwrap_or("general");
|
||||||
|
let source = req.source.as_deref().unwrap_or("custom");
|
||||||
|
let visibility = req.visibility.as_deref().unwrap_or("public");
|
||||||
|
let tools = req.tools.as_deref().unwrap_or(&[]);
|
||||||
|
let capabilities = req.capabilities.as_deref().unwrap_or(&[]);
|
||||||
|
|
||||||
|
let result = service::create_template(
|
||||||
|
&state.db, &req.name, req.description.as_deref(),
|
||||||
|
category, source, req.model.as_deref(),
|
||||||
|
req.system_prompt.as_deref(),
|
||||||
|
tools, capabilities,
|
||||||
|
req.temperature, req.max_tokens, visibility,
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "agent_template.create", "agent_template", &result.id,
|
||||||
|
Some(serde_json::json!({"name": req.name})), ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/agent-templates/:id — 获取单个 Agent 模板
|
||||||
|
pub async fn get_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
) -> SaasResult<Json<AgentTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "model:read")?;
|
||||||
|
Ok(Json(service::get_template(&state.db, &id).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/agent-templates/:id — 更新 Agent 模板
|
||||||
|
pub async fn update_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Json(req): Json<UpdateAgentTemplateRequest>,
|
||||||
|
) -> SaasResult<Json<AgentTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "model:manage")?;
|
||||||
|
|
||||||
|
let result = service::update_template(
|
||||||
|
&state.db, &id,
|
||||||
|
req.description.as_deref(),
|
||||||
|
req.model.as_deref(),
|
||||||
|
req.system_prompt.as_deref(),
|
||||||
|
req.tools.as_deref(),
|
||||||
|
req.capabilities.as_deref(),
|
||||||
|
req.temperature,
|
||||||
|
req.max_tokens,
|
||||||
|
req.visibility.as_deref(),
|
||||||
|
req.status.as_deref(),
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "agent_template.update", "agent_template", &id,
|
||||||
|
None, ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DELETE /api/v1/agent-templates/:id — 归档 Agent 模板
|
||||||
|
pub async fn archive_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
) -> SaasResult<Json<AgentTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "model:manage")?;
|
||||||
|
|
||||||
|
let result = service::archive_template(&state.db, &id).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "agent_template.archive", "agent_template", &id,
|
||||||
|
None, ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
17
crates/zclaw-saas/src/agent_template/mod.rs
Normal file
17
crates/zclaw-saas/src/agent_template/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
//! Agent 配置模板管理模块
|
||||||
|
|
||||||
|
pub mod types;
|
||||||
|
pub mod service;
|
||||||
|
pub mod handlers;
|
||||||
|
|
||||||
|
use axum::routing::{delete, get, post};
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
/// Agent 模板管理路由 (需要认证)
|
||||||
|
pub fn routes() -> axum::Router<AppState> {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/api/v1/agent-templates", get(handlers::list_templates).post(handlers::create_template))
|
||||||
|
.route("/api/v1/agent-templates/:id", get(handlers::get_template))
|
||||||
|
.route("/api/v1/agent-templates/:id", post(handlers::update_template))
|
||||||
|
.route("/api/v1/agent-templates/:id", delete(handlers::archive_template))
|
||||||
|
}
|
||||||
272
crates/zclaw-saas/src/agent_template/service.rs
Normal file
272
crates/zclaw-saas/src/agent_template/service.rs
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
//! Agent 配置模板业务逻辑
|
||||||
|
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
use super::types::*;
|
||||||
|
|
||||||
|
fn row_to_template(
|
||||||
|
row: (String, String, Option<String>, String, String, Option<String>, Option<String>,
|
||||||
|
String, String, Option<f64>, Option<i32>, String, String, i32, String, String),
|
||||||
|
) -> AgentTemplateInfo {
|
||||||
|
AgentTemplateInfo {
|
||||||
|
id: row.0, name: row.1, description: row.2, category: row.3, source: row.4,
|
||||||
|
model: row.5, system_prompt: row.6, tools: serde_json::from_str(&row.7).unwrap_or_default(),
|
||||||
|
capabilities: serde_json::from_str(&row.8).unwrap_or_default(),
|
||||||
|
temperature: row.9, max_tokens: row.10, visibility: row.11, status: row.12,
|
||||||
|
current_version: row.13, created_at: row.14, updated_at: row.15,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建 Agent 模板
|
||||||
|
pub async fn create_template(
|
||||||
|
db: &PgPool,
|
||||||
|
name: &str,
|
||||||
|
description: Option<&str>,
|
||||||
|
category: &str,
|
||||||
|
source: &str,
|
||||||
|
model: Option<&str>,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
tools: &[String],
|
||||||
|
capabilities: &[String],
|
||||||
|
temperature: Option<f64>,
|
||||||
|
max_tokens: Option<i32>,
|
||||||
|
visibility: &str,
|
||||||
|
) -> SaasResult<AgentTemplateInfo> {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let tools_json = serde_json::to_string(tools).unwrap_or_else(|_| "[]".to_string());
|
||||||
|
let caps_json = serde_json::to_string(capabilities).unwrap_or_else(|_| "[]".to_string());
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO agent_templates (id, name, description, category, source, model, system_prompt,
|
||||||
|
tools, capabilities, temperature, max_tokens, visibility, status, current_version, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, 'active', 1, $13, $13)"
|
||||||
|
)
|
||||||
|
.bind(&id).bind(name).bind(description).bind(category).bind(source)
|
||||||
|
.bind(model).bind(system_prompt).bind(&tools_json).bind(&caps_json)
|
||||||
|
.bind(temperature).bind(max_tokens).bind(visibility).bind(&now)
|
||||||
|
.execute(db).await.map_err(|e| {
|
||||||
|
if e.to_string().contains("unique") {
|
||||||
|
SaasError::AlreadyExists(format!("Agent 模板 '{}' 已存在", name))
|
||||||
|
} else {
|
||||||
|
SaasError::Database(e)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
get_template(db, &id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取单个模板
|
||||||
|
pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo> {
|
||||||
|
let row: Option<_> = sqlx::query_as(
|
||||||
|
"SELECT id, name, description, category, source, model, system_prompt,
|
||||||
|
tools, capabilities, temperature, max_tokens, visibility, status,
|
||||||
|
current_version, created_at, updated_at
|
||||||
|
FROM agent_templates WHERE id = $1"
|
||||||
|
).bind(id).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
row.map(row_to_template)
|
||||||
|
.ok_or_else(|| SaasError::NotFound(format!("Agent 模板 {} 不存在", id)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列出模板(分页 + 过滤)
|
||||||
|
/// 使用动态参数化查询,安全拼接 WHERE 条件。
|
||||||
|
pub async fn list_templates(
|
||||||
|
db: &PgPool,
|
||||||
|
query: &AgentTemplateListQuery,
|
||||||
|
) -> SaasResult<crate::common::PaginatedResponse<AgentTemplateInfo>> {
|
||||||
|
let page = query.page.unwrap_or(1).max(1);
|
||||||
|
let page_size = query.page_size.unwrap_or(20).min(100);
|
||||||
|
let offset = ((page - 1) * page_size) as i64;
|
||||||
|
|
||||||
|
// 动态构建参数化 WHERE 子句
|
||||||
|
let mut conditions: Vec<String> = vec!["1=1".to_string()];
|
||||||
|
let mut param_idx = 1u32;
|
||||||
|
let mut cat_bind: Option<String> = None;
|
||||||
|
let mut src_bind: Option<String> = None;
|
||||||
|
let mut vis_bind: Option<String> = None;
|
||||||
|
let mut st_bind: Option<String> = None;
|
||||||
|
|
||||||
|
if let Some(ref cat) = query.category {
|
||||||
|
param_idx += 1;
|
||||||
|
conditions.push(format!("category = ${}", param_idx));
|
||||||
|
cat_bind = Some(cat.clone());
|
||||||
|
}
|
||||||
|
if let Some(ref src) = query.source {
|
||||||
|
param_idx += 1;
|
||||||
|
conditions.push(format!("source = ${}", param_idx));
|
||||||
|
src_bind = Some(src.clone());
|
||||||
|
}
|
||||||
|
if let Some(ref vis) = query.visibility {
|
||||||
|
param_idx += 1;
|
||||||
|
conditions.push(format!("visibility = ${}", param_idx));
|
||||||
|
vis_bind = Some(vis.clone());
|
||||||
|
}
|
||||||
|
if let Some(ref st) = query.status {
|
||||||
|
param_idx += 1;
|
||||||
|
conditions.push(format!("status = ${}", param_idx));
|
||||||
|
st_bind = Some(st.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let where_clause = conditions.join(" AND ");
|
||||||
|
|
||||||
|
// COUNT 查询: WHERE 参数绑定 ($1..$N)
|
||||||
|
let count_idx = param_idx;
|
||||||
|
let count_sql = format!(
|
||||||
|
"SELECT COUNT(*) FROM agent_templates WHERE {}",
|
||||||
|
where_clause
|
||||||
|
);
|
||||||
|
let count_limit_idx = count_idx + 1;
|
||||||
|
let count_offset_idx = count_limit_idx + 1;
|
||||||
|
let data_sql = format!(
|
||||||
|
"SELECT id, name, description, category, source, model, system_prompt,
|
||||||
|
tools, capabilities, temperature, max_tokens, visibility, status,
|
||||||
|
current_version, created_at, updated_at
|
||||||
|
FROM agent_templates WHERE {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
|
||||||
|
where_clause, count_limit_idx, count_offset_idx
|
||||||
|
);
|
||||||
|
|
||||||
|
// 构建 COUNT 查询并绑定参数
|
||||||
|
let mut count_q = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||||
|
if let Some(ref v) = cat_bind { count_q = count_q.bind(v); }
|
||||||
|
if let Some(ref v) = src_bind { count_q = count_q.bind(v); }
|
||||||
|
if let Some(ref v) = vis_bind { count_q = count_q.bind(v); }
|
||||||
|
if let Some(ref v) = st_bind { count_q = count_q.bind(v); }
|
||||||
|
let total: i64 = count_q.fetch_one(db).await?;
|
||||||
|
|
||||||
|
// 构建数据查询并绑定参数
|
||||||
|
let mut data_q = sqlx::query_as::<_, (
|
||||||
|
String, String, Option<String>, String, String, Option<String>, Option<String>,
|
||||||
|
String, String, Option<f64>, Option<i32>, String, String, i32, String, String
|
||||||
|
)>(&data_sql);
|
||||||
|
if let Some(ref v) = cat_bind { data_q = data_q.bind(v); }
|
||||||
|
if let Some(ref v) = src_bind { data_q = data_q.bind(v); }
|
||||||
|
if let Some(ref v) = vis_bind { data_q = data_q.bind(v); }
|
||||||
|
if let Some(ref v) = st_bind { data_q = data_q.bind(v); }
|
||||||
|
data_q = data_q.bind(page_size as i64).bind(offset);
|
||||||
|
|
||||||
|
let rows = data_q.fetch_all(db).await?;
|
||||||
|
let items = rows.into_iter().map(row_to_template).collect();
|
||||||
|
|
||||||
|
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 更新模板
|
||||||
|
/// 使用动态参数化查询,安全拼接 SET 子句。
|
||||||
|
pub async fn update_template(
|
||||||
|
db: &PgPool,
|
||||||
|
id: &str,
|
||||||
|
description: Option<&str>,
|
||||||
|
model: Option<&str>,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
tools: Option<&[String]>,
|
||||||
|
capabilities: Option<&[String]>,
|
||||||
|
temperature: Option<f64>,
|
||||||
|
max_tokens: Option<i32>,
|
||||||
|
visibility: Option<&str>,
|
||||||
|
status: Option<&str>,
|
||||||
|
) -> SaasResult<AgentTemplateInfo> {
|
||||||
|
// 确认存在
|
||||||
|
get_template(db, id).await?;
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let mut set_clauses: Vec<String> = vec![];
|
||||||
|
let mut param_idx = 1u32;
|
||||||
|
|
||||||
|
// 收集需要绑定的值(按顺序)
|
||||||
|
let mut desc_val: Option<String> = None;
|
||||||
|
let mut model_val: Option<String> = None;
|
||||||
|
let mut sp_val: Option<String> = None;
|
||||||
|
let mut tools_val: Option<String> = None;
|
||||||
|
let mut caps_val: Option<String> = None;
|
||||||
|
let mut temp_val: Option<f64> = None;
|
||||||
|
let mut mt_val: Option<i32> = None;
|
||||||
|
let mut vis_val: Option<String> = None;
|
||||||
|
let mut st_val: Option<String> = None;
|
||||||
|
|
||||||
|
if let Some(desc) = description {
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("description = ${}", param_idx));
|
||||||
|
desc_val = Some(desc.to_string());
|
||||||
|
}
|
||||||
|
if let Some(m) = model {
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("model = ${}", param_idx));
|
||||||
|
model_val = Some(m.to_string());
|
||||||
|
}
|
||||||
|
if let Some(sp) = system_prompt {
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("system_prompt = ${}", param_idx));
|
||||||
|
sp_val = Some(sp.to_string());
|
||||||
|
}
|
||||||
|
if let Some(t) = tools {
|
||||||
|
let json = serde_json::to_string(t).unwrap_or_else(|_| "[]".to_string());
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("tools = ${}", param_idx));
|
||||||
|
tools_val = Some(json);
|
||||||
|
}
|
||||||
|
if let Some(c) = capabilities {
|
||||||
|
let json = serde_json::to_string(c).unwrap_or_else(|_| "[]".to_string());
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("capabilities = ${}", param_idx));
|
||||||
|
caps_val = Some(json);
|
||||||
|
}
|
||||||
|
if let Some(t) = temperature {
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("temperature = ${}", param_idx));
|
||||||
|
temp_val = Some(t);
|
||||||
|
}
|
||||||
|
if let Some(m) = max_tokens {
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("max_tokens = ${}", param_idx));
|
||||||
|
mt_val = Some(m);
|
||||||
|
}
|
||||||
|
if let Some(v) = visibility {
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("visibility = ${}", param_idx));
|
||||||
|
vis_val = Some(v.to_string());
|
||||||
|
}
|
||||||
|
if let Some(s) = status {
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("status = ${}", param_idx));
|
||||||
|
st_val = Some(s.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if set_clauses.is_empty() {
|
||||||
|
return get_template(db, id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// updated_at
|
||||||
|
param_idx += 1;
|
||||||
|
set_clauses.push(format!("updated_at = ${}", param_idx));
|
||||||
|
|
||||||
|
// WHERE id = $N
|
||||||
|
let id_idx = param_idx + 1;
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
"UPDATE agent_templates SET {} WHERE id = ${}",
|
||||||
|
set_clauses.join(", "), id_idx
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut q = sqlx::query(&sql);
|
||||||
|
if let Some(ref v) = desc_val { q = q.bind(v); }
|
||||||
|
if let Some(ref v) = model_val { q = q.bind(v); }
|
||||||
|
if let Some(ref v) = sp_val { q = q.bind(v); }
|
||||||
|
if let Some(ref v) = tools_val { q = q.bind(v); }
|
||||||
|
if let Some(ref v) = caps_val { q = q.bind(v); }
|
||||||
|
if let Some(v) = temp_val { q = q.bind(v); }
|
||||||
|
if let Some(v) = mt_val { q = q.bind(v); }
|
||||||
|
if let Some(ref v) = vis_val { q = q.bind(v); }
|
||||||
|
if let Some(ref v) = st_val { q = q.bind(v); }
|
||||||
|
q = q.bind(&now);
|
||||||
|
q = q.bind(id);
|
||||||
|
|
||||||
|
q.execute(db).await?;
|
||||||
|
|
||||||
|
get_template(db, id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 归档模板
|
||||||
|
pub async fn archive_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo> {
|
||||||
|
update_template(db, id, None, None, None, None, None, None, None, None, Some("archived")).await
|
||||||
|
}
|
||||||
65
crates/zclaw-saas/src/agent_template/types.rs
Normal file
65
crates/zclaw-saas/src/agent_template/types.rs
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
//! Agent 配置模板类型定义
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// --- Agent Template ---
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct AgentTemplateInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub category: String,
|
||||||
|
pub source: String,
|
||||||
|
pub model: Option<String>,
|
||||||
|
pub system_prompt: Option<String>,
|
||||||
|
pub tools: Vec<String>,
|
||||||
|
pub capabilities: Vec<String>,
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
pub max_tokens: Option<i32>,
|
||||||
|
pub visibility: String,
|
||||||
|
pub status: String,
|
||||||
|
pub current_version: i32,
|
||||||
|
pub created_at: String,
|
||||||
|
pub updated_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CreateAgentTemplateRequest {
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub category: Option<String>,
|
||||||
|
pub source: Option<String>,
|
||||||
|
pub model: Option<String>,
|
||||||
|
pub system_prompt: Option<String>,
|
||||||
|
pub tools: Option<Vec<String>>,
|
||||||
|
pub capabilities: Option<Vec<String>>,
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
pub max_tokens: Option<i32>,
|
||||||
|
pub visibility: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct UpdateAgentTemplateRequest {
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub model: Option<String>,
|
||||||
|
pub system_prompt: Option<String>,
|
||||||
|
pub tools: Option<Vec<String>>,
|
||||||
|
pub capabilities: Option<Vec<String>>,
|
||||||
|
pub temperature: Option<f64>,
|
||||||
|
pub max_tokens: Option<i32>,
|
||||||
|
pub visibility: Option<String>,
|
||||||
|
pub status: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List ---
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct AgentTemplateListQuery {
|
||||||
|
pub category: Option<String>,
|
||||||
|
pub source: Option<String>,
|
||||||
|
pub visibility: Option<String>,
|
||||||
|
pub status: Option<String>,
|
||||||
|
pub page: Option<u32>,
|
||||||
|
pub page_size: Option<u32>,
|
||||||
|
}
|
||||||
@@ -6,26 +6,45 @@ use secrecy::ExposeSecret;
|
|||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use super::{
|
use super::{
|
||||||
jwt::create_token,
|
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
|
||||||
password::{hash_password, verify_password},
|
password::{hash_password, verify_password},
|
||||||
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic},
|
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// POST /api/v1/auth/register
|
/// POST /api/v1/auth/register
|
||||||
|
/// 注册成功后自动签发 JWT,返回与 login 一致的 LoginResponse
|
||||||
pub async fn register(
|
pub async fn register(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
Json(req): Json<RegisterRequest>,
|
Json(req): Json<RegisterRequest>,
|
||||||
) -> SaasResult<(StatusCode, Json<AccountPublic>)> {
|
) -> SaasResult<(StatusCode, Json<LoginResponse>)> {
|
||||||
if req.username.len() < 3 {
|
if req.username.len() < 3 {
|
||||||
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into()));
|
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into()));
|
||||||
}
|
}
|
||||||
|
if req.username.len() > 32 {
|
||||||
|
return Err(SaasError::InvalidInput("用户名最多 32 个字符".into()));
|
||||||
|
}
|
||||||
|
let username_re = regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap();
|
||||||
|
if !username_re.is_match(&req.username) {
|
||||||
|
return Err(SaasError::InvalidInput("用户名只能包含字母、数字、下划线和连字符".into()));
|
||||||
|
}
|
||||||
|
if !req.email.contains('@') || !req.email.contains('.') {
|
||||||
|
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
||||||
|
}
|
||||||
if req.password.len() < 8 {
|
if req.password.len() < 8 {
|
||||||
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
|
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
|
||||||
}
|
}
|
||||||
|
if req.password.len() > 128 {
|
||||||
|
return Err(SaasError::InvalidInput("密码最多 128 个字符".into()));
|
||||||
|
}
|
||||||
|
if let Some(ref name) = req.display_name {
|
||||||
|
if name.len() > 64 {
|
||||||
|
return Err(SaasError::InvalidInput("显示名称最多 64 个字符".into()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let existing: Vec<(String,)> = sqlx::query_as(
|
let existing: Vec<(String,)> = sqlx::query_as(
|
||||||
"SELECT id FROM accounts WHERE username = ?1 OR email = ?2"
|
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
|
||||||
)
|
)
|
||||||
.bind(&req.username)
|
.bind(&req.username)
|
||||||
.bind(&req.email)
|
.bind(&req.email)
|
||||||
@@ -44,7 +63,7 @@ pub async fn register(
|
|||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'active', ?7, ?7)"
|
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7, $7)"
|
||||||
)
|
)
|
||||||
.bind(&account_id)
|
.bind(&account_id)
|
||||||
.bind(&req.username)
|
.bind(&req.username)
|
||||||
@@ -59,7 +78,30 @@ pub async fn register(
|
|||||||
let client_ip = addr.ip().to_string();
|
let client_ip = addr.ip().to_string();
|
||||||
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
|
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
|
||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(AccountPublic {
|
// 注册成功后自动签发 JWT + Refresh Token
|
||||||
|
let permissions = get_role_permissions(&state.db, &role).await?;
|
||||||
|
let config = state.config.read().await;
|
||||||
|
let token = create_token(
|
||||||
|
&account_id, &role, permissions.clone(),
|
||||||
|
state.jwt_secret.expose_secret(),
|
||||||
|
config.auth.jwt_expiration_hours,
|
||||||
|
)?;
|
||||||
|
let refresh_token = create_refresh_token(
|
||||||
|
&account_id, &role, permissions,
|
||||||
|
state.jwt_secret.expose_secret(),
|
||||||
|
config.auth.refresh_token_hours,
|
||||||
|
)?;
|
||||||
|
drop(config);
|
||||||
|
|
||||||
|
store_refresh_token(
|
||||||
|
&state.db, &account_id, &refresh_token,
|
||||||
|
state.jwt_secret.expose_secret(), 168,
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
Ok((StatusCode::CREATED, Json(LoginResponse {
|
||||||
|
token,
|
||||||
|
refresh_token,
|
||||||
|
account: AccountPublic {
|
||||||
id: account_id,
|
id: account_id,
|
||||||
username: req.username,
|
username: req.username,
|
||||||
email: req.email,
|
email: req.email,
|
||||||
@@ -68,6 +110,7 @@ pub async fn register(
|
|||||||
status: "active".into(),
|
status: "active".into(),
|
||||||
totp_enabled: false,
|
totp_enabled: false,
|
||||||
created_at: now,
|
created_at: now,
|
||||||
|
},
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,7 +123,7 @@ pub async fn login(
|
|||||||
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||||
FROM accounts WHERE username = ?1 OR email = ?1"
|
FROM accounts WHERE username = $1 OR email = $1"
|
||||||
)
|
)
|
||||||
.bind(&req.username)
|
.bind(&req.username)
|
||||||
.fetch_optional(&state.db)
|
.fetch_optional(&state.db)
|
||||||
@@ -94,7 +137,7 @@ pub async fn login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let (password_hash,): (String,) = sqlx::query_as(
|
let (password_hash,): (String,) = sqlx::query_as(
|
||||||
"SELECT password_hash FROM accounts WHERE id = ?1"
|
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
@@ -110,7 +153,7 @@ pub async fn login(
|
|||||||
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
|
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
|
||||||
|
|
||||||
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||||||
"SELECT totp_secret FROM accounts WHERE id = ?1"
|
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
@@ -120,6 +163,12 @@ pub async fn login(
|
|||||||
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
|
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
// 解密 TOTP secret (兼容旧的明文格式)
|
||||||
|
let config = state.config.read().await;
|
||||||
|
let enc_key = config.totp_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
let secret = super::totp::decrypt_totp_for_login(&secret, &enc_key)?;
|
||||||
|
|
||||||
if !super::totp::verify_totp_code(&secret, code) {
|
if !super::totp::verify_totp_code(&secret, code) {
|
||||||
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
|
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
|
||||||
}
|
}
|
||||||
@@ -132,16 +181,28 @@ pub async fn login(
|
|||||||
state.jwt_secret.expose_secret(),
|
state.jwt_secret.expose_secret(),
|
||||||
config.auth.jwt_expiration_hours,
|
config.auth.jwt_expiration_hours,
|
||||||
)?;
|
)?;
|
||||||
|
let refresh_token = create_refresh_token(
|
||||||
|
&id, &role, permissions,
|
||||||
|
state.jwt_secret.expose_secret(),
|
||||||
|
config.auth.refresh_token_hours,
|
||||||
|
)?;
|
||||||
|
drop(config);
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
sqlx::query("UPDATE accounts SET last_login_at = ?1 WHERE id = ?2")
|
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
|
||||||
.bind(&now).bind(&id)
|
.bind(&now).bind(&id)
|
||||||
.execute(&state.db).await?;
|
.execute(&state.db).await?;
|
||||||
let client_ip = addr.ip().to_string();
|
let client_ip = addr.ip().to_string();
|
||||||
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
|
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
|
||||||
|
|
||||||
|
store_refresh_token(
|
||||||
|
&state.db, &id, &refresh_token,
|
||||||
|
state.jwt_secret.expose_secret(), 168,
|
||||||
|
).await?;
|
||||||
|
|
||||||
Ok(Json(LoginResponse {
|
Ok(Json(LoginResponse {
|
||||||
token,
|
token,
|
||||||
|
refresh_token,
|
||||||
account: AccountPublic {
|
account: AccountPublic {
|
||||||
id, username, email, display_name, role, status, totp_enabled, created_at,
|
id, username, email, display_name, role, status, totp_enabled, created_at,
|
||||||
},
|
},
|
||||||
@@ -149,17 +210,92 @@ pub async fn login(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/auth/refresh
|
/// POST /api/v1/auth/refresh
|
||||||
|
/// 使用 refresh_token 换取新的 access + refresh token 对
|
||||||
|
/// refresh_token 一次性使用,使用后立即失效
|
||||||
pub async fn refresh(
|
pub async fn refresh(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
Json(req): Json<RefreshRequest>,
|
||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
// 1. 验证 refresh token 签名 (跳过过期检查,但有 7 天窗口限制)
|
||||||
|
let claims = verify_token_skip_expiry(&req.refresh_token, state.jwt_secret.expose_secret())?;
|
||||||
|
|
||||||
|
// 2. 确认是 refresh 类型 token
|
||||||
|
if claims.token_type != "refresh" {
|
||||||
|
return Err(SaasError::AuthError("无效的 refresh token".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let jti = claims.jti.as_deref()
|
||||||
|
.ok_or_else(|| SaasError::AuthError("refresh token 缺少 jti".into()))?;
|
||||||
|
|
||||||
|
// 3. 从 DB 查找 refresh token,确保未被使用
|
||||||
|
let row: Option<(String,)> = sqlx::query_as(
|
||||||
|
"SELECT account_id FROM refresh_tokens WHERE jti = $1 AND used_at IS NULL AND expires_at > $2"
|
||||||
|
)
|
||||||
|
.bind(jti)
|
||||||
|
.bind(&chrono::Utc::now().to_rfc3339())
|
||||||
|
.fetch_optional(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let token_account_id = row
|
||||||
|
.ok_or_else(|| SaasError::AuthError("refresh token 已使用、已过期或不存在".into()))?
|
||||||
|
.0;
|
||||||
|
|
||||||
|
// 4. 验证 token 中的 account_id 与 DB 中的一致
|
||||||
|
if token_account_id != claims.sub {
|
||||||
|
return Err(SaasError::AuthError("refresh token 账号不匹配".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 5. 标记旧 refresh token 为已使用 (一次性)
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
sqlx::query("UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2")
|
||||||
|
.bind(&now).bind(jti)
|
||||||
|
.execute(&state.db).await?;
|
||||||
|
|
||||||
|
// 6. 获取最新角色权限
|
||||||
|
let (role,): (String,) = sqlx::query_as(
|
||||||
|
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
|
||||||
|
)
|
||||||
|
.bind(&claims.sub)
|
||||||
|
.fetch_optional(&state.db)
|
||||||
|
.await?
|
||||||
|
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
|
||||||
|
|
||||||
|
let permissions = get_role_permissions(&state.db, &role).await?;
|
||||||
|
|
||||||
|
// 7. 创建新的 access token + refresh token
|
||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let token = create_token(
|
let new_access = create_token(
|
||||||
&ctx.account_id, &ctx.role, ctx.permissions.clone(),
|
&claims.sub, &role, permissions.clone(),
|
||||||
state.jwt_secret.expose_secret(),
|
state.jwt_secret.expose_secret(),
|
||||||
config.auth.jwt_expiration_hours,
|
config.auth.jwt_expiration_hours,
|
||||||
)?;
|
)?;
|
||||||
Ok(Json(serde_json::json!({ "token": token })))
|
let new_refresh = create_refresh_token(
|
||||||
|
&claims.sub, &role, permissions.clone(),
|
||||||
|
state.jwt_secret.expose_secret(),
|
||||||
|
config.auth.refresh_token_hours,
|
||||||
|
)?;
|
||||||
|
drop(config);
|
||||||
|
|
||||||
|
// 8. 存储新 refresh token 到 DB
|
||||||
|
let new_claims = verify_token(&new_refresh, state.jwt_secret.expose_secret())?;
|
||||||
|
let new_jti = new_claims.jti.unwrap_or_default();
|
||||||
|
let new_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let refresh_expires = (chrono::Utc::now() + chrono::Duration::hours(168)).to_rfc3339();
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
)
|
||||||
|
.bind(&new_id).bind(&claims.sub).bind(&new_jti)
|
||||||
|
.bind(sha256_hex(&new_refresh)).bind(&refresh_expires).bind(&now)
|
||||||
|
.execute(&state.db).await?;
|
||||||
|
|
||||||
|
// 9. 清理过期/已使用的 refresh tokens (异步, 不阻塞)
|
||||||
|
cleanup_expired_refresh_tokens(&state.db).await?;
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({
|
||||||
|
"token": new_access,
|
||||||
|
"refresh_token": new_refresh,
|
||||||
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
||||||
@@ -170,7 +306,7 @@ pub async fn me(
|
|||||||
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
let row: Option<(String, String, String, String, String, String, bool, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||||
FROM accounts WHERE id = ?1"
|
FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.fetch_optional(&state.db)
|
.fetch_optional(&state.db)
|
||||||
@@ -196,7 +332,7 @@ pub async fn change_password(
|
|||||||
|
|
||||||
// 获取当前密码哈希
|
// 获取当前密码哈希
|
||||||
let (password_hash,): (String,) = sqlx::query_as(
|
let (password_hash,): (String,) = sqlx::query_as(
|
||||||
"SELECT password_hash FROM accounts WHERE id = ?1"
|
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
@@ -210,7 +346,7 @@ pub async fn change_password(
|
|||||||
// 更新密码
|
// 更新密码
|
||||||
let new_hash = hash_password(&req.new_password)?;
|
let new_hash = hash_password(&req.new_password)?;
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
sqlx::query("UPDATE accounts SET password_hash = ?1, updated_at = ?2 WHERE id = ?3")
|
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
|
||||||
.bind(&new_hash)
|
.bind(&new_hash)
|
||||||
.bind(&now)
|
.bind(&now)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
@@ -223,9 +359,9 @@ pub async fn change_password(
|
|||||||
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
|
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> {
|
pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
|
||||||
let row: Option<(String,)> = sqlx::query_as(
|
let row: Option<(String,)> = sqlx::query_as(
|
||||||
"SELECT permissions FROM roles WHERE id = ?1"
|
"SELECT permissions FROM roles WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(role)
|
.bind(role)
|
||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
@@ -252,7 +388,7 @@ pub fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
|
|||||||
|
|
||||||
/// 记录操作日志
|
/// 记录操作日志
|
||||||
pub async fn log_operation(
|
pub async fn log_operation(
|
||||||
db: &sqlx::SqlitePool,
|
db: &sqlx::PgPool,
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
action: &str,
|
action: &str,
|
||||||
target_type: &str,
|
target_type: &str,
|
||||||
@@ -263,7 +399,7 @@ pub async fn log_operation(
|
|||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)"
|
VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||||
)
|
)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
.bind(action)
|
.bind(action)
|
||||||
@@ -276,3 +412,45 @@ pub async fn log_operation(
|
|||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 存储 refresh token 到 DB
|
||||||
|
async fn store_refresh_token(
|
||||||
|
db: &sqlx::PgPool,
|
||||||
|
account_id: &str,
|
||||||
|
refresh_token: &str,
|
||||||
|
secret: &str,
|
||||||
|
refresh_hours: i64,
|
||||||
|
) -> SaasResult<()> {
|
||||||
|
let claims = verify_token(refresh_token, secret)?;
|
||||||
|
let jti = claims.jti.unwrap_or_default();
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let expires_at = (chrono::Utc::now() + chrono::Duration::hours(refresh_hours)).to_rfc3339();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
)
|
||||||
|
.bind(&id).bind(account_id).bind(&jti)
|
||||||
|
.bind(sha256_hex(refresh_token)).bind(&expires_at).bind(&now)
|
||||||
|
.execute(db).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 清理过期和已使用的 refresh tokens
|
||||||
|
async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀)
|
||||||
|
sqlx::query(
|
||||||
|
"DELETE FROM refresh_tokens WHERE (used_at IS NOT NULL AND used_at < $1) OR (expires_at < $1)"
|
||||||
|
)
|
||||||
|
.bind(&now)
|
||||||
|
.execute(db).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// SHA-256 hex digest
|
||||||
|
fn sha256_hex(input: &str) -> String {
|
||||||
|
use sha2::{Sha256, Digest};
|
||||||
|
hex::encode(Sha256::digest(input.as_bytes()))
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,27 +9,52 @@ use crate::error::SaasResult;
|
|||||||
/// JWT Claims
|
/// JWT Claims
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct Claims {
|
pub struct Claims {
|
||||||
|
/// JWT ID — 唯一标识,用于 token 追踪和吊销
|
||||||
|
pub jti: Option<String>,
|
||||||
pub sub: String,
|
pub sub: String,
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub permissions: Vec<String>,
|
pub permissions: Vec<String>,
|
||||||
|
/// token 类型: "access" 或 "refresh"
|
||||||
|
#[serde(default = "default_token_type")]
|
||||||
|
pub token_type: String,
|
||||||
pub iat: i64,
|
pub iat: i64,
|
||||||
pub exp: i64,
|
pub exp: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_token_type() -> String {
|
||||||
|
"access".to_string()
|
||||||
|
}
|
||||||
|
|
||||||
impl Claims {
|
impl Claims {
|
||||||
pub fn new(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
|
pub fn new_access(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
Self {
|
Self {
|
||||||
|
jti: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
sub: account_id.to_string(),
|
sub: account_id.to_string(),
|
||||||
role: role.to_string(),
|
role: role.to_string(),
|
||||||
permissions,
|
permissions,
|
||||||
|
token_type: "access".to_string(),
|
||||||
iat: now.timestamp(),
|
iat: now.timestamp(),
|
||||||
exp: (now + Duration::hours(expiration_hours)).timestamp(),
|
exp: (now + Duration::hours(expiration_hours)).timestamp(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 创建 refresh token claims (有效期更长,用于一次性刷新)
|
||||||
|
pub fn new_refresh(account_id: &str, role: &str, permissions: Vec<String>, refresh_hours: i64) -> Self {
|
||||||
|
let now = Utc::now();
|
||||||
|
Self {
|
||||||
|
jti: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
sub: account_id.to_string(),
|
||||||
|
role: role.to_string(),
|
||||||
|
permissions,
|
||||||
|
token_type: "refresh".to_string(),
|
||||||
|
iat: now.timestamp(),
|
||||||
|
exp: (now + Duration::hours(refresh_hours)).timestamp(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 创建 JWT Token
|
/// 创建 Access JWT Token
|
||||||
pub fn create_token(
|
pub fn create_token(
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
role: &str,
|
role: &str,
|
||||||
@@ -37,7 +62,24 @@ pub fn create_token(
|
|||||||
secret: &str,
|
secret: &str,
|
||||||
expiration_hours: i64,
|
expiration_hours: i64,
|
||||||
) -> SaasResult<String> {
|
) -> SaasResult<String> {
|
||||||
let claims = Claims::new(account_id, role, permissions, expiration_hours);
|
let claims = Claims::new_access(account_id, role, permissions, expiration_hours);
|
||||||
|
let token = encode(
|
||||||
|
&Header::default(),
|
||||||
|
&claims,
|
||||||
|
&EncodingKey::from_secret(secret.as_bytes()),
|
||||||
|
)?;
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建 Refresh JWT Token (独立 jti,有效期更长)
|
||||||
|
pub fn create_refresh_token(
|
||||||
|
account_id: &str,
|
||||||
|
role: &str,
|
||||||
|
permissions: Vec<String>,
|
||||||
|
secret: &str,
|
||||||
|
refresh_hours: i64,
|
||||||
|
) -> SaasResult<String> {
|
||||||
|
let claims = Claims::new_refresh(account_id, role, permissions, refresh_hours);
|
||||||
let token = encode(
|
let token = encode(
|
||||||
&Header::default(),
|
&Header::default(),
|
||||||
&claims,
|
&claims,
|
||||||
@@ -56,6 +98,52 @@ pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> {
|
|||||||
Ok(token_data.claims)
|
Ok(token_data.claims)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 验证 JWT Token 但跳过过期检查(仅用于 refresh token 刷新)
|
||||||
|
/// 限制: 原始 token 的 iat 必须在 7 天内
|
||||||
|
pub fn verify_token_skip_expiry(token: &str, secret: &str) -> SaasResult<Claims> {
|
||||||
|
let mut validation = Validation::default();
|
||||||
|
validation.validate_exp = false;
|
||||||
|
let token_data = decode::<Claims>(
|
||||||
|
token,
|
||||||
|
&DecodingKey::from_secret(secret.as_bytes()),
|
||||||
|
&validation,
|
||||||
|
)?;
|
||||||
|
let claims = &token_data.claims;
|
||||||
|
|
||||||
|
// 限制刷新窗口: token 签发时间必须在 7 天内
|
||||||
|
let now = Utc::now().timestamp();
|
||||||
|
let max_refresh_window = 7 * 24 * 3600; // 7 天
|
||||||
|
if now - claims.iat > max_refresh_window {
|
||||||
|
return Err(jsonwebtoken::errors::Error::from(
|
||||||
|
jsonwebtoken::errors::ErrorKind::ExpiredSignature
|
||||||
|
).into());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(token_data.claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Token 对: access token + refresh token
|
||||||
|
#[derive(Debug, serde::Serialize)]
|
||||||
|
pub struct TokenPair {
|
||||||
|
pub access_token: String,
|
||||||
|
pub refresh_token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 创建 access + refresh token 对
|
||||||
|
pub fn create_token_pair(
|
||||||
|
account_id: &str,
|
||||||
|
role: &str,
|
||||||
|
permissions: Vec<String>,
|
||||||
|
secret: &str,
|
||||||
|
access_hours: i64,
|
||||||
|
refresh_hours: i64,
|
||||||
|
) -> SaasResult<TokenPair> {
|
||||||
|
Ok(TokenPair {
|
||||||
|
access_token: create_token(account_id, role, permissions.clone(), secret, access_hours)?,
|
||||||
|
refresh_token: create_refresh_token(account_id, role, permissions, secret, refresh_hours)?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -74,6 +162,8 @@ mod tests {
|
|||||||
assert_eq!(claims.sub, "account-123");
|
assert_eq!(claims.sub, "account-123");
|
||||||
assert_eq!(claims.role, "admin");
|
assert_eq!(claims.role, "admin");
|
||||||
assert_eq!(claims.permissions, vec!["model:read"]);
|
assert_eq!(claims.permissions, vec!["model:read"]);
|
||||||
|
assert!(claims.jti.is_some());
|
||||||
|
assert_eq!(claims.token_type, "access");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -88,4 +178,17 @@ mod tests {
|
|||||||
let result = verify_token(&token, "wrong-secret");
|
let result = verify_token(&token, "wrong-secret");
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_refresh_token_has_different_jti() {
|
||||||
|
let access = create_token("acct-1", "user", vec![], TEST_SECRET, 1).unwrap();
|
||||||
|
let refresh = create_refresh_token("acct-1", "user", vec![], TEST_SECRET, 168).unwrap();
|
||||||
|
|
||||||
|
let access_claims = verify_token(&access, TEST_SECRET).unwrap();
|
||||||
|
let refresh_claims = verify_token(&refresh, TEST_SECRET).unwrap();
|
||||||
|
|
||||||
|
assert_ne!(access_claims.jti, refresh_claims.jti);
|
||||||
|
assert_eq!(access_claims.token_type, "access");
|
||||||
|
assert_eq!(refresh_claims.token_type, "refresh");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
|||||||
|
|
||||||
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
|
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
|
||||||
"SELECT account_id, expires_at, permissions FROM api_tokens
|
"SELECT account_id, expires_at, permissions FROM api_tokens
|
||||||
WHERE token_hash = ?1 AND revoked_at IS NULL"
|
WHERE token_hash = $1 AND revoked_at IS NULL"
|
||||||
)
|
)
|
||||||
.bind(&token_hash)
|
.bind(&token_hash)
|
||||||
.fetch_optional(&state.db)
|
.fetch_optional(&state.db)
|
||||||
@@ -50,7 +50,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
|||||||
|
|
||||||
// 查询关联账号的角色
|
// 查询关联账号的角色
|
||||||
let (role,): (String,) = sqlx::query_as(
|
let (role,): (String,) = sqlx::query_as(
|
||||||
"SELECT role FROM accounts WHERE id = ?1 AND status = 'active'"
|
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
|
||||||
)
|
)
|
||||||
.bind(&account_id)
|
.bind(&account_id)
|
||||||
.fetch_optional(&state.db)
|
.fetch_optional(&state.db)
|
||||||
@@ -71,7 +71,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
|||||||
let db = state.db.clone();
|
let db = state.db.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2")
|
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
||||||
.bind(&now).bind(&token_hash)
|
.bind(&now).bind(&token_hash)
|
||||||
.execute(&db).await;
|
.execute(&db).await;
|
||||||
});
|
});
|
||||||
@@ -121,7 +121,8 @@ pub async fn auth_middleware(
|
|||||||
verify_api_token(&state, token, client_ip.clone()).await
|
verify_api_token(&state, token, client_ip.clone()).await
|
||||||
} else {
|
} else {
|
||||||
// JWT 路径
|
// JWT 路径
|
||||||
jwt::verify_token(token, state.jwt_secret.expose_secret())
|
let verify_result = jwt::verify_token(token, state.jwt_secret.expose_secret());
|
||||||
|
verify_result
|
||||||
.map(|claims| AuthContext {
|
.map(|claims| AuthContext {
|
||||||
account_id: claims.sub,
|
account_id: claims.sub,
|
||||||
role: claims.role,
|
role: claims.role,
|
||||||
@@ -153,6 +154,7 @@ pub fn routes() -> axum::Router<AppState> {
|
|||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/api/v1/auth/register", post(handlers::register))
|
.route("/api/v1/auth/register", post(handlers::register))
|
||||||
.route("/api/v1/auth/login", post(handlers::login))
|
.route("/api/v1/auth/login", post(handlers::login))
|
||||||
|
.route("/api/v1/auth/refresh", post(handlers::refresh))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 需要认证的路由
|
/// 需要认证的路由
|
||||||
@@ -160,7 +162,6 @@ pub fn protected_routes() -> axum::Router<AppState> {
|
|||||||
use axum::routing::{get, post, put};
|
use axum::routing::{get, post, put};
|
||||||
|
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/api/v1/auth/refresh", post(handlers::refresh))
|
|
||||||
.route("/api/v1/auth/me", get(handlers::me))
|
.route("/api/v1/auth/me", get(handlers::me))
|
||||||
.route("/api/v1/auth/password", put(handlers::change_password))
|
.route("/api/v1/auth/password", put(handlers::change_password))
|
||||||
.route("/api/v1/auth/totp/setup", post(totp::setup_totp))
|
.route("/api/v1/auth/totp/setup", post(totp::setup_totp))
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ use crate::state::AppState;
|
|||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use crate::auth::types::AuthContext;
|
use crate::auth::types::AuthContext;
|
||||||
use crate::auth::handlers::log_operation;
|
use crate::auth::handlers::log_operation;
|
||||||
|
use crate::crypto;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// TOTP 设置响应
|
/// TOTP 设置响应
|
||||||
@@ -46,6 +47,21 @@ fn base32_decode(data: &str) -> Option<Vec<u8>> {
|
|||||||
data_encoding::BASE32.decode(data.as_bytes()).ok()
|
data_encoding::BASE32.decode(data.as_bytes()).ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 加密 TOTP secret (AES-256-GCM,随机 nonce)
|
||||||
|
/// 存储格式: enc:<base64(nonce||ciphertext)>
|
||||||
|
/// 委托给 crypto::encrypt_value 统一加密
|
||||||
|
fn encrypt_totp_secret(plaintext: &str, key: &[u8; 32]) -> Result<String, SaasError> {
|
||||||
|
crate::crypto::encrypt_value(plaintext, key)
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 解密 TOTP secret (仅支持新格式: 随机 nonce)
|
||||||
|
/// 旧的固定 nonce 格式应通过启动时迁移转换。
|
||||||
|
fn decrypt_totp_secret(encrypted: &str, key: &[u8; 32]) -> Result<String, SaasError> {
|
||||||
|
crate::crypto::decrypt_value(encrypted, key)
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))
|
||||||
|
}
|
||||||
|
|
||||||
/// 生成 TOTP 密钥并返回 otpauth URI
|
/// 生成 TOTP 密钥并返回 otpauth URI
|
||||||
pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse {
|
pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse {
|
||||||
let secret = generate_random_secret();
|
let secret = generate_random_secret();
|
||||||
@@ -94,7 +110,7 @@ pub async fn setup_totp(
|
|||||||
) -> SaasResult<Json<TotpSetupResponse>> {
|
) -> SaasResult<Json<TotpSetupResponse>> {
|
||||||
// 如果已启用 TOTP,先清除旧密钥
|
// 如果已启用 TOTP,先清除旧密钥
|
||||||
let (username,): (String,) = sqlx::query_as(
|
let (username,): (String,) = sqlx::query_as(
|
||||||
"SELECT username FROM accounts WHERE id = ?1"
|
"SELECT username FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
@@ -103,9 +119,13 @@ pub async fn setup_totp(
|
|||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
|
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
|
||||||
|
|
||||||
// 存储密钥 (但不启用,需要 /verify 确认)
|
// 加密后存储密钥 (但不启用,需要 /verify 确认)
|
||||||
sqlx::query("UPDATE accounts SET totp_secret = ?1 WHERE id = ?2")
|
let enc_key = config.totp_encryption_key()
|
||||||
.bind(&setup.secret)
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
let encrypted_secret = encrypt_totp_secret(&setup.secret, &enc_key)?;
|
||||||
|
|
||||||
|
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
|
||||||
|
.bind(&encrypted_secret)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -130,23 +150,42 @@ pub async fn verify_totp(
|
|||||||
|
|
||||||
// 获取存储的密钥
|
// 获取存储的密钥
|
||||||
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||||||
"SELECT totp_secret FROM accounts WHERE id = ?1"
|
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let secret = totp_secret.ok_or_else(|| {
|
let encrypted_secret = totp_secret.ok_or_else(|| {
|
||||||
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
|
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
// 解密 secret (兼容旧的明文格式)
|
||||||
|
let config = state.config.read().await;
|
||||||
|
let enc_key = config.totp_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
let secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
||||||
|
decrypt_totp_secret(&encrypted_secret, &enc_key)?
|
||||||
|
} else {
|
||||||
|
// 旧格式: 明文存储,需要迁移
|
||||||
|
encrypted_secret.clone()
|
||||||
|
};
|
||||||
|
|
||||||
if !verify_totp_code(&secret, code) {
|
if !verify_totp_code(&secret, code) {
|
||||||
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证成功 → 启用 TOTP
|
// 验证成功 → 启用 TOTP,同时确保密钥已加密
|
||||||
|
let final_secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
||||||
|
encrypted_secret
|
||||||
|
} else {
|
||||||
|
// 迁移: 加密旧明文密钥
|
||||||
|
encrypt_totp_secret(&secret, &enc_key)?
|
||||||
|
};
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
sqlx::query("UPDATE accounts SET totp_enabled = 1, updated_at = ?1 WHERE id = ?2")
|
sqlx::query("UPDATE accounts SET totp_enabled = true, totp_secret = $1, updated_at = $2 WHERE id = $3")
|
||||||
|
.bind(&final_secret)
|
||||||
.bind(&now)
|
.bind(&now)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
@@ -167,7 +206,7 @@ pub async fn disable_totp(
|
|||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
// 验证密码
|
// 验证密码
|
||||||
let (password_hash,): (String,) = sqlx::query_as(
|
let (password_hash,): (String,) = sqlx::query_as(
|
||||||
"SELECT password_hash FROM accounts WHERE id = ?1"
|
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
@@ -179,7 +218,7 @@ pub async fn disable_totp(
|
|||||||
|
|
||||||
// 清除 TOTP
|
// 清除 TOTP
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
sqlx::query("UPDATE accounts SET totp_enabled = 0, totp_secret = NULL, updated_at = ?1 WHERE id = ?2")
|
sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
|
||||||
.bind(&now)
|
.bind(&now)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
@@ -190,3 +229,14 @@ pub async fn disable_totp(
|
|||||||
|
|
||||||
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
|
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 解密 TOTP secret (供 login handler 使用)
|
||||||
|
/// 返回解密后的明文 secret
|
||||||
|
pub fn decrypt_totp_for_login(encrypted_secret: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
||||||
|
if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
||||||
|
decrypt_totp_secret(encrypted_secret, enc_key)
|
||||||
|
} else {
|
||||||
|
// 兼容旧的明文格式
|
||||||
|
Ok(encrypted_secret.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ pub struct LoginRequest {
|
|||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
pub struct LoginResponse {
|
pub struct LoginResponse {
|
||||||
pub token: String,
|
pub token: String,
|
||||||
|
pub refresh_token: String,
|
||||||
pub account: AccountPublic,
|
pub account: AccountPublic,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,3 +55,9 @@ pub struct AuthContext {
|
|||||||
pub permissions: Vec<String>,
|
pub permissions: Vec<String>,
|
||||||
pub client_ip: Option<String>,
|
pub client_ip: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Token 刷新请求
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct RefreshRequest {
|
||||||
|
pub refresh_token: String,
|
||||||
|
}
|
||||||
|
|||||||
51
crates/zclaw-saas/src/common.rs
Normal file
51
crates/zclaw-saas/src/common.rs
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
//! 公共类型和工具函数
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
/// 分页响应通用包装
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct PaginatedResponse<T: Serialize> {
|
||||||
|
pub items: Vec<T>,
|
||||||
|
pub total: i64,
|
||||||
|
pub page: u32,
|
||||||
|
pub page_size: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 分页上限
|
||||||
|
pub const MAX_PAGE_SIZE: u32 = 100;
|
||||||
|
|
||||||
|
/// 默认分页大小
|
||||||
|
pub const DEFAULT_PAGE_SIZE: u32 = 20;
|
||||||
|
|
||||||
|
/// 规范化分页参数,返回 (page, page_size, offset)
|
||||||
|
pub fn normalize_pagination(page: Option<u32>, page_size: Option<u32>) -> (u32, u32, i64) {
|
||||||
|
let p = page.unwrap_or(1).max(1);
|
||||||
|
let ps = page_size.unwrap_or(DEFAULT_PAGE_SIZE).min(MAX_PAGE_SIZE).max(1);
|
||||||
|
let offset = ((p - 1) * ps) as i64;
|
||||||
|
(p, ps, offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_pagination_defaults() {
|
||||||
|
let (page, size, offset) = normalize_pagination(None, None);
|
||||||
|
assert_eq!(page, 1);
|
||||||
|
assert_eq!(size, DEFAULT_PAGE_SIZE);
|
||||||
|
assert_eq!(offset, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_pagination_clamp() {
|
||||||
|
let (page, size, offset) = normalize_pagination(None, Some(999));
|
||||||
|
assert_eq!(size, MAX_PAGE_SIZE);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_normalize_pagination_offset() {
|
||||||
|
let (page, size, offset) = normalize_pagination(Some(3), Some(10));
|
||||||
|
assert_eq!(offset, 20);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,7 +2,8 @@
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use secrecy::SecretString;
|
use secrecy::{ExposeSecret, SecretString};
|
||||||
|
use sha2::Digest;
|
||||||
|
|
||||||
/// SaaS 服务器完整配置
|
/// SaaS 服务器完整配置
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
@@ -40,6 +41,9 @@ pub struct AuthConfig {
|
|||||||
pub jwt_expiration_hours: i64,
|
pub jwt_expiration_hours: i64,
|
||||||
#[serde(default = "default_totp_issuer")]
|
#[serde(default = "default_totp_issuer")]
|
||||||
pub totp_issuer: String,
|
pub totp_issuer: String,
|
||||||
|
/// Refresh Token 有效期 (小时), 默认 168 小时 = 7 天
|
||||||
|
#[serde(default = "default_refresh_hours")]
|
||||||
|
pub refresh_token_hours: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 中转服务配置
|
/// 中转服务配置
|
||||||
@@ -59,9 +63,10 @@ pub struct RelayConfig {
|
|||||||
|
|
||||||
fn default_host() -> String { "0.0.0.0".into() }
|
fn default_host() -> String { "0.0.0.0".into() }
|
||||||
fn default_port() -> u16 { 8080 }
|
fn default_port() -> u16 { 8080 }
|
||||||
fn default_db_url() -> String { "sqlite:./saas-data.db".into() }
|
fn default_db_url() -> String { "postgres://localhost:5432/zclaw".into() }
|
||||||
fn default_jwt_hours() -> i64 { 24 }
|
fn default_jwt_hours() -> i64 { 24 }
|
||||||
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
|
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
|
||||||
|
fn default_refresh_hours() -> i64 { 168 }
|
||||||
fn default_max_queue() -> usize { 1000 }
|
fn default_max_queue() -> usize { 1000 }
|
||||||
fn default_max_concurrent() -> usize { 5 }
|
fn default_max_concurrent() -> usize { 5 }
|
||||||
fn default_batch_window() -> u64 { 50 }
|
fn default_batch_window() -> u64 { 50 }
|
||||||
@@ -124,6 +129,7 @@ impl Default for AuthConfig {
|
|||||||
Self {
|
Self {
|
||||||
jwt_expiration_hours: default_jwt_hours(),
|
jwt_expiration_hours: default_jwt_hours(),
|
||||||
totp_issuer: default_totp_issuer(),
|
totp_issuer: default_totp_issuer(),
|
||||||
|
refresh_token_hours: default_refresh_hours(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -147,7 +153,7 @@ impl SaaSConfig {
|
|||||||
.map(PathBuf::from)
|
.map(PathBuf::from)
|
||||||
.unwrap_or_else(|_| PathBuf::from("saas-config.toml"));
|
.unwrap_or_else(|_| PathBuf::from("saas-config.toml"));
|
||||||
|
|
||||||
let config = if config_path.exists() {
|
let mut config = if config_path.exists() {
|
||||||
let content = std::fs::read_to_string(&config_path)?;
|
let content = std::fs::read_to_string(&config_path)?;
|
||||||
toml::from_str(&content)?
|
toml::from_str(&content)?
|
||||||
} else {
|
} else {
|
||||||
@@ -155,6 +161,11 @@ impl SaaSConfig {
|
|||||||
SaaSConfig::default()
|
SaaSConfig::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 环境变量覆盖数据库 URL (避免在配置文件中存储密码)
|
||||||
|
if let Ok(db_url) = std::env::var("ZCLAW_DATABASE_URL") {
|
||||||
|
config.database.url = db_url;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,4 +192,47 @@ impl SaaSConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 获取 API Key 加密密钥 (复用 TOTP 加密密钥)
|
||||||
|
pub fn api_key_encryption_key(&self) -> anyhow::Result<[u8; 32]> {
|
||||||
|
self.totp_encryption_key()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取 TOTP 加密密钥 (AES-256-GCM, 32 字节)
|
||||||
|
/// 从 ZCLAW_TOTP_ENCRYPTION_KEY 环境变量加载 (hex 编码的 64 字符)
|
||||||
|
/// 开发环境使用默认值 (不安全)
|
||||||
|
pub fn totp_encryption_key(&self) -> anyhow::Result<[u8; 32]> {
|
||||||
|
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||||
|
.map(|v| v == "true" || v == "1")
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
match std::env::var("ZCLAW_TOTP_ENCRYPTION_KEY") {
|
||||||
|
Ok(hex_key) => {
|
||||||
|
if hex_key.len() != 64 {
|
||||||
|
anyhow::bail!("ZCLAW_TOTP_ENCRYPTION_KEY 必须是 64 个十六进制字符 (32 字节)");
|
||||||
|
}
|
||||||
|
let mut key = [0u8; 32];
|
||||||
|
for i in 0..32 {
|
||||||
|
key[i] = u8::from_str_radix(&hex_key[i*2..i*2+2], 16)
|
||||||
|
.map_err(|_| anyhow::anyhow!("ZCLAW_TOTP_ENCRYPTION_KEY 包含无效的十六进制字符"))?;
|
||||||
|
}
|
||||||
|
Ok(key)
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
if is_dev {
|
||||||
|
tracing::warn!("ZCLAW_TOTP_ENCRYPTION_KEY not set, using development default (INSECURE)");
|
||||||
|
// 开发环境使用固定密钥
|
||||||
|
let mut key = [0u8; 32];
|
||||||
|
key.copy_from_slice(b"zclaw-dev-totp-encrypt-key-32b!x");
|
||||||
|
Ok(key)
|
||||||
|
} else {
|
||||||
|
// 生产环境: 使用 JWT 密钥的 SHA-256 哈希作为加密密钥
|
||||||
|
tracing::warn!("ZCLAW_TOTP_ENCRYPTION_KEY not set, deriving from JWT secret");
|
||||||
|
let jwt = self.jwt_secret()?;
|
||||||
|
let hash = sha2::Sha256::digest(jwt.expose_secret().as_bytes());
|
||||||
|
Ok(hash.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
103
crates/zclaw-saas/src/crypto.rs
Normal file
103
crates/zclaw-saas/src/crypto.rs
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
//! 通用加密工具 (AES-256-GCM)
|
||||||
|
//!
|
||||||
|
//! 提供 API Key、TOTP secret 等敏感数据的加密/解密。
|
||||||
|
//! 存储格式: `enc:<base64(nonce(12 bytes) || ciphertext)>`
|
||||||
|
|
||||||
|
use aes_gcm::aead::{Aead, KeyInit, OsRng};
|
||||||
|
use aes_gcm::aead::rand_core::RngCore;
|
||||||
|
use aes_gcm::{Aes256Gcm, Nonce};
|
||||||
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
|
||||||
|
/// 加密值的前缀标识
|
||||||
|
pub const ENCRYPTED_PREFIX: &str = "enc:";
|
||||||
|
|
||||||
|
/// AES-256-GCM nonce 长度 (12 字节)
|
||||||
|
const NONCE_SIZE: usize = 12;
|
||||||
|
|
||||||
|
/// 加密明文值 (AES-256-GCM, 随机 nonce)
|
||||||
|
///
|
||||||
|
/// 返回格式: `enc:<base64(nonce(12 bytes) || ciphertext)>`
|
||||||
|
/// 每次加密使用随机 nonce,相同明文产生不同密文。
|
||||||
|
pub fn encrypt_value(plaintext: &str, key: &[u8; 32]) -> SaasResult<String> {
|
||||||
|
let cipher = Aes256Gcm::new_from_slice(key)
|
||||||
|
.map_err(|e| SaasError::Encryption(format!("加密初始化失败: {}", e)))?;
|
||||||
|
|
||||||
|
let mut nonce_bytes = [0u8; NONCE_SIZE];
|
||||||
|
OsRng.fill_bytes(&mut nonce_bytes);
|
||||||
|
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||||
|
|
||||||
|
let ciphertext = cipher.encrypt(nonce, plaintext.as_bytes())
|
||||||
|
.map_err(|e| SaasError::Encryption(format!("加密失败: {}", e)))?;
|
||||||
|
|
||||||
|
let mut combined = nonce_bytes.to_vec();
|
||||||
|
combined.extend_from_slice(&ciphertext);
|
||||||
|
|
||||||
|
Ok(format!("{}{}", ENCRYPTED_PREFIX, data_encoding::BASE64.encode(&combined)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 解密 `enc:` 前缀的加密值
|
||||||
|
///
|
||||||
|
/// 仅支持新格式 (随机 nonce),不支持旧格式 (固定 nonce)。
|
||||||
|
/// 旧格式数据应通过一次性迁移函数转换。
|
||||||
|
pub fn decrypt_value(encrypted: &str, key: &[u8; 32]) -> SaasResult<String> {
|
||||||
|
let encoded = encrypted.strip_prefix(ENCRYPTED_PREFIX)
|
||||||
|
.ok_or_else(|| SaasError::Encryption("加密值格式无效 (缺少 enc: 前缀)".into()))?;
|
||||||
|
|
||||||
|
let raw = data_encoding::BASE64.decode(encoded.as_bytes())
|
||||||
|
.map_err(|_| SaasError::Encryption("加密值 Base64 解码失败".into()))?;
|
||||||
|
|
||||||
|
if raw.len() <= NONCE_SIZE {
|
||||||
|
return Err(SaasError::Encryption("加密值数据不完整".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let cipher = Aes256Gcm::new_from_slice(key)
|
||||||
|
.map_err(|e| SaasError::Encryption(format!("解密初始化失败: {}", e)))?;
|
||||||
|
|
||||||
|
let (nonce_bytes, ciphertext) = raw.split_at(NONCE_SIZE);
|
||||||
|
let nonce = Nonce::from_slice(nonce_bytes);
|
||||||
|
|
||||||
|
let plaintext = cipher.decrypt(nonce, ciphertext)
|
||||||
|
.map_err(|_| SaasError::Encryption("解密失败 (密钥可能已变更)".into()))?;
|
||||||
|
|
||||||
|
String::from_utf8(plaintext)
|
||||||
|
.map_err(|_| SaasError::Encryption("解密后数据无效 UTF-8".into()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 检查值是否已加密 (以 `enc:` 开头)
|
||||||
|
pub fn is_encrypted(value: &str) -> bool {
|
||||||
|
value.starts_with(ENCRYPTED_PREFIX)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 批量迁移: 将旧的固定 nonce 加密值重新加密为随机 nonce 格式
|
||||||
|
///
|
||||||
|
/// 输入为旧格式 (固定 nonce `zclaw_totp_nce`) 加密的 base64 数据,
|
||||||
|
/// 输出为新格式 `enc:<base64(random_nonce || ciphertext)>`。
|
||||||
|
pub fn re_encrypt_from_legacy(legacy_base64: &str, legacy_key: &[u8; 32], new_key: &[u8; 32]) -> SaasResult<String> {
|
||||||
|
// 先用旧 nonce 解密
|
||||||
|
let cipher = Aes256Gcm::new_from_slice(legacy_key)
|
||||||
|
.map_err(|e| SaasError::Encryption(format!("解密初始化失败: {}", e)))?;
|
||||||
|
|
||||||
|
let raw = data_encoding::BASE64.decode(legacy_base64.as_bytes())
|
||||||
|
.or_else(|_| data_encoding::BASE32.decode(legacy_base64.as_bytes()))
|
||||||
|
.map_err(|_| SaasError::Encryption("旧格式 Base64/Base32 解码失败".into()))?;
|
||||||
|
|
||||||
|
// 尝试新格式 (前 12 字节为 nonce)
|
||||||
|
if raw.len() > NONCE_SIZE {
|
||||||
|
let (nonce_bytes, ciphertext) = raw.split_at(NONCE_SIZE);
|
||||||
|
let nonce = Nonce::from_slice(nonce_bytes);
|
||||||
|
if let Ok(plaintext_bytes) = cipher.decrypt(nonce, ciphertext) {
|
||||||
|
let plaintext = String::from_utf8(plaintext_bytes)
|
||||||
|
.map_err(|_| SaasError::Encryption("旧格式解密后数据无效".into()))?;
|
||||||
|
return encrypt_value(&plaintext, new_key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 回退到旧格式: 固定 nonce
|
||||||
|
let legacy_nonce = Nonce::from_slice(b"zclaw_totp_nce");
|
||||||
|
let plaintext_bytes = cipher.decrypt(legacy_nonce, raw.as_ref())
|
||||||
|
.map_err(|_| SaasError::Encryption("旧格式解密失败".into()))?;
|
||||||
|
let plaintext = String::from_utf8(plaintext_bytes)
|
||||||
|
.map_err(|_| SaasError::Encryption("旧格式解密后数据无效".into()))?;
|
||||||
|
|
||||||
|
encrypt_value(&plaintext, new_key)
|
||||||
|
}
|
||||||
@@ -1,9 +1,10 @@
|
|||||||
//! 数据库初始化与 Schema
|
//! 数据库初始化与 Schema (PostgreSQL)
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::postgres::PgPoolOptions;
|
||||||
|
use sqlx::PgPool;
|
||||||
use crate::error::SaasResult;
|
use crate::error::SaasResult;
|
||||||
|
|
||||||
const SCHEMA_VERSION: i32 = 1;
|
const SCHEMA_VERSION: i32 = 4;
|
||||||
|
|
||||||
const SCHEMA_SQL: &str = r#"
|
const SCHEMA_SQL: &str = r#"
|
||||||
CREATE TABLE IF NOT EXISTS saas_schema_version (
|
CREATE TABLE IF NOT EXISTS saas_schema_version (
|
||||||
@@ -20,7 +21,7 @@ CREATE TABLE IF NOT EXISTS accounts (
|
|||||||
role TEXT NOT NULL DEFAULT 'user',
|
role TEXT NOT NULL DEFAULT 'user',
|
||||||
status TEXT NOT NULL DEFAULT 'active',
|
status TEXT NOT NULL DEFAULT 'active',
|
||||||
totp_secret TEXT,
|
totp_secret TEXT,
|
||||||
totp_enabled INTEGER NOT NULL DEFAULT 0,
|
totp_enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
last_login_at TEXT,
|
last_login_at TEXT,
|
||||||
created_at TEXT NOT NULL,
|
created_at TEXT NOT NULL,
|
||||||
updated_at TEXT NOT NULL
|
updated_at TEXT NOT NULL
|
||||||
@@ -49,7 +50,7 @@ CREATE TABLE IF NOT EXISTS roles (
|
|||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
description TEXT,
|
description TEXT,
|
||||||
permissions TEXT NOT NULL DEFAULT '[]',
|
permissions TEXT NOT NULL DEFAULT '[]',
|
||||||
is_system INTEGER NOT NULL DEFAULT 0,
|
is_system BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
created_at TEXT NOT NULL,
|
created_at TEXT NOT NULL,
|
||||||
updated_at TEXT NOT NULL
|
updated_at TEXT NOT NULL
|
||||||
);
|
);
|
||||||
@@ -64,7 +65,7 @@ CREATE TABLE IF NOT EXISTS permission_templates (
|
|||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS operation_logs (
|
CREATE TABLE IF NOT EXISTS operation_logs (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
account_id TEXT,
|
account_id TEXT,
|
||||||
action TEXT NOT NULL,
|
action TEXT NOT NULL,
|
||||||
target_type TEXT,
|
target_type TEXT,
|
||||||
@@ -84,7 +85,7 @@ CREATE TABLE IF NOT EXISTS providers (
|
|||||||
api_key TEXT,
|
api_key TEXT,
|
||||||
base_url TEXT NOT NULL,
|
base_url TEXT NOT NULL,
|
||||||
api_protocol TEXT NOT NULL DEFAULT 'openai',
|
api_protocol TEXT NOT NULL DEFAULT 'openai',
|
||||||
enabled INTEGER NOT NULL DEFAULT 1,
|
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
rate_limit_rpm INTEGER,
|
rate_limit_rpm INTEGER,
|
||||||
rate_limit_tpm INTEGER,
|
rate_limit_tpm INTEGER,
|
||||||
config_json TEXT DEFAULT '{}',
|
config_json TEXT DEFAULT '{}',
|
||||||
@@ -97,13 +98,13 @@ CREATE TABLE IF NOT EXISTS models (
|
|||||||
provider_id TEXT NOT NULL,
|
provider_id TEXT NOT NULL,
|
||||||
model_id TEXT NOT NULL,
|
model_id TEXT NOT NULL,
|
||||||
alias TEXT NOT NULL,
|
alias TEXT NOT NULL,
|
||||||
context_window INTEGER NOT NULL DEFAULT 8192,
|
context_window BIGINT NOT NULL DEFAULT 8192,
|
||||||
max_output_tokens INTEGER NOT NULL DEFAULT 4096,
|
max_output_tokens BIGINT NOT NULL DEFAULT 4096,
|
||||||
supports_streaming INTEGER NOT NULL DEFAULT 1,
|
supports_streaming BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
supports_vision INTEGER NOT NULL DEFAULT 0,
|
supports_vision BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
enabled INTEGER NOT NULL DEFAULT 1,
|
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
pricing_input REAL DEFAULT 0,
|
pricing_input DOUBLE PRECISION DEFAULT 0,
|
||||||
pricing_output REAL DEFAULT 0,
|
pricing_output DOUBLE PRECISION DEFAULT 0,
|
||||||
created_at TEXT NOT NULL,
|
created_at TEXT NOT NULL,
|
||||||
updated_at TEXT NOT NULL,
|
updated_at TEXT NOT NULL,
|
||||||
UNIQUE(provider_id, model_id),
|
UNIQUE(provider_id, model_id),
|
||||||
@@ -118,7 +119,7 @@ CREATE TABLE IF NOT EXISTS account_api_keys (
|
|||||||
key_value TEXT NOT NULL,
|
key_value TEXT NOT NULL,
|
||||||
key_label TEXT,
|
key_label TEXT,
|
||||||
permissions TEXT NOT NULL DEFAULT '[]',
|
permissions TEXT NOT NULL DEFAULT '[]',
|
||||||
enabled INTEGER NOT NULL DEFAULT 1,
|
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
last_used_at TEXT,
|
last_used_at TEXT,
|
||||||
created_at TEXT NOT NULL,
|
created_at TEXT NOT NULL,
|
||||||
updated_at TEXT NOT NULL,
|
updated_at TEXT NOT NULL,
|
||||||
@@ -129,7 +130,7 @@ CREATE TABLE IF NOT EXISTS account_api_keys (
|
|||||||
CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id);
|
CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS usage_records (
|
CREATE TABLE IF NOT EXISTS usage_records (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
account_id TEXT NOT NULL,
|
account_id TEXT NOT NULL,
|
||||||
provider_id TEXT NOT NULL,
|
provider_id TEXT NOT NULL,
|
||||||
model_id TEXT NOT NULL,
|
model_id TEXT NOT NULL,
|
||||||
@@ -176,7 +177,7 @@ CREATE TABLE IF NOT EXISTS config_items (
|
|||||||
default_value TEXT,
|
default_value TEXT,
|
||||||
source TEXT NOT NULL DEFAULT 'local',
|
source TEXT NOT NULL DEFAULT 'local',
|
||||||
description TEXT,
|
description TEXT,
|
||||||
requires_restart INTEGER NOT NULL DEFAULT 0,
|
requires_restart BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
created_at TEXT NOT NULL,
|
created_at TEXT NOT NULL,
|
||||||
updated_at TEXT NOT NULL,
|
updated_at TEXT NOT NULL,
|
||||||
UNIQUE(category, key_path)
|
UNIQUE(category, key_path)
|
||||||
@@ -184,7 +185,7 @@ CREATE TABLE IF NOT EXISTS config_items (
|
|||||||
CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category);
|
CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category);
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS config_sync_log (
|
CREATE TABLE IF NOT EXISTS config_sync_log (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
account_id TEXT NOT NULL,
|
account_id TEXT NOT NULL,
|
||||||
client_fingerprint TEXT NOT NULL,
|
client_fingerprint TEXT NOT NULL,
|
||||||
action TEXT NOT NULL,
|
action TEXT NOT NULL,
|
||||||
@@ -210,84 +211,232 @@ CREATE TABLE IF NOT EXISTS devices (
|
|||||||
CREATE INDEX IF NOT EXISTS idx_devices_account ON devices(account_id);
|
CREATE INDEX IF NOT EXISTS idx_devices_account ON devices(account_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_devices_device_id ON devices(device_id);
|
CREATE INDEX IF NOT EXISTS idx_devices_device_id ON devices(device_id);
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_unique ON devices(account_id, device_id);
|
CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_unique ON devices(account_id, device_id);
|
||||||
|
|
||||||
|
-- 提示词模板主表
|
||||||
|
CREATE TABLE IF NOT EXISTS prompt_templates (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL UNIQUE,
|
||||||
|
category TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
source TEXT NOT NULL DEFAULT 'builtin',
|
||||||
|
current_version INTEGER NOT NULL DEFAULT 1,
|
||||||
|
status TEXT NOT NULL DEFAULT 'active',
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_prompt_status ON prompt_templates(status);
|
||||||
|
|
||||||
|
-- 提示词版本表(不可变)
|
||||||
|
CREATE TABLE IF NOT EXISTS prompt_versions (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
template_id TEXT NOT NULL,
|
||||||
|
version INTEGER NOT NULL,
|
||||||
|
system_prompt TEXT,
|
||||||
|
user_prompt_template TEXT,
|
||||||
|
variables TEXT NOT NULL DEFAULT '[]',
|
||||||
|
changelog TEXT,
|
||||||
|
min_app_version TEXT,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
UNIQUE(template_id, version)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_prompt_ver_template ON prompt_versions(template_id);
|
||||||
|
|
||||||
|
-- 客户端提示词同步状态
|
||||||
|
CREATE TABLE IF NOT EXISTS prompt_sync_status (
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
template_id TEXT NOT NULL,
|
||||||
|
synced_version INTEGER NOT NULL,
|
||||||
|
synced_at TEXT NOT NULL,
|
||||||
|
PRIMARY KEY(device_id, template_id)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Provider Key Pool 表
|
||||||
|
CREATE TABLE IF NOT EXISTS provider_keys (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
provider_id TEXT NOT NULL,
|
||||||
|
key_label TEXT NOT NULL,
|
||||||
|
key_value TEXT NOT NULL,
|
||||||
|
priority INTEGER NOT NULL DEFAULT 0,
|
||||||
|
max_rpm INTEGER,
|
||||||
|
max_tpm INTEGER,
|
||||||
|
quota_reset_interval TEXT,
|
||||||
|
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
|
last_429_at TEXT,
|
||||||
|
cooldown_until TEXT,
|
||||||
|
total_requests BIGINT NOT NULL DEFAULT 0,
|
||||||
|
total_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_pkeys_provider ON provider_keys(provider_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_pkeys_active ON provider_keys(provider_id, is_active);
|
||||||
|
|
||||||
|
-- Key 使用量滑动窗口
|
||||||
|
CREATE TABLE IF NOT EXISTS key_usage_window (
|
||||||
|
key_id TEXT NOT NULL,
|
||||||
|
window_minute TEXT NOT NULL,
|
||||||
|
request_count INTEGER NOT NULL DEFAULT 0,
|
||||||
|
token_count BIGINT NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY(key_id, window_minute)
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Agent 配置模板表
|
||||||
|
CREATE TABLE IF NOT EXISTS agent_templates (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
description TEXT,
|
||||||
|
category TEXT NOT NULL DEFAULT 'general',
|
||||||
|
source TEXT NOT NULL DEFAULT 'builtin',
|
||||||
|
model TEXT,
|
||||||
|
system_prompt TEXT,
|
||||||
|
tools TEXT NOT NULL DEFAULT '[]'::text,
|
||||||
|
capabilities TEXT NOT NULL DEFAULT '[]'::text,
|
||||||
|
temperature DOUBLE PRECISION,
|
||||||
|
max_tokens INTEGER,
|
||||||
|
visibility TEXT NOT NULL DEFAULT 'public',
|
||||||
|
status TEXT NOT NULL DEFAULT 'active',
|
||||||
|
current_version INTEGER NOT NULL DEFAULT 1,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_agent_tmpl_status ON agent_templates(status);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_agent_tmpl_visibility ON agent_templates(visibility);
|
||||||
|
|
||||||
|
-- 桌面端遥测上报表(Token 用量统计,无内容)
|
||||||
|
CREATE TABLE IF NOT EXISTS telemetry_reports (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
account_id TEXT NOT NULL,
|
||||||
|
device_id TEXT NOT NULL,
|
||||||
|
app_version TEXT,
|
||||||
|
model_id TEXT NOT NULL,
|
||||||
|
input_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
output_tokens BIGINT NOT NULL DEFAULT 0,
|
||||||
|
latency_ms INTEGER,
|
||||||
|
success BOOLEAN NOT NULL DEFAULT TRUE,
|
||||||
|
error_type TEXT,
|
||||||
|
connection_mode TEXT NOT NULL DEFAULT 'tauri',
|
||||||
|
reported_at TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_telemetry_account ON telemetry_reports(account_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_telemetry_time ON telemetry_reports(reported_at);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_telemetry_model ON telemetry_reports(model_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_telemetry_day ON telemetry_reports((SUBSTRING(reported_at, 1, 10)));
|
||||||
|
|
||||||
|
-- Refresh Token 存储 (一次性使用, JWT jti 追踪)
|
||||||
|
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
account_id TEXT NOT NULL,
|
||||||
|
jti TEXT NOT NULL UNIQUE,
|
||||||
|
token_hash TEXT NOT NULL,
|
||||||
|
expires_at TEXT NOT NULL,
|
||||||
|
used_at TEXT,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_refresh_account ON refresh_tokens(account_id);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_refresh_jti ON refresh_tokens(jti);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_refresh_expires ON refresh_tokens(expires_at);
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
const SEED_ROLES: &str = r#"
|
const SEED_ROLES: &str = r#"
|
||||||
INSERT OR IGNORE INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
|
INSERT INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
|
||||||
VALUES
|
VALUES
|
||||||
('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write"]', 1, datetime('now'), datetime('now')),
|
('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write","prompt:read","prompt:write","prompt:publish","prompt:admin"]', TRUE, '2026-01-01T00:00:00+00:00', '2026-01-01T00:00:00+00:00'),
|
||||||
('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write"]', 1, datetime('now'), datetime('now')),
|
('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write","prompt:read","prompt:write","prompt:publish"]', TRUE, '2026-01-01T00:00:00+00:00', '2026-01-01T00:00:00+00:00'),
|
||||||
('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read"]', 1, datetime('now'), datetime('now'));
|
('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read","prompt:read"]', TRUE, '2026-01-01T00:00:00+00:00', '2026-01-01T00:00:00+00:00')
|
||||||
|
ON CONFLICT (id) DO NOTHING;
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
/// 初始化数据库
|
/// 初始化数据库
|
||||||
pub async fn init_db(database_url: &str) -> SaasResult<SqlitePool> {
|
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||||
if database_url.starts_with("sqlite:") {
|
let pool = PgPoolOptions::new()
|
||||||
let path_part = database_url.strip_prefix("sqlite:").unwrap_or("");
|
.max_connections(20)
|
||||||
if path_part != ":memory:" {
|
.min_connections(2)
|
||||||
if let Some(parent) = std::path::Path::new(path_part).parent() {
|
.acquire_timeout(std::time::Duration::from_secs(5))
|
||||||
if !parent.as_os_str().is_empty() && !parent.exists() {
|
.idle_timeout(std::time::Duration::from_secs(600))
|
||||||
std::fs::create_dir_all(parent)?;
|
.connect(database_url)
|
||||||
}
|
.await?;
|
||||||
}
|
|
||||||
|
// PostgreSQL 不支持在一个 prepared statement 中执行多条 SQL
|
||||||
|
// 需要逐条执行
|
||||||
|
for stmt in SCHEMA_SQL.split(';') {
|
||||||
|
let trimmed = stmt.trim();
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
sqlx::query(trimmed).execute(&pool).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let pool = SqlitePool::connect(database_url).await?;
|
sqlx::query("INSERT INTO saas_schema_version (version) VALUES ($1) ON CONFLICT DO NOTHING")
|
||||||
sqlx::query("PRAGMA journal_mode=WAL;")
|
|
||||||
.execute(&pool)
|
|
||||||
.await?;
|
|
||||||
sqlx::query(SCHEMA_SQL).execute(&pool).await?;
|
|
||||||
sqlx::query("INSERT OR IGNORE INTO saas_schema_version (version) VALUES (?1)")
|
|
||||||
.bind(SCHEMA_VERSION)
|
.bind(SCHEMA_VERSION)
|
||||||
.execute(&pool)
|
.execute(&pool)
|
||||||
.await?;
|
.await?;
|
||||||
sqlx::query(SEED_ROLES).execute(&pool).await?;
|
|
||||||
|
for stmt in SEED_ROLES.split(';') {
|
||||||
|
let trimmed = stmt.trim();
|
||||||
|
if !trimmed.is_empty() {
|
||||||
|
sqlx::query(trimmed).execute(&pool).await?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
seed_admin_account(&pool).await?;
|
seed_admin_account(&pool).await?;
|
||||||
|
seed_builtin_prompts(&pool).await?;
|
||||||
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
|
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
|
||||||
Ok(pool)
|
Ok(pool)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 创建内存数据库 (测试用)
|
|
||||||
pub async fn init_memory_db() -> SaasResult<SqlitePool> {
|
|
||||||
let pool = SqlitePool::connect("sqlite::memory:").await?;
|
|
||||||
sqlx::query(SCHEMA_SQL).execute(&pool).await?;
|
|
||||||
sqlx::query("INSERT OR IGNORE INTO saas_schema_version (version) VALUES (?1)")
|
|
||||||
.bind(SCHEMA_VERSION)
|
|
||||||
.execute(&pool)
|
|
||||||
.await?;
|
|
||||||
sqlx::query(SEED_ROLES).execute(&pool).await?;
|
|
||||||
Ok(pool)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号
|
/// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号
|
||||||
async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
|
/// 或者更新现有 admin 用户的角色为 super_admin
|
||||||
let has_accounts: (bool,) = sqlx::query_as(
|
async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
|
||||||
"SELECT EXISTS(SELECT 1 FROM accounts LIMIT 1) as has"
|
|
||||||
)
|
|
||||||
.fetch_one(pool)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if has_accounts.0 {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
let admin_username = std::env::var("ZCLAW_ADMIN_USERNAME")
|
let admin_username = std::env::var("ZCLAW_ADMIN_USERNAME")
|
||||||
.unwrap_or_else(|_| "admin".to_string());
|
.unwrap_or_else(|_| "admin".to_string());
|
||||||
|
|
||||||
|
// 检查是否设置了管理员密码
|
||||||
let admin_password = match std::env::var("ZCLAW_ADMIN_PASSWORD") {
|
let admin_password = match std::env::var("ZCLAW_ADMIN_PASSWORD") {
|
||||||
Ok(pwd) => pwd,
|
Ok(pwd) => pwd,
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
tracing::warn!(
|
// 没有设置密码,尝试更新现有 admin 用户的角色
|
||||||
"accounts 表为空但未设置 ZCLAW_ADMIN_PASSWORD 环境变量。\
|
let result = sqlx::query(
|
||||||
请通过 POST /api/v1/auth/register 注册首个用户,然后手动将其 role 改为 super_admin。\
|
"UPDATE accounts SET role = 'super_admin' WHERE username = $1 AND role != 'super_admin'"
|
||||||
或设置 ZCLAW_ADMIN_USERNAME 和 ZCLAW_ADMIN_PASSWORD 环境变量后重启服务。"
|
)
|
||||||
);
|
.bind(&admin_username)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if result.rows_affected() > 0 {
|
||||||
|
tracing::info!("已将用户 {} 的角色更新为 super_admin", admin_username);
|
||||||
|
}
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::auth::password::hash_password;
|
// 检查 admin 用户是否已存在
|
||||||
|
let existing: Option<(String,)> = sqlx::query_as(
|
||||||
|
"SELECT id FROM accounts WHERE username = $1"
|
||||||
|
)
|
||||||
|
.bind(&admin_username)
|
||||||
|
.fetch_optional(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some((account_id,)) = existing {
|
||||||
|
// 更新现有用户的密码和角色
|
||||||
|
use crate::auth::password::hash_password;
|
||||||
|
let password_hash = hash_password(&admin_password)?;
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE accounts SET password_hash = $1, role = 'super_admin', updated_at = $2 WHERE id = $3"
|
||||||
|
)
|
||||||
|
.bind(&password_hash)
|
||||||
|
.bind(&now)
|
||||||
|
.bind(&account_id)
|
||||||
|
.execute(pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
tracing::info!("已更新用户 {} 的密码和角色为 super_admin", admin_username);
|
||||||
|
} else {
|
||||||
|
// 创建新的 super_admin 账号
|
||||||
|
use crate::auth::password::hash_password;
|
||||||
let password_hash = hash_password(&admin_password)?;
|
let password_hash = hash_password(&admin_password)?;
|
||||||
let account_id = uuid::Uuid::new_v4().to_string();
|
let account_id = uuid::Uuid::new_v4().to_string();
|
||||||
let email = format!("{}@zclaw.local", admin_username);
|
let email = format!("{}@zclaw.local", admin_username);
|
||||||
@@ -295,7 +444,7 @@ async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
|
|||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, 'super_admin', 'active', ?6, ?6)"
|
VALUES ($1, $2, $3, $4, $5, 'super_admin', 'active', $6, $6)"
|
||||||
)
|
)
|
||||||
.bind(&account_id)
|
.bind(&account_id)
|
||||||
.bind(&admin_username)
|
.bind(&admin_username)
|
||||||
@@ -306,44 +455,74 @@ async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
|
|||||||
.execute(pool)
|
.execute(pool)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
tracing::info!(
|
tracing::info!("自动创建 super_admin 账号: username={}, email={}", admin_username, email);
|
||||||
"自动创建 super_admin 账号: username={}, email={}", admin_username, email
|
}
|
||||||
);
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 种子化内置提示词模板(仅当表为空时)
|
||||||
|
async fn seed_builtin_prompts(pool: &PgPool) -> SaasResult<()> {
|
||||||
|
let count: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM prompt_templates")
|
||||||
|
.fetch_one(pool).await?;
|
||||||
|
|
||||||
|
if count.0 > 0 {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
// reflection 提示词
|
||||||
|
let reflection_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let reflection_ver_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_templates (id, name, category, description, source, current_version, status, created_at, updated_at)
|
||||||
|
VALUES ($1, 'reflection', 'builtin_system', 'Agent 自我反思引擎', 'builtin', 1, 'active', $2, $2)"
|
||||||
|
).bind(&reflection_id).bind(&now).execute(pool).await?;
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_versions (id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at)
|
||||||
|
VALUES ($1, $2, 1, $3, $4, '[]', '初始版本', NULL, $5)"
|
||||||
|
).bind(&reflection_ver_id).bind(&reflection_id)
|
||||||
|
.bind("你是一个 AI Agent 的自我反思引擎。分析最近的对话历史,识别行为模式,并生成改进建议。\n\n输出 JSON 格式:\n{\n \"patterns\": [\n {\n \"observation\": \"观察到的模式描述\",\n \"frequency\": 数字,\n \"sentiment\": \"positive/negative/neutral\",\n \"evidence\": [\"证据1\", \"证据2\"]\n }\n ],\n \"improvements\": [\n {\n \"area\": \"改进领域\",\n \"suggestion\": \"具体建议\",\n \"priority\": \"high/medium/low\"\n }\n ],\n \"identityProposals\": []\n}")
|
||||||
|
.bind("分析以下对话历史,进行自我反思:\n\n{{context}}\n\n请识别行为模式(积极和消极),并提供具体的改进建议。")
|
||||||
|
.bind(&now).execute(pool).await?;
|
||||||
|
|
||||||
|
// compaction 提示词
|
||||||
|
let compaction_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let compaction_ver_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_templates (id, name, category, description, source, current_version, status, created_at, updated_at)
|
||||||
|
VALUES ($1, 'compaction', 'builtin_compaction', '对话上下文压缩', 'builtin', 1, 'active', $2, $2)"
|
||||||
|
).bind(&compaction_id).bind(&now).execute(pool).await?;
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_versions (id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at)
|
||||||
|
VALUES ($1, $2, 1, $3, $4, '[]', '初始版本', NULL, $5)"
|
||||||
|
).bind(&compaction_ver_id).bind(&compaction_id)
|
||||||
|
.bind("你是一个对话摘要专家。将长对话压缩为简洁的摘要,保留关键信息。\n\n要求:\n1. 保留所有重要决策和结论\n2. 保留用户偏好和约束\n3. 保留未完成的任务\n4. 保持时间顺序\n5. 摘要应能在后续对话中替代原始内容")
|
||||||
|
.bind("请将以下对话压缩为简洁摘要,保留关键信息:\n\n{{messages}}")
|
||||||
|
.bind(&now).execute(pool).await?;
|
||||||
|
|
||||||
|
// extraction 提示词
|
||||||
|
let extraction_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let extraction_ver_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_templates (id, name, category, description, source, current_version, status, created_at, updated_at)
|
||||||
|
VALUES ($1, 'extraction', 'builtin_extraction', '记忆提取引擎', 'builtin', 1, 'active', $2, $2)"
|
||||||
|
).bind(&extraction_id).bind(&now).execute(pool).await?;
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_versions (id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at)
|
||||||
|
VALUES ($1, $2, 1, $3, $4, '[]', '初始版本', NULL, $5)"
|
||||||
|
).bind(&extraction_ver_id).bind(&extraction_id)
|
||||||
|
.bind("你是一个记忆提取专家。从对话中提取值得长期记住的信息。\n\n提取类型:\n- fact: 用户告知的事实(如\"我的公司叫XXX\")\n- preference: 用户的偏好(如\"我喜欢简洁的回答\")\n- lesson: 本次对话的经验教训\n- task: 未完成的任务或承诺\n\n输出 JSON 数组:\n[\n {\n \"content\": \"记忆内容\",\n \"type\": \"fact/preference/lesson/task\",\n \"importance\": 1-10,\n \"tags\": [\"标签1\", \"标签2\"]\n }\n]")
|
||||||
|
.bind("从以下对话中提取值得长期记住的信息:\n\n{{conversation}}\n\n如果没有值得记忆的内容,返回空数组 []。")
|
||||||
|
.bind(&now).execute(pool).await?;
|
||||||
|
|
||||||
|
tracing::info!("Seeded 3 builtin prompt templates (reflection, compaction, extraction)");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容
|
||||||
|
// 集成测试见 tests/integration_test.rs
|
||||||
#[tokio::test]
|
|
||||||
async fn test_init_memory_db() {
|
|
||||||
let pool = init_memory_db().await.unwrap();
|
|
||||||
let roles: Vec<(String,)> = sqlx::query_as(
|
|
||||||
"SELECT id FROM roles WHERE is_system = 1"
|
|
||||||
)
|
|
||||||
.fetch_all(&pool)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(roles.len(), 3);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_schema_tables_exist() {
|
|
||||||
let pool = init_memory_db().await.unwrap();
|
|
||||||
let tables = [
|
|
||||||
"accounts", "api_tokens", "roles", "permission_templates",
|
|
||||||
"operation_logs", "providers", "models", "account_api_keys",
|
|
||||||
"usage_records", "relay_tasks", "config_items", "config_sync_log", "devices",
|
|
||||||
];
|
|
||||||
for table in tables {
|
|
||||||
let count: (i64,) = sqlx::query_as(&format!(
|
|
||||||
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='{}'", table
|
|
||||||
))
|
|
||||||
.fetch_one(&pool)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
assert_eq!(count.0, 1, "Table {} should exist", table);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
//!
|
//!
|
||||||
//! 独立的 SaaS 后端服务,提供账号权限管理、模型配置、请求中转和配置迁移。
|
//! 独立的 SaaS 后端服务,提供账号权限管理、模型配置、请求中转和配置迁移。
|
||||||
|
|
||||||
|
pub mod common;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod crypto;
|
||||||
pub mod db;
|
pub mod db;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod middleware;
|
pub mod middleware;
|
||||||
@@ -13,3 +15,7 @@ pub mod account;
|
|||||||
pub mod model_config;
|
pub mod model_config;
|
||||||
pub mod relay;
|
pub mod relay;
|
||||||
pub mod migration;
|
pub mod migration;
|
||||||
|
pub mod role;
|
||||||
|
pub mod prompt;
|
||||||
|
pub mod agent_template;
|
||||||
|
pub mod telemetry;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
//! ZCLAW SaaS 服务入口
|
//! ZCLAW SaaS 服务入口
|
||||||
|
|
||||||
|
use axum::extract::State;
|
||||||
use tracing::info;
|
use tracing::info;
|
||||||
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
|
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
|
||||||
|
|
||||||
@@ -19,7 +20,11 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
info!("Database initialized");
|
info!("Database initialized");
|
||||||
|
|
||||||
let state = AppState::new(db, config.clone())?;
|
let state = AppState::new(db, config.clone())?;
|
||||||
let app = build_router(state);
|
|
||||||
|
// 后台定时任务
|
||||||
|
spawn_background_tasks(state.clone());
|
||||||
|
|
||||||
|
let app = build_router(state).await;
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.host, config.server.port))
|
let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.host, config.server.port))
|
||||||
.await?;
|
.await?;
|
||||||
@@ -29,14 +34,68 @@ async fn main() -> anyhow::Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_router(state: AppState) -> axum::Router {
|
async fn health_handler(State(state): State<AppState>) -> axum::Json<serde_json::Value> {
|
||||||
|
let db_healthy = sqlx::query_scalar::<_, i32>("SELECT 1")
|
||||||
|
.fetch_one(&state.db)
|
||||||
|
.await
|
||||||
|
.is_ok();
|
||||||
|
|
||||||
|
let status = if db_healthy { "healthy" } else { "degraded" };
|
||||||
|
let _code = if db_healthy { 200 } else { 503 };
|
||||||
|
|
||||||
|
axum::Json(serde_json::json!({
|
||||||
|
"status": status,
|
||||||
|
"database": db_healthy,
|
||||||
|
"timestamp": chrono::Utc::now().to_rfc3339(),
|
||||||
|
"version": env!("CARGO_PKG_VERSION"),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 启动后台定时任务
|
||||||
|
fn spawn_background_tasks(state: AppState) {
|
||||||
|
// 每 5 分钟清理过期的限流条目
|
||||||
|
let rate_limit_state = state.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
rate_limit_state.cleanup_rate_limit_entries();
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// 每 24 小时清理 90 天未活跃的设备
|
||||||
|
// 注意: last_seen_at 为 TEXT 类型,使用 rfc3339 字符串比较(字典序等价于时间序)
|
||||||
|
let cleanup_state = state.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut interval = tokio::time::interval(std::time::Duration::from_secs(86400));
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
let cutoff = (chrono::Utc::now() - chrono::Duration::days(90)).to_rfc3339();
|
||||||
|
match sqlx::query("DELETE FROM devices WHERE last_seen_at < $1")
|
||||||
|
.bind(&cutoff)
|
||||||
|
.execute(&cleanup_state.db)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(result) if result.rows_affected() > 0 => {
|
||||||
|
info!("Cleaned up {} stale devices", result.rows_affected());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Failed to cleanup stale devices: {}", e);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn build_router(state: AppState) -> axum::Router {
|
||||||
use axum::middleware;
|
use axum::middleware;
|
||||||
use tower_http::cors::{Any, CorsLayer};
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
use tower_http::trace::TraceLayer;
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
use axum::http::HeaderValue;
|
use axum::http::HeaderValue;
|
||||||
let cors = {
|
let cors = {
|
||||||
let config = state.config.blocking_read();
|
let config = state.config.read().await;
|
||||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||||
.map(|v| v == "true" || v == "1")
|
.map(|v| v == "true" || v == "1")
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
@@ -56,18 +115,42 @@ fn build_router(state: AppState) -> axum::Router {
|
|||||||
.collect();
|
.collect();
|
||||||
CorsLayer::new()
|
CorsLayer::new()
|
||||||
.allow_origin(origins)
|
.allow_origin(origins)
|
||||||
.allow_methods(Any)
|
.allow_methods([
|
||||||
.allow_headers(Any)
|
axum::http::Method::GET,
|
||||||
|
axum::http::Method::POST,
|
||||||
|
axum::http::Method::PUT,
|
||||||
|
axum::http::Method::PATCH,
|
||||||
|
axum::http::Method::DELETE,
|
||||||
|
axum::http::Method::OPTIONS,
|
||||||
|
])
|
||||||
|
.allow_headers([
|
||||||
|
axum::http::header::AUTHORIZATION,
|
||||||
|
axum::http::header::CONTENT_TYPE,
|
||||||
|
axum::http::HeaderName::from_static("x-request-id"),
|
||||||
|
])
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let public_routes = zclaw_saas::auth::routes();
|
let public_routes = zclaw_saas::auth::routes()
|
||||||
|
.route("/api/health", axum::routing::get(health_handler));
|
||||||
|
|
||||||
let protected_routes = zclaw_saas::auth::protected_routes()
|
let protected_routes = zclaw_saas::auth::protected_routes()
|
||||||
.merge(zclaw_saas::account::routes())
|
.merge(zclaw_saas::account::routes())
|
||||||
.merge(zclaw_saas::model_config::routes())
|
.merge(zclaw_saas::model_config::routes())
|
||||||
.merge(zclaw_saas::relay::routes())
|
.merge(zclaw_saas::relay::routes())
|
||||||
.merge(zclaw_saas::migration::routes())
|
.merge(zclaw_saas::migration::routes())
|
||||||
|
.merge(zclaw_saas::role::routes())
|
||||||
|
.merge(zclaw_saas::prompt::routes())
|
||||||
|
.merge(zclaw_saas::agent_template::routes())
|
||||||
|
.merge(zclaw_saas::telemetry::routes())
|
||||||
|
.layer(middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
zclaw_saas::middleware::api_version_middleware,
|
||||||
|
))
|
||||||
|
.layer(middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
zclaw_saas::middleware::request_id_middleware,
|
||||||
|
))
|
||||||
.layer(middleware::from_fn_with_state(
|
.layer(middleware::from_fn_with_state(
|
||||||
state.clone(),
|
state.clone(),
|
||||||
zclaw_saas::middleware::rate_limit_middleware,
|
zclaw_saas::middleware::rate_limit_middleware,
|
||||||
|
|||||||
@@ -1,81 +1,83 @@
|
|||||||
//! 通用中间件
|
//! 中间件模块
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Request, State},
|
body::Body,
|
||||||
http::StatusCode,
|
extract::State,
|
||||||
|
http::{HeaderValue, Request, Response},
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::{IntoResponse, Response},
|
response::IntoResponse,
|
||||||
};
|
};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
|
use crate::error::SaasError;
|
||||||
|
use crate::auth::types::AuthContext;
|
||||||
|
|
||||||
/// 滑动窗口速率限制中间件
|
/// 请求 ID 追踪中间件
|
||||||
///
|
/// 为每个请求生成唯一 ID,便于日志追踪
|
||||||
/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。
|
pub async fn request_id_middleware(
|
||||||
/// 超限时返回 429 Too Many Requests + Retry-After header。
|
State(_state): State<AppState>,
|
||||||
|
mut req: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Response<Body> {
|
||||||
|
let request_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
|
||||||
|
req.extensions_mut().insert(request_id.clone());
|
||||||
|
|
||||||
|
let mut response = next.run(req).await;
|
||||||
|
|
||||||
|
if let Ok(value) = HeaderValue::from_str(&request_id) {
|
||||||
|
response.headers_mut().insert("X-Request-ID", value);
|
||||||
|
}
|
||||||
|
|
||||||
|
response
|
||||||
|
}
|
||||||
|
|
||||||
|
/// API 版本控制中间件
|
||||||
|
/// 在响应头中添加版本信息
|
||||||
|
pub async fn api_version_middleware(
|
||||||
|
State(_state): State<AppState>,
|
||||||
|
req: Request<Body>,
|
||||||
|
next: Next,
|
||||||
|
) -> Response<Body> {
|
||||||
|
let mut response = next.run(req).await;
|
||||||
|
|
||||||
|
response.headers_mut().insert("X-API-Version", HeaderValue::from_static("1.0.0"));
|
||||||
|
response.headers_mut().insert("X-API-Deprecated", HeaderValue::from_static("false"));
|
||||||
|
|
||||||
|
response
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 速率限制中间件
|
||||||
|
/// 基于账号的请求频率限制
|
||||||
pub async fn rate_limit_middleware(
|
pub async fn rate_limit_middleware(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
req: Request,
|
req: Request<Body>,
|
||||||
next: Next,
|
next: Next,
|
||||||
) -> Response {
|
) -> Response<Body> {
|
||||||
// 从 AuthContext 提取 account_id(由 auth_middleware 在此之前注入)
|
let account_id = req.extensions()
|
||||||
let account_id = req
|
.get::<AuthContext>()
|
||||||
.extensions()
|
.map(|ctx| ctx.account_id.clone())
|
||||||
.get::<crate::auth::types::AuthContext>()
|
.unwrap_or_else(|| "anonymous".to_string());
|
||||||
.map(|ctx| ctx.account_id.clone());
|
|
||||||
|
|
||||||
let account_id = match account_id {
|
|
||||||
Some(id) => id,
|
|
||||||
None => return next.run(req).await,
|
|
||||||
};
|
|
||||||
|
|
||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let rpm = config.rate_limit.requests_per_minute as u64;
|
let rate_limit = config.rate_limit.requests_per_minute as usize;
|
||||||
let burst = config.rate_limit.burst as u64;
|
|
||||||
let max_requests = rpm + burst;
|
let key = format!("rate_limit:{}", account_id);
|
||||||
drop(config);
|
|
||||||
|
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let window_start = now - std::time::Duration::from_secs(60);
|
let window_start = now - std::time::Duration::from_secs(60);
|
||||||
|
|
||||||
// 滑动窗口: 清理过期条目 + 计数
|
let mut entries = state.rate_limit_entries.entry(key).or_insert_with(Vec::new);
|
||||||
let current_count = {
|
entries.retain(|&time| time > window_start);
|
||||||
let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default();
|
|
||||||
entries.retain(|&ts| ts > window_start);
|
if entries.len() >= rate_limit {
|
||||||
let count = entries.len() as u64;
|
return SaasError::RateLimited(format!(
|
||||||
if count < max_requests {
|
"请求频率超限,每分钟最多 {} 次请求",
|
||||||
|
rate_limit
|
||||||
|
)).into_response();
|
||||||
|
}
|
||||||
|
|
||||||
entries.push(now);
|
entries.push(now);
|
||||||
0 // 未超限
|
|
||||||
} else {
|
|
||||||
count
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if current_count >= max_requests {
|
|
||||||
// 计算最早条目的过期时间作为 Retry-After
|
|
||||||
let retry_after = if let Some(mut entries) = state.rate_limit_entries.get_mut(&account_id) {
|
|
||||||
entries.sort();
|
|
||||||
let earliest = *entries.first().unwrap_or(&now);
|
|
||||||
let elapsed = now.duration_since(earliest).as_secs();
|
|
||||||
60u64.saturating_sub(elapsed)
|
|
||||||
} else {
|
|
||||||
60
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
StatusCode::TOO_MANY_REQUESTS,
|
|
||||||
[
|
|
||||||
("Retry-After", retry_after.to_string()),
|
|
||||||
("Content-Type", "application/json".to_string()),
|
|
||||||
],
|
|
||||||
axum::Json(serde_json::json!({
|
|
||||||
"error": "RATE_LIMITED",
|
|
||||||
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after),
|
|
||||||
})),
|
|
||||||
)
|
|
||||||
.into_response();
|
|
||||||
}
|
|
||||||
|
|
||||||
next.run(req).await
|
next.run(req).await
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,16 +7,23 @@ use axum::{
|
|||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use crate::error::SaasResult;
|
use crate::error::SaasResult;
|
||||||
use crate::auth::types::AuthContext;
|
use crate::auth::types::AuthContext;
|
||||||
use crate::auth::handlers::check_permission;
|
use crate::auth::handlers::{check_permission, log_operation};
|
||||||
|
use crate::common::PaginatedResponse;
|
||||||
use super::{types::*, service};
|
use super::{types::*, service};
|
||||||
|
|
||||||
/// GET /api/v1/config/items?category=xxx&source=xxx
|
/// GET /api/v1/config/items?category=xxx&source=xxx&page=1&page_size=20
|
||||||
pub async fn list_config_items(
|
pub async fn list_config_items(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Query(query): Query<ConfigQuery>,
|
Query(query): Query<ConfigQuery>,
|
||||||
_ctx: Extension<AuthContext>,
|
_ctx: Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<ConfigItemInfo>>> {
|
) -> SaasResult<Json<PaginatedResponse<ConfigItemInfo>>> {
|
||||||
service::list_config_items(&state.db, &query).await.map(Json)
|
let filter_query = ConfigQuery {
|
||||||
|
category: query.category.clone(),
|
||||||
|
source: query.source.clone(),
|
||||||
|
page: None,
|
||||||
|
page_size: None,
|
||||||
|
};
|
||||||
|
service::list_config_items(&state.db, &filter_query, query.page, query.page_size).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/config/items/:id
|
/// GET /api/v1/config/items/:id
|
||||||
@@ -36,10 +43,11 @@ pub async fn create_config_item(
|
|||||||
) -> SaasResult<(StatusCode, Json<ConfigItemInfo>)> {
|
) -> SaasResult<(StatusCode, Json<ConfigItemInfo>)> {
|
||||||
check_permission(&ctx, "config:write")?;
|
check_permission(&ctx, "config:write")?;
|
||||||
let item = service::create_config_item(&state.db, &req).await?;
|
let item = service::create_config_item(&state.db, &req).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "config.create", "config_item", &item.id, None, ctx.client_ip.as_deref()).await?;
|
||||||
Ok((StatusCode::CREATED, Json(item)))
|
Ok((StatusCode::CREATED, Json(item)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// PUT /api/v1/config/items/:id (admin only)
|
/// PATCH /api/v1/config/items/:id (admin only)
|
||||||
pub async fn update_config_item(
|
pub async fn update_config_item(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Path(id): Path<String>,
|
Path(id): Path<String>,
|
||||||
@@ -47,7 +55,9 @@ pub async fn update_config_item(
|
|||||||
Json(req): Json<UpdateConfigItemRequest>,
|
Json(req): Json<UpdateConfigItemRequest>,
|
||||||
) -> SaasResult<Json<ConfigItemInfo>> {
|
) -> SaasResult<Json<ConfigItemInfo>> {
|
||||||
check_permission(&ctx, "config:write")?;
|
check_permission(&ctx, "config:write")?;
|
||||||
service::update_config_item(&state.db, &id, &req).await.map(Json)
|
let item = service::update_config_item(&state.db, &id, &req).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "config.update", "config_item", &id, None, ctx.client_ip.as_deref()).await?;
|
||||||
|
Ok(Json(item))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// DELETE /api/v1/config/items/:id (admin only)
|
/// DELETE /api/v1/config/items/:id (admin only)
|
||||||
@@ -58,6 +68,7 @@ pub async fn delete_config_item(
|
|||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
check_permission(&ctx, "config:write")?;
|
check_permission(&ctx, "config:write")?;
|
||||||
service::delete_config_item(&state.db, &id).await?;
|
service::delete_config_item(&state.db, &id).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "config.delete", "config_item", &id, None, ctx.client_ip.as_deref()).await?;
|
||||||
Ok(Json(serde_json::json!({"ok": true})))
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -76,32 +87,95 @@ pub async fn seed_config(
|
|||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
check_permission(&ctx, "config:write")?;
|
check_permission(&ctx, "config:write")?;
|
||||||
let count = service::seed_default_config_items(&state.db).await?;
|
let count = service::seed_default_config_items(&state.db).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "config.seed", "config_item", "batch", Some(serde_json::json!({"count": count})), ctx.client_ip.as_deref()).await?;
|
||||||
Ok(Json(serde_json::json!({"created": count})))
|
Ok(Json(serde_json::json!({"created": count})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/config/sync
|
/// POST /api/v1/config/sync (需要 config:write 权限)
|
||||||
pub async fn sync_config(
|
pub async fn sync_config(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<SyncConfigRequest>,
|
Json(req): Json<SyncConfigRequest>,
|
||||||
) -> SaasResult<Json<super::service::ConfigSyncResult>> {
|
) -> SaasResult<Json<super::service::ConfigSyncResult>> {
|
||||||
super::service::sync_config(&state.db, &ctx.account_id, &req).await.map(Json)
|
// 权限检查:仅 config:write 可推送配置
|
||||||
|
check_permission(&ctx, "config:write")?;
|
||||||
|
|
||||||
|
let result = super::service::sync_config(&state.db, &ctx.account_id, &req).await?;
|
||||||
|
|
||||||
|
// 审计日志
|
||||||
|
log_operation(
|
||||||
|
&state.db,
|
||||||
|
&ctx.account_id,
|
||||||
|
"config.sync",
|
||||||
|
"config",
|
||||||
|
"batch",
|
||||||
|
Some(serde_json::json!({
|
||||||
|
"client_fingerprint": req.client_fingerprint,
|
||||||
|
"action": req.action,
|
||||||
|
"config_count": req.config_keys.len(),
|
||||||
|
})),
|
||||||
|
ctx.client_ip.as_deref(),
|
||||||
|
).await.ok();
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/config/diff
|
/// POST /api/v1/config/diff
|
||||||
/// 计算客户端与 SaaS 端的配置差异 (不修改数据)
|
/// 计算客户端与 SaaS 端的配置差异 (不修改数据)
|
||||||
pub async fn config_diff(
|
pub async fn config_diff(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(_ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<SyncConfigRequest>,
|
Json(req): Json<SyncConfigRequest>,
|
||||||
) -> SaasResult<Json<ConfigDiffResponse>> {
|
) -> SaasResult<Json<ConfigDiffResponse>> {
|
||||||
|
// diff 操作虽然不修改数据,但涉及敏感配置信息,仍需认证用户
|
||||||
service::compute_config_diff(&state.db, &req).await.map(Json)
|
service::compute_config_diff(&state.db, &req).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/config/sync-logs
|
/// GET /api/v1/config/sync-logs?page=1&page_size=20
|
||||||
pub async fn list_sync_logs(
|
pub async fn list_sync_logs(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<ConfigSyncLogInfo>>> {
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
service::list_sync_logs(&state.db, &ctx.account_id).await.map(Json)
|
) -> SaasResult<Json<crate::common::PaginatedResponse<ConfigSyncLogInfo>>> {
|
||||||
|
let page: u32 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1).max(1);
|
||||||
|
let page_size: u32 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(20).min(100);
|
||||||
|
service::list_sync_logs(&state.db, &ctx.account_id, page, page_size).await.map(Json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/config/pull?since=2026-03-28T00:00:00Z
|
||||||
|
/// 批量拉取配置(供桌面端启动时一次性拉取)
|
||||||
|
/// 返回扁平的 key-value map,可选 since 参数过滤仅返回该时间之后更新的配置
|
||||||
|
pub async fn pull_config(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
_ctx: Extension<AuthContext>,
|
||||||
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
let since = params.get("since").cloned();
|
||||||
|
let items = service::fetch_all_config_items(
|
||||||
|
&state.db,
|
||||||
|
&ConfigQuery { category: None, source: None, page: None, page_size: None },
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
let mut configs: Vec<serde_json::Value> = Vec::new();
|
||||||
|
for item in items {
|
||||||
|
// 如果指定了 since,只返回 updated_at > since 的配置
|
||||||
|
if let Some(ref since_val) = since {
|
||||||
|
if item.updated_at <= *since_val {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
configs.push(serde_json::json!({
|
||||||
|
"key": item.key_path,
|
||||||
|
"category": item.category,
|
||||||
|
"value": item.current_value,
|
||||||
|
"value_type": item.value_type,
|
||||||
|
"default": item.default_value,
|
||||||
|
"updated_at": item.updated_at,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({
|
||||||
|
"configs": configs,
|
||||||
|
"pulled_at": chrono::Utc::now().to_rfc3339(),
|
||||||
|
})))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,10 +11,11 @@ use crate::state::AppState;
|
|||||||
pub fn routes() -> axum::Router<AppState> {
|
pub fn routes() -> axum::Router<AppState> {
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/api/v1/config/items", get(handlers::list_config_items).post(handlers::create_config_item))
|
.route("/api/v1/config/items", get(handlers::list_config_items).post(handlers::create_config_item))
|
||||||
.route("/api/v1/config/items/{id}", get(handlers::get_config_item).put(handlers::update_config_item).delete(handlers::delete_config_item))
|
.route("/api/v1/config/items/:id", get(handlers::get_config_item).put(handlers::update_config_item).delete(handlers::delete_config_item))
|
||||||
.route("/api/v1/config/analysis", get(handlers::analyze_config))
|
.route("/api/v1/config/analysis", get(handlers::analyze_config))
|
||||||
.route("/api/v1/config/seed", post(handlers::seed_config))
|
.route("/api/v1/config/seed", post(handlers::seed_config))
|
||||||
.route("/api/v1/config/sync", post(handlers::sync_config))
|
.route("/api/v1/config/sync", post(handlers::sync_config))
|
||||||
.route("/api/v1/config/diff", post(handlers::config_diff))
|
.route("/api/v1/config/diff", post(handlers::config_diff))
|
||||||
.route("/api/v1/config/sync-logs", get(handlers::list_sync_logs))
|
.route("/api/v1/config/sync-logs", get(handlers::list_sync_logs))
|
||||||
|
.route("/api/v1/config/pull", get(handlers::pull_config))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +1,29 @@
|
|||||||
//! 配置迁移业务逻辑
|
//! 配置迁移业务逻辑
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::PgPool;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
use crate::common::{PaginatedResponse, normalize_pagination};
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
// ============ Config Items ============
|
// ============ Config Items ============
|
||||||
|
|
||||||
pub async fn list_config_items(
|
/// Fetch all config items matching the query (internal use, no pagination).
|
||||||
db: &SqlitePool, query: &ConfigQuery,
|
pub(crate) async fn fetch_all_config_items(
|
||||||
|
db: &PgPool, query: &ConfigQuery,
|
||||||
) -> SaasResult<Vec<ConfigItemInfo>> {
|
) -> SaasResult<Vec<ConfigItemInfo>> {
|
||||||
let sql = match (&query.category, &query.source) {
|
let sql = match (&query.category, &query.source) {
|
||||||
(Some(_), Some(_)) => {
|
(Some(_), Some(_)) => {
|
||||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||||
FROM config_items WHERE category = ?1 AND source = ?2 ORDER BY category, key_path"
|
FROM config_items WHERE category = $1 AND source = $2 ORDER BY category, key_path"
|
||||||
}
|
}
|
||||||
(Some(_), None) => {
|
(Some(_), None) => {
|
||||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||||
FROM config_items WHERE category = ?1 ORDER BY key_path"
|
FROM config_items WHERE category = $1 ORDER BY key_path"
|
||||||
}
|
}
|
||||||
(None, Some(_)) => {
|
(None, Some(_)) => {
|
||||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||||
FROM config_items WHERE source = ?1 ORDER BY category, key_path"
|
FROM config_items WHERE source = $1 ORDER BY category, key_path"
|
||||||
}
|
}
|
||||||
(None, None) => {
|
(None, None) => {
|
||||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||||
@@ -44,11 +46,58 @@ pub async fn list_config_items(
|
|||||||
}).collect())
|
}).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<ConfigItemInfo> {
|
/// Paginated list of config items (HTTP handler entry point).
|
||||||
|
pub async fn list_config_items(
|
||||||
|
db: &PgPool, query: &ConfigQuery,
|
||||||
|
page: Option<u32>, page_size: Option<u32>,
|
||||||
|
) -> SaasResult<PaginatedResponse<ConfigItemInfo>> {
|
||||||
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||||||
|
|
||||||
|
// Build WHERE clause for count and data queries
|
||||||
|
let (where_clause, has_category, has_source) = match (&query.category, &query.source) {
|
||||||
|
(Some(_), Some(_)) => ("WHERE category = $1 AND source = $2", true, true),
|
||||||
|
(Some(_), None) => ("WHERE category = $1", true, false),
|
||||||
|
(None, Some(_)) => ("WHERE source = $1", false, true),
|
||||||
|
(None, None) => ("", false, false),
|
||||||
|
};
|
||||||
|
|
||||||
|
let count_sql = format!("SELECT COUNT(*) FROM config_items {}", where_clause);
|
||||||
|
let data_sql = format!(
|
||||||
|
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||||
|
FROM config_items {} ORDER BY category, key_path LIMIT {} OFFSET {}",
|
||||||
|
where_clause, "$p", "$o"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Determine param indices for LIMIT/OFFSET based on filter params
|
||||||
|
let (limit_idx, offset_idx) = match (has_category, has_source) {
|
||||||
|
(true, true) => ("$3", "$4"),
|
||||||
|
(true, false) | (false, true) => ("$2", "$3"),
|
||||||
|
(false, false) => ("$1", "$2"),
|
||||||
|
};
|
||||||
|
let data_sql = data_sql.replace("$p", limit_idx).replace("$o", offset_idx);
|
||||||
|
|
||||||
|
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||||
|
if has_category { count_query = count_query.bind(&query.category); }
|
||||||
|
if has_source { count_query = count_query.bind(&query.source); }
|
||||||
|
let total: i64 = count_query.fetch_one(db).await?;
|
||||||
|
|
||||||
|
let mut data_query = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)>(&data_sql);
|
||||||
|
if has_category { data_query = data_query.bind(&query.category); }
|
||||||
|
if has_source { data_query = data_query.bind(&query.source); }
|
||||||
|
let rows = data_query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||||
|
|
||||||
|
let items = rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
|
||||||
|
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at }
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
Ok(PaginatedResponse { items, total, page: p, page_size: ps })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigItemInfo> {
|
||||||
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)> =
|
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
|
||||||
FROM config_items WHERE id = ?1"
|
FROM config_items WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(item_id)
|
.bind(item_id)
|
||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
@@ -61,7 +110,7 @@ pub async fn get_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<Confi
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_config_item(
|
pub async fn create_config_item(
|
||||||
db: &SqlitePool, req: &CreateConfigItemRequest,
|
db: &PgPool, req: &CreateConfigItemRequest,
|
||||||
) -> SaasResult<ConfigItemInfo> {
|
) -> SaasResult<ConfigItemInfo> {
|
||||||
let id = uuid::Uuid::new_v4().to_string();
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
@@ -70,7 +119,7 @@ pub async fn create_config_item(
|
|||||||
|
|
||||||
// 检查唯一性
|
// 检查唯一性
|
||||||
let existing: Option<(String,)> = sqlx::query_as(
|
let existing: Option<(String,)> = sqlx::query_as(
|
||||||
"SELECT id FROM config_items WHERE category = ?1 AND key_path = ?2"
|
"SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
|
||||||
)
|
)
|
||||||
.bind(&req.category).bind(&req.key_path)
|
.bind(&req.category).bind(&req.key_path)
|
||||||
.fetch_optional(db).await?;
|
.fetch_optional(db).await?;
|
||||||
@@ -83,7 +132,7 @@ pub async fn create_config_item(
|
|||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
|
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?10)"
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $10)"
|
||||||
)
|
)
|
||||||
.bind(&id).bind(&req.category).bind(&req.key_path).bind(&req.value_type)
|
.bind(&id).bind(&req.category).bind(&req.key_path).bind(&req.value_type)
|
||||||
.bind(&req.current_value).bind(&req.default_value).bind(source)
|
.bind(&req.current_value).bind(&req.default_value).bind(source)
|
||||||
@@ -94,25 +143,27 @@ pub async fn create_config_item(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_config_item(
|
pub async fn update_config_item(
|
||||||
db: &SqlitePool, item_id: &str, req: &UpdateConfigItemRequest,
|
db: &PgPool, item_id: &str, req: &UpdateConfigItemRequest,
|
||||||
) -> SaasResult<ConfigItemInfo> {
|
) -> SaasResult<ConfigItemInfo> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let mut updates = Vec::new();
|
let mut updates = Vec::new();
|
||||||
let mut params: Vec<String> = Vec::new();
|
let mut params: Vec<String> = Vec::new();
|
||||||
|
let mut param_idx = 1usize;
|
||||||
|
|
||||||
if let Some(ref v) = req.current_value { updates.push("current_value = ?"); params.push(v.clone()); }
|
if let Some(ref v) = req.current_value { updates.push(format!("current_value = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||||
if let Some(ref v) = req.source { updates.push("source = ?"); params.push(v.clone()); }
|
if let Some(ref v) = req.source { updates.push(format!("source = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||||
if let Some(ref v) = req.description { updates.push("description = ?"); params.push(v.clone()); }
|
if let Some(ref v) = req.description { updates.push(format!("description = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||||
|
|
||||||
if updates.is_empty() {
|
if updates.is_empty() {
|
||||||
return get_config_item(db, item_id).await;
|
return get_config_item(db, item_id).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
updates.push("updated_at = ?");
|
updates.push(format!("updated_at = ${}", param_idx));
|
||||||
params.push(now);
|
params.push(now);
|
||||||
|
param_idx += 1;
|
||||||
params.push(item_id.to_string());
|
params.push(item_id.to_string());
|
||||||
|
|
||||||
let sql = format!("UPDATE config_items SET {} WHERE id = ?", updates.join(", "));
|
let sql = format!("UPDATE config_items SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||||
let mut query = sqlx::query(&sql);
|
let mut query = sqlx::query(&sql);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
query = query.bind(p);
|
query = query.bind(p);
|
||||||
@@ -122,8 +173,8 @@ pub async fn update_config_item(
|
|||||||
get_config_item(db, item_id).await
|
get_config_item(db, item_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn delete_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<()> {
|
pub async fn delete_config_item(db: &PgPool, item_id: &str) -> SaasResult<()> {
|
||||||
let result = sqlx::query("DELETE FROM config_items WHERE id = ?1")
|
let result = sqlx::query("DELETE FROM config_items WHERE id = $1")
|
||||||
.bind(item_id).execute(db).await?;
|
.bind(item_id).execute(db).await?;
|
||||||
if result.rows_affected() == 0 {
|
if result.rows_affected() == 0 {
|
||||||
return Err(SaasError::NotFound(format!("配置项 {} 不存在", item_id)));
|
return Err(SaasError::NotFound(format!("配置项 {} 不存在", item_id)));
|
||||||
@@ -133,8 +184,8 @@ pub async fn delete_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<()
|
|||||||
|
|
||||||
// ============ Config Analysis ============
|
// ============ Config Analysis ============
|
||||||
|
|
||||||
pub async fn analyze_config(db: &SqlitePool) -> SaasResult<ConfigAnalysis> {
|
pub async fn analyze_config(db: &PgPool) -> SaasResult<ConfigAnalysis> {
|
||||||
let items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
|
let items = fetch_all_config_items(db, &ConfigQuery { category: None, source: None, page: None, page_size: None }).await?;
|
||||||
|
|
||||||
let mut categories: std::collections::HashMap<String, (i64, i64)> = std::collections::HashMap::new();
|
let mut categories: std::collections::HashMap<String, (i64, i64)> = std::collections::HashMap::new();
|
||||||
for item in &items {
|
for item in &items {
|
||||||
@@ -157,7 +208,7 @@ pub async fn analyze_config(db: &SqlitePool) -> SaasResult<ConfigAnalysis> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 种子默认配置项
|
/// 种子默认配置项
|
||||||
pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
pub async fn seed_default_config_items(db: &PgPool) -> SaasResult<usize> {
|
||||||
let defaults = [
|
let defaults = [
|
||||||
("server", "server.host", "string", Some("127.0.0.1"), Some("127.0.0.1"), "服务器监听地址"),
|
("server", "server.host", "string", Some("127.0.0.1"), Some("127.0.0.1"), "服务器监听地址"),
|
||||||
("server", "server.port", "integer", Some("4200"), Some("4200"), "服务器端口"),
|
("server", "server.port", "integer", Some("4200"), Some("4200"), "服务器端口"),
|
||||||
@@ -172,6 +223,19 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
|||||||
("llm", "llm.default_provider", "string", Some("zhipu"), Some("zhipu"), "默认 LLM Provider"),
|
("llm", "llm.default_provider", "string", Some("zhipu"), Some("zhipu"), "默认 LLM Provider"),
|
||||||
("llm", "llm.temperature", "float", Some("0.7"), Some("0.7"), "默认温度"),
|
("llm", "llm.temperature", "float", Some("0.7"), Some("0.7"), "默认温度"),
|
||||||
("llm", "llm.max_tokens", "integer", Some("4096"), Some("4096"), "默认最大 token 数"),
|
("llm", "llm.max_tokens", "integer", Some("4096"), Some("4096"), "默认最大 token 数"),
|
||||||
|
// 安全策略配置
|
||||||
|
("security", "security.autonomy_level", "string", Some("standard"), Some("standard"), "自主级别: minimal/standard/full"),
|
||||||
|
("security", "security.max_tokens_per_request", "integer", Some("32768"), Some("32768"), "单次请求最大 Token 数"),
|
||||||
|
("security", "security.shell_enabled", "boolean", Some("true"), Some("true"), "是否启用 Shell 工具"),
|
||||||
|
("security", "security.shell_whitelist", "array", Some("[]"), Some("[]"), "Shell 命令白名单 (空=全部禁止)"),
|
||||||
|
("security", "security.file_write_enabled", "boolean", Some("true"), Some("true"), "是否允许文件写入"),
|
||||||
|
("security", "security.network_access_enabled", "boolean", Some("true"), Some("true"), "是否允许网络访问"),
|
||||||
|
("security", "security.browser_enabled", "boolean", Some("true"), Some("true"), "是否启用浏览器自动化"),
|
||||||
|
("security", "security.max_concurrent_tasks", "integer", Some("3"), Some("3"), "最大并发自主任务数"),
|
||||||
|
("security", "security.approval_required", "boolean", Some("false"), Some("false"), "高风险操作是否需要审批"),
|
||||||
|
("security", "security.content_filter_enabled", "boolean", Some("true"), Some("true"), "是否启用内容过滤"),
|
||||||
|
("security", "security.audit_log_enabled", "boolean", Some("true"), Some("true"), "是否启用审计日志"),
|
||||||
|
("security", "security.audit_log_max_entries", "integer", Some("500"), Some("500"), "审计日志最大条目数"),
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut created = 0;
|
let mut created = 0;
|
||||||
@@ -179,7 +243,7 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
|||||||
|
|
||||||
for (category, key_path, value_type, default_value, current_value, description) in defaults {
|
for (category, key_path, value_type, default_value, current_value, description) in defaults {
|
||||||
let existing: Option<(String,)> = sqlx::query_as(
|
let existing: Option<(String,)> = sqlx::query_as(
|
||||||
"SELECT id FROM config_items WHERE category = ?1 AND key_path = ?2"
|
"SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
|
||||||
)
|
)
|
||||||
.bind(category).bind(key_path)
|
.bind(category).bind(key_path)
|
||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
@@ -189,7 +253,7 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
|||||||
let id = uuid::Uuid::new_v4().to_string();
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
|
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'local', ?7, 0, ?8, ?8)"
|
VALUES ($1, $2, $3, $4, $5, $6, 'local', $7, false, $8, $8)"
|
||||||
)
|
)
|
||||||
.bind(&id).bind(category).bind(key_path).bind(value_type)
|
.bind(&id).bind(category).bind(key_path).bind(value_type)
|
||||||
.bind(current_value).bind(default_value).bind(description).bind(&now)
|
.bind(current_value).bind(default_value).bind(description).bind(&now)
|
||||||
@@ -206,9 +270,9 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
|
|||||||
|
|
||||||
/// 计算客户端与 SaaS 端的配置差异
|
/// 计算客户端与 SaaS 端的配置差异
|
||||||
pub async fn compute_config_diff(
|
pub async fn compute_config_diff(
|
||||||
db: &SqlitePool, req: &SyncConfigRequest,
|
db: &PgPool, req: &SyncConfigRequest,
|
||||||
) -> SaasResult<ConfigDiffResponse> {
|
) -> SaasResult<ConfigDiffResponse> {
|
||||||
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
|
let saas_items = fetch_all_config_items(db, &ConfigQuery { category: None, source: None, page: None, page_size: None }).await?;
|
||||||
|
|
||||||
let mut items = Vec::new();
|
let mut items = Vec::new();
|
||||||
let mut conflicts = 0usize;
|
let mut conflicts = 0usize;
|
||||||
@@ -248,17 +312,18 @@ pub async fn compute_config_diff(
|
|||||||
|
|
||||||
/// 执行配置同步 (实际写入 config_items)
|
/// 执行配置同步 (实际写入 config_items)
|
||||||
pub async fn sync_config(
|
pub async fn sync_config(
|
||||||
db: &SqlitePool, account_id: &str, req: &SyncConfigRequest,
|
db: &PgPool, account_id: &str, req: &SyncConfigRequest,
|
||||||
) -> SaasResult<ConfigSyncResult> {
|
) -> SaasResult<ConfigSyncResult> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let config_keys_str = serde_json::to_string(&req.config_keys)?;
|
let config_keys_str = serde_json::to_string(&req.config_keys)?;
|
||||||
let client_values_str = Some(serde_json::to_string(&req.client_values)?);
|
let client_values_str = Some(serde_json::to_string(&req.client_values)?);
|
||||||
|
|
||||||
// 获取 SaaS 端的配置值
|
// 获取 SaaS 端的配置值
|
||||||
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
|
let saas_items = fetch_all_config_items(db, &ConfigQuery { category: None, source: None, page: None, page_size: None }).await?;
|
||||||
let mut updated = 0i64;
|
let mut updated = 0i64;
|
||||||
let created = 0i64;
|
let mut created = 0i64;
|
||||||
let mut skipped = 0i64;
|
let mut skipped = 0i64;
|
||||||
|
let mut conflicts: Vec<String> = Vec::new();
|
||||||
|
|
||||||
for key in &req.config_keys {
|
for key in &req.config_keys {
|
||||||
let client_val = req.client_values.get(key)
|
let client_val = req.client_values.get(key)
|
||||||
@@ -269,26 +334,55 @@ pub async fn sync_config(
|
|||||||
|
|
||||||
match req.action.as_str() {
|
match req.action.as_str() {
|
||||||
"push" => {
|
"push" => {
|
||||||
// 客户端推送 → 覆盖 SaaS 值
|
// 客户端推送 → 覆盖 SaaS 值 (带 CAS 保护)
|
||||||
if let Some(val) = &client_val {
|
if let Some(val) = &client_val {
|
||||||
if let Some(item) = saas_item {
|
if let Some(item) = saas_item {
|
||||||
// 更新已有配置项
|
// CAS: 如果客户端提供了该 key 的 timestamp,做乐观锁
|
||||||
sqlx::query("UPDATE config_items SET current_value = ?1, source = 'local', updated_at = ?2 WHERE id = ?3")
|
if let Some(ref client_ts) = req.client_timestamps.get(key) {
|
||||||
|
let result = sqlx::query(
|
||||||
|
"UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3 AND updated_at = $4"
|
||||||
|
)
|
||||||
|
.bind(val).bind(&now).bind(&item.id).bind(client_ts)
|
||||||
|
.execute(db).await?;
|
||||||
|
if result.rows_affected() == 0 {
|
||||||
|
// SaaS 端已被修改 → 跳过,记录冲突
|
||||||
|
tracing::warn!(
|
||||||
|
"[ConfigSync] CAS conflict for key '{}': client_ts={}, saas_ts={}",
|
||||||
|
key, client_ts, item.updated_at
|
||||||
|
);
|
||||||
|
conflicts.push(key.clone());
|
||||||
|
skipped += 1;
|
||||||
|
} else {
|
||||||
|
updated += 1;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 无 CAS timestamp → 无条件覆盖 (向后兼容)
|
||||||
|
sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
|
||||||
.bind(val).bind(&now).bind(&item.id)
|
.bind(val).bind(&now).bind(&item.id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
updated += 1;
|
updated += 1;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// 推送时如果 SaaS 不存在该 key,记录跳过
|
// 推送时 SaaS 不存在该 key → 创建新配置项
|
||||||
skipped += 1;
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let parts: Vec<&str> = key.splitn(2, '.').collect();
|
||||||
|
let category = parts.first().unwrap_or(&"general").to_string();
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, 'string', $4, $4, 'local', '客户端推送', false, $5, $5)"
|
||||||
|
)
|
||||||
|
.bind(&id).bind(&category).bind(key).bind(val).bind(&now)
|
||||||
|
.execute(db).await?;
|
||||||
|
created += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"merge" => {
|
"merge" => {
|
||||||
// 合并: 客户端有值且 SaaS 无值 → 创建; 都有值 → SaaS 优先保留
|
// 合并: 客户端有值且 SaaS 无值 → 填入; 都有值 → SaaS 优先保留
|
||||||
if let Some(val) = &client_val {
|
if let Some(val) = &client_val {
|
||||||
if let Some(item) = saas_item {
|
if let Some(item) = saas_item {
|
||||||
if item.current_value.is_none() || item.current_value.as_deref() == Some("") {
|
if item.current_value.is_none() || item.current_value.as_deref() == Some("") {
|
||||||
sqlx::query("UPDATE config_items SET current_value = ?1, source = 'local', updated_at = ?2 WHERE id = ?3")
|
sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
|
||||||
.bind(val).bind(&now).bind(&item.id)
|
.bind(val).bind(&now).bind(&item.id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
updated += 1;
|
updated += 1;
|
||||||
@@ -296,11 +390,12 @@ pub async fn sync_config(
|
|||||||
// 冲突: SaaS 有值 → 保留 SaaS 值
|
// 冲突: SaaS 有值 → 保留 SaaS 值
|
||||||
skipped += 1;
|
skipped += 1;
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
// 客户端有但 SaaS 完全没有的 key → 不自动创建 (需要管理员先创建)
|
// 客户端有但 SaaS 完全没有的 key → 不自动创建 (需要管理员先创建)
|
||||||
skipped += 1;
|
skipped += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
// 默认: 记录日志但不修改 (向后兼容旧行为)
|
// 默认: 记录日志但不修改 (向后兼容旧行为)
|
||||||
}
|
}
|
||||||
@@ -323,7 +418,7 @@ pub async fn sync_config(
|
|||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO config_sync_log (account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)
|
"INSERT INTO config_sync_log (account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)"
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||||
)
|
)
|
||||||
.bind(account_id).bind(&req.client_fingerprint)
|
.bind(account_id).bind(&req.client_fingerprint)
|
||||||
.bind(&req.action).bind(&config_keys_str).bind(&client_values_str)
|
.bind(&req.action).bind(&config_keys_str).bind(&client_values_str)
|
||||||
@@ -331,7 +426,7 @@ pub async fn sync_config(
|
|||||||
.execute(db)
|
.execute(db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(ConfigSyncResult { updated, created, skipped })
|
Ok(ConfigSyncResult { updated, created, skipped, conflicts })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 同步结果
|
/// 同步结果
|
||||||
@@ -340,21 +435,36 @@ pub struct ConfigSyncResult {
|
|||||||
pub updated: i64,
|
pub updated: i64,
|
||||||
pub created: i64,
|
pub created: i64,
|
||||||
pub skipped: i64,
|
pub skipped: i64,
|
||||||
|
/// Keys skipped due to CAS conflict (SaaS was modified after client read)
|
||||||
|
pub conflicts: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_sync_logs(
|
pub async fn list_sync_logs(
|
||||||
db: &SqlitePool, account_id: &str,
|
db: &PgPool, account_id: &str, page: u32, page_size: u32,
|
||||||
) -> SaasResult<Vec<ConfigSyncLogInfo>> {
|
) -> SaasResult<crate::common::PaginatedResponse<ConfigSyncLogInfo>> {
|
||||||
|
let offset = ((page - 1) * page_size) as i64;
|
||||||
|
|
||||||
|
let total: (i64,) = sqlx::query_as(
|
||||||
|
"SELECT COUNT(*) FROM config_sync_log WHERE account_id = $1"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.fetch_one(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, String)> =
|
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at
|
"SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at
|
||||||
FROM config_sync_log WHERE account_id = ?1 ORDER BY created_at DESC LIMIT 50"
|
FROM config_sync_log WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||||
)
|
)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
|
.bind(page_size as i64)
|
||||||
|
.bind(offset)
|
||||||
.fetch_all(db)
|
.fetch_all(db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
Ok(rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| {
|
let items = rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| {
|
||||||
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at }
|
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at }
|
||||||
}).collect())
|
}).collect();
|
||||||
|
|
||||||
|
Ok(crate::common::PaginatedResponse { items, total: total.0, page, page_size })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -77,6 +77,11 @@ pub struct SyncConfigRequest {
|
|||||||
pub action: String,
|
pub action: String,
|
||||||
pub config_keys: Vec<String>,
|
pub config_keys: Vec<String>,
|
||||||
pub client_values: serde_json::Value,
|
pub client_values: serde_json::Value,
|
||||||
|
/// Client-side timestamps per key for optimistic locking (push CAS).
|
||||||
|
/// Maps `key_path` → `updated_at` as seen by client before this push.
|
||||||
|
/// Keys present here get `WHERE updated_at = $ts` on UPDATE; absent keys use unconditional overwrite.
|
||||||
|
#[serde(default)]
|
||||||
|
pub client_timestamps: std::collections::HashMap<String, String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_sync_action() -> String { "push".to_string() }
|
fn default_sync_action() -> String { "push".to_string() }
|
||||||
@@ -103,4 +108,6 @@ pub struct ConfigDiffResponse {
|
|||||||
pub struct ConfigQuery {
|
pub struct ConfigQuery {
|
||||||
pub category: Option<String>,
|
pub category: Option<String>,
|
||||||
pub source: Option<String>,
|
pub source: Option<String>,
|
||||||
|
pub page: Option<u32>,
|
||||||
|
pub page_size: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,19 +5,24 @@ use axum::{
|
|||||||
http::StatusCode, Json,
|
http::StatusCode, Json,
|
||||||
};
|
};
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use crate::error::SaasResult;
|
use crate::error::{SaasResult, SaasError};
|
||||||
use crate::auth::types::AuthContext;
|
use crate::auth::types::AuthContext;
|
||||||
use crate::auth::handlers::{log_operation, check_permission};
|
use crate::auth::handlers::{log_operation, check_permission};
|
||||||
|
use crate::common::PaginatedResponse;
|
||||||
use super::{types::*, service};
|
use super::{types::*, service};
|
||||||
|
|
||||||
// ============ Providers ============
|
// ============ Providers ============
|
||||||
|
|
||||||
/// GET /api/v1/providers
|
/// GET /api/v1/providers?enabled=true&page=1&page_size=20
|
||||||
pub async fn list_providers(
|
pub async fn list_providers(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
_ctx: Extension<AuthContext>,
|
_ctx: Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<ProviderInfo>>> {
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
service::list_providers(&state.db).await.map(Json)
|
) -> SaasResult<Json<PaginatedResponse<ProviderInfo>>> {
|
||||||
|
let page = params.get("page").and_then(|v| v.parse().ok());
|
||||||
|
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
|
||||||
|
let enabled_filter = params.get("enabled").and_then(|v| v.parse().ok());
|
||||||
|
service::list_providers(&state.db, page, page_size, enabled_filter).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/providers/:id
|
/// GET /api/v1/providers/:id
|
||||||
@@ -36,13 +41,17 @@ pub async fn create_provider(
|
|||||||
Json(req): Json<CreateProviderRequest>,
|
Json(req): Json<CreateProviderRequest>,
|
||||||
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
|
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
|
||||||
check_permission(&ctx, "provider:manage")?;
|
check_permission(&ctx, "provider:manage")?;
|
||||||
let provider = service::create_provider(&state.db, &req).await?;
|
let config = state.config.read().await;
|
||||||
|
let enc_key = config.api_key_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
drop(config);
|
||||||
|
let provider = service::create_provider(&state.db, &req, &enc_key).await?;
|
||||||
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
|
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
|
||||||
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
||||||
Ok((StatusCode::CREATED, Json(provider)))
|
Ok((StatusCode::CREATED, Json(provider)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// PUT /api/v1/providers/:id (admin only)
|
/// PATCH /api/v1/providers/:id (admin only)
|
||||||
pub async fn update_provider(
|
pub async fn update_provider(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Path(id): Path<String>,
|
Path(id): Path<String>,
|
||||||
@@ -50,7 +59,11 @@ pub async fn update_provider(
|
|||||||
Json(req): Json<UpdateProviderRequest>,
|
Json(req): Json<UpdateProviderRequest>,
|
||||||
) -> SaasResult<Json<ProviderInfo>> {
|
) -> SaasResult<Json<ProviderInfo>> {
|
||||||
check_permission(&ctx, "provider:manage")?;
|
check_permission(&ctx, "provider:manage")?;
|
||||||
let provider = service::update_provider(&state.db, &id, &req).await?;
|
let config = state.config.read().await;
|
||||||
|
let enc_key = config.api_key_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
drop(config);
|
||||||
|
let provider = service::update_provider(&state.db, &id, &req, &enc_key).await?;
|
||||||
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
|
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
|
||||||
Ok(Json(provider))
|
Ok(Json(provider))
|
||||||
}
|
}
|
||||||
@@ -69,14 +82,16 @@ pub async fn delete_provider(
|
|||||||
|
|
||||||
// ============ Models ============
|
// ============ Models ============
|
||||||
|
|
||||||
/// GET /api/v1/models?provider_id=xxx
|
/// GET /api/v1/models?provider_id=xxx&page=1&page_size=20
|
||||||
pub async fn list_models(
|
pub async fn list_models(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
_ctx: Extension<AuthContext>,
|
_ctx: Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<ModelInfo>>> {
|
) -> SaasResult<Json<PaginatedResponse<ModelInfo>>> {
|
||||||
let provider_id = params.get("provider_id").map(|s| s.as_str());
|
let provider_id = params.get("provider_id").map(|s| s.as_str());
|
||||||
service::list_models(&state.db, provider_id).await.map(Json)
|
let page = params.get("page").and_then(|v| v.parse().ok());
|
||||||
|
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
|
||||||
|
service::list_models(&state.db, provider_id, page, page_size).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/models/:id
|
/// GET /api/v1/models/:id
|
||||||
@@ -101,7 +116,7 @@ pub async fn create_model(
|
|||||||
Ok((StatusCode::CREATED, Json(model)))
|
Ok((StatusCode::CREATED, Json(model)))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// PUT /api/v1/models/:id (admin only)
|
/// PATCH /api/v1/models/:id (admin only)
|
||||||
pub async fn update_model(
|
pub async fn update_model(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Path(id): Path<String>,
|
Path(id): Path<String>,
|
||||||
@@ -128,14 +143,16 @@ pub async fn delete_model(
|
|||||||
|
|
||||||
// ============ Account API Keys ============
|
// ============ Account API Keys ============
|
||||||
|
|
||||||
/// GET /api/v1/keys?provider_id=xxx
|
/// GET /api/v1/keys?provider_id=xxx&page=1&page_size=20
|
||||||
pub async fn list_api_keys(
|
pub async fn list_api_keys(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> {
|
) -> SaasResult<Json<PaginatedResponse<AccountApiKeyInfo>>> {
|
||||||
let provider_id = params.get("provider_id").map(|s| s.as_str());
|
let provider_id = params.get("provider_id").map(|s| s.as_str());
|
||||||
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id).await.map(Json)
|
let page = params.get("page").and_then(|v| v.parse().ok());
|
||||||
|
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
|
||||||
|
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id, page, page_size).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/keys
|
/// POST /api/v1/keys
|
||||||
@@ -144,7 +161,11 @@ pub async fn create_api_key(
|
|||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<CreateAccountApiKeyRequest>,
|
Json(req): Json<CreateAccountApiKeyRequest>,
|
||||||
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
|
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
|
||||||
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req).await?;
|
let config = state.config.read().await;
|
||||||
|
let enc_key = config.api_key_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
drop(config);
|
||||||
|
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req, &enc_key).await?;
|
||||||
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
|
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
|
||||||
Some(serde_json::json!({"provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
|
Some(serde_json::json!({"provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
|
||||||
Ok((StatusCode::CREATED, Json(key)))
|
Ok((StatusCode::CREATED, Json(key)))
|
||||||
@@ -157,7 +178,11 @@ pub async fn rotate_api_key(
|
|||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<RotateApiKeyRequest>,
|
Json(req): Json<RotateApiKeyRequest>,
|
||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value).await?;
|
let config = state.config.read().await;
|
||||||
|
let enc_key = config.api_key_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
drop(config);
|
||||||
|
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value, &enc_key).await?;
|
||||||
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, ctx.client_ip.as_deref()).await?;
|
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, ctx.client_ip.as_deref()).await?;
|
||||||
Ok(Json(serde_json::json!({"ok": true})))
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
}
|
}
|
||||||
@@ -189,6 +214,6 @@ pub async fn list_provider_models(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Path(provider_id): Path<String>,
|
Path(provider_id): Path<String>,
|
||||||
_ctx: Extension<AuthContext>,
|
_ctx: Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<ModelInfo>>> {
|
) -> SaasResult<Json<PaginatedResponse<ModelInfo>>> {
|
||||||
service::list_models(&state.db, Some(&provider_id)).await.map(Json)
|
service::list_models(&state.db, Some(&provider_id), None, None).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,15 +12,15 @@ pub fn routes() -> axum::Router<AppState> {
|
|||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
// Providers
|
// Providers
|
||||||
.route("/api/v1/providers", get(handlers::list_providers).post(handlers::create_provider))
|
.route("/api/v1/providers", get(handlers::list_providers).post(handlers::create_provider))
|
||||||
.route("/api/v1/providers/{id}", get(handlers::get_provider).put(handlers::update_provider).delete(handlers::delete_provider))
|
.route("/api/v1/providers/:id", get(handlers::get_provider).patch(handlers::update_provider).delete(handlers::delete_provider))
|
||||||
.route("/api/v1/providers/{id}/models", get(handlers::list_provider_models))
|
.route("/api/v1/providers/:id/models", get(handlers::list_provider_models))
|
||||||
// Models
|
// Models
|
||||||
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
|
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
|
||||||
.route("/api/v1/models/{id}", get(handlers::get_model).put(handlers::update_model).delete(handlers::delete_model))
|
.route("/api/v1/models/:id", get(handlers::get_model).patch(handlers::update_model).delete(handlers::delete_model))
|
||||||
// Account API Keys
|
// Account API Keys
|
||||||
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
|
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
|
||||||
.route("/api/v1/keys/{id}", delete(handlers::revoke_api_key))
|
.route("/api/v1/keys/:id", delete(handlers::revoke_api_key))
|
||||||
.route("/api/v1/keys/{id}/rotate", post(handlers::rotate_api_key))
|
.route("/api/v1/keys/:id/rotate", post(handlers::rotate_api_key))
|
||||||
// Usage
|
// Usage
|
||||||
.route("/api/v1/usage", get(handlers::get_usage))
|
.route("/api/v1/usage", get(handlers::get_usage))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,30 +1,61 @@
|
|||||||
//! 模型配置业务逻辑
|
//! 模型配置业务逻辑
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::{PgPool, Row};
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
use crate::common::{PaginatedResponse, normalize_pagination};
|
||||||
|
use crate::crypto;
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
|
||||||
// ============ Providers ============
|
// ============ Providers ============
|
||||||
|
|
||||||
pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> {
|
pub async fn list_providers(
|
||||||
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
db: &PgPool, page: Option<u32>, page_size: Option<u32>, enabled_filter: Option<bool>,
|
||||||
sqlx::query_as(
|
) -> SaasResult<PaginatedResponse<ProviderInfo>> {
|
||||||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||||||
FROM providers ORDER BY name"
|
|
||||||
)
|
|
||||||
.fetch_all(db)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
|
let (count_sql, data_sql) = if enabled_filter.is_some() {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM providers WHERE enabled = $1",
|
||||||
|
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||||
|
FROM providers WHERE enabled = $1 ORDER BY name LIMIT $2 OFFSET $3",
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM providers",
|
||||||
|
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||||
|
FROM providers ORDER BY name LIMIT $1 OFFSET $2",
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let total: (i64,) = if let Some(en) = enabled_filter {
|
||||||
|
sqlx::query_as(count_sql).bind(en).fetch_one(db).await?
|
||||||
|
} else {
|
||||||
|
sqlx::query_as(count_sql).fetch_one(db).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||||
|
if let Some(en) = enabled_filter {
|
||||||
|
sqlx::query_as(data_sql)
|
||||||
|
.bind(en).bind(ps as i64).bind(offset)
|
||||||
|
.fetch_all(db).await?
|
||||||
|
} else {
|
||||||
|
sqlx::query_as(data_sql)
|
||||||
|
.bind(ps as i64).bind(offset)
|
||||||
|
.fetch_all(db).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let items = rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
|
||||||
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
|
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }
|
||||||
}).collect())
|
}).collect();
|
||||||
|
|
||||||
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
|
||||||
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
|
||||||
FROM providers WHERE id = ?1"
|
FROM providers WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(provider_id)
|
.bind(provider_id)
|
||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
@@ -36,22 +67,33 @@ pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<Prov
|
|||||||
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
|
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> SaasResult<ProviderInfo> {
|
pub async fn create_provider(db: &PgPool, req: &CreateProviderRequest, enc_key: &[u8; 32]) -> SaasResult<ProviderInfo> {
|
||||||
let id = uuid::Uuid::new_v4().to_string();
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
// 检查名称唯一性
|
// 检查名称唯一性
|
||||||
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = ?1")
|
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
|
||||||
.bind(&req.name).fetch_optional(db).await?;
|
.bind(&req.name).fetch_optional(db).await?;
|
||||||
if existing.is_some() {
|
if existing.is_some() {
|
||||||
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
|
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 加密 API Key 后存储
|
||||||
|
let encrypted_api_key = if let Some(ref key) = req.api_key {
|
||||||
|
if key.is_empty() {
|
||||||
|
String::new()
|
||||||
|
} else {
|
||||||
|
crypto::encrypt_value(key, enc_key)?
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
String::new()
|
||||||
|
};
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
|
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?8, ?9, ?9)"
|
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
|
||||||
)
|
)
|
||||||
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.api_key)
|
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
|
||||||
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
|
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
|
|
||||||
@@ -59,29 +101,34 @@ pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> Sa
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_provider(
|
pub async fn update_provider(
|
||||||
db: &SqlitePool, provider_id: &str, req: &UpdateProviderRequest,
|
db: &PgPool, provider_id: &str, req: &UpdateProviderRequest, enc_key: &[u8; 32],
|
||||||
) -> SaasResult<ProviderInfo> {
|
) -> SaasResult<ProviderInfo> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let mut updates = Vec::new();
|
let mut updates = Vec::new();
|
||||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||||
|
let mut param_idx = 1;
|
||||||
|
|
||||||
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(Box::new(v.clone())); }
|
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||||
if let Some(ref v) = req.base_url { updates.push("base_url = ?"); params.push(Box::new(v.clone())); }
|
if let Some(ref v) = req.base_url { updates.push(format!("base_url = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||||
if let Some(ref v) = req.api_protocol { updates.push("api_protocol = ?"); params.push(Box::new(v.clone())); }
|
if let Some(ref v) = req.api_protocol { updates.push(format!("api_protocol = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||||
if let Some(ref v) = req.api_key { updates.push("api_key = ?"); params.push(Box::new(v.clone())); }
|
if let Some(ref v) = req.api_key {
|
||||||
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
|
let encrypted = if v.is_empty() { String::new() } else { crypto::encrypt_value(v, enc_key)? };
|
||||||
if let Some(v) = req.rate_limit_rpm { updates.push("rate_limit_rpm = ?"); params.push(Box::new(v)); }
|
updates.push(format!("api_key = ${}", param_idx)); params.push(Box::new(encrypted)); param_idx += 1;
|
||||||
if let Some(v) = req.rate_limit_tpm { updates.push("rate_limit_tpm = ?"); params.push(Box::new(v)); }
|
}
|
||||||
|
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
|
if let Some(v) = req.rate_limit_rpm { updates.push(format!("rate_limit_rpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
|
if let Some(v) = req.rate_limit_tpm { updates.push(format!("rate_limit_tpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
|
|
||||||
if updates.is_empty() {
|
if updates.is_empty() {
|
||||||
return get_provider(db, provider_id).await;
|
return get_provider(db, provider_id).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
updates.push("updated_at = ?");
|
updates.push(format!("updated_at = ${}", param_idx));
|
||||||
params.push(Box::new(now.clone()));
|
params.push(Box::new(now.clone()));
|
||||||
|
param_idx += 1;
|
||||||
params.push(Box::new(provider_id.to_string()));
|
params.push(Box::new(provider_id.to_string()));
|
||||||
|
|
||||||
let sql = format!("UPDATE providers SET {} WHERE id = ?", updates.join(", "));
|
let sql = format!("UPDATE providers SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||||
let mut query = sqlx::query(&sql);
|
let mut query = sqlx::query(&sql);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
query = query.bind(format!("{}", p));
|
query = query.bind(format!("{}", p));
|
||||||
@@ -91,8 +138,8 @@ pub async fn update_provider(
|
|||||||
get_provider(db, provider_id).await
|
get_provider(db, provider_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<()> {
|
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
|
||||||
let result = sqlx::query("DELETE FROM providers WHERE id = ?1")
|
let result = sqlx::query("DELETE FROM providers WHERE id = $1")
|
||||||
.bind(provider_id).execute(db).await?;
|
.bind(provider_id).execute(db).await?;
|
||||||
|
|
||||||
if result.rows_affected() == 0 {
|
if result.rows_affected() == 0 {
|
||||||
@@ -103,27 +150,45 @@ pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<(
|
|||||||
|
|
||||||
// ============ Models ============
|
// ============ Models ============
|
||||||
|
|
||||||
pub async fn list_models(db: &SqlitePool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
|
pub async fn list_models(
|
||||||
let sql = if provider_id.is_some() {
|
db: &PgPool, provider_id: Option<&str>, page: Option<u32>, page_size: Option<u32>,
|
||||||
|
) -> SaasResult<PaginatedResponse<ModelInfo>> {
|
||||||
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||||||
|
|
||||||
|
let (count_sql, data_sql) = if provider_id.is_some() {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM models WHERE provider_id = $1",
|
||||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||||
FROM models WHERE provider_id = ?1 ORDER BY alias"
|
FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3",
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM models",
|
||||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||||
FROM models ORDER BY provider_id, alias"
|
FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2",
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(sql);
|
let total: (i64,) = if let Some(pid) = provider_id {
|
||||||
|
sqlx::query_as(count_sql).bind(pid).fetch_one(db).await?
|
||||||
|
} else {
|
||||||
|
sqlx::query_as(count_sql).fetch_one(db).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(data_sql);
|
||||||
if let Some(pid) = provider_id {
|
if let Some(pid) = provider_id {
|
||||||
query = query.bind(pid);
|
query = query.bind(pid);
|
||||||
}
|
}
|
||||||
|
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||||
|
|
||||||
let rows = query.fetch_all(db).await?;
|
let items = rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
|
||||||
Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
|
|
||||||
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
|
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }
|
||||||
}).collect())
|
}).collect();
|
||||||
|
|
||||||
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
|
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
|
||||||
// 验证 provider 存在
|
// 验证 provider 存在
|
||||||
let provider = get_provider(db, &req.provider_id).await?;
|
let provider = get_provider(db, &req.provider_id).await?;
|
||||||
|
|
||||||
@@ -132,7 +197,7 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
|
|||||||
|
|
||||||
// 检查 model 唯一性
|
// 检查 model 唯一性
|
||||||
let existing: Option<(String,)> = sqlx::query_as(
|
let existing: Option<(String,)> = sqlx::query_as(
|
||||||
"SELECT id FROM models WHERE provider_id = ?1 AND model_id = ?2"
|
"SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
|
||||||
)
|
)
|
||||||
.bind(&req.provider_id).bind(&req.model_id)
|
.bind(&req.provider_id).bind(&req.model_id)
|
||||||
.fetch_optional(db).await?;
|
.fetch_optional(db).await?;
|
||||||
@@ -152,7 +217,7 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
|
|||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
|
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, ?9, ?10, ?11, ?11)"
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
|
||||||
)
|
)
|
||||||
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
|
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
|
||||||
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
|
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
|
||||||
@@ -161,11 +226,11 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
|
|||||||
get_model(db, &id).await
|
get_model(db, &id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo> {
|
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||||
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
|
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
|
||||||
FROM models WHERE id = ?1"
|
FROM models WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(model_id)
|
.bind(model_id)
|
||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
@@ -178,30 +243,32 @@ pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo>
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_model(
|
pub async fn update_model(
|
||||||
db: &SqlitePool, model_id: &str, req: &UpdateModelRequest,
|
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
|
||||||
) -> SaasResult<ModelInfo> {
|
) -> SaasResult<ModelInfo> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let mut updates = Vec::new();
|
let mut updates = Vec::new();
|
||||||
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
|
||||||
|
let mut param_idx = 1;
|
||||||
|
|
||||||
if let Some(ref v) = req.alias { updates.push("alias = ?"); params.push(Box::new(v.clone())); }
|
if let Some(ref v) = req.alias { updates.push(format!("alias = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
|
||||||
if let Some(v) = req.context_window { updates.push("context_window = ?"); params.push(Box::new(v)); }
|
if let Some(v) = req.context_window { updates.push(format!("context_window = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
if let Some(v) = req.max_output_tokens { updates.push("max_output_tokens = ?"); params.push(Box::new(v)); }
|
if let Some(v) = req.max_output_tokens { updates.push(format!("max_output_tokens = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
if let Some(v) = req.supports_streaming { updates.push("supports_streaming = ?"); params.push(Box::new(v)); }
|
if let Some(v) = req.supports_streaming { updates.push(format!("supports_streaming = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
if let Some(v) = req.supports_vision { updates.push("supports_vision = ?"); params.push(Box::new(v)); }
|
if let Some(v) = req.supports_vision { updates.push(format!("supports_vision = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); }
|
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
if let Some(v) = req.pricing_input { updates.push("pricing_input = ?"); params.push(Box::new(v)); }
|
if let Some(v) = req.pricing_input { updates.push(format!("pricing_input = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
if let Some(v) = req.pricing_output { updates.push("pricing_output = ?"); params.push(Box::new(v)); }
|
if let Some(v) = req.pricing_output { updates.push(format!("pricing_output = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
|
||||||
|
|
||||||
if updates.is_empty() {
|
if updates.is_empty() {
|
||||||
return get_model(db, model_id).await;
|
return get_model(db, model_id).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
updates.push("updated_at = ?");
|
updates.push(format!("updated_at = ${}", param_idx));
|
||||||
params.push(Box::new(now.clone()));
|
params.push(Box::new(now.clone()));
|
||||||
|
param_idx += 1;
|
||||||
params.push(Box::new(model_id.to_string()));
|
params.push(Box::new(model_id.to_string()));
|
||||||
|
|
||||||
let sql = format!("UPDATE models SET {} WHERE id = ?", updates.join(", "));
|
let sql = format!("UPDATE models SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||||
let mut query = sqlx::query(&sql);
|
let mut query = sqlx::query(&sql);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
query = query.bind(format!("{}", p));
|
query = query.bind(format!("{}", p));
|
||||||
@@ -211,8 +278,8 @@ pub async fn update_model(
|
|||||||
get_model(db, model_id).await
|
get_model(db, model_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
|
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
|
||||||
let result = sqlx::query("DELETE FROM models WHERE id = ?1")
|
let result = sqlx::query("DELETE FROM models WHERE id = $1")
|
||||||
.bind(model_id).execute(db).await?;
|
.bind(model_id).execute(db).await?;
|
||||||
|
|
||||||
if result.rows_affected() == 0 {
|
if result.rows_affected() == 0 {
|
||||||
@@ -224,32 +291,52 @@ pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
|
|||||||
// ============ Account API Keys ============
|
// ============ Account API Keys ============
|
||||||
|
|
||||||
pub async fn list_account_api_keys(
|
pub async fn list_account_api_keys(
|
||||||
db: &SqlitePool, account_id: &str, provider_id: Option<&str>,
|
db: &PgPool, account_id: &str, provider_id: Option<&str>,
|
||||||
) -> SaasResult<Vec<AccountApiKeyInfo>> {
|
page: Option<u32>, page_size: Option<u32>,
|
||||||
let sql = if provider_id.is_some() {
|
) -> SaasResult<PaginatedResponse<AccountApiKeyInfo>> {
|
||||||
|
let (p, ps, offset) = normalize_pagination(page, page_size);
|
||||||
|
|
||||||
|
// Build COUNT and data queries based on whether provider_id is provided
|
||||||
|
let (count_sql, data_sql) = if provider_id.is_some() {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL",
|
||||||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
||||||
FROM account_api_keys WHERE account_id = ?1 AND provider_id = ?2 AND revoked_at IS NULL ORDER BY created_at DESC"
|
FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $3 OFFSET $4",
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL",
|
||||||
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
|
||||||
FROM account_api_keys WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3",
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(sql)
|
let total: (i64,) = if provider_id.is_some() {
|
||||||
|
let mut q = sqlx::query_as(count_sql).bind(account_id);
|
||||||
|
if let Some(pid) = provider_id { q = q.bind(pid); }
|
||||||
|
q.fetch_one(db).await?
|
||||||
|
} else {
|
||||||
|
sqlx::query_as(count_sql).bind(account_id).fetch_one(db).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(data_sql)
|
||||||
.bind(account_id);
|
.bind(account_id);
|
||||||
if let Some(pid) = provider_id {
|
if let Some(pid) = provider_id {
|
||||||
query = query.bind(pid);
|
query = query.bind(pid);
|
||||||
}
|
}
|
||||||
|
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||||
|
|
||||||
let rows = query.fetch_all(db).await?;
|
let items = rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
|
||||||
Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
|
|
||||||
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||||
let masked = mask_api_key(&key_value);
|
let masked = mask_api_key(&key_value);
|
||||||
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
|
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked }
|
||||||
}).collect())
|
}).collect();
|
||||||
|
|
||||||
|
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn create_account_api_key(
|
pub async fn create_account_api_key(
|
||||||
db: &SqlitePool, account_id: &str, req: &CreateAccountApiKeyRequest,
|
db: &PgPool, account_id: &str, req: &CreateAccountApiKeyRequest, enc_key: &[u8; 32],
|
||||||
) -> SaasResult<AccountApiKeyInfo> {
|
) -> SaasResult<AccountApiKeyInfo> {
|
||||||
// 验证 provider 存在
|
// 验证 provider 存在
|
||||||
get_provider(db, &req.provider_id).await?;
|
get_provider(db, &req.provider_id).await?;
|
||||||
@@ -258,11 +345,14 @@ pub async fn create_account_api_key(
|
|||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let permissions = serde_json::to_string(&req.permissions)?;
|
let permissions = serde_json::to_string(&req.permissions)?;
|
||||||
|
|
||||||
|
// 加密 key_value 后存储
|
||||||
|
let encrypted_key_value = crypto::encrypt_value(&req.key_value, enc_key)?;
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?7)"
|
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
|
||||||
)
|
)
|
||||||
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&req.key_value)
|
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value)
|
||||||
.bind(&req.key_label).bind(&permissions).bind(&now)
|
.bind(&req.key_label).bind(&permissions).bind(&now)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
|
|
||||||
@@ -275,13 +365,14 @@ pub async fn create_account_api_key(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn rotate_account_api_key(
|
pub async fn rotate_account_api_key(
|
||||||
db: &SqlitePool, key_id: &str, account_id: &str, new_key_value: &str,
|
db: &PgPool, key_id: &str, account_id: &str, new_key_value: &str, enc_key: &[u8; 32],
|
||||||
) -> SaasResult<()> {
|
) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let encrypted_value = crypto::encrypt_value(new_key_value, enc_key)?;
|
||||||
let result = sqlx::query(
|
let result = sqlx::query(
|
||||||
"UPDATE account_api_keys SET key_value = ?1, updated_at = ?2 WHERE id = ?3 AND account_id = ?4 AND revoked_at IS NULL"
|
"UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
|
||||||
)
|
)
|
||||||
.bind(new_key_value).bind(&now).bind(key_id).bind(account_id)
|
.bind(&encrypted_value).bind(&now).bind(key_id).bind(account_id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
|
|
||||||
if result.rows_affected() == 0 {
|
if result.rows_affected() == 0 {
|
||||||
@@ -291,11 +382,11 @@ pub async fn rotate_account_api_key(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn revoke_account_api_key(
|
pub async fn revoke_account_api_key(
|
||||||
db: &SqlitePool, key_id: &str, account_id: &str,
|
db: &PgPool, key_id: &str, account_id: &str,
|
||||||
) -> SaasResult<()> {
|
) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
let result = sqlx::query(
|
let result = sqlx::query(
|
||||||
"UPDATE account_api_keys SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL"
|
"UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
||||||
)
|
)
|
||||||
.bind(&now).bind(key_id).bind(account_id)
|
.bind(&now).bind(key_id).bind(account_id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
@@ -309,25 +400,30 @@ pub async fn revoke_account_api_key(
|
|||||||
// ============ Usage Statistics ============
|
// ============ Usage Statistics ============
|
||||||
|
|
||||||
pub async fn get_usage_stats(
|
pub async fn get_usage_stats(
|
||||||
db: &SqlitePool, account_id: &str, query: &UsageQuery,
|
db: &PgPool, account_id: &str, query: &UsageQuery,
|
||||||
) -> SaasResult<UsageStats> {
|
) -> SaasResult<UsageStats> {
|
||||||
let mut where_clauses = vec!["account_id = ?".to_string()];
|
let mut param_idx = 1;
|
||||||
|
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
|
||||||
let mut params: Vec<String> = vec![account_id.to_string()];
|
let mut params: Vec<String> = vec![account_id.to_string()];
|
||||||
|
param_idx += 1;
|
||||||
|
|
||||||
if let Some(ref from) = query.from {
|
if let Some(ref from) = query.from {
|
||||||
where_clauses.push("created_at >= ?".to_string());
|
where_clauses.push(format!("created_at >= ${}", param_idx));
|
||||||
params.push(from.clone());
|
params.push(from.clone());
|
||||||
|
param_idx += 1;
|
||||||
}
|
}
|
||||||
if let Some(ref to) = query.to {
|
if let Some(ref to) = query.to {
|
||||||
where_clauses.push("created_at <= ?".to_string());
|
where_clauses.push(format!("created_at <= ${}", param_idx));
|
||||||
params.push(to.clone());
|
params.push(to.clone());
|
||||||
|
param_idx += 1;
|
||||||
}
|
}
|
||||||
if let Some(ref pid) = query.provider_id {
|
if let Some(ref pid) = query.provider_id {
|
||||||
where_clauses.push("provider_id = ?".to_string());
|
where_clauses.push(format!("provider_id = ${}", param_idx));
|
||||||
params.push(pid.clone());
|
params.push(pid.clone());
|
||||||
|
param_idx += 1;
|
||||||
}
|
}
|
||||||
if let Some(ref mid) = query.model_id {
|
if let Some(ref mid) = query.model_id {
|
||||||
where_clauses.push("model_id = ?".to_string());
|
where_clauses.push(format!("model_id = ${}", param_idx));
|
||||||
params.push(mid.clone());
|
params.push(mid.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,18 +431,21 @@ pub async fn get_usage_stats(
|
|||||||
|
|
||||||
// 总量统计
|
// 总量统计
|
||||||
let total_sql = format!(
|
let total_sql = format!(
|
||||||
"SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
"SELECT COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||||
FROM usage_records WHERE {}", where_sql
|
FROM usage_records WHERE {}", where_sql
|
||||||
);
|
);
|
||||||
let mut total_query = sqlx::query_as::<_, (i64, i64, i64)>(&total_sql);
|
let mut total_query = sqlx::query(&total_sql);
|
||||||
for p in ¶ms {
|
for p in ¶ms {
|
||||||
total_query = total_query.bind(p);
|
total_query = total_query.bind(p);
|
||||||
}
|
}
|
||||||
let (total_requests, total_input, total_output) = total_query.fetch_one(db).await?;
|
let row = total_query.fetch_one(db).await?;
|
||||||
|
let total_requests: i64 = row.try_get(0).unwrap_or(0);
|
||||||
|
let total_input: i64 = row.try_get(1).unwrap_or(0);
|
||||||
|
let total_output: i64 = row.try_get(2).unwrap_or(0);
|
||||||
|
|
||||||
// 按模型统计
|
// 按模型统计
|
||||||
let by_model_sql = format!(
|
let by_model_sql = format!(
|
||||||
"SELECT provider_id, model_id, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
"SELECT provider_id, model_id, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||||
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
|
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
|
||||||
where_sql
|
where_sql
|
||||||
);
|
);
|
||||||
@@ -360,21 +459,27 @@ pub async fn get_usage_stats(
|
|||||||
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
|
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
|
||||||
}).collect();
|
}).collect();
|
||||||
|
|
||||||
// 按天统计 (最近 30 天)
|
// 按天统计 (使用 days 参数或默认 30 天)
|
||||||
let from_30d = (chrono::Utc::now() - chrono::Duration::days(30)).to_rfc3339();
|
let days = query.days.unwrap_or(30).min(365).max(1) as i64;
|
||||||
|
let from_days = (chrono::Utc::now() - chrono::Duration::days(days)).format("%Y-%m-%d").to_string() + "T00:00:00Z";
|
||||||
let daily_sql = format!(
|
let daily_sql = format!(
|
||||||
"SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
"SELECT SUBSTRING(created_at, 1, 10) as day, COUNT(*)::bigint, COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
|
||||||
FROM usage_records WHERE account_id = ?1 AND created_at >= ?2
|
FROM usage_records WHERE account_id = $1 AND created_at >= $2
|
||||||
GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30"
|
GROUP BY SUBSTRING(created_at, 1, 10) ORDER BY day DESC LIMIT $3"
|
||||||
);
|
);
|
||||||
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
|
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
|
||||||
.bind(account_id).bind(&from_30d)
|
.bind(account_id).bind(&from_days).bind(days as i32)
|
||||||
.fetch_all(db).await?;
|
.fetch_all(db).await?;
|
||||||
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
|
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
|
||||||
.map(|(date, count, input, output)| {
|
.map(|(date, count, input, output)| {
|
||||||
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
|
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
|
||||||
}).collect();
|
}).collect();
|
||||||
|
|
||||||
|
// 按 group_by 过滤返回
|
||||||
|
let group_by = query.group_by.as_deref();
|
||||||
|
let by_model = if group_by == Some("model") || group_by.is_none() { by_model } else { vec![] };
|
||||||
|
let by_day = if group_by == Some("day") || group_by.is_none() { by_day } else { vec![] };
|
||||||
|
|
||||||
Ok(UsageStats {
|
Ok(UsageStats {
|
||||||
total_requests,
|
total_requests,
|
||||||
total_input_tokens: total_input,
|
total_input_tokens: total_input,
|
||||||
@@ -385,14 +490,14 @@ pub async fn get_usage_stats(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn record_usage(
|
pub async fn record_usage(
|
||||||
db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str,
|
db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
|
||||||
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
|
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
|
||||||
status: &str, error_message: Option<&str>,
|
status: &str, error_message: Option<&str>,
|
||||||
) -> SaasResult<()> {
|
) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
|
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)"
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||||
)
|
)
|
||||||
.bind(account_id).bind(provider_id).bind(model_id)
|
.bind(account_id).bind(provider_id).bind(model_id)
|
||||||
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
|
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
|
||||||
|
|||||||
@@ -149,6 +149,10 @@ pub struct UsageQuery {
|
|||||||
pub to: Option<String>,
|
pub to: Option<String>,
|
||||||
pub provider_id: Option<String>,
|
pub provider_id: Option<String>,
|
||||||
pub model_id: Option<String>,
|
pub model_id: Option<String>,
|
||||||
|
/// 聚合维度: "day" 或 "model"。不传则返回完整 UsageStats
|
||||||
|
pub group_by: Option<String>,
|
||||||
|
/// 最近 N 天
|
||||||
|
pub days: Option<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- Seed Data ---
|
// --- Seed Data ---
|
||||||
|
|||||||
173
crates/zclaw-saas/src/prompt/handlers.rs
Normal file
173
crates/zclaw-saas/src/prompt/handlers.rs
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
//! 提示词模板 HTTP 处理器
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Extension, Path, Query, State},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::error::SaasResult;
|
||||||
|
use crate::auth::types::AuthContext;
|
||||||
|
use crate::auth::handlers::{log_operation, check_permission};
|
||||||
|
use super::types::*;
|
||||||
|
use super::service;
|
||||||
|
|
||||||
|
/// GET /api/v1/prompts/check — OTA 批量检查更新
|
||||||
|
pub async fn check_updates(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<PromptCheckRequest>,
|
||||||
|
) -> SaasResult<Json<PromptCheckResponse>> {
|
||||||
|
let result = service::check_updates(&state.db, &req.device_id, &req.versions).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "prompt.check", "prompt", &req.device_id,
|
||||||
|
Some(serde_json::json!({"updates_count": result.updates.len()})), ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/prompts — 列表全部模板
|
||||||
|
pub async fn list_prompts(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Query(query): Query<PromptListQuery>,
|
||||||
|
) -> SaasResult<Json<crate::common::PaginatedResponse<PromptTemplateInfo>>> {
|
||||||
|
check_permission(&ctx, "prompt:read")?;
|
||||||
|
Ok(Json(service::list_templates(&state.db, &query).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/prompts — 创建提示词模板
|
||||||
|
pub async fn create_prompt(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<CreatePromptRequest>,
|
||||||
|
) -> SaasResult<Json<PromptTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "prompt:write")?;
|
||||||
|
|
||||||
|
let source = req.source.as_deref().unwrap_or("custom");
|
||||||
|
let result = service::create_template(
|
||||||
|
&state.db, &req.name, &req.category, req.description.as_deref(),
|
||||||
|
source, &req.system_prompt,
|
||||||
|
req.user_prompt_template.as_deref(),
|
||||||
|
req.variables.clone(),
|
||||||
|
req.min_app_version.as_deref(),
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "prompt.create", "prompt", &result.id,
|
||||||
|
Some(serde_json::json!({"name": req.name})), ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/prompts/{name} — 获取模板(按名称)
|
||||||
|
pub async fn get_prompt(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
) -> SaasResult<Json<PromptTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "prompt:read")?;
|
||||||
|
Ok(Json(service::get_template_by_name(&state.db, &name).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PUT /api/v1/prompts/{name} — 更新模板元数据
|
||||||
|
pub async fn update_prompt(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
Json(req): Json<UpdatePromptRequest>,
|
||||||
|
) -> SaasResult<Json<PromptTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "prompt:write")?;
|
||||||
|
|
||||||
|
let tmpl = service::get_template_by_name(&state.db, &name).await?;
|
||||||
|
let result = service::update_template(
|
||||||
|
&state.db, &tmpl.id,
|
||||||
|
req.description.as_deref(),
|
||||||
|
req.status.as_deref(),
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "prompt.update", "prompt", &tmpl.id,
|
||||||
|
Some(serde_json::json!({"name": name})), ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DELETE /api/v1/prompts/{name} — 归档模板
|
||||||
|
pub async fn archive_prompt(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
) -> SaasResult<Json<PromptTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "prompt:admin")?;
|
||||||
|
|
||||||
|
let tmpl = service::get_template_by_name(&state.db, &name).await?;
|
||||||
|
let result = service::update_template(&state.db, &tmpl.id, None, Some("archived")).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "prompt.archive", "prompt", &tmpl.id, None, ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/prompts/{name}/versions — 查看版本历史
|
||||||
|
pub async fn list_versions(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
) -> SaasResult<Json<Vec<PromptVersionInfo>>> {
|
||||||
|
check_permission(&ctx, "prompt:read")?;
|
||||||
|
|
||||||
|
let tmpl = service::get_template_by_name(&state.db, &name).await?;
|
||||||
|
Ok(Json(service::list_versions(&state.db, &tmpl.id).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/prompts/{name}/versions/{version} — 获取特定版本
|
||||||
|
pub async fn get_version(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path((name, _version)): Path<(String, i32)>,
|
||||||
|
) -> SaasResult<Json<PromptVersionInfo>> {
|
||||||
|
check_permission(&ctx, "prompt:read")?;
|
||||||
|
|
||||||
|
let _tmpl = service::get_template_by_name(&state.db, &name).await?;
|
||||||
|
Ok(Json(service::get_current_version(&state.db, &name).await?))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/prompts/{name}/versions — 发布新版本
|
||||||
|
pub async fn create_version(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(name): Path<String>,
|
||||||
|
Json(req): Json<CreateVersionRequest>,
|
||||||
|
) -> SaasResult<Json<PromptVersionInfo>> {
|
||||||
|
check_permission(&ctx, "prompt:write")?;
|
||||||
|
|
||||||
|
let tmpl = service::get_template_by_name(&state.db, &name).await?;
|
||||||
|
let result = service::create_version(
|
||||||
|
&state.db, &tmpl.id,
|
||||||
|
&req.system_prompt,
|
||||||
|
req.user_prompt_template.as_deref(),
|
||||||
|
req.variables.clone(),
|
||||||
|
req.changelog.as_deref(),
|
||||||
|
req.min_app_version.as_deref(),
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "prompt.publish_version", "prompt", &tmpl.id,
|
||||||
|
Some(serde_json::json!({"name": name, "version": result.version})), ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/prompts/{name}/rollback/{version} — 回退到指定版本
|
||||||
|
pub async fn rollback_version(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path((name, version)): Path<(String, i32)>,
|
||||||
|
) -> SaasResult<Json<PromptTemplateInfo>> {
|
||||||
|
check_permission(&ctx, "prompt:admin")?;
|
||||||
|
|
||||||
|
let tmpl = service::get_template_by_name(&state.db, &name).await?;
|
||||||
|
let result = service::rollback_version(&state.db, &tmpl.id, version).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "prompt.rollback", "prompt", &tmpl.id,
|
||||||
|
Some(serde_json::json!({"name": name, "target_version": version})), ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
19
crates/zclaw-saas/src/prompt/mod.rs
Normal file
19
crates/zclaw-saas/src/prompt/mod.rs
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
//! 提示词模板管理模块
|
||||||
|
|
||||||
|
pub mod types;
|
||||||
|
pub mod service;
|
||||||
|
pub mod handlers;
|
||||||
|
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
/// 提示词管理路由 (需要认证)
|
||||||
|
pub fn routes() -> axum::Router<AppState> {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/api/v1/prompts/check", post(handlers::check_updates))
|
||||||
|
.route("/api/v1/prompts", get(handlers::list_prompts).post(handlers::create_prompt))
|
||||||
|
.route("/api/v1/prompts/:name", get(handlers::get_prompt).put(handlers::update_prompt).delete(handlers::archive_prompt))
|
||||||
|
.route("/api/v1/prompts/:name/versions", get(handlers::list_versions).post(handlers::create_version))
|
||||||
|
.route("/api/v1/prompts/:name/versions/:version", get(handlers::get_version))
|
||||||
|
.route("/api/v1/prompts/:name/rollback/:version", post(handlers::rollback_version))
|
||||||
|
}
|
||||||
323
crates/zclaw-saas/src/prompt/service.rs
Normal file
323
crates/zclaw-saas/src/prompt/service.rs
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
//! 提示词模板服务层
|
||||||
|
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
use crate::common::PaginatedResponse;
|
||||||
|
use crate::common::normalize_pagination;
|
||||||
|
use super::types::*;
|
||||||
|
|
||||||
|
/// 创建提示词模板 + 初始版本
|
||||||
|
pub async fn create_template(
|
||||||
|
db: &PgPool,
|
||||||
|
name: &str,
|
||||||
|
category: &str,
|
||||||
|
description: Option<&str>,
|
||||||
|
source: &str,
|
||||||
|
system_prompt: &str,
|
||||||
|
user_prompt_template: Option<&str>,
|
||||||
|
variables: Option<serde_json::Value>,
|
||||||
|
min_app_version: Option<&str>,
|
||||||
|
) -> SaasResult<PromptTemplateInfo> {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let version_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let vars_json = variables.unwrap_or(serde_json::json!([])).to_string();
|
||||||
|
|
||||||
|
// 插入模板
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_templates (id, name, category, description, source, current_version, status, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, 1, 'active', $6, $6)"
|
||||||
|
)
|
||||||
|
.bind(&id).bind(name).bind(category).bind(description).bind(source).bind(&now)
|
||||||
|
.execute(db).await.map_err(|e| {
|
||||||
|
if e.to_string().contains("unique") {
|
||||||
|
SaasError::AlreadyExists(format!("提示词模板 '{}' 已存在", name))
|
||||||
|
} else {
|
||||||
|
SaasError::Database(e)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// 插入 v1
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_versions (id, template_id, version, system_prompt, user_prompt_template, variables, min_app_version, created_at)
|
||||||
|
VALUES ($1, $2, 1, $3, $4, $5, $6, $7)"
|
||||||
|
)
|
||||||
|
.bind(&version_id).bind(&id).bind(system_prompt).bind(user_prompt_template).bind(&vars_json).bind(min_app_version).bind(&now)
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
get_template(db, &id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取单个模板
|
||||||
|
pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<PromptTemplateInfo> {
|
||||||
|
let row: Option<(String, String, String, Option<String>, String, i32, String, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
|
||||||
|
FROM prompt_templates WHERE id = $1"
|
||||||
|
).bind(id).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
let (id, name, category, description, source, current_version, status, created_at, updated_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 {} 不存在", id)))?;
|
||||||
|
|
||||||
|
Ok(PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按名称获取模板
|
||||||
|
pub async fn get_template_by_name(db: &PgPool, name: &str) -> SaasResult<PromptTemplateInfo> {
|
||||||
|
let row: Option<(String, String, String, Option<String>, String, i32, String, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
|
||||||
|
FROM prompt_templates WHERE name = $1"
|
||||||
|
).bind(name).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
let (id, name, category, description, source, current_version, status, created_at, updated_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound(format!("提示词模板 '{}' 不存在", name)))?;
|
||||||
|
|
||||||
|
Ok(PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列表模板
|
||||||
|
pub async fn list_templates(
|
||||||
|
db: &PgPool,
|
||||||
|
query: &PromptListQuery,
|
||||||
|
) -> SaasResult<PaginatedResponse<PromptTemplateInfo>> {
|
||||||
|
let (page, page_size, offset) = normalize_pagination(query.page, query.page_size);
|
||||||
|
|
||||||
|
let mut where_clauses = vec!["1=1".to_string()];
|
||||||
|
let mut count_sql = String::from("SELECT COUNT(*) FROM prompt_templates WHERE ");
|
||||||
|
let mut data_sql = String::from(
|
||||||
|
"SELECT id, name, category, description, source, current_version, status, created_at, updated_at
|
||||||
|
FROM prompt_templates WHERE "
|
||||||
|
);
|
||||||
|
|
||||||
|
if let Some(ref cat) = query.category {
|
||||||
|
where_clauses.push(format!("category = '{}'", cat.replace('\'', "''")));
|
||||||
|
}
|
||||||
|
if let Some(ref src) = query.source {
|
||||||
|
where_clauses.push(format!("source = '{}'", src.replace('\'', "''")));
|
||||||
|
}
|
||||||
|
if let Some(ref st) = query.status {
|
||||||
|
where_clauses.push(format!("status = '{}'", st.replace('\'', "''")));
|
||||||
|
}
|
||||||
|
|
||||||
|
let where_clause = where_clauses.join(" AND ");
|
||||||
|
count_sql.push_str(&where_clause);
|
||||||
|
data_sql.push_str(&where_clause);
|
||||||
|
data_sql.push_str(&format!(" ORDER BY updated_at DESC LIMIT {} OFFSET {}", page_size, offset));
|
||||||
|
|
||||||
|
let total: i64 = sqlx::query_scalar(&count_sql).fetch_one(db).await?;
|
||||||
|
|
||||||
|
let rows: Vec<(String, String, String, Option<String>, String, i32, String, String, String)> =
|
||||||
|
sqlx::query_as(&data_sql).fetch_all(db).await?;
|
||||||
|
|
||||||
|
let items: Vec<PromptTemplateInfo> = rows.into_iter().map(|(id, name, category, description, source, current_version, status, created_at, updated_at)| {
|
||||||
|
PromptTemplateInfo { id, name, category, description, source, current_version, status, created_at, updated_at }
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
Ok(PaginatedResponse { items, total, page, page_size })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 更新模板元数据(不修改内容)
|
||||||
|
pub async fn update_template(
|
||||||
|
db: &PgPool,
|
||||||
|
id: &str,
|
||||||
|
description: Option<&str>,
|
||||||
|
status: Option<&str>,
|
||||||
|
) -> SaasResult<PromptTemplateInfo> {
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
if let Some(desc) = description {
|
||||||
|
sqlx::query("UPDATE prompt_templates SET description = $1, updated_at = $2 WHERE id = $3")
|
||||||
|
.bind(desc).bind(&now).bind(id).execute(db).await?;
|
||||||
|
}
|
||||||
|
if let Some(st) = status {
|
||||||
|
let valid = ["active", "deprecated", "archived"];
|
||||||
|
if !valid.contains(&st) {
|
||||||
|
return Err(SaasError::InvalidInput(format!("无效状态: {},允许: {}", st, valid.join(", "))));
|
||||||
|
}
|
||||||
|
sqlx::query("UPDATE prompt_templates SET status = $1, updated_at = $2 WHERE id = $3")
|
||||||
|
.bind(st).bind(&now).bind(id).execute(db).await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
get_template(db, id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 发布新版本
|
||||||
|
pub async fn create_version(
|
||||||
|
db: &PgPool,
|
||||||
|
template_id: &str,
|
||||||
|
system_prompt: &str,
|
||||||
|
user_prompt_template: Option<&str>,
|
||||||
|
variables: Option<serde_json::Value>,
|
||||||
|
changelog: Option<&str>,
|
||||||
|
min_app_version: Option<&str>,
|
||||||
|
) -> SaasResult<PromptVersionInfo> {
|
||||||
|
let tmpl = get_template(db, template_id).await?;
|
||||||
|
|
||||||
|
let new_version = tmpl.current_version + 1;
|
||||||
|
let version_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let vars_json = variables.unwrap_or(serde_json::json!([])).to_string();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_versions (id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||||
|
)
|
||||||
|
.bind(&version_id).bind(template_id).bind(new_version)
|
||||||
|
.bind(system_prompt).bind(user_prompt_template).bind(&vars_json).bind(changelog).bind(min_app_version).bind(&now)
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
// 更新模板的 current_version
|
||||||
|
sqlx::query("UPDATE prompt_templates SET current_version = $1, updated_at = $2 WHERE id = $3")
|
||||||
|
.bind(new_version).bind(&now).bind(template_id)
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
get_version(db, &version_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取特定版本
|
||||||
|
pub async fn get_version(db: &PgPool, version_id: &str) -> SaasResult<PromptVersionInfo> {
|
||||||
|
let row: Option<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
|
||||||
|
FROM prompt_versions WHERE id = $1"
|
||||||
|
).bind(version_id).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
let (id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound(format!("提示词版本 {} 不存在", version_id)))?;
|
||||||
|
|
||||||
|
let variables: serde_json::Value = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
|
||||||
|
|
||||||
|
Ok(PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取模板的当前版本内容
|
||||||
|
pub async fn get_current_version(db: &PgPool, template_name: &str) -> SaasResult<PromptVersionInfo> {
|
||||||
|
let tmpl = get_template_by_name(db, template_name).await?;
|
||||||
|
|
||||||
|
let row: Option<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
|
||||||
|
FROM prompt_versions WHERE template_id = $1 AND version = $2"
|
||||||
|
).bind(&tmpl.id).bind(tmpl.current_version).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
let (id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound(format!("提示词 '{}' 的版本 {} 不存在", template_name, tmpl.current_version)))?;
|
||||||
|
|
||||||
|
let variables: serde_json::Value = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
|
||||||
|
|
||||||
|
Ok(PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 列出模板的所有版本
|
||||||
|
pub async fn list_versions(
|
||||||
|
db: &PgPool,
|
||||||
|
template_id: &str,
|
||||||
|
) -> SaasResult<Vec<PromptVersionInfo>> {
|
||||||
|
let rows: Vec<(String, String, i32, String, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at
|
||||||
|
FROM prompt_versions WHERE template_id = $1 ORDER BY version DESC"
|
||||||
|
).bind(template_id).fetch_all(db).await?;
|
||||||
|
|
||||||
|
Ok(rows.into_iter().map(|(id, template_id, version, system_prompt, user_prompt_template, variables_str, changelog, min_app_version, created_at)| {
|
||||||
|
let variables = serde_json::from_str(&variables_str).unwrap_or(serde_json::json!([]));
|
||||||
|
PromptVersionInfo { id, template_id, version, system_prompt, user_prompt_template, variables, changelog, min_app_version, created_at }
|
||||||
|
}).collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 回退到指定版本
|
||||||
|
pub async fn rollback_version(
|
||||||
|
db: &PgPool,
|
||||||
|
template_id: &str,
|
||||||
|
target_version: i32,
|
||||||
|
) -> SaasResult<PromptTemplateInfo> {
|
||||||
|
// 验证目标版本存在
|
||||||
|
let exists: (bool,) = sqlx::query_as(
|
||||||
|
"SELECT EXISTS(SELECT 1 FROM prompt_versions WHERE template_id = $1 AND version = $2)"
|
||||||
|
).bind(template_id).bind(target_version).fetch_one(db).await?;
|
||||||
|
|
||||||
|
if !exists.0 {
|
||||||
|
return Err(SaasError::NotFound(format!("版本 {} 不存在", target_version)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
sqlx::query("UPDATE prompt_templates SET current_version = $1, updated_at = $2 WHERE id = $3")
|
||||||
|
.bind(target_version).bind(&now).bind(template_id)
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
get_template(db, template_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// OTA 批量检查更新
|
||||||
|
pub async fn check_updates(
|
||||||
|
db: &PgPool,
|
||||||
|
device_id: &str,
|
||||||
|
client_versions: &std::collections::HashMap<String, i32>,
|
||||||
|
) -> SaasResult<PromptCheckResponse> {
|
||||||
|
let mut updates = Vec::new();
|
||||||
|
|
||||||
|
for (name, client_ver) in client_versions {
|
||||||
|
let tmpl = match get_template_by_name(db, name).await {
|
||||||
|
Ok(t) if t.status == "active" => t,
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
if tmpl.current_version > *client_ver {
|
||||||
|
// 获取最新版本内容
|
||||||
|
if let Ok(ver) = get_current_version(db, name).await {
|
||||||
|
updates.push(PromptUpdatePayload {
|
||||||
|
name: tmpl.name.clone(),
|
||||||
|
version: ver.version,
|
||||||
|
system_prompt: ver.system_prompt,
|
||||||
|
user_prompt_template: ver.user_prompt_template,
|
||||||
|
variables: ver.variables,
|
||||||
|
source: tmpl.source.clone(),
|
||||||
|
min_app_version: ver.min_app_version,
|
||||||
|
changelog: ver.changelog,
|
||||||
|
});
|
||||||
|
|
||||||
|
// 更新同步状态
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO prompt_sync_status (device_id, template_id, synced_version, synced_at)
|
||||||
|
VALUES ($1, $2, $3, $4)
|
||||||
|
ON CONFLICT (device_id, template_id) DO UPDATE SET synced_version = $3, synced_at = $4"
|
||||||
|
)
|
||||||
|
.bind(device_id).bind(&tmpl.id).bind(ver.version).bind(&now)
|
||||||
|
.execute(db).await.ok(); // 非关键,忽略错误
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 首次连接:返回所有 active 的 builtin 模板
|
||||||
|
if client_versions.is_empty() {
|
||||||
|
let all_templates = list_templates(db, &PromptListQuery {
|
||||||
|
source: Some("builtin".into()),
|
||||||
|
status: Some("active".into()),
|
||||||
|
category: None,
|
||||||
|
page: None,
|
||||||
|
page_size: Some(100),
|
||||||
|
}).await?;
|
||||||
|
|
||||||
|
for tmpl in &all_templates.items {
|
||||||
|
if let Ok(ver) = get_current_version(db, &tmpl.name).await {
|
||||||
|
updates.push(PromptUpdatePayload {
|
||||||
|
name: tmpl.name.clone(),
|
||||||
|
version: ver.version,
|
||||||
|
system_prompt: ver.system_prompt,
|
||||||
|
user_prompt_template: ver.user_prompt_template,
|
||||||
|
variables: ver.variables,
|
||||||
|
source: tmpl.source.clone(),
|
||||||
|
min_app_version: ver.min_app_version,
|
||||||
|
changelog: ver.changelog,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(PromptCheckResponse {
|
||||||
|
updates,
|
||||||
|
server_time: chrono::Utc::now().to_rfc3339(),
|
||||||
|
})
|
||||||
|
}
|
||||||
97
crates/zclaw-saas/src/prompt/types.rs
Normal file
97
crates/zclaw-saas/src/prompt/types.rs
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
//! 提示词模板类型定义
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// --- Prompt Template ---
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PromptTemplateInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub category: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub source: String,
|
||||||
|
pub current_version: i32,
|
||||||
|
pub status: String,
|
||||||
|
pub created_at: String,
|
||||||
|
pub updated_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CreatePromptRequest {
|
||||||
|
pub name: String,
|
||||||
|
pub category: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub source: Option<String>,
|
||||||
|
pub system_prompt: String,
|
||||||
|
pub user_prompt_template: Option<String>,
|
||||||
|
pub variables: Option<serde_json::Value>,
|
||||||
|
pub min_app_version: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct UpdatePromptRequest {
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub status: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- Prompt Version ---
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PromptVersionInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub template_id: String,
|
||||||
|
pub version: i32,
|
||||||
|
pub system_prompt: String,
|
||||||
|
pub user_prompt_template: Option<String>,
|
||||||
|
pub variables: serde_json::Value,
|
||||||
|
pub changelog: Option<String>,
|
||||||
|
pub min_app_version: Option<String>,
|
||||||
|
pub created_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CreateVersionRequest {
|
||||||
|
pub system_prompt: String,
|
||||||
|
pub user_prompt_template: Option<String>,
|
||||||
|
pub variables: Option<serde_json::Value>,
|
||||||
|
pub changelog: Option<String>,
|
||||||
|
pub min_app_version: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- OTA Check ---
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct PromptCheckRequest {
|
||||||
|
pub device_id: String,
|
||||||
|
pub versions: std::collections::HashMap<String, i32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct PromptCheckResponse {
|
||||||
|
pub updates: Vec<PromptUpdatePayload>,
|
||||||
|
pub server_time: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct PromptUpdatePayload {
|
||||||
|
pub name: String,
|
||||||
|
pub version: i32,
|
||||||
|
pub system_prompt: String,
|
||||||
|
pub user_prompt_template: Option<String>,
|
||||||
|
pub variables: serde_json::Value,
|
||||||
|
pub source: String,
|
||||||
|
pub min_app_version: Option<String>,
|
||||||
|
pub changelog: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- List ---
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct PromptListQuery {
|
||||||
|
pub category: Option<String>,
|
||||||
|
pub source: Option<String>,
|
||||||
|
pub status: Option<String>,
|
||||||
|
pub page: Option<u32>,
|
||||||
|
pub page_size: Option<u32>,
|
||||||
|
}
|
||||||
@@ -23,6 +23,22 @@ pub async fn chat_completions(
|
|||||||
) -> SaasResult<Response> {
|
) -> SaasResult<Response> {
|
||||||
check_permission(&ctx, "relay:use")?;
|
check_permission(&ctx, "relay:use")?;
|
||||||
|
|
||||||
|
// 队列容量检查:防止过载
|
||||||
|
let config = state.config.read().await;
|
||||||
|
let queued_count: i64 = sqlx::query_scalar(
|
||||||
|
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
||||||
|
)
|
||||||
|
.bind(&ctx.account_id)
|
||||||
|
.fetch_one(&state.db)
|
||||||
|
.await
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
if queued_count >= config.relay.max_queue_size as i64 {
|
||||||
|
return Err(SaasError::RateLimited(
|
||||||
|
format!("队列已满 ({} 个任务排队中),请稍后重试", queued_count)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
let model_name = req.get("model")
|
let model_name = req.get("model")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
|
||||||
@@ -32,7 +48,7 @@ pub async fn chat_completions(
|
|||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
// 查找 model 对应的 provider
|
// 查找 model 对应的 provider
|
||||||
let models = model_service::list_models(&state.db, None).await?;
|
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||||
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled)
|
||||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||||
|
|
||||||
@@ -42,15 +58,6 @@ pub async fn chat_completions(
|
|||||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 provider 的 API key (从数据库直接查询)
|
|
||||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
|
||||||
"SELECT api_key FROM providers WHERE id = ?1"
|
|
||||||
)
|
|
||||||
.bind(&target_model.provider_id)
|
|
||||||
.fetch_optional(&state.db)
|
|
||||||
.await?
|
|
||||||
.flatten();
|
|
||||||
|
|
||||||
let request_body = serde_json::to_string(&req)?;
|
let request_body = serde_json::to_string(&req)?;
|
||||||
|
|
||||||
// 创建中转任务
|
// 创建中转任务
|
||||||
@@ -64,27 +71,22 @@ pub async fn chat_completions(
|
|||||||
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||||
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
// 执行中转 (带重试)
|
// 获取加密密钥用于解密 API Key
|
||||||
|
let enc_key = config.api_key_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
|
// 执行中转 (Key Pool 自动选择 + 429 轮转)
|
||||||
let response = service::execute_relay(
|
let response = service::execute_relay(
|
||||||
&state.db, &task.id, &provider.base_url,
|
&state.db, &task.id, &target_model.provider_id,
|
||||||
provider_api_key.as_deref(), &request_body, stream,
|
&provider.base_url, &request_body, stream,
|
||||||
config.relay.max_attempts,
|
config.relay.max_attempts,
|
||||||
config.relay.retry_delay_ms,
|
config.relay.retry_delay_ms,
|
||||||
|
&enc_key,
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
match response {
|
match response {
|
||||||
Ok(service::RelayResponse::Json(body)) => {
|
Ok(service::RelayResponse::Json(body)) => {
|
||||||
// 记录用量
|
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
||||||
let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
|
|
||||||
let input_tokens = parsed.get("usage")
|
|
||||||
.and_then(|u| u.get("prompt_tokens"))
|
|
||||||
.and_then(|v| v.as_i64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
let output_tokens = parsed.get("usage")
|
|
||||||
.and_then(|u| u.get("completion_tokens"))
|
|
||||||
.and_then(|v| v.as_i64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
model_service::record_usage(
|
model_service::record_usage(
|
||||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||||
&target_model.model_id, input_tokens, output_tokens,
|
&target_model.model_id, input_tokens, output_tokens,
|
||||||
@@ -94,13 +96,14 @@ pub async fn chat_completions(
|
|||||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||||
}
|
}
|
||||||
Ok(service::RelayResponse::Sse(body)) => {
|
Ok(service::RelayResponse::Sse(body)) => {
|
||||||
|
// SSE 流的 usage 统计在 service 层异步处理
|
||||||
|
// 这里先记录一个占位记录,实际值会在流结束后更新
|
||||||
model_service::record_usage(
|
model_service::record_usage(
|
||||||
&state.db, &ctx.account_id, &target_model.provider_id,
|
&state.db, &ctx.account_id, &target_model.provider_id,
|
||||||
&target_model.model_id, 0, 0,
|
&target_model.model_id, 0, 0,
|
||||||
None, "success", None,
|
None, "streaming", None,
|
||||||
).await?;
|
).await?;
|
||||||
|
|
||||||
// 流式响应: 直接转发 axum::body::Body
|
|
||||||
let response = axum::response::Response::builder()
|
let response = axum::response::Response::builder()
|
||||||
.status(StatusCode::OK)
|
.status(StatusCode::OK)
|
||||||
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
|
||||||
@@ -126,7 +129,7 @@ pub async fn list_tasks(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Query(query): Query<RelayTaskQuery>,
|
Query(query): Query<RelayTaskQuery>,
|
||||||
) -> SaasResult<Json<Vec<RelayTaskInfo>>> {
|
) -> SaasResult<Json<crate::common::PaginatedResponse<RelayTaskInfo>>> {
|
||||||
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
|
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,11 +153,11 @@ pub async fn list_available_models(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
_ctx: Extension<AuthContext>,
|
_ctx: Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||||
let providers = model_service::list_providers(&state.db).await?;
|
let providers = model_service::list_providers(&state.db, None, None, None).await?.items;
|
||||||
let enabled_provider_ids: std::collections::HashSet<String> =
|
let enabled_provider_ids: std::collections::HashSet<String> =
|
||||||
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
|
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
|
||||||
|
|
||||||
let models = model_service::list_models(&state.db, None).await?;
|
let models = model_service::list_models(&state.db, None, None, None).await?.items;
|
||||||
let available: Vec<serde_json::Value> = models.into_iter()
|
let available: Vec<serde_json::Value> = models.into_iter()
|
||||||
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
|
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
|
||||||
.map(|m| {
|
.map(|m| {
|
||||||
@@ -191,17 +194,10 @@ pub async fn retry_task(
|
|||||||
|
|
||||||
// 获取 provider 信息
|
// 获取 provider 信息
|
||||||
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
|
||||||
let provider_api_key: Option<String> = sqlx::query_scalar(
|
|
||||||
"SELECT api_key FROM providers WHERE id = ?1"
|
|
||||||
)
|
|
||||||
.bind(&task.provider_id)
|
|
||||||
.fetch_optional(&state.db)
|
|
||||||
.await?
|
|
||||||
.flatten();
|
|
||||||
|
|
||||||
// 读取原始请求体
|
// 读取原始请求体
|
||||||
let request_body: Option<String> = sqlx::query_scalar(
|
let request_body: Option<String> = sqlx::query_scalar(
|
||||||
"SELECT request_body FROM relay_tasks WHERE id = ?1"
|
"SELECT request_body FROM relay_tasks WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.fetch_optional(&state.db)
|
.fetch_optional(&state.db)
|
||||||
@@ -219,23 +215,27 @@ pub async fn retry_task(
|
|||||||
let max_attempts = task.max_attempts as u32;
|
let max_attempts = task.max_attempts as u32;
|
||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let base_delay_ms = config.relay.retry_delay_ms;
|
let base_delay_ms = config.relay.retry_delay_ms;
|
||||||
|
let enc_key = config.api_key_encryption_key()
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
||||||
|
|
||||||
// 重置任务状态为 queued 以允许新的 processing
|
// 重置任务状态为 queued 以允许新的 processing
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = ?1"
|
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(&id)
|
.bind(&id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// 异步执行重试
|
// 异步执行重试 (Key Pool 自动选择)
|
||||||
let db = state.db.clone();
|
let db = state.db.clone();
|
||||||
let task_id = id.clone();
|
let task_id = id.clone();
|
||||||
|
let provider_id = task.provider_id.clone();
|
||||||
|
let base_url = provider.base_url.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match service::execute_relay(
|
match service::execute_relay(
|
||||||
&db, &task_id, &provider.base_url,
|
&db, &task_id, &provider_id,
|
||||||
provider_api_key.as_deref(), &body, stream,
|
&base_url, &body, stream,
|
||||||
max_attempts, base_delay_ms,
|
max_attempts, base_delay_ms, &enc_key,
|
||||||
).await {
|
).await {
|
||||||
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
|
||||||
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
|
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
|
||||||
@@ -247,3 +247,96 @@ pub async fn retry_task(
|
|||||||
|
|
||||||
Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
|
Ok(Json(serde_json::json!({"ok": true, "task_id": id})))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ Key Pool 管理 (admin only) ============
|
||||||
|
|
||||||
|
/// GET /api/v1/providers/:provider_id/keys
|
||||||
|
pub async fn list_provider_keys(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(provider_id): Path<String>,
|
||||||
|
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||||
|
check_permission(&ctx, "provider:manage")?;
|
||||||
|
let keys = super::key_pool::list_provider_keys(&state.db, &provider_id).await?;
|
||||||
|
Ok(Json(keys))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/providers/:provider_id/keys
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
pub struct AddKeyRequest {
|
||||||
|
pub key_label: String,
|
||||||
|
pub key_value: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub priority: i32,
|
||||||
|
pub max_rpm: Option<i64>,
|
||||||
|
pub max_tpm: Option<i64>,
|
||||||
|
pub quota_reset_interval: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_provider_key(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path(provider_id): Path<String>,
|
||||||
|
Json(req): Json<AddKeyRequest>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
check_permission(&ctx, "provider:manage")?;
|
||||||
|
|
||||||
|
if req.key_label.trim().is_empty() {
|
||||||
|
return Err(SaasError::InvalidInput("key_label 不能为空".into()));
|
||||||
|
}
|
||||||
|
if req.key_value.trim().is_empty() {
|
||||||
|
return Err(SaasError::InvalidInput("key_value 不能为空".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let key_id = super::key_pool::add_provider_key(
|
||||||
|
&state.db, &provider_id, &req.key_label, &req.key_value,
|
||||||
|
req.priority, req.max_rpm, req.max_tpm,
|
||||||
|
req.quota_reset_interval.as_deref(),
|
||||||
|
).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "provider_key.add", "provider_key", &key_id,
|
||||||
|
Some(serde_json::json!({"provider_id": provider_id, "label": req.key_label})),
|
||||||
|
ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({"ok": true, "key_id": key_id})))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PUT /api/v1/providers/:provider_id/keys/:key_id/toggle
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
pub struct ToggleKeyRequest {
|
||||||
|
pub active: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn toggle_provider_key(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path((provider_id, key_id)): Path<(String, String)>,
|
||||||
|
Json(req): Json<ToggleKeyRequest>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
check_permission(&ctx, "provider:manage")?;
|
||||||
|
|
||||||
|
super::key_pool::toggle_key_active(&state.db, &key_id, req.active).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "provider_key.toggle", "provider_key", &key_id,
|
||||||
|
Some(serde_json::json!({"provider_id": provider_id, "active": req.active})),
|
||||||
|
ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DELETE /api/v1/providers/:provider_id/keys/:key_id
|
||||||
|
pub async fn delete_provider_key(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Path((provider_id, key_id)): Path<(String, String)>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
check_permission(&ctx, "provider:manage")?;
|
||||||
|
|
||||||
|
super::key_pool::delete_provider_key(&state.db, &key_id).await?;
|
||||||
|
|
||||||
|
log_operation(&state.db, &ctx.account_id, "provider_key.delete", "provider_key", &key_id,
|
||||||
|
Some(serde_json::json!({"provider_id": provider_id})),
|
||||||
|
ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
|
}
|
||||||
|
|||||||
320
crates/zclaw-saas/src/relay/key_pool.rs
Normal file
320
crates/zclaw-saas/src/relay/key_pool.rs
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
//! Provider Key Pool 服务
|
||||||
|
//!
|
||||||
|
//! 管理 provider 的多个 API Key,实现智能轮转绕过限额。
|
||||||
|
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
use crate::crypto;
|
||||||
|
|
||||||
|
/// 解密 key_value (如果已加密),否则原样返回
|
||||||
|
fn decrypt_key_value(encrypted: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
||||||
|
if crypto::is_encrypted(encrypted) {
|
||||||
|
crypto::decrypt_value(encrypted, enc_key)
|
||||||
|
.map_err(|e| SaasError::Internal(e.to_string()))
|
||||||
|
} else {
|
||||||
|
// 兼容旧的明文格式
|
||||||
|
Ok(encrypted.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Key Pool 中的可用 Key
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct PoolKey {
|
||||||
|
pub id: String,
|
||||||
|
pub key_value: String,
|
||||||
|
pub priority: i32,
|
||||||
|
pub max_rpm: Option<i64>,
|
||||||
|
pub max_tpm: Option<i64>,
|
||||||
|
pub quota_reset_interval: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Key 选择结果
|
||||||
|
pub struct KeySelection {
|
||||||
|
pub key: PoolKey,
|
||||||
|
pub key_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 从 provider 的 Key Pool 中选择最佳可用 Key
|
||||||
|
pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32]) -> SaasResult<KeySelection> {
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||||||
|
|
||||||
|
// 获取所有活跃 Key
|
||||||
|
let rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<String>)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, key_value, priority, max_rpm, max_tpm, quota_reset_interval
|
||||||
|
FROM provider_keys
|
||||||
|
WHERE provider_id = $1 AND is_active = TRUE AND (cooldown_until IS NULL OR cooldown_until <= $2)
|
||||||
|
ORDER BY priority ASC"
|
||||||
|
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||||||
|
|
||||||
|
if rows.is_empty() {
|
||||||
|
// 检查是否有冷却中的 Key,返回预计等待时间
|
||||||
|
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||||||
|
"SELECT cooldown_until FROM provider_keys
|
||||||
|
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until > $2
|
||||||
|
ORDER BY cooldown_until ASC
|
||||||
|
LIMIT 1"
|
||||||
|
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
if let Some((earliest,)) = cooldown_row {
|
||||||
|
// 尝试解析时间差
|
||||||
|
let wait_secs = parse_cooldown_remaining(&earliest, &now);
|
||||||
|
return Err(SaasError::RateLimited(
|
||||||
|
format!("所有 Key 均在冷却中,预计 {} 秒后可用", wait_secs)
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 provider 级别的单 Key
|
||||||
|
let provider_key: Option<String> = sqlx::query_scalar(
|
||||||
|
"SELECT api_key FROM providers WHERE id = $1"
|
||||||
|
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||||||
|
|
||||||
|
if let Some(key) = provider_key {
|
||||||
|
let decrypted = decrypt_key_value(&key, enc_key)?;
|
||||||
|
return Ok(KeySelection {
|
||||||
|
key: PoolKey {
|
||||||
|
id: "provider-fallback".to_string(),
|
||||||
|
key_value: decrypted,
|
||||||
|
priority: 0,
|
||||||
|
max_rpm: None,
|
||||||
|
max_tpm: None,
|
||||||
|
quota_reset_interval: None,
|
||||||
|
},
|
||||||
|
key_id: "provider-fallback".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查滑动窗口使用量
|
||||||
|
for (id, key_value, priority, max_rpm, max_tpm, quota_reset_interval) in rows {
|
||||||
|
// 检查 RPM 限额
|
||||||
|
if let Some(rpm_limit) = max_rpm {
|
||||||
|
if rpm_limit > 0 {
|
||||||
|
let window: Option<(i64,)> = sqlx::query_as(
|
||||||
|
"SELECT COALESCE(SUM(request_count), 0) FROM key_usage_window
|
||||||
|
WHERE key_id = $1 AND window_minute = $2"
|
||||||
|
).bind(&id).bind(¤t_minute).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
if let Some((count,)) = window {
|
||||||
|
if count >= rpm_limit {
|
||||||
|
tracing::debug!("Key {} hit RPM limit ({}/{})", id, count, rpm_limit);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 TPM 限额
|
||||||
|
if let Some(tpm_limit) = max_tpm {
|
||||||
|
if tpm_limit > 0 {
|
||||||
|
let window: Option<(i64,)> = sqlx::query_as(
|
||||||
|
"SELECT COALESCE(SUM(token_count), 0) FROM key_usage_window
|
||||||
|
WHERE key_id = $1 AND window_minute = $2"
|
||||||
|
).bind(&id).bind(¤t_minute).fetch_optional(db).await?;
|
||||||
|
|
||||||
|
if let Some((tokens,)) = window {
|
||||||
|
if tokens >= tpm_limit {
|
||||||
|
tracing::debug!("Key {} hit TPM limit ({}/{})", id, tokens, tpm_limit);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 此 Key 可用 — 解密 key_value
|
||||||
|
let decrypted_kv = decrypt_key_value(&key_value, enc_key)?;
|
||||||
|
return Ok(KeySelection {
|
||||||
|
key: PoolKey {
|
||||||
|
id: id.clone(),
|
||||||
|
key_value: decrypted_kv,
|
||||||
|
priority,
|
||||||
|
max_rpm,
|
||||||
|
max_tpm,
|
||||||
|
quota_reset_interval,
|
||||||
|
},
|
||||||
|
key_id: id,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有 Key 都超限,回退到 provider 单 Key
|
||||||
|
let provider_key: Option<String> = sqlx::query_scalar(
|
||||||
|
"SELECT api_key FROM providers WHERE id = $1"
|
||||||
|
).bind(provider_id).fetch_optional(db).await?.flatten();
|
||||||
|
|
||||||
|
if let Some(key) = provider_key {
|
||||||
|
let decrypted = decrypt_key_value(&key, enc_key)?;
|
||||||
|
return Ok(KeySelection {
|
||||||
|
key: PoolKey {
|
||||||
|
id: "provider-fallback".to_string(),
|
||||||
|
key_value: decrypted,
|
||||||
|
priority: 0,
|
||||||
|
max_rpm: None,
|
||||||
|
max_tpm: None,
|
||||||
|
quota_reset_interval: None,
|
||||||
|
},
|
||||||
|
key_id: "provider-fallback".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(SaasError::RateLimited(
|
||||||
|
format!("Provider {} 所有 Key 均已达限额", provider_id)
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 记录 Key 使用量(滑动窗口)
|
||||||
|
pub async fn record_key_usage(
|
||||||
|
db: &PgPool,
|
||||||
|
key_id: &str,
|
||||||
|
tokens: Option<i64>,
|
||||||
|
) -> SaasResult<()> {
|
||||||
|
let current_minute = chrono::Utc::now().format("%Y-%m-%dT%H:%M").to_string();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO key_usage_window (key_id, window_minute, request_count, token_count)
|
||||||
|
VALUES ($1, $2, 1, $3)
|
||||||
|
ON CONFLICT (key_id, window_minute) DO UPDATE
|
||||||
|
SET request_count = key_usage_window.request_count + 1,
|
||||||
|
token_count = key_usage_window.token_count + $3"
|
||||||
|
)
|
||||||
|
.bind(key_id).bind(¤t_minute).bind(tokens.unwrap_or(0))
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
// 更新 Key 的累计统计
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE provider_keys SET total_requests = total_requests + 1, total_tokens = total_tokens + COALESCE($1, 0), updated_at = $2
|
||||||
|
WHERE id = $3"
|
||||||
|
)
|
||||||
|
.bind(tokens).bind(&chrono::Utc::now().to_rfc3339()).bind(key_id)
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 标记 Key 收到 429,设置冷却期
|
||||||
|
pub async fn mark_key_429(
|
||||||
|
db: &PgPool,
|
||||||
|
key_id: &str,
|
||||||
|
retry_after_seconds: Option<u64>,
|
||||||
|
) -> SaasResult<()> {
|
||||||
|
let cooldown = if let Some(secs) = retry_after_seconds {
|
||||||
|
(chrono::Utc::now() + chrono::Duration::seconds(secs as i64)).to_rfc3339()
|
||||||
|
} else {
|
||||||
|
// 默认 5 分钟冷却
|
||||||
|
(chrono::Utc::now() + chrono::Duration::minutes(5)).to_rfc3339()
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3
|
||||||
|
WHERE id = $4"
|
||||||
|
)
|
||||||
|
.bind(&now).bind(&cooldown).bind(&now).bind(key_id)
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
"Key {} 收到 429,冷却至 {}",
|
||||||
|
key_id,
|
||||||
|
cooldown
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 获取 provider 的所有 Key(管理用)
|
||||||
|
pub async fn list_provider_keys(
|
||||||
|
db: &PgPool,
|
||||||
|
provider_id: &str,
|
||||||
|
) -> SaasResult<Vec<serde_json::Value>> {
|
||||||
|
let rows: Vec<(String, String, String, i32, Option<i64>, Option<i64>, Option<String>, bool, Option<String>, Option<String>, i64, i64, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, provider_id, key_label, priority, max_rpm, max_tpm, quota_reset_interval, is_active,
|
||||||
|
last_429_at, cooldown_until, total_requests, total_tokens, created_at, updated_at
|
||||||
|
FROM provider_keys WHERE provider_id = $1 ORDER BY priority ASC"
|
||||||
|
).bind(provider_id).fetch_all(db).await?;
|
||||||
|
|
||||||
|
Ok(rows.into_iter().map(|r| {
|
||||||
|
serde_json::json!({
|
||||||
|
"id": r.0,
|
||||||
|
"provider_id": r.1,
|
||||||
|
"key_label": r.2,
|
||||||
|
"priority": r.3,
|
||||||
|
"max_rpm": r.4,
|
||||||
|
"max_tpm": r.5,
|
||||||
|
"quota_reset_interval": r.6,
|
||||||
|
"is_active": r.7,
|
||||||
|
"last_429_at": r.8,
|
||||||
|
"cooldown_until": r.9,
|
||||||
|
"total_requests": r.10,
|
||||||
|
"total_tokens": r.11,
|
||||||
|
"created_at": r.12,
|
||||||
|
"updated_at": r.13,
|
||||||
|
})
|
||||||
|
}).collect())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 添加 Key 到 Pool
|
||||||
|
pub async fn add_provider_key(
|
||||||
|
db: &PgPool,
|
||||||
|
provider_id: &str,
|
||||||
|
key_label: &str,
|
||||||
|
key_value: &str,
|
||||||
|
priority: i32,
|
||||||
|
max_rpm: Option<i64>,
|
||||||
|
max_tpm: Option<i64>,
|
||||||
|
quota_reset_interval: Option<&str>,
|
||||||
|
) -> SaasResult<String> {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO provider_keys (id, provider_id, key_label, key_value, priority, max_rpm, max_tpm, quota_reset_interval, is_active, total_requests, total_tokens, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, TRUE, 0, 0, $9, $9)"
|
||||||
|
)
|
||||||
|
.bind(&id).bind(provider_id).bind(key_label).bind(key_value)
|
||||||
|
.bind(priority).bind(max_rpm).bind(max_tpm).bind(quota_reset_interval).bind(&now)
|
||||||
|
.execute(db).await?;
|
||||||
|
|
||||||
|
tracing::info!("Added key '{}' to provider {}", key_label, provider_id);
|
||||||
|
Ok(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 切换 Key 活跃状态
|
||||||
|
pub async fn toggle_key_active(
|
||||||
|
db: &PgPool,
|
||||||
|
key_id: &str,
|
||||||
|
active: bool,
|
||||||
|
) -> SaasResult<()> {
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
|
||||||
|
).bind(active).bind(&now).bind(key_id).execute(db).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 删除 Key
|
||||||
|
pub async fn delete_provider_key(
|
||||||
|
db: &PgPool,
|
||||||
|
key_id: &str,
|
||||||
|
) -> SaasResult<()> {
|
||||||
|
sqlx::query("DELETE FROM provider_keys WHERE id = $1")
|
||||||
|
.bind(key_id).execute(db).await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 解析冷却剩余时间(秒)
|
||||||
|
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
|
||||||
|
let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until);
|
||||||
|
let current = chrono::DateTime::parse_from_rfc3339(now);
|
||||||
|
|
||||||
|
match (cooldown, current) {
|
||||||
|
(Ok(c), Ok(n)) => {
|
||||||
|
let diff = c.signed_duration_since(n);
|
||||||
|
diff.num_seconds().max(0)
|
||||||
|
}
|
||||||
|
_ => 300, // 默认 5 分钟
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,16 +3,23 @@
|
|||||||
pub mod types;
|
pub mod types;
|
||||||
pub mod service;
|
pub mod service;
|
||||||
pub mod handlers;
|
pub mod handlers;
|
||||||
|
pub mod key_pool;
|
||||||
|
|
||||||
use axum::routing::{get, post};
|
use axum::routing::{delete, get, post, put};
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
|
|
||||||
/// 中转服务路由 (需要认证)
|
/// 中转服务路由 (需要认证)
|
||||||
pub fn routes() -> axum::Router<AppState> {
|
pub fn routes() -> axum::Router<AppState> {
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
|
// Relay 核心端点
|
||||||
.route("/api/v1/relay/chat/completions", post(handlers::chat_completions))
|
.route("/api/v1/relay/chat/completions", post(handlers::chat_completions))
|
||||||
.route("/api/v1/relay/tasks", get(handlers::list_tasks))
|
.route("/api/v1/relay/tasks", get(handlers::list_tasks))
|
||||||
.route("/api/v1/relay/tasks/{id}", get(handlers::get_task))
|
.route("/api/v1/relay/tasks/:id", get(handlers::get_task))
|
||||||
.route("/api/v1/relay/tasks/{id}/retry", post(handlers::retry_task))
|
.route("/api/v1/relay/tasks/:id/retry", post(handlers::retry_task))
|
||||||
.route("/api/v1/relay/models", get(handlers::list_available_models))
|
.route("/api/v1/relay/models", get(handlers::list_available_models))
|
||||||
|
// Key Pool 管理 (admin only)
|
||||||
|
.route("/api/v1/providers/:provider_id/keys", get(handlers::list_provider_keys))
|
||||||
|
.route("/api/v1/providers/:provider_id/keys", post(handlers::add_provider_key))
|
||||||
|
.route("/api/v1/providers/:provider_id/keys/:key_id/toggle", put(handlers::toggle_provider_key))
|
||||||
|
.route("/api/v1/providers/:provider_id/keys/:key_id", delete(handlers::delete_provider_key))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
//! 中转服务核心逻辑
|
//! 中转服务核心逻辑
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::PgPool;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::sync::Mutex;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
@@ -18,7 +20,7 @@ fn is_retryable_error(e: &reqwest::Error) -> bool {
|
|||||||
// ============ Relay Task Management ============
|
// ============ Relay Task Management ============
|
||||||
|
|
||||||
pub async fn create_relay_task(
|
pub async fn create_relay_task(
|
||||||
db: &SqlitePool,
|
db: &PgPool,
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
provider_id: &str,
|
provider_id: &str,
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
@@ -33,7 +35,7 @@ pub async fn create_relay_task(
|
|||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, ?8, ?9, ?9)"
|
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)"
|
||||||
)
|
)
|
||||||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||||||
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
||||||
@@ -42,11 +44,11 @@ pub async fn create_relay_task(
|
|||||||
get_relay_task(db, &id).await
|
get_relay_task(db, &id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||||
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
|
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||||
FROM relay_tasks WHERE id = ?1"
|
FROM relay_tasks WHERE id = $1"
|
||||||
)
|
)
|
||||||
.bind(task_id)
|
.bind(task_id)
|
||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
@@ -63,50 +65,62 @@ pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayT
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn list_relay_tasks(
|
pub async fn list_relay_tasks(
|
||||||
db: &SqlitePool, account_id: &str, query: &RelayTaskQuery,
|
db: &PgPool, account_id: &str, query: &RelayTaskQuery,
|
||||||
) -> SaasResult<Vec<RelayTaskInfo>> {
|
) -> SaasResult<crate::common::PaginatedResponse<RelayTaskInfo>> {
|
||||||
let page = query.page.unwrap_or(1).max(1);
|
let page = query.page.unwrap_or(1).max(1) as u32;
|
||||||
let page_size = query.page_size.unwrap_or(20).min(100);
|
let page_size = query.page_size.unwrap_or(20).min(100) as u32;
|
||||||
let offset = (page - 1) * page_size;
|
let offset = ((page - 1) * page_size) as i64;
|
||||||
|
|
||||||
let sql = if query.status.is_some() {
|
let (count_sql, data_sql) = if query.status.is_some() {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status = $2",
|
||||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||||
FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4"
|
FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
|
(
|
||||||
|
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1",
|
||||||
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
|
||||||
FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3"
|
FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(sql)
|
let total: i64 = if query.status.is_some() {
|
||||||
|
sqlx::query_scalar(count_sql).bind(account_id).bind(query.status.as_ref().unwrap()).fetch_one(db).await?
|
||||||
|
} else {
|
||||||
|
sqlx::query_scalar(count_sql).bind(account_id).fetch_one(db).await?
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(data_sql)
|
||||||
.bind(account_id);
|
.bind(account_id);
|
||||||
|
|
||||||
if let Some(ref status) = query.status {
|
if let Some(ref status) = query.status {
|
||||||
query_builder = query_builder.bind(status);
|
query_builder = query_builder.bind(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
query_builder = query_builder.bind(page_size).bind(offset);
|
let rows = query_builder.bind(page_size as i64).bind(offset).fetch_all(db).await?;
|
||||||
|
let items: Vec<RelayTaskInfo> = rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
||||||
let rows = query_builder.fetch_all(db).await?;
|
|
||||||
Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
|
|
||||||
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
|
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at }
|
||||||
}).collect())
|
}).collect();
|
||||||
|
|
||||||
|
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn update_task_status(
|
pub async fn update_task_status(
|
||||||
db: &SqlitePool, task_id: &str, status: &str,
|
db: &PgPool, task_id: &str, status: &str,
|
||||||
input_tokens: Option<i64>, output_tokens: Option<i64>,
|
input_tokens: Option<i64>, output_tokens: Option<i64>,
|
||||||
error_message: Option<&str>,
|
error_message: Option<&str>,
|
||||||
) -> SaasResult<()> {
|
) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
let update_sql = match status {
|
let update_sql = match status {
|
||||||
"processing" => "started_at = ?1, status = 'processing', attempt_count = attempt_count + 1",
|
"processing" => "started_at = $1, status = 'processing', attempt_count = attempt_count + 1",
|
||||||
"completed" => "completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens)",
|
"completed" => "completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens)",
|
||||||
"failed" => "completed_at = ?1, status = 'failed', error_message = ?2",
|
"failed" => "completed_at = $1, status = 'failed', error_message = $2",
|
||||||
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
|
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
|
||||||
};
|
};
|
||||||
|
|
||||||
let sql = format!("UPDATE relay_tasks SET {} WHERE id = ?4", update_sql);
|
let sql = format!("UPDATE relay_tasks SET {} WHERE id = $4", update_sql);
|
||||||
|
|
||||||
let mut query = sqlx::query(&sql).bind(&now);
|
let mut query = sqlx::query(&sql).bind(&now);
|
||||||
if status == "completed" {
|
if status == "completed" {
|
||||||
@@ -123,15 +137,43 @@ pub async fn update_task_status(
|
|||||||
|
|
||||||
// ============ Relay Execution ============
|
// ============ Relay Execution ============
|
||||||
|
|
||||||
|
/// SSE 流中的 usage 信息捕获器
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
struct SseUsageCapture {
|
||||||
|
input_tokens: i64,
|
||||||
|
output_tokens: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SseUsageCapture {
|
||||||
|
fn parse_sse_line(&mut self, line: &str) {
|
||||||
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
|
if data == "[DONE]" {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(data) {
|
||||||
|
if let Some(usage) = parsed.get("usage") {
|
||||||
|
if let Some(input) = usage.get("prompt_tokens").and_then(|v| v.as_i64()) {
|
||||||
|
self.input_tokens = input;
|
||||||
|
}
|
||||||
|
if let Some(output) = usage.get("completion_tokens").and_then(|v| v.as_i64()) {
|
||||||
|
self.output_tokens = output;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub async fn execute_relay(
|
pub async fn execute_relay(
|
||||||
db: &SqlitePool,
|
db: &PgPool,
|
||||||
task_id: &str,
|
task_id: &str,
|
||||||
|
provider_id: &str,
|
||||||
provider_base_url: &str,
|
provider_base_url: &str,
|
||||||
provider_api_key: Option<&str>,
|
|
||||||
request_body: &str,
|
request_body: &str,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
max_attempts: u32,
|
max_attempts: u32,
|
||||||
base_delay_ms: u64,
|
base_delay_ms: u64,
|
||||||
|
enc_key: &[u8; 32],
|
||||||
) -> SaasResult<RelayResponse> {
|
) -> SaasResult<RelayResponse> {
|
||||||
validate_provider_url(provider_base_url)?;
|
validate_provider_url(provider_base_url)?;
|
||||||
|
|
||||||
@@ -144,17 +186,47 @@ pub async fn execute_relay(
|
|||||||
|
|
||||||
let max_attempts = max_attempts.max(1).min(5);
|
let max_attempts = max_attempts.max(1).min(5);
|
||||||
|
|
||||||
|
// Key Pool 轮转状态
|
||||||
|
let mut current_key_id: Option<String> = None;
|
||||||
|
let mut current_api_key: Option<String> = None;
|
||||||
|
|
||||||
for attempt in 0..max_attempts {
|
for attempt in 0..max_attempts {
|
||||||
let is_first = attempt == 0;
|
let is_first = attempt == 0;
|
||||||
if is_first {
|
if is_first {
|
||||||
update_task_status(db, task_id, "processing", None, None, None).await?;
|
update_task_status(db, task_id, "processing", None, None, None).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 首次或 429 后需要重新选择 Key
|
||||||
|
if current_key_id.is_none() {
|
||||||
|
match super::key_pool::select_best_key(db, provider_id, enc_key).await {
|
||||||
|
Ok(selection) => {
|
||||||
|
let key_id = selection.key_id.clone();
|
||||||
|
let key_value = selection.key.key_value.clone();
|
||||||
|
tracing::debug!(
|
||||||
|
"Relay task {} 选择 Key {} (attempt {})",
|
||||||
|
task_id, key_id, attempt + 1
|
||||||
|
);
|
||||||
|
current_key_id = Some(key_id);
|
||||||
|
current_api_key = Some(key_value);
|
||||||
|
}
|
||||||
|
Err(SaasError::RateLimited(msg)) => {
|
||||||
|
// 所有 Key 均在冷却中
|
||||||
|
let err_msg = format!("Key Pool 耗尽: {}", msg);
|
||||||
|
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||||
|
return Err(SaasError::RateLimited(msg));
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let key_id = current_key_id.as_ref().unwrap().clone();
|
||||||
|
let api_key = current_api_key.clone();
|
||||||
|
|
||||||
let mut req_builder = client.post(&url)
|
let mut req_builder = client.post(&url)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.body(request_body.to_string());
|
.body(request_body.to_string());
|
||||||
|
|
||||||
if let Some(key) = provider_api_key {
|
if let Some(ref key) = api_key {
|
||||||
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,31 +234,128 @@ pub async fn execute_relay(
|
|||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(resp) if resp.status().is_success() => {
|
Ok(resp) if resp.status().is_success() => {
|
||||||
// 成功
|
|
||||||
if stream {
|
if stream {
|
||||||
let byte_stream = resp.bytes_stream()
|
let usage_capture = Arc::new(Mutex::new(SseUsageCapture::default()));
|
||||||
.map(|result| result.map_err(std::io::Error::other));
|
let usage_capture_clone = usage_capture.clone();
|
||||||
let body = axum::body::Body::from_stream(byte_stream);
|
let db_clone = db.clone();
|
||||||
update_task_status(db, task_id, "completed", None, None, None).await?;
|
let task_id_clone = task_id.to_string();
|
||||||
|
let key_id_for_spawn = key_id.clone();
|
||||||
|
|
||||||
|
// Bounded channel for backpressure: 128 chunks (~128KB) buffer.
|
||||||
|
// If the client reads slowly, the upstream is signaled via
|
||||||
|
// backpressure instead of growing memory indefinitely.
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel::<Result<bytes::Bytes, std::io::Error>>(128);
|
||||||
|
|
||||||
|
// Spawn a task to consume the upstream stream and forward through the bounded channel
|
||||||
|
tokio::spawn(async move {
|
||||||
|
use futures::StreamExt;
|
||||||
|
let mut upstream = resp.bytes_stream();
|
||||||
|
while let Some(chunk_result) = upstream.next().await {
|
||||||
|
match chunk_result {
|
||||||
|
Ok(chunk) => {
|
||||||
|
// Parse SSE lines for usage tracking
|
||||||
|
if let Ok(text) = std::str::from_utf8(&chunk) {
|
||||||
|
if let Ok(mut capture) = usage_capture_clone.lock() {
|
||||||
|
for line in text.lines() {
|
||||||
|
capture.parse_sse_line(line);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Forward to bounded channel — if full, this applies backpressure
|
||||||
|
if tx.send(Ok(chunk)).await.is_err() {
|
||||||
|
tracing::debug!("SSE relay: client disconnected, stopping upstream");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(std::io::Error::other(e))).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert mpsc::Receiver into a Body stream
|
||||||
|
let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||||
|
let body = axum::body::Body::from_stream(body_stream);
|
||||||
|
|
||||||
|
// SSE 流结束后异步记录 usage + Key 使用量
|
||||||
|
tokio::spawn(async move {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
|
||||||
|
let (input, output) = match usage_capture.lock() {
|
||||||
|
Ok(capture) => (
|
||||||
|
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
|
||||||
|
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
|
||||||
|
),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Usage capture lock poisoned: {}", e);
|
||||||
|
(None, None)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
// 记录任务状态
|
||||||
|
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
|
||||||
|
tracing::warn!("Failed to update task status after SSE stream: {}", e);
|
||||||
|
}
|
||||||
|
// 记录 Key 使用量
|
||||||
|
let total_tokens = input.unwrap_or(0) + output.unwrap_or(0);
|
||||||
|
if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await {
|
||||||
|
tracing::warn!("Failed to record key usage: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
return Ok(RelayResponse::Sse(body));
|
return Ok(RelayResponse::Sse(body));
|
||||||
} else {
|
} else {
|
||||||
let body = resp.text().await.unwrap_or_default();
|
let body = resp.text().await.unwrap_or_default();
|
||||||
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
let (input_tokens, output_tokens) = extract_token_usage(&body);
|
||||||
update_task_status(db, task_id, "completed",
|
update_task_status(db, task_id, "completed",
|
||||||
Some(input_tokens), Some(output_tokens), None).await?;
|
Some(input_tokens), Some(output_tokens), None).await?;
|
||||||
|
// 记录 Key 使用量
|
||||||
|
let _ = super::key_pool::record_key_usage(
|
||||||
|
db, &key_id, Some(input_tokens + output_tokens),
|
||||||
|
).await;
|
||||||
return Ok(RelayResponse::Json(body));
|
return Ok(RelayResponse::Json(body));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(resp) => {
|
Ok(resp) => {
|
||||||
let status = resp.status().as_u16();
|
let status = resp.status().as_u16();
|
||||||
|
if status == 429 {
|
||||||
|
// 解析 Retry-After header
|
||||||
|
let retry_after = resp.headers()
|
||||||
|
.get("retry-after")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.and_then(|v| v.parse::<u64>().ok());
|
||||||
|
|
||||||
|
// 标记 Key 为 429 冷却
|
||||||
|
if let Err(e) = super::key_pool::mark_key_429(db, &key_id, retry_after).await {
|
||||||
|
tracing::warn!("Failed to mark key 429: {}", e);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 强制下次迭代重新选择 Key
|
||||||
|
current_key_id = None;
|
||||||
|
current_api_key = None;
|
||||||
|
|
||||||
|
if attempt + 1 >= max_attempts {
|
||||||
|
let err_msg = format!(
|
||||||
|
"Key Pool 轮转耗尽 ({} attempts),所有 Key 均被限流",
|
||||||
|
max_attempts
|
||||||
|
);
|
||||||
|
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||||
|
return Err(SaasError::RateLimited(err_msg));
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
"Relay task {} 收到 429,Key {} 已标记冷却 (attempt {}/{})",
|
||||||
|
task_id, key_id, attempt + 1, max_attempts
|
||||||
|
);
|
||||||
|
// 429 时立即切换 Key 重试,不做退避延迟
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
|
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
|
||||||
// 4xx 客户端错误或已达最大重试次数 → 立即失败
|
|
||||||
let body = resp.text().await.unwrap_or_default();
|
let body = resp.text().await.unwrap_or_default();
|
||||||
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]);
|
||||||
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
|
||||||
return Err(SaasError::Relay(err_msg));
|
return Err(SaasError::Relay(err_msg));
|
||||||
}
|
}
|
||||||
// 可重试的服务端错误 → 继续循环
|
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
|
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
|
||||||
task_id, status, attempt + 1, max_attempts
|
task_id, status, attempt + 1, max_attempts
|
||||||
@@ -205,12 +374,11 @@ pub async fn execute_relay(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 指数退避: base_delay * 2^attempt
|
// 非 429 错误使用指数退避
|
||||||
let delay_ms = base_delay_ms * (1 << attempt);
|
let delay_ms = base_delay_ms * (1 << attempt);
|
||||||
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 理论上不会到达 (循环内已处理),但满足编译器
|
|
||||||
Err(SaasError::Relay("重试次数已耗尽".into()))
|
Err(SaasError::Relay("重试次数已耗尽".into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -228,6 +396,7 @@ fn hash_request(body: &str) -> String {
|
|||||||
hex::encode(Sha256::digest(body.as_bytes()))
|
hex::encode(Sha256::digest(body.as_bytes()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 从 JSON 响应中提取 token 使用量
|
||||||
fn extract_token_usage(body: &str) -> (i64, i64) {
|
fn extract_token_usage(body: &str) -> (i64, i64) {
|
||||||
let parsed: serde_json::Value = match serde_json::from_str(body) {
|
let parsed: serde_json::Value = match serde_json::from_str(body) {
|
||||||
Ok(v) => v,
|
Ok(v) => v,
|
||||||
@@ -247,6 +416,11 @@ fn extract_token_usage(body: &str) -> (i64, i64) {
|
|||||||
(input, output)
|
(input, output)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 从 JSON 响应中提取 token 使用量 (公开版本)
|
||||||
|
pub fn extract_token_usage_from_json(body: &str) -> (i64, i64) {
|
||||||
|
extract_token_usage(body)
|
||||||
|
}
|
||||||
|
|
||||||
/// SSRF 防护: 验证 provider URL 不指向内网
|
/// SSRF 防护: 验证 provider URL 不指向内网
|
||||||
fn validate_provider_url(url: &str) -> SaasResult<()> {
|
fn validate_provider_url(url: &str) -> SaasResult<()> {
|
||||||
let parsed: url::Url = url.parse().map_err(|_| {
|
let parsed: url::Url = url.parse().map_err(|_| {
|
||||||
@@ -274,6 +448,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
|
|||||||
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
|
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// 去除 IPv6 方括号
|
||||||
|
let host = host.trim_start_matches('[').trim_end_matches(']');
|
||||||
|
|
||||||
// 精确匹配的阻止列表
|
// 精确匹配的阻止列表
|
||||||
let blocked_exact = [
|
let blocked_exact = [
|
||||||
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
|
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
|
||||||
@@ -292,16 +469,39 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
|
||||||
|
if host.parse::<u64>().is_ok() {
|
||||||
|
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
|
||||||
|
}
|
||||||
|
|
||||||
|
// 阻止十六进制/八进制 IP 混淆 (如 0x7f000001, 0177.0.0.1)
|
||||||
|
if host.chars().all(|c| c.is_ascii_hexdigit() || c == '.' || c == ':' || c == 'x' || c == 'X') {
|
||||||
|
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
|
||||||
|
}
|
||||||
|
|
||||||
// 阻止 IPv4 私有网段 (通过解析 IP)
|
// 阻止 IPv4 私有网段 (通过解析 IP)
|
||||||
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
|
||||||
if is_private_ip(&ip) {
|
if is_private_ip(&ip) {
|
||||||
return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
|
return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
|
||||||
}
|
}
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
|
// 对域名做 DNS 解析,检查解析结果是否指向内网
|
||||||
if host.parse::<u64>().is_ok() {
|
let addr_str: String = format!("{}:0", host);
|
||||||
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
|
match std::net::ToSocketAddrs::to_socket_addrs(&addr_str) {
|
||||||
|
Ok(addrs) => {
|
||||||
|
for sockaddr in addrs {
|
||||||
|
if is_private_ip(&sockaddr.ip()) {
|
||||||
|
return Err(SaasError::InvalidInput(
|
||||||
|
"provider URL 域名解析到内网地址".into()
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
// DNS 解析失败,可能是无效域名,不阻止请求
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|||||||
130
crates/zclaw-saas/src/role/handlers.rs
Normal file
130
crates/zclaw-saas/src/role/handlers.rs
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
//! 角色管理 HTTP 处理器
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Extension, Json, Path, State},
|
||||||
|
http::StatusCode,
|
||||||
|
};
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::error::SaasResult;
|
||||||
|
use crate::auth::types::AuthContext;
|
||||||
|
use crate::auth::handlers::{check_permission, log_operation};
|
||||||
|
use super::{types::*, service};
|
||||||
|
|
||||||
|
fn require_admin(ctx: &AuthContext) -> SaasResult<()> {
|
||||||
|
check_permission(ctx, "account:admin")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/roles
|
||||||
|
pub async fn list_roles(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<Vec<RoleInfo>>> {
|
||||||
|
check_permission(&ctx, "account:read")?;
|
||||||
|
service::list_roles(&state.db).await.map(Json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/roles/:id
|
||||||
|
pub async fn get_role(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<RoleInfo>> {
|
||||||
|
check_permission(&ctx, "account:read")?;
|
||||||
|
service::get_role(&state.db, &id).await.map(Json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/roles
|
||||||
|
pub async fn create_role(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<CreateRoleRequest>,
|
||||||
|
) -> SaasResult<(StatusCode, Json<RoleInfo>)> {
|
||||||
|
require_admin(&ctx)?;
|
||||||
|
let role = service::create_role(&state.db, &req).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "role.create", "role", &role.id,
|
||||||
|
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
||||||
|
Ok((StatusCode::CREATED, Json(role)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// PUT /api/v1/roles/:id
|
||||||
|
pub async fn update_role(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<UpdateRoleRequest>,
|
||||||
|
) -> SaasResult<Json<RoleInfo>> {
|
||||||
|
require_admin(&ctx)?;
|
||||||
|
let role = service::update_role(&state.db, &id, &req).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "role.update", "role", &id, None, ctx.client_ip.as_deref()).await?;
|
||||||
|
Ok(Json(role))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DELETE /api/v1/roles/:id
|
||||||
|
pub async fn delete_role(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
require_admin(&ctx)?;
|
||||||
|
service::delete_role(&state.db, &id).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "role.delete", "role", &id, None, ctx.client_ip.as_deref()).await?;
|
||||||
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/permission-templates
|
||||||
|
pub async fn list_templates(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<Vec<PermissionTemplate>>> {
|
||||||
|
check_permission(&ctx, "account:read")?;
|
||||||
|
service::list_templates(&state.db).await.map(Json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/permission-templates/:id
|
||||||
|
pub async fn get_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<PermissionTemplate>> {
|
||||||
|
check_permission(&ctx, "account:read")?;
|
||||||
|
service::get_template(&state.db, &id).await.map(Json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/permission-templates
|
||||||
|
pub async fn create_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<CreateTemplateRequest>,
|
||||||
|
) -> SaasResult<(StatusCode, Json<PermissionTemplate>)> {
|
||||||
|
require_admin(&ctx)?;
|
||||||
|
let template = service::create_template(&state.db, &req).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "template.create", "permission_template", &template.id,
|
||||||
|
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
||||||
|
Ok((StatusCode::CREATED, Json(template)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// DELETE /api/v1/permission-templates/:id
|
||||||
|
pub async fn delete_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
require_admin(&ctx)?;
|
||||||
|
service::delete_template(&state.db, &id).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "template.delete", "permission_template", &id, None, ctx.client_ip.as_deref()).await?;
|
||||||
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/permission-templates/:id/apply
|
||||||
|
pub async fn apply_template(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<ApplyTemplateRequest>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
require_admin(&ctx)?;
|
||||||
|
let count = service::apply_template_to_accounts(&state.db, &id, &req.account_ids).await?;
|
||||||
|
log_operation(&state.db, &ctx.account_id, "template.apply", "permission_template", &id,
|
||||||
|
Some(serde_json::json!({"accounts": &req.account_ids, "applied_count": count})), ctx.client_ip.as_deref()).await?;
|
||||||
|
Ok(Json(serde_json::json!({"ok": true, "applied_count": count})))
|
||||||
|
}
|
||||||
34
crates/zclaw-saas/src/role/handlers_ext.rs
Normal file
34
crates/zclaw-saas/src/role/handlers_ext.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
//! 角色管理模块
|
||||||
|
//! handlers_ext - 获取角色权限列表(公开 API)
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Extension, Path, State},
|
||||||
|
http::StatusCode,
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::error::SaasResult;
|
||||||
|
use crate::auth::types::AuthContext;
|
||||||
|
use crate::auth::handlers::check_permission;
|
||||||
|
use super::{types::*, service};
|
||||||
|
|
||||||
|
use crate::role::handlers_ext;
|
||||||
|
|
||||||
|
/// GET /api/v1/roles/:id/permissions - 公开 API,无需登录验证
|
||||||
|
pub async fn get_role_permissions(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<String>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
) -> SaasResult<Json<Vec<String>>> {
|
||||||
|
check_permission(&ctx, "account:read")?;
|
||||||
|
|
||||||
|
let row: Option<(String,)> = sqlx::query_as(
|
||||||
|
"SELECT permissions FROM roles WHERE id = $1"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.fetch_optional(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
|
||||||
|
Ok(permissions)
|
||||||
|
}
|
||||||
31
crates/zclaw-saas/src/role/mod.rs
Normal file
31
crates/zclaw-saas/src/role/mod.rs
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
//! 角色管理模块
|
||||||
|
|
||||||
|
pub mod types;
|
||||||
|
pub mod service;
|
||||||
|
pub mod handlers;
|
||||||
|
|
||||||
|
pub mod handlers_ext;
|
||||||
|
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router<AppState> {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/api/v1/roles", get(handlers::list_roles).post(handlers::create_role))
|
||||||
|
.route("/api/v1/roles/:id", get(handlers::get_role).put(handlers::update_role).delete(handlers::delete_role))
|
||||||
|
.route("/api/v1/permission-templates", get(handlers::list_templates).post(handlers::create_template))
|
||||||
|
.route("/api/v1/permission-templates/:id", get(handlers::get_template).delete(handlers::delete_template))
|
||||||
|
.route("/api/v1/permission-templates/:id/apply", post(handlers::apply_template))
|
||||||
|
.route("/api/v1/roles/:id/permissions", get(handlers::get_role_permissions))
|
||||||
|
handlers
|
||||||
|
}use axum::routing::{get, post};
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
pub fn routes() -> axum::Router<AppState> {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/api/v1/roles", get(handlers::list_roles).post(handlers::create_role))
|
||||||
|
.route("/api/v1/roles/:id", get(handlers::get_role).put(handlers::update_role).delete(handlers::delete_role))
|
||||||
|
.route("/api/v1/permission-templates", get(handlers::list_templates).post(handlers::create_template))
|
||||||
|
.route("/api/v1/permission-templates/:id", get(handlers::get_template).delete(handlers::delete_template))
|
||||||
|
.route("/api/v1/permission-templates/:id/apply", post(handlers::apply_template))
|
||||||
|
}
|
||||||
238
crates/zclaw-saas/src/role/service.rs
Normal file
238
crates/zclaw-saas/src/role/service.rs
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
//! 角色管理业务逻辑
|
||||||
|
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
use super::types::*;
|
||||||
|
|
||||||
|
pub async fn list_roles(db: &PgPool) -> SaasResult<Vec<RoleInfo>> {
|
||||||
|
let rows: Vec<(String, String, Option<String>, String, bool, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, name, description, permissions, is_system, created_at, updated_at
|
||||||
|
FROM roles ORDER BY
|
||||||
|
CASE id
|
||||||
|
WHEN 'super_admin' THEN 1
|
||||||
|
WHEN 'admin' THEN 2
|
||||||
|
WHEN 'user' THEN 3
|
||||||
|
ELSE 4
|
||||||
|
END"
|
||||||
|
)
|
||||||
|
.fetch_all(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let roles = rows.into_iter().map(|(id, name, description, perms, is_system, created_at, updated_at)| {
|
||||||
|
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||||
|
RoleInfo { id, name, description, permissions, is_system, created_at, updated_at }
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
Ok(roles)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_role(db: &PgPool, role_id: &str) -> SaasResult<RoleInfo> {
|
||||||
|
let row: Option<(String, String, Option<String>, String, bool, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, name, description, permissions, is_system, created_at, updated_at
|
||||||
|
FROM roles WHERE id = $1"
|
||||||
|
)
|
||||||
|
.bind(role_id)
|
||||||
|
.fetch_optional(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (id, name, description, perms, is_system, created_at, updated_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound(format!("角色 {} 不存在", role_id)))?;
|
||||||
|
|
||||||
|
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||||
|
Ok(RoleInfo { id, name, description, permissions, is_system, created_at, updated_at })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_role(db: &PgPool, req: &CreateRoleRequest) -> SaasResult<RoleInfo> {
|
||||||
|
let existing: Option<(String,)> = sqlx::query_as(
|
||||||
|
"SELECT id FROM roles WHERE id = $1"
|
||||||
|
)
|
||||||
|
.bind(&req.id)
|
||||||
|
.fetch_optional(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if existing.is_some() {
|
||||||
|
return Err(SaasError::AlreadyExists(format!("角色 {} 已存在", req.id)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let permissions = serde_json::to_string(&req.permissions)?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, false, $5, $5)"
|
||||||
|
)
|
||||||
|
.bind(&req.id)
|
||||||
|
.bind(&req.name)
|
||||||
|
.bind(&req.description)
|
||||||
|
.bind(&permissions)
|
||||||
|
.bind(&now)
|
||||||
|
.execute(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(RoleInfo {
|
||||||
|
id: req.id.clone(),
|
||||||
|
name: req.name.clone(),
|
||||||
|
description: req.description.clone(),
|
||||||
|
permissions: req.permissions.clone(),
|
||||||
|
is_system: false,
|
||||||
|
created_at: now.clone(),
|
||||||
|
updated_at: now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update_role(db: &PgPool, role_id: &str, req: &UpdateRoleRequest) -> SaasResult<RoleInfo> {
|
||||||
|
let existing = get_role(db, role_id).await?;
|
||||||
|
|
||||||
|
if existing.is_system {
|
||||||
|
return Err(SaasError::Forbidden("系统角色不可修改".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let name = req.name.as_ref().unwrap_or(&existing.name);
|
||||||
|
let description = req.description.as_ref().or(existing.description.as_ref());
|
||||||
|
let permissions = req.permissions.as_ref().unwrap_or(&existing.permissions);
|
||||||
|
let permissions_json = serde_json::to_string(permissions)?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"UPDATE roles SET name = $1, description = $2, permissions = $3, updated_at = $4 WHERE id = $5"
|
||||||
|
)
|
||||||
|
.bind(name)
|
||||||
|
.bind(description)
|
||||||
|
.bind(&permissions_json)
|
||||||
|
.bind(&now)
|
||||||
|
.bind(role_id)
|
||||||
|
.execute(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(RoleInfo {
|
||||||
|
id: role_id.to_string(),
|
||||||
|
name: name.clone(),
|
||||||
|
description: description.cloned(),
|
||||||
|
permissions: permissions.clone(),
|
||||||
|
is_system: false,
|
||||||
|
created_at: existing.created_at,
|
||||||
|
updated_at: now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_role(db: &PgPool, role_id: &str) -> SaasResult<()> {
|
||||||
|
let existing = get_role(db, role_id).await?;
|
||||||
|
|
||||||
|
if existing.is_system {
|
||||||
|
return Err(SaasError::Forbidden("系统角色不可删除".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = sqlx::query("DELETE FROM roles WHERE id = $1 AND is_system = false")
|
||||||
|
.bind(role_id)
|
||||||
|
.execute(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if result.rows_affected() == 0 {
|
||||||
|
return Err(SaasError::NotFound(format!("角色 {} 不存在", role_id)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list_templates(db: &PgPool) -> SaasResult<Vec<PermissionTemplate>> {
|
||||||
|
let rows: Vec<(String, String, Option<String>, String, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, name, description, permissions, created_at, updated_at
|
||||||
|
FROM permission_templates ORDER BY created_at DESC"
|
||||||
|
)
|
||||||
|
.fetch_all(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let templates = rows.into_iter().map(|(id, name, description, perms, created_at, updated_at)| {
|
||||||
|
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||||
|
PermissionTemplate { id, name, description, permissions, created_at, updated_at }
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
Ok(templates)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_template(db: &PgPool, template_id: &str) -> SaasResult<PermissionTemplate> {
|
||||||
|
let row: Option<(String, String, Option<String>, String, String, String)> =
|
||||||
|
sqlx::query_as(
|
||||||
|
"SELECT id, name, description, permissions, created_at, updated_at
|
||||||
|
FROM permission_templates WHERE id = $1"
|
||||||
|
)
|
||||||
|
.bind(template_id)
|
||||||
|
.fetch_optional(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (id, name, description, perms, created_at, updated_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound(format!("权限模板 {} 不存在", template_id)))?;
|
||||||
|
|
||||||
|
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||||
|
Ok(PermissionTemplate { id, name, description, permissions, created_at, updated_at })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_template(db: &PgPool, req: &CreateTemplateRequest) -> SaasResult<PermissionTemplate> {
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
let permissions = serde_json::to_string(&req.permissions)?;
|
||||||
|
|
||||||
|
sqlx::query(
|
||||||
|
"INSERT INTO permission_templates (id, name, description, permissions, created_at, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $5)"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(&req.name)
|
||||||
|
.bind(&req.description)
|
||||||
|
.bind(&permissions)
|
||||||
|
.bind(&now)
|
||||||
|
.execute(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(PermissionTemplate {
|
||||||
|
id,
|
||||||
|
name: req.name.clone(),
|
||||||
|
description: req.description.clone(),
|
||||||
|
permissions: req.permissions.clone(),
|
||||||
|
created_at: now.clone(),
|
||||||
|
updated_at: now,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_template(db: &PgPool, template_id: &str) -> SaasResult<()> {
|
||||||
|
let result = sqlx::query("DELETE FROM permission_templates WHERE id = $1")
|
||||||
|
.bind(template_id)
|
||||||
|
.execute(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if result.rows_affected() == 0 {
|
||||||
|
return Err(SaasError::NotFound(format!("权限模板 {} 不存在", template_id)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn apply_template_to_accounts(
|
||||||
|
db: &PgPool,
|
||||||
|
template_id: &str,
|
||||||
|
account_ids: &[String],
|
||||||
|
) -> SaasResult<usize> {
|
||||||
|
let template = get_template(db, template_id).await?;
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
let mut success_count = 0;
|
||||||
|
for account_id in account_ids {
|
||||||
|
let result = sqlx::query(
|
||||||
|
"UPDATE accounts SET role = $1, updated_at = $2 WHERE id = $3"
|
||||||
|
)
|
||||||
|
.bind(&template.id)
|
||||||
|
.bind(&now)
|
||||||
|
.bind(account_id)
|
||||||
|
.execute(db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if result.rows_affected() > 0 {
|
||||||
|
success_count += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(success_count)
|
||||||
|
}
|
||||||
51
crates/zclaw-saas/src/role/types.rs
Normal file
51
crates/zclaw-saas/src/role/types.rs
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
//! 角色管理类型定义
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct RoleInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub permissions: Vec<String>,
|
||||||
|
pub is_system: bool,
|
||||||
|
pub created_at: String,
|
||||||
|
pub updated_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CreateRoleRequest {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub permissions: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct UpdateRoleRequest {
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub permissions: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct PermissionTemplate {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub permissions: Vec<String>,
|
||||||
|
pub created_at: String,
|
||||||
|
pub updated_at: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct CreateTemplateRequest {
|
||||||
|
pub name: String,
|
||||||
|
pub description: Option<String>,
|
||||||
|
pub permissions: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ApplyTemplateRequest {
|
||||||
|
pub account_ids: Vec<String>,
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
//! 应用状态
|
//! 应用状态
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::PgPool;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
@@ -10,7 +10,7 @@ use crate::config::SaaSConfig;
|
|||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
/// 数据库连接池
|
/// 数据库连接池
|
||||||
pub db: SqlitePool,
|
pub db: PgPool,
|
||||||
/// 服务器配置 (可热更新)
|
/// 服务器配置 (可热更新)
|
||||||
pub config: Arc<RwLock<SaaSConfig>>,
|
pub config: Arc<RwLock<SaaSConfig>>,
|
||||||
/// JWT 密钥
|
/// JWT 密钥
|
||||||
@@ -20,7 +20,7 @@ pub struct AppState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppState {
|
||||||
pub fn new(db: SqlitePool, config: SaaSConfig) -> anyhow::Result<Self> {
|
pub fn new(db: PgPool, config: SaaSConfig) -> anyhow::Result<Self> {
|
||||||
let jwt_secret = config.jwt_secret()?;
|
let jwt_secret = config.jwt_secret()?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
db,
|
db,
|
||||||
@@ -29,4 +29,13 @@ impl AppState {
|
|||||||
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// 清理过期的限流条目 (60 秒窗口外的记录)
|
||||||
|
pub fn cleanup_rate_limit_entries(&self) {
|
||||||
|
let window_start = Instant::now() - std::time::Duration::from_secs(60);
|
||||||
|
self.rate_limit_entries.retain(|_, entries| {
|
||||||
|
entries.retain(|&ts| ts > window_start);
|
||||||
|
!entries.is_empty()
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
114
crates/zclaw-saas/src/telemetry/handlers.rs
Normal file
114
crates/zclaw-saas/src/telemetry/handlers.rs
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
//! 遥测 API 处理器
|
||||||
|
|
||||||
|
use axum::{
|
||||||
|
extract::{Extension, Query, State},
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
use crate::error::SaasResult;
|
||||||
|
use crate::auth::types::AuthContext;
|
||||||
|
use crate::auth::handlers::log_operation;
|
||||||
|
use crate::state::AppState;
|
||||||
|
use super::types::*;
|
||||||
|
|
||||||
|
/// POST /api/v1/telemetry/report
|
||||||
|
///
|
||||||
|
/// 接收桌面端上报的 Token 用量统计(无内容,仅计数)。
|
||||||
|
/// 桌面端定期批量上报本地 LLM 调用的用量数据。
|
||||||
|
pub async fn report_telemetry(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<TelemetryReportRequest>,
|
||||||
|
) -> SaasResult<Json<TelemetryReportResponse>> {
|
||||||
|
// 限制单次上报条目数(防止滥用)
|
||||||
|
let entries = if req.entries.len() > 500 {
|
||||||
|
&req.entries[..500]
|
||||||
|
} else {
|
||||||
|
&req.entries
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = super::service::ingest_telemetry(
|
||||||
|
&state.db,
|
||||||
|
&ctx.account_id,
|
||||||
|
&req.device_id,
|
||||||
|
&req.app_version,
|
||||||
|
entries,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// 审计日志:记录遥测上报事件
|
||||||
|
log_operation(
|
||||||
|
&state.db,
|
||||||
|
&ctx.account_id,
|
||||||
|
"telemetry.report",
|
||||||
|
"telemetry",
|
||||||
|
&req.device_id,
|
||||||
|
Some(serde_json::json!({"entry_count": entries.len(), "app_version": req.app_version})),
|
||||||
|
ctx.client_ip.as_deref(),
|
||||||
|
).await.ok(); // 非阻塞:日志写入失败不影响上报
|
||||||
|
|
||||||
|
Ok(Json(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/telemetry/stats
|
||||||
|
///
|
||||||
|
/// 按模型聚合用量统计(当前用户)
|
||||||
|
pub async fn get_telemetry_stats(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Query(query): Query<TelemetryStatsQuery>,
|
||||||
|
) -> SaasResult<Json<Vec<ModelUsageStat>>> {
|
||||||
|
let stats = super::service::get_model_stats(
|
||||||
|
&state.db,
|
||||||
|
&ctx.account_id,
|
||||||
|
&query,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(stats))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// GET /api/v1/telemetry/daily
|
||||||
|
///
|
||||||
|
/// 按天聚合用量统计(当前用户)
|
||||||
|
pub async fn get_daily_stats(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Query(query): Query<TelemetryStatsQuery>,
|
||||||
|
) -> SaasResult<Json<Vec<DailyUsageStat>>> {
|
||||||
|
let stats = super::service::get_daily_stats(
|
||||||
|
&state.db,
|
||||||
|
&ctx.account_id,
|
||||||
|
&query,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(stats))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /api/v1/telemetry/audit
|
||||||
|
///
|
||||||
|
/// 接收桌面端上报的审计日志摘要(仅操作类型和计数,无具体内容)。
|
||||||
|
pub async fn report_audit_summary(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Extension(ctx): Extension<AuthContext>,
|
||||||
|
Json(req): Json<AuditSummaryRequest>,
|
||||||
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
|
let entries = if req.entries.len() > 200 {
|
||||||
|
&req.entries[..200]
|
||||||
|
} else {
|
||||||
|
&req.entries
|
||||||
|
};
|
||||||
|
|
||||||
|
let written = super::service::ingest_audit_summary(
|
||||||
|
&state.db,
|
||||||
|
&ctx.account_id,
|
||||||
|
&req.device_id,
|
||||||
|
entries,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(serde_json::json!({
|
||||||
|
"accepted": written,
|
||||||
|
"total": entries.len(),
|
||||||
|
})))
|
||||||
|
}
|
||||||
20
crates/zclaw-saas/src/telemetry/mod.rs
Normal file
20
crates/zclaw-saas/src/telemetry/mod.rs
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
//! 使用量遥测模块
|
||||||
|
//!
|
||||||
|
//! 接收桌面端上报的本地 LLM 调用 Token 用量统计(无内容),
|
||||||
|
//! 提供聚合查询 API 供 Admin 面板使用。
|
||||||
|
|
||||||
|
pub mod types;
|
||||||
|
pub mod service;
|
||||||
|
pub mod handlers;
|
||||||
|
|
||||||
|
use axum::routing::{get, post};
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
/// 遥测路由 (需要认证)
|
||||||
|
pub fn routes() -> axum::Router<AppState> {
|
||||||
|
axum::Router::new()
|
||||||
|
.route("/api/v1/telemetry/report", post(handlers::report_telemetry))
|
||||||
|
.route("/api/v1/telemetry/stats", get(handlers::get_telemetry_stats))
|
||||||
|
.route("/api/v1/telemetry/daily", get(handlers::get_daily_stats))
|
||||||
|
.route("/api/v1/telemetry/audit", post(handlers::report_audit_summary))
|
||||||
|
}
|
||||||
226
crates/zclaw-saas/src/telemetry/service.rs
Normal file
226
crates/zclaw-saas/src/telemetry/service.rs
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
//! 遥测服务逻辑
|
||||||
|
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use crate::error::SaasResult;
|
||||||
|
use super::types::*;
|
||||||
|
|
||||||
|
/// 批量写入遥测记录
|
||||||
|
pub async fn ingest_telemetry(
|
||||||
|
db: &PgPool,
|
||||||
|
account_id: &str,
|
||||||
|
device_id: &str,
|
||||||
|
app_version: &str,
|
||||||
|
entries: &[TelemetryEntry],
|
||||||
|
) -> SaasResult<TelemetryReportResponse> {
|
||||||
|
let mut accepted = 0usize;
|
||||||
|
let mut rejected = 0usize;
|
||||||
|
|
||||||
|
for entry in entries {
|
||||||
|
// 基本验证
|
||||||
|
if entry.input_tokens < 0 || entry.output_tokens < 0 {
|
||||||
|
rejected += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if entry.model_id.is_empty() {
|
||||||
|
rejected += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let id = uuid::Uuid::new_v4().to_string();
|
||||||
|
let now = chrono::Utc::now().to_rfc3339();
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
"INSERT INTO telemetry_reports
|
||||||
|
(id, account_id, device_id, app_version, model_id, input_tokens, output_tokens,
|
||||||
|
latency_ms, success, error_type, connection_mode, reported_at, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.bind(account_id)
|
||||||
|
.bind(device_id)
|
||||||
|
.bind(app_version)
|
||||||
|
.bind(&entry.model_id)
|
||||||
|
.bind(entry.input_tokens)
|
||||||
|
.bind(entry.output_tokens)
|
||||||
|
.bind(entry.latency_ms)
|
||||||
|
.bind(entry.success)
|
||||||
|
.bind(&entry.error_type)
|
||||||
|
.bind(&entry.connection_mode)
|
||||||
|
.bind(&entry.timestamp)
|
||||||
|
.bind(&now)
|
||||||
|
.execute(db)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => accepted += 1,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Failed to insert telemetry entry: {}", e);
|
||||||
|
rejected += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(TelemetryReportResponse { accepted, rejected })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按模型聚合用量统计
|
||||||
|
pub async fn get_model_stats(
|
||||||
|
db: &PgPool,
|
||||||
|
account_id: &str,
|
||||||
|
query: &TelemetryStatsQuery,
|
||||||
|
) -> SaasResult<Vec<ModelUsageStat>> {
|
||||||
|
let mut param_idx: i32 = 1;
|
||||||
|
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
|
||||||
|
let mut params: Vec<String> = vec![account_id.to_string()];
|
||||||
|
param_idx += 1;
|
||||||
|
|
||||||
|
if let Some(ref from) = query.from {
|
||||||
|
where_clauses.push(format!("reported_at >= ${}", param_idx));
|
||||||
|
params.push(from.clone());
|
||||||
|
param_idx += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref to) = query.to {
|
||||||
|
where_clauses.push(format!("reported_at <= ${}", param_idx));
|
||||||
|
params.push(to.clone());
|
||||||
|
param_idx += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref model) = query.model_id {
|
||||||
|
where_clauses.push(format!("model_id = ${}", param_idx));
|
||||||
|
params.push(model.clone());
|
||||||
|
param_idx += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref mode) = query.connection_mode {
|
||||||
|
where_clauses.push(format!("connection_mode = ${}", param_idx));
|
||||||
|
params.push(mode.clone());
|
||||||
|
param_idx += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let where_sql = where_clauses.join(" AND ");
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
"SELECT
|
||||||
|
model_id,
|
||||||
|
COUNT(*)::bigint as request_count,
|
||||||
|
COALESCE(SUM(input_tokens), 0)::bigint as input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0)::bigint as output_tokens,
|
||||||
|
AVG(latency_ms) as avg_latency_ms,
|
||||||
|
(COUNT(*) FILTER (WHERE success = true))::float / NULLIF(COUNT(*), 0) as success_rate
|
||||||
|
FROM telemetry_reports
|
||||||
|
WHERE {}
|
||||||
|
GROUP BY model_id
|
||||||
|
ORDER BY request_count DESC
|
||||||
|
LIMIT 50",
|
||||||
|
where_sql
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut query_builder = sqlx::query_as::<_, (String, i64, i64, i64, Option<f64>, Option<f64>)>(&sql);
|
||||||
|
for p in ¶ms {
|
||||||
|
query_builder = query_builder.bind(p);
|
||||||
|
}
|
||||||
|
|
||||||
|
let rows = query_builder.fetch_all(db).await?;
|
||||||
|
|
||||||
|
let stats: Vec<ModelUsageStat> = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|(model_id, request_count, input_tokens, output_tokens, avg_latency_ms, success_rate)| {
|
||||||
|
ModelUsageStat {
|
||||||
|
model_id,
|
||||||
|
request_count,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
avg_latency_ms,
|
||||||
|
success_rate: success_rate.unwrap_or(0.0),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(stats)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 写入审计日志摘要(批量写入 operation_logs)
|
||||||
|
pub async fn ingest_audit_summary(
|
||||||
|
db: &PgPool,
|
||||||
|
account_id: &str,
|
||||||
|
device_id: &str,
|
||||||
|
entries: &[AuditSummaryEntry],
|
||||||
|
) -> SaasResult<usize> {
|
||||||
|
let mut written = 0usize;
|
||||||
|
|
||||||
|
for entry in entries {
|
||||||
|
if entry.action.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 审计详情仅包含操作类型和目标,不包含用户内容
|
||||||
|
let details = serde_json::json!({
|
||||||
|
"source": "desktop",
|
||||||
|
"device_id": device_id,
|
||||||
|
"result": entry.result,
|
||||||
|
});
|
||||||
|
|
||||||
|
let result = sqlx::query(
|
||||||
|
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, created_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, $6)"
|
||||||
|
)
|
||||||
|
.bind(account_id)
|
||||||
|
.bind(&entry.action)
|
||||||
|
.bind("desktop_audit")
|
||||||
|
.bind(&entry.target)
|
||||||
|
.bind(&details)
|
||||||
|
.bind(&entry.timestamp)
|
||||||
|
.execute(db)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(_) => written += 1,
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Failed to insert audit summary entry: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(written)
|
||||||
|
}/// 按天聚合用量统计
|
||||||
|
pub async fn get_daily_stats(
|
||||||
|
db: &PgPool,
|
||||||
|
account_id: &str,
|
||||||
|
query: &TelemetryStatsQuery,
|
||||||
|
) -> SaasResult<Vec<DailyUsageStat>> {
|
||||||
|
let days = query.days.unwrap_or(30).min(90).max(1);
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
"SELECT
|
||||||
|
SUBSTRING(reported_at, 1, 10) as day,
|
||||||
|
COUNT(*)::bigint as request_count,
|
||||||
|
COALESCE(SUM(input_tokens), 0)::bigint as input_tokens,
|
||||||
|
COALESCE(SUM(output_tokens), 0)::bigint as output_tokens,
|
||||||
|
COUNT(DISTINCT device_id)::bigint as unique_devices
|
||||||
|
FROM telemetry_reports
|
||||||
|
WHERE account_id = $1
|
||||||
|
AND reported_at >= to_char(CURRENT_DATE - INTERVAL '{} days', 'YYYY-MM-DD')
|
||||||
|
GROUP BY SUBSTRING(reported_at, 1, 10)
|
||||||
|
ORDER BY day DESC",
|
||||||
|
days
|
||||||
|
);
|
||||||
|
|
||||||
|
let rows: Vec<(String, i64, i64, i64, i64)> =
|
||||||
|
sqlx::query_as(&sql).bind(account_id).fetch_all(db).await?;
|
||||||
|
|
||||||
|
let stats: Vec<DailyUsageStat> = rows
|
||||||
|
.into_iter()
|
||||||
|
.map(|(day, request_count, input_tokens, output_tokens, unique_devices)| {
|
||||||
|
DailyUsageStat {
|
||||||
|
day,
|
||||||
|
request_count,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
unique_devices,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(stats)
|
||||||
|
}
|
||||||
98
crates/zclaw-saas/src/telemetry/types.rs
Normal file
98
crates/zclaw-saas/src/telemetry/types.rs
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
//! 遥测类型定义
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// 审计日志摘要条目(桌面端上报,仅操作类型和计数,无内容)
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct AuditSummaryEntry {
|
||||||
|
/// 操作类型(如 "hand.trigger", "agent.create")
|
||||||
|
pub action: String,
|
||||||
|
/// 操作目标(如 Agent 名称或 Hand 名称)
|
||||||
|
pub target: String,
|
||||||
|
/// 操作结果: "success" | "failure" | "pending"
|
||||||
|
pub result: String,
|
||||||
|
/// 操作时间(ISO 8601)
|
||||||
|
pub timestamp: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 审计摘要上报请求
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct AuditSummaryRequest {
|
||||||
|
/// 设备 ID
|
||||||
|
pub device_id: String,
|
||||||
|
/// 审计条目列表
|
||||||
|
pub entries: Vec<AuditSummaryEntry>,
|
||||||
|
}
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct TelemetryEntry {
|
||||||
|
/// 模型标识(如 "gpt-4o", "glm-4-flash")
|
||||||
|
pub model_id: String,
|
||||||
|
/// 输入 Token 数
|
||||||
|
pub input_tokens: i64,
|
||||||
|
/// 输出 Token 数
|
||||||
|
pub output_tokens: i64,
|
||||||
|
/// 调用延迟(毫秒)
|
||||||
|
pub latency_ms: Option<i64>,
|
||||||
|
/// 调用是否成功
|
||||||
|
pub success: bool,
|
||||||
|
/// 错误类型(失败时)
|
||||||
|
pub error_type: Option<String>,
|
||||||
|
/// 调用时间(ISO 8601)
|
||||||
|
pub timestamp: String,
|
||||||
|
/// 连接模式: "tauri"(本地直连)/ "saas"(通过中转)
|
||||||
|
pub connection_mode: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 桌面端遥测上报请求
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct TelemetryReportRequest {
|
||||||
|
/// 设备 ID
|
||||||
|
pub device_id: String,
|
||||||
|
/// 桌面端版本
|
||||||
|
pub app_version: String,
|
||||||
|
/// 用量条目列表(批量上报)
|
||||||
|
pub entries: Vec<TelemetryEntry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 遥测上报响应
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct TelemetryReportResponse {
|
||||||
|
pub accepted: usize,
|
||||||
|
pub rejected: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 遥测统计查询参数
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct TelemetryStatsQuery {
|
||||||
|
/// 起始日期 (ISO 8601)
|
||||||
|
pub from: Option<String>,
|
||||||
|
/// 结束日期 (ISO 8601)
|
||||||
|
pub to: Option<String>,
|
||||||
|
/// 按模型过滤
|
||||||
|
pub model_id: Option<String>,
|
||||||
|
/// 按连接模式过滤
|
||||||
|
pub connection_mode: Option<String>,
|
||||||
|
/// 按天分组时的时间范围(天数)
|
||||||
|
pub days: Option<i64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按模型聚合的用量统计
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct ModelUsageStat {
|
||||||
|
pub model_id: String,
|
||||||
|
pub request_count: i64,
|
||||||
|
pub input_tokens: i64,
|
||||||
|
pub output_tokens: i64,
|
||||||
|
pub avg_latency_ms: Option<f64>,
|
||||||
|
pub success_rate: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// 按天的用量统计
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct DailyUsageStat {
|
||||||
|
pub day: String,
|
||||||
|
pub request_count: i64,
|
||||||
|
pub input_tokens: i64,
|
||||||
|
pub output_tokens: i64,
|
||||||
|
pub unique_devices: i64,
|
||||||
|
}
|
||||||
203
crates/zclaw-saas/tests/account_test.rs
Normal file
203
crates/zclaw-saas/tests/account_test.rs
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use common::*;
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Account listing
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_accounts_forbidden_for_regular_user() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "useracct").await;
|
||||||
|
let (status, _) = send(&app, get("/api/v1/accounts", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_accounts_success_as_admin() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "adminacct").await;
|
||||||
|
let (status, body) = send(&app, get("/api/v1/accounts", &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
// Should include at least the admin + the auto-seeded testadmin
|
||||||
|
assert!(body["items"].is_array() || body.is_array());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn get_own_account() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "ownacct").await;
|
||||||
|
|
||||||
|
// First get own account info from /me
|
||||||
|
let (status, me) = send(&app, get("/api/v1/auth/me", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK, "get /me: {me}");
|
||||||
|
let account_id = me["id"].as_str().unwrap();
|
||||||
|
eprintln!("DEBUG account_id = {account_id}");
|
||||||
|
|
||||||
|
let url = format!("/api/v1/accounts/{}", account_id);
|
||||||
|
eprintln!("DEBUG url = {url}");
|
||||||
|
let (status, body) = send(&app, get(&url, &token)).await;
|
||||||
|
eprintln!("DEBUG status = {status}, body = {body}");
|
||||||
|
assert_eq!(status, StatusCode::OK, "get own account: {body}");
|
||||||
|
assert_eq!(body["username"], "ownacct");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn update_own_account_display_name() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "updateacct").await;
|
||||||
|
|
||||||
|
// Get account ID from /me
|
||||||
|
let (_, me) = send(&app, get("/api/v1/auth/me", &token)).await;
|
||||||
|
let account_id = me["id"].as_str().unwrap();
|
||||||
|
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
patch(
|
||||||
|
&format!("/api/v1/accounts/{account_id}"),
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "display_name": "New Display Name" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK, "update account: {body}");
|
||||||
|
assert_eq!(body["display_name"], "New Display Name");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// API Token lifecycle
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn api_token_create_list_revoke() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "tokenuser").await;
|
||||||
|
|
||||||
|
// Create
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/tokens",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "name": "test-token", "permissions": ["model:read", "relay:use"] }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK, "create token: {body}");
|
||||||
|
let raw_token = body["token"].as_str().unwrap();
|
||||||
|
assert!(raw_token.starts_with("zclaw_"));
|
||||||
|
let token_id = body["id"].as_str().unwrap();
|
||||||
|
|
||||||
|
// List (paginated response: {items, total, page, page_size})
|
||||||
|
let (status, list) = send(&app, get("/api/v1/tokens", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK, "list tokens: {list}");
|
||||||
|
assert!(list["items"].is_array(), "tokens list should have items field: {list}");
|
||||||
|
assert_eq!(list["items"].as_array().unwrap().len(), 1);
|
||||||
|
|
||||||
|
// Use the API token to authenticate
|
||||||
|
let (status, _) = send(&app, get("/api/v1/auth/me", raw_token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// Revoke
|
||||||
|
let (status, _) = send(&app, delete(&format!("/api/v1/tokens/{token_id}"), &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// After revoke, API token no longer works
|
||||||
|
let (status, _) = send(&app, get("/api/v1/auth/me", raw_token)).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Device registration
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn device_register_and_list() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "deviceuser").await;
|
||||||
|
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/devices/register",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({
|
||||||
|
"device_id": "test-device-001",
|
||||||
|
"device_name": "Test Desktop",
|
||||||
|
"platform": "windows",
|
||||||
|
"app_version": "0.1.0"
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
let (status, body) = send(&app, get("/api/v1/devices", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK, "list devices: {body}");
|
||||||
|
let devices = body["items"].as_array().expect("devices should be paginated {items}");
|
||||||
|
assert_eq!(devices.len(), 1);
|
||||||
|
assert_eq!(devices[0]["device_id"], "test-device-001");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn device_upsert_on_reregister() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "upsertdev").await;
|
||||||
|
|
||||||
|
send(&app, post("/api/v1/devices/register", &token, serde_json::json!({
|
||||||
|
"device_id": "dev-upsert", "device_name": "Old Name"
|
||||||
|
}))).await;
|
||||||
|
|
||||||
|
send(&app, post("/api/v1/devices/register", &token, serde_json::json!({
|
||||||
|
"device_id": "dev-upsert", "device_name": "New Name"
|
||||||
|
}))).await;
|
||||||
|
|
||||||
|
let (_, body) = send(&app, get("/api/v1/devices", &token)).await;
|
||||||
|
let devs = body["items"].as_array().expect("devices should be paginated {items}");
|
||||||
|
assert_eq!(devs.len(), 1);
|
||||||
|
assert_eq!(devs[0]["device_name"], "New Name");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn device_heartbeat() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "hbuser").await;
|
||||||
|
|
||||||
|
// Register first
|
||||||
|
send(&app, post("/api/v1/devices/register", &token, serde_json::json!({
|
||||||
|
"device_id": "hb-dev"
|
||||||
|
}))).await;
|
||||||
|
|
||||||
|
// Heartbeat
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post("/api/v1/devices/heartbeat", &token, serde_json::json!({ "device_id": "hb-dev" })),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// Heartbeat nonexistent → 404
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post("/api/v1/devices/heartbeat", &token, serde_json::json!({ "device_id": "ghost" })),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Operation logs (admin only)
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn operation_logs_forbidden_for_user() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "loguser").await;
|
||||||
|
let (status, _) = send(&app, get("/api/v1/logs/operations", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dashboard_stats_admin() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "statsadmin").await;
|
||||||
|
let (status, _) = send(&app, get("/api/v1/stats/dashboard", &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
97
crates/zclaw-saas/tests/agent_template_test.rs
Normal file
97
crates/zclaw-saas/tests/agent_template_test.rs
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use common::*;
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// List templates
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_template_list_empty() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "atlist").await;
|
||||||
|
let (status, body) = send(&app, get("/api/v1/agent-templates", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(body.is_array() || body["items"].is_array());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Full CRUD
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_template_crud() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "atadmin").await;
|
||||||
|
|
||||||
|
// Create
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/agent-templates",
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({
|
||||||
|
"name": "Test Agent",
|
||||||
|
"description": "A test agent template",
|
||||||
|
"category": "general",
|
||||||
|
"model": "test-model-v1",
|
||||||
|
"system_prompt": "You are a test agent.",
|
||||||
|
"tools": ["search", "browser"],
|
||||||
|
"capabilities": ["reasoning", "code"],
|
||||||
|
"temperature": 0.7,
|
||||||
|
"max_tokens": 4096
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
if status != StatusCode::OK {
|
||||||
|
eprintln!("ERROR create agent template: status={status}, body={body}");
|
||||||
|
}
|
||||||
|
assert_eq!(status, StatusCode::OK, "create agent template: {body}");
|
||||||
|
let tmpl_id = body["id"].as_str().unwrap();
|
||||||
|
assert_eq!(body["name"], "Test Agent");
|
||||||
|
|
||||||
|
// Get
|
||||||
|
let (status, body) = send(&app, get(&format!("/api/v1/agent-templates/{tmpl_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["name"], "Test Agent");
|
||||||
|
assert_eq!(body["model"], "test-model-v1");
|
||||||
|
|
||||||
|
// Update (POST for update)
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
&format!("/api/v1/agent-templates/{tmpl_id}"),
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({
|
||||||
|
"description": "Updated description",
|
||||||
|
"temperature": 0.5
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["description"], "Updated description");
|
||||||
|
|
||||||
|
// Archive (DELETE)
|
||||||
|
let (status, _) = send(&app, delete(&format!("/api/v1/agent-templates/{tmpl_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Permission enforcement
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_template_create_forbidden_for_user() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "atuser").await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/agent-templates",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "name": "Forbidden" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
385
crates/zclaw-saas/tests/auth_test.rs
Normal file
385
crates/zclaw-saas/tests/auth_test.rs
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use common::*;
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Registration
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_success() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (token, refresh, json) = register(&app, "alice", "alice@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
assert!(!token.is_empty());
|
||||||
|
assert!(!refresh.is_empty());
|
||||||
|
assert_eq!(json["account"]["username"], "alice");
|
||||||
|
assert_eq!(json["account"]["role"], "user");
|
||||||
|
assert_eq!(json["account"]["status"], "active");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_duplicate_username() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
register(&app, "dupuser", "dup@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
serde_json::json!({ "username": "dupuser", "email": "other@test.io", "password": DEFAULT_PASSWORD }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::CONFLICT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_duplicate_email() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
register(&app, "user1", "same@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
serde_json::json!({ "username": "user2", "email": "same@test.io", "password": DEFAULT_PASSWORD }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::CONFLICT);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_validation_short_username() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
serde_json::json!({ "username": "ab", "email": "a@b.c", "password": DEFAULT_PASSWORD }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||||
|
assert_eq!(body["error"], "INVALID_INPUT");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_validation_bad_email() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
serde_json::json!({ "username": "goodname", "email": "no-at-sign", "password": DEFAULT_PASSWORD }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn register_validation_short_password() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
serde_json::json!({ "username": "goodname", "email": "a@b.c", "password": "short" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Login
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn login_success() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
register(&app, "loginuser", "login@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
let (token, refresh, json) = login(&app, "loginuser", DEFAULT_PASSWORD).await;
|
||||||
|
assert!(!token.is_empty());
|
||||||
|
assert!(!refresh.is_empty());
|
||||||
|
assert_eq!(json["account"]["username"], "loginuser");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn login_wrong_password() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
register(&app, "wrongpwd", "wrong@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
serde_json::json!({ "username": "wrongpwd", "password": "incorrect_password" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
assert_eq!(body["error"], "AUTH_ERROR");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn login_nonexistent_user() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
serde_json::json!({ "username": "ghost", "password": "whatever" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Auth chain: register → login → me (P0)
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn auth_chain_register_login_me() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
|
||||||
|
// 1. Register
|
||||||
|
let (token, _refresh, json) = register(&app, "chainuser", "chain@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
assert_eq!(json["account"]["username"], "chainuser");
|
||||||
|
|
||||||
|
// 2. GET /me with the registration token
|
||||||
|
let (status, me) = send(&app, get("/api/v1/auth/me", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(me["username"], "chainuser");
|
||||||
|
assert_eq!(me["role"], "user");
|
||||||
|
|
||||||
|
// 3. Login separately
|
||||||
|
let (token2, _, _) = login(&app, "chainuser", DEFAULT_PASSWORD).await;
|
||||||
|
|
||||||
|
// 4. GET /me with the login token
|
||||||
|
let (status, me2) = send(&app, get("/api/v1/auth/me", &token2)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(me2["id"], me["id"]); // same account
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_without_token_is_unauthorized() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let req = axum::http::Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri("/api/v1/auth/me")
|
||||||
|
.body(axum::body::Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let (status, _) = send(&app, req).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn me_with_invalid_token_is_unauthorized() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (status, _) = send(&app, get("/api/v1/auth/me", "invalid.jwt.token")).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Refresh token
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn refresh_token_success() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (_, refresh, _) = register(&app, "refreshuser", "refresh@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
|
||||||
|
// Use refresh token to get a new pair
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
serde_json::json!({ "refresh_token": refresh }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(body["token"].is_string());
|
||||||
|
assert!(body["refresh_token"].is_string());
|
||||||
|
let new_token = body["token"].as_str().unwrap();
|
||||||
|
let new_refresh = body["refresh_token"].as_str().unwrap();
|
||||||
|
assert!(!new_token.is_empty());
|
||||||
|
assert!(!new_refresh.is_empty());
|
||||||
|
|
||||||
|
// New token works for /me
|
||||||
|
let (status, _) = send(&app, get("/api/v1/auth/me", new_token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn refresh_token_one_time_use() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (_, refresh, _) = register(&app, "onetime", "onetime@test.io", DEFAULT_PASSWORD).await;
|
||||||
|
|
||||||
|
// First refresh succeeds
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
serde_json::json!({ "refresh_token": refresh }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// Second use of the same refresh token fails
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
serde_json::json!({ "refresh_token": refresh }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn refresh_with_invalid_token() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
serde_json::json!({ "refresh_token": "garbage" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Password change
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn change_password_success() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "pwduser").await;
|
||||||
|
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
put(
|
||||||
|
"/api/v1/auth/password",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "old_password": DEFAULT_PASSWORD, "new_password": "BrandNewP@ss1" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// Login with new password works
|
||||||
|
let (_, _, _) = login(&app, "pwduser", "BrandNewP@ss1").await;
|
||||||
|
|
||||||
|
// Login with old password fails
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post_public(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
serde_json::json!({ "username": "pwduser", "password": DEFAULT_PASSWORD }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn change_password_wrong_old() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "wrongold").await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
put(
|
||||||
|
"/api/v1/auth/password",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "old_password": "wrong_old_pass", "new_password": "BrandNewP@ss1" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn change_password_too_short() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "shortpwd").await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
put(
|
||||||
|
"/api/v1/auth/password",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "old_password": DEFAULT_PASSWORD, "new_password": "abc" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// TOTP 2FA (P0 chain test)
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn totp_setup_and_disable() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "totpuser").await;
|
||||||
|
|
||||||
|
// Setup TOTP → returns otpauth_uri + secret
|
||||||
|
let (status, body) = send(&app, post("/api/v1/auth/totp/setup", &token, serde_json::json!({}))).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(body["otpauth_uri"].is_string());
|
||||||
|
assert!(body["secret"].is_string());
|
||||||
|
|
||||||
|
// Disable TOTP (requires password)
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post("/api/v1/auth/totp/disable", &token, serde_json::json!({ "password": DEFAULT_PASSWORD })),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// After disable, login without TOTP code succeeds
|
||||||
|
let (_, _, login_json) = login(&app, "totpuser", DEFAULT_PASSWORD).await;
|
||||||
|
assert_eq!(login_json["account"]["totp_enabled"], false);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn totp_disable_wrong_password() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "totpwrong").await;
|
||||||
|
// Setup first
|
||||||
|
send(&app, post("/api/v1/auth/totp/setup", &token, serde_json::json!({}))).await;
|
||||||
|
// Try disable with wrong password
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/auth/totp/disable",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "password": "wrong_password_here" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::UNAUTHORIZED);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn totp_verify_wrong_code() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "totpbadcode").await;
|
||||||
|
// Setup
|
||||||
|
send(&app, post("/api/v1/auth/totp/setup", &token, serde_json::json!({}))).await;
|
||||||
|
// Verify with a definitely-wrong code
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/auth/totp/verify",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "code": "000000" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::BAD_REQUEST);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Health endpoint
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn health_check() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let req = axum::http::Request::builder()
|
||||||
|
.uri("/api/health")
|
||||||
|
.body(axum::body::Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
let (status, _) = send(&app, req).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
382
crates/zclaw-saas/tests/common/mod.rs
Normal file
382
crates/zclaw-saas/tests/common/mod.rs
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
//! Integration test harness for zclaw-saas
|
||||||
|
//!
|
||||||
|
//! Uses a **shared** PostgreSQL database (`zclaw_test_shared`) with per-test
|
||||||
|
//! TRUNCATE isolation. Only one database is created; each test truncates all
|
||||||
|
//! tables and re-seeds via `init_db`.
|
||||||
|
//!
|
||||||
|
//! # Setup
|
||||||
|
//!
|
||||||
|
//! ```bash
|
||||||
|
//! # Start PostgreSQL (e.g. via Docker Compose)
|
||||||
|
//! docker compose up -d postgres
|
||||||
|
//!
|
||||||
|
//! # Set the test database URL (point to the base DB for CREATE DATABASE)
|
||||||
|
//! export TEST_DATABASE_URL="postgres://postgres:123123@localhost:5432/zclaw"
|
||||||
|
//!
|
||||||
|
//! # Run tests
|
||||||
|
//! cargo test -p zclaw-saas
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
use axum::body::Body;
|
||||||
|
use axum::http::{Request, StatusCode};
|
||||||
|
use axum::Router;
|
||||||
|
use sqlx::PgPool;
|
||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use tower::ServiceExt;
|
||||||
|
use zclaw_saas::config::SaaSConfig;
|
||||||
|
use zclaw_saas::db::init_db;
|
||||||
|
use zclaw_saas::state::AppState;
|
||||||
|
|
||||||
|
pub const MAX_BODY: usize = 2 * 1024 * 1024;
|
||||||
|
pub const DEFAULT_PASSWORD: &str = "testpassword123";
|
||||||
|
|
||||||
|
const SHARED_DB_NAME: &str = "zclaw_test_shared";
|
||||||
|
|
||||||
|
/// Schema version counter — increment to force DROP+CREATE on next run.
|
||||||
|
const SCHEMA_VERSION: u32 = 2;
|
||||||
|
|
||||||
|
/// Whether the shared test database has been created at the current schema version.
|
||||||
|
static DB_CREATED: AtomicBool = AtomicBool::new(false);
|
||||||
|
|
||||||
|
// ── Database helpers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Resolve the base test database URL (used to connect for CREATE DATABASE).
|
||||||
|
pub fn test_database_url() -> String {
|
||||||
|
std::env::var("TEST_DATABASE_URL")
|
||||||
|
.or_else(|_| std::env::var("DATABASE_URL"))
|
||||||
|
.unwrap_or_else(|_| "postgres://postgres:123123@localhost:5432/zclaw".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the shared test database URL by replacing the database name.
|
||||||
|
fn shared_db_url() -> String {
|
||||||
|
let mut url = test_database_url();
|
||||||
|
if let Some(pos) = url.rfind('/') {
|
||||||
|
url.truncate(pos + 1);
|
||||||
|
url.push_str(SHARED_DB_NAME);
|
||||||
|
}
|
||||||
|
url
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ensure the shared test database exists with a clean schema.
|
||||||
|
/// Runs once per process: drops the old DB and recreates it.
|
||||||
|
async fn ensure_shared_db() -> String {
|
||||||
|
if !DB_CREATED.swap(true, Ordering::SeqCst) {
|
||||||
|
let base = test_database_url();
|
||||||
|
let pool = PgPool::connect(&base)
|
||||||
|
.await
|
||||||
|
.expect("Cannot connect to PostgreSQL — is it running?");
|
||||||
|
// Drop + recreate for a clean schema
|
||||||
|
let _ = sqlx::query(&format!("DROP DATABASE IF EXISTS \"{}\"", SHARED_DB_NAME))
|
||||||
|
.execute(&pool)
|
||||||
|
.await;
|
||||||
|
sqlx::query(&format!("CREATE DATABASE \"{}\"", SHARED_DB_NAME))
|
||||||
|
.execute(&pool)
|
||||||
|
.await
|
||||||
|
.expect("Failed to create shared test database");
|
||||||
|
drop(pool);
|
||||||
|
}
|
||||||
|
shared_db_url()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Truncate all public tables in the database (CASCADE handles FK).
|
||||||
|
async fn truncate_all_tables(pool: &PgPool) {
|
||||||
|
sqlx::query(
|
||||||
|
r#"DO $$
|
||||||
|
DECLARE
|
||||||
|
r RECORD;
|
||||||
|
BEGIN
|
||||||
|
FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP
|
||||||
|
EXECUTE 'TRUNCATE TABLE ' || quote_ident(r.tablename) || ' CASCADE';
|
||||||
|
END LOOP;
|
||||||
|
END$$;"#,
|
||||||
|
)
|
||||||
|
.execute(pool)
|
||||||
|
.await
|
||||||
|
.expect("Failed to truncate tables");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── App builder ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Build a full Axum `Router` wired to the shared test database.
|
||||||
|
///
|
||||||
|
/// Flow per test:
|
||||||
|
/// 1. Ensure shared DB exists (once)
|
||||||
|
/// 2. Truncate all tables (isolation)
|
||||||
|
/// 3. Re-run `init_db` to seed fresh data
|
||||||
|
/// 4. Return `(Router, PgPool)`
|
||||||
|
pub async fn build_test_app() -> (Router, PgPool) {
|
||||||
|
let db_url = ensure_shared_db().await;
|
||||||
|
|
||||||
|
// Dev-mode env vars
|
||||||
|
std::env::set_var("ZCLAW_SAAS_DEV", "true");
|
||||||
|
std::env::set_var("ZCLAW_SAAS_JWT_SECRET", "test-jwt-secret-do-not-use-in-prod");
|
||||||
|
std::env::set_var("ZCLAW_ADMIN_USERNAME", "testadmin");
|
||||||
|
std::env::set_var("ZCLAW_ADMIN_PASSWORD", "Admin123456");
|
||||||
|
|
||||||
|
// Truncate all data for test isolation
|
||||||
|
let truncate_pool = PgPool::connect(&db_url)
|
||||||
|
.await
|
||||||
|
.expect("Cannot connect to shared test DB");
|
||||||
|
truncate_all_tables(&truncate_pool).await;
|
||||||
|
drop(truncate_pool);
|
||||||
|
|
||||||
|
// init_db: schema (IF NOT EXISTS, fast) + seed data
|
||||||
|
let pool = init_db(&db_url).await.expect("init_db failed");
|
||||||
|
|
||||||
|
let mut config = SaaSConfig::default();
|
||||||
|
config.auth.jwt_expiration_hours = 24;
|
||||||
|
config.auth.refresh_token_hours = 168;
|
||||||
|
config.rate_limit.requests_per_minute = 10_000;
|
||||||
|
config.rate_limit.burst = 1_000;
|
||||||
|
|
||||||
|
let state = AppState::new(pool.clone(), config).expect("AppState::new failed");
|
||||||
|
let router = build_router(state);
|
||||||
|
(router, pool)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_router(state: AppState) -> Router {
|
||||||
|
use axum::middleware;
|
||||||
|
use tower_http::cors::{Any, CorsLayer};
|
||||||
|
use tower_http::trace::TraceLayer;
|
||||||
|
|
||||||
|
let public_routes = zclaw_saas::auth::routes()
|
||||||
|
.route("/api/health", axum::routing::get(health_handler));
|
||||||
|
|
||||||
|
let protected_routes = zclaw_saas::auth::protected_routes()
|
||||||
|
.merge(zclaw_saas::account::routes())
|
||||||
|
.merge(zclaw_saas::model_config::routes())
|
||||||
|
.merge(zclaw_saas::relay::routes())
|
||||||
|
.merge(zclaw_saas::migration::routes())
|
||||||
|
.merge(zclaw_saas::role::routes())
|
||||||
|
.merge(zclaw_saas::prompt::routes())
|
||||||
|
.merge(zclaw_saas::agent_template::routes())
|
||||||
|
.merge(zclaw_saas::telemetry::routes())
|
||||||
|
.layer(middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
zclaw_saas::middleware::api_version_middleware,
|
||||||
|
))
|
||||||
|
.layer(middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
zclaw_saas::middleware::request_id_middleware,
|
||||||
|
))
|
||||||
|
.layer(middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
zclaw_saas::middleware::rate_limit_middleware,
|
||||||
|
))
|
||||||
|
.layer(middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
zclaw_saas::auth::auth_middleware,
|
||||||
|
));
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
.merge(public_routes)
|
||||||
|
.merge(protected_routes)
|
||||||
|
.layer(TraceLayer::new_for_http())
|
||||||
|
.layer(
|
||||||
|
CorsLayer::new()
|
||||||
|
.allow_origin(Any)
|
||||||
|
.allow_methods(Any)
|
||||||
|
.allow_headers(Any),
|
||||||
|
)
|
||||||
|
.with_state(state)
|
||||||
|
.layer(axum::middleware::from_fn(inject_connect_info))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Simple health handler for testing (mirrors main.rs health_handler).
|
||||||
|
async fn health_handler(State(state): axum::extract::State<AppState>) -> axum::Json<serde_json::Value> {
|
||||||
|
let db_healthy = sqlx::query_scalar::<_, i32>("SELECT 1")
|
||||||
|
.fetch_one(&state.db)
|
||||||
|
.await
|
||||||
|
.ok()
|
||||||
|
.map(|v| v == 1)
|
||||||
|
.unwrap_or(false);
|
||||||
|
let status = if db_healthy { "healthy" } else { "degraded" };
|
||||||
|
axum::Json(serde_json::json!({ "status": status, "database": db_healthy }))
|
||||||
|
}
|
||||||
|
|
||||||
|
use axum::extract::State;
|
||||||
|
async fn inject_connect_info(
|
||||||
|
mut req: axum::extract::Request,
|
||||||
|
next: axum::middleware::Next,
|
||||||
|
) -> axum::response::Response {
|
||||||
|
use axum::extract::ConnectInfo;
|
||||||
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
|
req.extensions_mut().insert(ConnectInfo::<SocketAddr>(
|
||||||
|
"127.0.0.1:12345".parse().unwrap(),
|
||||||
|
));
|
||||||
|
next.run(req).await
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── HTTP helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub async fn body_bytes(body: Body) -> Vec<u8> {
|
||||||
|
axum::body::to_bytes(body, MAX_BODY)
|
||||||
|
.await
|
||||||
|
.expect("body too large")
|
||||||
|
.to_vec()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn body_json(body: Body) -> serde_json::Value {
|
||||||
|
let bytes = body_bytes(body).await;
|
||||||
|
serde_json::from_slice(&bytes).unwrap_or_else(|e| {
|
||||||
|
panic!(
|
||||||
|
"Failed to parse JSON: {}\nBody: {}",
|
||||||
|
e,
|
||||||
|
String::from_utf8_lossy(&bytes)
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(uri: &str, token: &str) -> Request<Body> {
|
||||||
|
Request::builder()
|
||||||
|
.method("GET")
|
||||||
|
.uri(uri)
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn delete(uri: &str, token: &str) -> Request<Body> {
|
||||||
|
Request::builder()
|
||||||
|
.method("DELETE")
|
||||||
|
.uri(uri)
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post(uri: &str, token: &str, body: serde_json::Value) -> Request<Body> {
|
||||||
|
Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri(uri)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.body(Body::from(body.to_string()))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn post_public(uri: &str, body: serde_json::Value) -> Request<Body> {
|
||||||
|
Request::builder()
|
||||||
|
.method("POST")
|
||||||
|
.uri(uri)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.body(Body::from(body.to_string()))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn put(uri: &str, token: &str, body: serde_json::Value) -> Request<Body> {
|
||||||
|
Request::builder()
|
||||||
|
.method("PUT")
|
||||||
|
.uri(uri)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.body(Body::from(body.to_string()))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn patch(uri: &str, token: &str, body: serde_json::Value) -> Request<Body> {
|
||||||
|
Request::builder()
|
||||||
|
.method("PATCH")
|
||||||
|
.uri(uri)
|
||||||
|
.header("Content-Type", "application/json")
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.body(Body::from(body.to_string()))
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send request and return (status, body_json).
|
||||||
|
/// If body is empty, returns `serde_json::Value::Null` instead of panicking.
|
||||||
|
pub async fn send(app: &Router, req: Request<Body>) -> (StatusCode, serde_json::Value) {
|
||||||
|
let resp = app.clone().oneshot(req).await.unwrap();
|
||||||
|
let status = resp.status();
|
||||||
|
let bytes = body_bytes(resp.into_body()).await;
|
||||||
|
if bytes.is_empty() {
|
||||||
|
return (status, serde_json::Value::Null);
|
||||||
|
}
|
||||||
|
let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or_else(|e| {
|
||||||
|
panic!(
|
||||||
|
"Failed to parse JSON: {}\nBody: {}",
|
||||||
|
e,
|
||||||
|
String::from_utf8_lossy(&bytes)
|
||||||
|
)
|
||||||
|
});
|
||||||
|
(status, json)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Auth helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Register a new user. Returns (access_token, refresh_token, response_json).
|
||||||
|
pub async fn register(
|
||||||
|
app: &Router,
|
||||||
|
username: &str,
|
||||||
|
email: &str,
|
||||||
|
password: &str,
|
||||||
|
) -> (String, String, serde_json::Value) {
|
||||||
|
let resp = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(post_public(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
serde_json::json!({ "username": username, "email": email, "password": password }),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let status = resp.status();
|
||||||
|
let json = body_json(resp.into_body()).await;
|
||||||
|
assert_eq!(status, StatusCode::CREATED, "register failed: {json}");
|
||||||
|
let token = json["token"].as_str().unwrap().to_string();
|
||||||
|
let refresh = json["refresh_token"].as_str().unwrap().to_string();
|
||||||
|
(token, refresh, json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Login. Returns (access_token, refresh_token, response_json).
|
||||||
|
pub async fn login(
|
||||||
|
app: &Router,
|
||||||
|
username: &str,
|
||||||
|
password: &str,
|
||||||
|
) -> (String, String, serde_json::Value) {
|
||||||
|
let resp = app
|
||||||
|
.clone()
|
||||||
|
.oneshot(post_public(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
serde_json::json!({ "username": username, "password": password }),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let status = resp.status();
|
||||||
|
let json = body_json(resp.into_body()).await;
|
||||||
|
assert_eq!(status, StatusCode::OK, "login failed: {json}");
|
||||||
|
let token = json["token"].as_str().unwrap().to_string();
|
||||||
|
let refresh = json["refresh_token"].as_str().unwrap().to_string();
|
||||||
|
(token, refresh, json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register + return access token.
|
||||||
|
pub async fn register_token(app: &Router, username: &str) -> String {
|
||||||
|
let email = format!("{username}@test.io");
|
||||||
|
register(app, username, &email, DEFAULT_PASSWORD).await.0
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a user and promote to `admin`. Returns fresh JWT with admin permissions.
|
||||||
|
pub async fn admin_token(app: &Router, pool: &PgPool, username: &str) -> String {
|
||||||
|
let email = format!("{username}@test.io");
|
||||||
|
register(app, username, &email, DEFAULT_PASSWORD).await;
|
||||||
|
sqlx::query("UPDATE accounts SET role = 'admin' WHERE username = $1")
|
||||||
|
.bind(username)
|
||||||
|
.execute(pool)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
login(app, username, DEFAULT_PASSWORD).await.0
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a user and promote to `super_admin`. Returns fresh JWT.
|
||||||
|
pub async fn super_admin_token(app: &Router, pool: &PgPool, username: &str) -> String {
|
||||||
|
let email = format!("{username}@test.io");
|
||||||
|
register(app, username, &email, DEFAULT_PASSWORD).await;
|
||||||
|
sqlx::query("UPDATE accounts SET role = 'super_admin' WHERE username = $1")
|
||||||
|
.bind(username)
|
||||||
|
.execute(pool)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
login(app, username, DEFAULT_PASSWORD).await.0
|
||||||
|
}
|
||||||
174
crates/zclaw-saas/tests/migration_test.rs
Normal file
174
crates/zclaw-saas/tests/migration_test.rs
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use common::*;
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Config analysis
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_analysis_empty() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "cfganalyze").await;
|
||||||
|
let (status, body) = send(&app, get("/api/v1/config/analysis", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["total_items"], 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Config items CRUD
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_items_crud() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "cfgadmin").await;
|
||||||
|
|
||||||
|
// Create config item
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/config/items",
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({
|
||||||
|
"category": "server",
|
||||||
|
"key_path": "server.host",
|
||||||
|
"value_type": "string",
|
||||||
|
"current_value": "0.0.0.0",
|
||||||
|
"description": "Server bind address"
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::CREATED, "create config item: {body}");
|
||||||
|
let item_id = body["id"].as_str().unwrap();
|
||||||
|
|
||||||
|
// List
|
||||||
|
let (status, list) = send(&app, get("/api/v1/config/items", &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(list.is_array() || list["items"].is_array());
|
||||||
|
|
||||||
|
// Get
|
||||||
|
let (status, body) = send(&app, get(&format!("/api/v1/config/items/{item_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["key_path"], "server.host");
|
||||||
|
|
||||||
|
// Update
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
put(
|
||||||
|
&format!("/api/v1/config/items/{item_id}"),
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({ "current_value": "127.0.0.1" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["current_value"], "127.0.0.1");
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
let (status, _) = send(&app, delete(&format!("/api/v1/config/items/{item_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_items_write_forbidden_for_user() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "cfguser").await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/config/items",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "category": "x", "key_path": "y", "value_type": "string" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Config seed
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_seed_admin_only() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let user_token = register_token(&app, "cfgseeduser").await;
|
||||||
|
let (status, _) = send(&app, post("/api/v1/config/seed", &user_token, serde_json::json!({}))).await;
|
||||||
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Config sync (push)
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_sync_push() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "cfgsync").await;
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/config/sync",
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({
|
||||||
|
"client_fingerprint": "test-desktop-v1",
|
||||||
|
"action": "push",
|
||||||
|
"config_keys": ["server.host", "server.port"],
|
||||||
|
"client_values": { "server.host": "192.168.1.1", "server.port": "9090" }
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK, "config sync push: {body}");
|
||||||
|
// Push mode: keys don't exist in SaaS → auto-created
|
||||||
|
assert_eq!(body["created"], 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Config diff
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_diff() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "cfgdiff").await;
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/config/diff",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({
|
||||||
|
"client_fingerprint": "test-client",
|
||||||
|
"action": "push",
|
||||||
|
"config_keys": ["server.host"],
|
||||||
|
"client_values": { "server.host": "0.0.0.0" }
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["total_keys"], 1);
|
||||||
|
assert!(body["items"].is_array());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Config sync logs
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_sync_logs() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "cfglogs").await;
|
||||||
|
let (status, _) = send(&app, get("/api/v1/config/sync-logs", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Config pull
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn config_pull_empty() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "cfgpull").await;
|
||||||
|
let (status, _) = send(&app, get("/api/v1/config/pull", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
234
crates/zclaw-saas/tests/model_config_test.rs
Normal file
234
crates/zclaw-saas/tests/model_config_test.rs
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
mod common;
|
||||||
|
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use common::*;
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Provider CRUD
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_crud_full_lifecycle() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "provadmin").await;
|
||||||
|
|
||||||
|
// Create
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/providers",
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({
|
||||||
|
"name": "test-provider",
|
||||||
|
"display_name": "Test Provider",
|
||||||
|
"base_url": "https://api.example.com/v1"
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::CREATED, "create provider failed: {body}");
|
||||||
|
let provider_id = body["id"].as_str().unwrap().to_string();
|
||||||
|
|
||||||
|
// List (paginated)
|
||||||
|
let (status, body) = send(&app, get("/api/v1/providers", &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
let items = body["items"].as_array().expect("providers should be paginated {items}");
|
||||||
|
assert!(items.iter().any(|p| p["id"] == provider_id));
|
||||||
|
|
||||||
|
// Get
|
||||||
|
let (status, body) = send(&app, get(&format!("/api/v1/providers/{provider_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["name"], "test-provider");
|
||||||
|
|
||||||
|
// Update
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
patch(
|
||||||
|
&format!("/api/v1/providers/{provider_id}"),
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({ "display_name": "Updated Provider" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["display_name"], "Updated Provider");
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
let (status, _) = send(&app, delete(&format!("/api/v1/providers/{provider_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// Verify deleted
|
||||||
|
let (status, _) = send(&app, get(&format!("/api/v1/providers/{provider_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_create_forbidden_for_user() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "provuser").await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/providers",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "name": "x", "display_name": "X", "base_url": "https://x.com" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_list_accessible_to_all_authenticated() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "listprovuser").await;
|
||||||
|
let (status, body) = send(&app, get("/api/v1/providers", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(body["items"].is_array(), "providers list should be paginated: {body}");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Model CRUD
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn model_crud_with_provider() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "modeladmin").await;
|
||||||
|
|
||||||
|
// Create provider first
|
||||||
|
let (_, prov_body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/providers",
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({ "name": "model-prov", "display_name": "Model Prov", "base_url": "https://api.test.com/v1" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
let provider_id = prov_body["id"].as_str().unwrap();
|
||||||
|
|
||||||
|
// Create model
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/models",
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({
|
||||||
|
"provider_id": provider_id,
|
||||||
|
"model_id": "test-model-v1",
|
||||||
|
"alias": "Test Model",
|
||||||
|
"context_window": 8192
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::CREATED, "create model: {body}");
|
||||||
|
let model_id = body["id"].as_str().unwrap();
|
||||||
|
|
||||||
|
// List models (paginated)
|
||||||
|
let (status, list) = send(&app, get("/api/v1/models", &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(list["items"].is_array(), "models list should be paginated: {list}");
|
||||||
|
|
||||||
|
// Get model
|
||||||
|
let (status, _) = send(&app, get(&format!("/api/v1/models/{model_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
|
||||||
|
// Update model
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
patch(
|
||||||
|
&format!("/api/v1/models/{model_id}"),
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({ "alias": "Updated Alias" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["alias"], "Updated Alias");
|
||||||
|
|
||||||
|
// Delete model
|
||||||
|
let (status, _) = send(&app, delete(&format!("/api/v1/models/{model_id}"), &admin)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn model_create_forbidden_for_user() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "modeluser").await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/models",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "provider_id": "x", "model_id": "y", "alias": "Z" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::FORBIDDEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Account API Key
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn api_key_requires_existing_provider() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "keyuser").await;
|
||||||
|
let (status, _) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/keys",
|
||||||
|
&token,
|
||||||
|
serde_json::json!({ "provider_id": "nonexistent", "key_value": "sk-test", "key_label": "Test" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn api_key_lifecycle_with_provider() {
|
||||||
|
let (app, pool) = build_test_app().await;
|
||||||
|
let admin = admin_token(&app, &pool, "keyadmin").await;
|
||||||
|
|
||||||
|
// Create provider
|
||||||
|
let (_, prov) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/providers",
|
||||||
|
&admin,
|
||||||
|
serde_json::json!({ "name": "key-prov", "display_name": "Key Prov", "base_url": "https://api.test.com/v1" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
let provider_id = prov["id"].as_str().unwrap();
|
||||||
|
|
||||||
|
// Create key as regular user
|
||||||
|
let user_token = register_token(&app, "keyowner").await;
|
||||||
|
let (status, body) = send(
|
||||||
|
&app,
|
||||||
|
post(
|
||||||
|
"/api/v1/keys",
|
||||||
|
&user_token,
|
||||||
|
serde_json::json!({ "provider_id": provider_id, "key_value": "sk-test-key-123", "key_label": "My Key" }),
|
||||||
|
),
|
||||||
|
).await;
|
||||||
|
assert_eq!(status, StatusCode::CREATED, "create key: {body}");
|
||||||
|
let key_id = body["id"].as_str().unwrap();
|
||||||
|
|
||||||
|
// List keys (paginated)
|
||||||
|
let (status, list) = send(&app, get("/api/v1/keys", &user_token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert!(list["items"].is_array(), "keys list should be paginated: {list}");
|
||||||
|
|
||||||
|
// Delete key
|
||||||
|
let (status, _) = send(&app, delete(&format!("/api/v1/keys/{key_id}"), &user_token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
// Usage stats
|
||||||
|
// ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn usage_stats_empty() {
|
||||||
|
let (app, _pool) = build_test_app().await;
|
||||||
|
let token = register_token(&app, "usageuser").await;
|
||||||
|
let (status, body) = send(&app, get("/api/v1/usage", &token)).await;
|
||||||
|
assert_eq!(status, StatusCode::OK);
|
||||||
|
assert_eq!(body["total_requests"], 0);
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user