Compare commits

...

15 Commits

Author SHA1 Message Date
iven
44256a511c feat: 增强SaaS后端功能与安全性
refactor: 重构数据库连接使用PostgreSQL替代SQLite
feat(auth): 增加JWT验证的audience和issuer检查
feat(crypto): 添加AES-256-GCM字段加密支持
feat(api): 集成utoipa实现OpenAPI文档
fix(admin): 修复配置项表单验证逻辑
style: 统一代码格式与类型定义
docs: 更新技术栈文档说明PostgreSQL
2026-03-31 00:12:53 +08:00
iven
4d8d560d1f feat(saas): 桌面端 P2 客户端补齐 — TOTP 2FA、Relay 任务、Config 同步
- saas-client: 添加 TOTP/Relay/Config 类型和 typed 方法,login 支持 totp_code
- saasStore: TOTP 感知登录 (检测 TOTP_ERROR → 两步登录),TOTP 管理动作
- SaaSLogin: TOTP 验证码输入步骤 (6 位数字,Enter 提交)
- TOTPSettings (新): 启用流程 (QR 码 + secret + 验证码),禁用 (密码确认)
- RelayTasksPanel (新): 状态过滤、任务列表、Admin 重试按钮
- SaaSSettings: 集成 TOTP 和 Relay 面板到设置页
2026-03-27 18:20:11 +08:00
iven
452ff45a5f feat(saas): P2 增强 — TOTP 2FA、Relay 重试、配置同步升级
- TOTP 2FA: totp-rs v5.7.1 + data-encoding Base32, setup/verify/disable 流程,
  登录时 TOTP 验证集成, SaasError::Totp 返回 400
- Relay 重试: 指数退避 (base_delay_ms * 2^attempt), 错误分类 (4xx 不重试),
  Admin POST /tasks/:id/retry 端点
- 配置同步: push (客户端覆盖) / merge (SaaS 优先) / diff (只读对比),
  实际写入 config_items 表
- 集成测试: 27 个测试全部通过 (新增 6 个 P2 测试)
- 文档: 更新 SaaS 平台总览 (模块完成度 + API 端点列表)
2026-03-27 17:58:14 +08:00
iven
bc12f6899a feat(saas): Phase 4 端到端完善 — 设备注册、离线支持、配置迁移、集成测试
- 后端: devices 表 + register/heartbeat/list 端点 (UPSERT 语义)
- 桌面端: 设备 ID 持久化 + 5 分钟心跳 + 离线状态指示
- saas-client: 重试逻辑 (2 次指数退避) + isServerReachable 跟踪
- ConfigMigrationWizard: 3 步向导 (方向选择→冲突解决→结果)
- SaaSSettings: 修改密码折叠面板 + 迁移向导入口
- 集成测试: 21 个测试全部通过 (含设备注册/UPSERT/心跳、密码修改、E2E 生命周期)
- 修复 ConfigMigrationWizard merge 分支变量遮蔽 bug
2026-03-27 15:07:03 +08:00
iven
8cce2283f7 fix(saas): P0 安全修复 + P1 功能补全 — 角色提升、Admin 引导、IP 记录、密码修改
P0 安全修复:
- 修复 account update 自角色提升漏洞: 非 admin 用户更新自己时剥离 role 字段
- 添加 Admin 引导机制: accounts 表为空时自动从环境变量创建 super_admin

P1 功能补全:
- 所有 17 个 log_operation 调用点传入真实客户端 IP (ConnectInfo + X-Forwarded-For)
- AuthContext 新增 client_ip 字段, middleware 层自动提取
- main.rs 使用 into_make_service_with_connect_info 启用 SocketAddr 注入
- 新增 PUT /api/v1/auth/password 密码修改端点 (验证旧密码 + argon2 哈希)
- 桌面端 SaaS 设置页添加密码修改 UI (折叠式表单)
- SaaSClient 添加 changePassword() 方法
- 集成测试修复: 注入模拟 ConnectInfo 适配 onshot 测试模式
2026-03-27 14:45:47 +08:00
iven
15450ca895 feat(saas): Phase 3 桌面端 SaaS 集成 — 客户端、Store、UI、LLM 适配器
- saas-client.ts: SaaS HTTP 客户端 (登录/注册/Token/模型列表/Chat Relay/配置同步)
- saasStore.ts: Zustand 状态管理 (登录态、连接模式、可用模型、localStorage 持久化)
- connectionStore.ts: 集成 SaaS 模式分支 (connect() 优先检查 SaaS 连接模式)
- llm-service.ts: SaasLLMAdapter 实现 (通过 SaaS Relay 代理 LLM 调用)
- SaaSLogin.tsx: 登录/注册表单 (服务器地址、用户名、密码、邮箱)
- SaaSStatus.tsx: 连接状态展示 (账号信息、健康检查、可用模型列表)
- SaaSSettings.tsx: SaaS 设置页面入口 (登录态切换、功能列表)
- SettingsLayout.tsx: 添加 SaaS 平台菜单项
- store/index.ts: 导出 useSaaSStore
2026-03-27 14:21:23 +08:00
iven
a66b675675 feat(saas): Phase 2 Admin Web 管理后台 — 完整 CRUD + Dashboard 统计
后端:
- 添加 GET /api/v1/stats/dashboard 聚合统计端点
  (账号数/活跃服务商/今日请求/今日Token用量等7项指标)
- 需要 account:admin 权限

Admin 前端 (Next.js 14 + shadcn/ui + Tailwind + Recharts):
- 设计系统: Dark Mode OLED (#020617 背景, #22C55E CTA)
- 登录页: 双栏布局, 品牌区 + 表单
- Dashboard 布局: Sidebar 导航 + Header + 主内容区
- 仪表盘: 4 统计卡片 + AreaChart 请求趋势 + BarChart Token用量
- 8 个 CRUD 页面:
  - 账号管理 (搜索/角色/状态筛选, 编辑/启用禁用)
  - 服务商 (CRUD + API Key masked)
  - 模型管理 (Provider筛选, CRUD)
  - API 密钥 (创建/撤销, 一次性显示token)
  - 用量统计 (LineChart + BarChart)
  - 中转任务 (状态筛选, 展开详情)
  - 系统配置 (分类Tab, 编辑)
  - 操作日志 (Action筛选, 展开详情)
- 14 个 shadcn 风格 UI 组件 (手写实现)
- 类型化 API 客户端 (SaaSClient, 20+ 方法, 401 自动跳转)
- AuthGuard 路由保护 + useAuth() hook

验证: tsc --noEmit 零 error, pnpm build 13 页面成功, cargo test 21 通过
2026-03-27 14:06:50 +08:00
iven
d760b9ca10 feat(saas): Phase 1 后端能力补强 — API Token 认证、真实 SSE 流式、速率限制
Phase 1.1: API Token 认证中间件
- auth_middleware 新增 zclaw_ 前缀 token 分支 (SHA-256 验证)
- 合并 token 自身权限与角色权限,异步更新 last_used_at
- 添加 GET /api/v1/auth/me 端点返回当前用户信息
- get_role_permissions 改为 pub(crate) 供中间件调用

Phase 1.2: 真实 SSE 流式中转
- RelayResponse::Sse 改为 axum::body::Body (bytes_stream)
- 流式请求超时提升至 300s,转发 SSE headers (Cache-Control, Connection)
- 添加 futures 依赖用于 StreamExt

Phase 1.3: 滑动窗口速率限制中间件
- 按 account_id 做 per-minute 限流 (默认 60 rpm + 10 burst)
- 超限返回 429 + Retry-After header
- RateLimitConfig 支持配置化,DashMap 存储时间戳

21 tests passed, zero warnings.
2026-03-27 13:49:45 +08:00
iven
a0d59b1947 fix(saas): 统一权限体系 — check_permission 辅助函数 + admin:full 超级权限
- 新增 check_permission() 统一权限检查,admin:full 自动通过所有检查
- 统一种子角色权限名称与 handler 检查一致 (provider:manage, model:manage, config:write)
- super_admin 拥有 admin:full + 所有模块管理权限
- 全部 handler 迁移到 check_permission(),消除手动 contains 检查
2026-03-27 13:12:09 +08:00
iven
900430d93e fix(saas): 修复安全审查发现的 Critical/High/Medium 问题
- Critical: 移除注册接口的 role 字段,固定为 "user" 防止权限提升
- High: 生产环境未配置 cors_origins 时拒绝启动而非默认全开放
- Medium: 增强 SSRF 防护 — 阻止 IPv6 映射地址、私有 IP 网段、十进制 IP 格式
2026-03-27 13:09:59 +08:00
iven
94bf387aee fix(saas): 安全修复 — IDOR防护、SSRF防护、JWT密钥强制、错误信息脱敏、CORS配置化
- account: admin 权限守卫 (list_accounts/get_account/update_status/list_logs)
- relay: SSRF 防护 (禁止内网地址、限制 http scheme、30s 超时)
- config: 生产环境强制 ZCLAW_SAAS_JWT_SECRET 环境变量
- error: 500 错误不再泄露内部细节给客户端
- main: CORS 支持配置白名单 origins
- 全部 21 个测试通过 (7 unit + 14 integration)
2026-03-27 13:07:20 +08:00
iven
00a08c9f9b feat(saas): Phase 4 — 配置迁移模块
- 配置项 CRUD (列表/详情/创建/更新/删除)
- 配置分析端点 (按类别汇总, SaaS 托管统计)
- 13 个默认配置项种子数据 (server/agent/memory/llm)
- 配置同步协议 (客户端→SaaS, SaaS 优先策略)
- 同步日志记录和查询
- 3 个新集成测试覆盖配置迁移端点
2026-03-27 12:58:02 +08:00
iven
a99a3df9dd feat(saas): Phase 3 — 模型请求中转服务
- OpenAI 兼容 API 代理 (/api/v1/relay/chat/completions)
- 中转任务管理 (创建/查询/状态跟踪)
- 可用模型列表端点 (仅 enabled providers+models)
- 任务生命周期 (queued → processing → completed/failed)
- 用量自动记录 (token 统计 + 错误追踪)
- 3 个新集成测试覆盖中转端点
2026-03-27 12:58:02 +08:00
iven
fec64af565 feat(saas): Phase 2 — 模型配置模块
- Provider CRUD (列表/详情/创建/更新/删除)
- Model CRUD (列表/详情/创建/更新/删除)
- Account API Key 管理 (创建/轮换/撤销/掩码显示)
- Usage 统计 (总量/按模型/按天, 支持时间/供应商/模型过滤)
- 权限控制 (provider:manage, model:manage)
- 3 个新集成测试覆盖 providers/models/keys
2026-03-27 12:58:02 +08:00
iven
a2f8112d69 feat(saas): Phase 1 — 基础框架与账号管理模块
- 新增 zclaw-saas crate 作为 workspace 成员
- 配置系统 (TOML + 环境变量覆盖)
- 错误类型体系 (SaasError 16 变体, IntoResponse)
- SQLite 数据库 (12 表 schema, 内存/文件双模式, 3 系统角色种子数据)
- JWT 认证 (签发/验证/刷新)
- Argon2id 密码哈希
- 认证中间件 (公开/受保护路由分层)
- 账号管理 CRUD + API Token 管理 + 操作日志
- 7 单元测试 + 5 集成测试全部通过
2026-03-27 12:58:01 +08:00
208 changed files with 24428 additions and 92 deletions

93
.dockerignore Normal file
View File

@@ -0,0 +1,93 @@
# ============================================================
# ZCLAW SaaS Backend - Docker Ignore
# ============================================================
# Build artifacts
target/
# Frontend applications (not needed for SaaS backend)
desktop/
admin/
design-system/
# Node.js
node_modules/
.pnpm-store/
bun.lock
pnpm-lock.yaml
package.json
package-lock.json
# Git
.git/
.gitignore
# IDE and editor
.vscode/
.idea/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Docker
.docker/
docker-compose*.yml
Dockerfile
.dockerignore
# Documentation
docs/
*.md
!saas-config.toml
CLAUDE.md
CLAUDE*.md
# Environment files (secrets)
.env
.env.*
saas-env.example
# Data files
saas-data/
saas-data.db
saas-data.db-shm
saas-data.db-wal
*.db
*.db-shm
*.db-wal
# Test artifacts
tests/
test-results/
test.rs
*.log
# Temporary files
tmp-screenshot.png
tmp/
temp/
*.tmp
# Claude worktree metadata
.claude/
plans/
pipelines/
scripts/
hands/
skills/
plugins/
config/
extract.js
extract_models.js
extract_privacy.js
start-all.ps1
start.ps1
start.sh
Makefile
PROGRESS.md
CHANGELOG.md
pencil-new.pen

495
Cargo.lock generated
View File

@@ -110,6 +110,18 @@ dependencies = [
"derive_arbitrary",
]
[[package]]
name = "argon2"
version = "0.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072"
dependencies = [
"base64ct",
"blake2",
"cpufeatures",
"password-hash",
]
[[package]]
name = "async-broadcast"
version = "0.7.2"
@@ -315,6 +327,7 @@ checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f"
dependencies = [
"async-trait",
"axum-core",
"axum-macros",
"bytes",
"futures-util",
"http 1.4.0",
@@ -335,7 +348,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper",
"tokio",
"tower",
"tower 0.5.3",
"tower-layer",
"tower-service",
"tracing",
@@ -362,6 +375,47 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-extra"
version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c794b30c904f0a1c2fb7740f7df7f7972dfaa14ef6f57cb6178dc63e5dca2f04"
dependencies = [
"axum",
"axum-core",
"bytes",
"fastrand",
"futures-util",
"headers",
"http 1.4.0",
"http-body",
"http-body-util",
"mime",
"multer",
"pin-project-lite",
"serde",
"tower 0.5.3",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-macros"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d123550fa8d071b7255cb0cc04dc302baa6c8c4a79f55701552684d8399bce"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "base32"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "022dfe9eb35f19ebbcb51e0b40a5ab759f46ad60cadf7297e0bd085afb50e076"
[[package]]
name = "base64"
version = "0.21.7"
@@ -410,6 +464,15 @@ dependencies = [
"serde_core",
]
[[package]]
name = "blake2"
version = "0.10.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
dependencies = [
"digest",
]
[[package]]
name = "block-buffer"
version = "0.10.4"
@@ -654,6 +717,12 @@ version = "0.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
[[package]]
name = "constant_time_eq"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
[[package]]
name = "convert_case"
version = "0.4.0"
@@ -896,6 +965,12 @@ dependencies = [
"parking_lot_core",
]
[[package]]
name = "data-encoding"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea"
[[package]]
name = "der"
version = "0.7.10"
@@ -1168,6 +1243,15 @@ version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ef6b89e5b37196644d8796de5268852ff179b44e96276cf4290264843743bb7"
[[package]]
name = "encoding_rs"
version = "0.8.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3"
dependencies = [
"cfg-if",
]
[[package]]
name = "endi"
version = "1.1.1"
@@ -1894,6 +1978,30 @@ dependencies = [
"hashbrown 0.14.5",
]
[[package]]
name = "headers"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3314d5adb5d94bcdf56771f2e50dbbc80bb4bdf88967526706205ac9eff24eb"
dependencies = [
"base64 0.22.1",
"bytes",
"headers-core",
"http 1.4.0",
"httpdate",
"mime",
"sha1",
]
[[package]]
name = "headers-core"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "54b4a22553d4242c49fddb9ba998a99962b5cc6f22cb5a3482bec22522403ce4"
dependencies = [
"http 1.4.0",
]
[[package]]
name = "heck"
version = "0.4.1"
@@ -2433,6 +2541,21 @@ dependencies = [
"serde_json",
]
[[package]]
name = "jsonwebtoken"
version = "9.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
dependencies = [
"base64 0.22.1",
"js-sys",
"pem",
"ring",
"serde",
"serde_json",
"simple_asn1",
]
[[package]]
name = "keyboard-types"
version = "0.7.0"
@@ -2625,6 +2748,15 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]]
name = "matches"
version = "0.1.10"
@@ -2668,6 +2800,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@@ -2716,6 +2858,23 @@ dependencies = [
"windows-sys 0.60.2",
]
[[package]]
name = "multer"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83e87776546dc87511aa5ee218730c92b666d7264ab6ed41f9d215af9cd5224b"
dependencies = [
"bytes",
"encoding_rs",
"futures-util",
"http 1.4.0",
"httparse",
"memchr",
"mime",
"spin",
"version_check",
]
[[package]]
name = "native-tls"
version = "0.2.18"
@@ -2785,6 +2944,25 @@ dependencies = [
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-bigint-dig"
version = "0.8.6"
@@ -3120,6 +3298,17 @@ dependencies = [
"windows-link 0.2.1",
]
[[package]]
name = "password-hash"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
dependencies = [
"base64ct",
"rand_core 0.6.4",
"subtle",
]
[[package]]
name = "paste"
version = "1.0.15"
@@ -3132,6 +3321,16 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "pem"
version = "3.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d30c53c26bc5b31a98cd02d20f25a7c8567146caf63ed593a9d87b2775291be"
dependencies = [
"base64 0.22.1",
"serde_core",
]
[[package]]
name = "pem-rfc7468"
version = "0.7.0"
@@ -3334,6 +3533,26 @@ dependencies = [
"siphasher 1.0.2",
]
[[package]]
name = "pin-project"
version = "1.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1749c7ed4bcaf4c3d0a3efc28538844fb29bcdd7d2b67b2be7e20ba861ff517"
dependencies = [
"pin-project-internal",
]
[[package]]
name = "pin-project-internal"
version = "1.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b20ed30f105399776b9c883e68e536ef602a16ae6f596d2c473591d6ad64c6"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "pin-project-lite"
version = "0.2.17"
@@ -3860,7 +4079,7 @@ dependencies = [
"tokio",
"tokio-rustls",
"tokio-util",
"tower",
"tower 0.5.3",
"tower-http 0.6.8",
"tower-service",
"url",
@@ -3895,7 +4114,7 @@ dependencies = [
"sync_wrapper",
"tokio",
"tokio-util",
"tower",
"tower 0.5.3",
"tower-http 0.6.8",
"tower-service",
"url",
@@ -3939,6 +4158,41 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rust-embed"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04113cb9355a377d83f06ef1f0a45b8ab8cd7d8b1288160717d66df5c7988d27"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da0902e4c7c8e997159ab384e6d0fc91c221375f6894346ae107f47dd0f3ccaa"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"shellexpand",
"syn 2.0.117",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bcdef0be6fe7f6fa333b1073c949729274b05f123a0ad7efcb8efd878e5c3b1"
dependencies = [
"sha2",
"walkdir",
]
[[package]]
name = "rustc-hash"
version = "2.1.1"
@@ -4399,6 +4653,24 @@ dependencies = [
"digest",
]
[[package]]
name = "sharded-slab"
version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
dependencies = [
"lazy_static",
]
[[package]]
name = "shellexpand"
version = "3.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32824fab5e16e6c4d86dc1ba84489390419a39f97699852b66480bb87d297ed8"
dependencies = [
"dirs",
]
[[package]]
name = "shlex"
version = "1.3.0"
@@ -4431,6 +4703,18 @@ version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
[[package]]
name = "simple_asn1"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d585997b0ac10be3c5ee635f1bab02d512760d14b7c468801ac8a01d9ae5f1d"
dependencies = [
"num-bigint",
"num-traits",
"thiserror 2.0.18",
"time",
]
[[package]]
name = "siphasher"
version = "0.3.11"
@@ -4565,6 +4849,7 @@ dependencies = [
"atoi",
"byteorder",
"bytes",
"chrono",
"crc",
"crossbeam-queue",
"either",
@@ -4625,6 +4910,7 @@ dependencies = [
"sha2",
"sqlx-core",
"sqlx-mysql",
"sqlx-postgres",
"sqlx-sqlite",
"syn 1.0.109",
"tempfile",
@@ -4643,6 +4929,7 @@ dependencies = [
"bitflags 2.11.0",
"byteorder",
"bytes",
"chrono",
"crc",
"digest",
"dotenvy",
@@ -4684,6 +4971,7 @@ dependencies = [
"base64 0.21.7",
"bitflags 2.11.0",
"byteorder",
"chrono",
"crc",
"dotenvy",
"etcetera",
@@ -4719,6 +5007,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
dependencies = [
"atoi",
"chrono",
"flume",
"futures-channel",
"futures-core",
@@ -5261,6 +5550,15 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "thread_local"
version = "1.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185"
dependencies = [
"cfg-if",
]
[[package]]
name = "time"
version = "0.3.47"
@@ -5505,6 +5803,34 @@ version = "1.1.0+spec-1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d282ade6016312faf3e41e57ebbba0c073e4056dab1232ab1cb624199648f8ed"
[[package]]
name = "totp-rs"
version = "5.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2b36a9dd327e9f401320a2cb4572cc76ff43742bcfc3291f871691050f140ba"
dependencies = [
"base32",
"constant_time_eq",
"hmac",
"sha1",
"sha2",
]
[[package]]
name = "tower"
version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8fa9be0de6cf49e536ce1851f987bd21a43b771b09473c3549a6c853db37c1c"
dependencies = [
"futures-core",
"futures-util",
"pin-project",
"pin-project-lite",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower"
version = "0.5.3"
@@ -5535,6 +5861,7 @@ dependencies = [
"pin-project-lite",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
@@ -5550,7 +5877,7 @@ dependencies = [
"http-body",
"iri-string",
"pin-project-lite",
"tower",
"tower 0.5.3",
"tower-layer",
"tower-service",
]
@@ -5597,6 +5924,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a"
dependencies = [
"once_cell",
"valuable",
]
[[package]]
name = "tracing-log"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
dependencies = [
"log",
"once_cell",
"tracing-core",
]
[[package]]
name = "tracing-subscriber"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab",
"smallvec",
"thread_local",
"tracing",
"tracing-core",
"tracing-log",
]
[[package]]
@@ -5691,6 +6048,12 @@ dependencies = [
"unic-common",
]
[[package]]
name = "unicase"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]]
name = "unicode-bidi"
version = "0.3.18"
@@ -5801,6 +6164,70 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "utoipa"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23"
dependencies = [
"indexmap 2.13.0",
"serde",
"serde_json",
"utoipa-gen 4.3.1",
]
[[package]]
name = "utoipa"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fcc29c80c21c31608227e0912b2d7fddba57ad76b606890627ba8ee7964e993"
dependencies = [
"indexmap 2.13.0",
"serde",
"serde_json",
"utoipa-gen 5.4.0",
]
[[package]]
name = "utoipa-gen"
version = "4.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "utoipa-gen"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d79d08d92ab8af4c5e8a6da20c47ae3f61a0f1dabc1997cdf2d082b757ca08b"
dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.117",
]
[[package]]
name = "utoipa-swagger-ui"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f839caa8e09dddc3ff1c3112a91ef7da0601075ba5025d9f33ae99c4cb9b6e51"
dependencies = [
"axum",
"mime_guess",
"regex",
"rust-embed",
"serde",
"serde_json",
"utoipa 4.2.3",
"zip 0.6.6",
]
[[package]]
name = "uuid"
version = "1.22.0"
@@ -5814,6 +6241,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "valuable"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "vcpkg"
version = "0.2.15"
@@ -7034,7 +7467,7 @@ dependencies = [
"zclaw-runtime",
"zclaw-skills",
"zclaw-types",
"zip",
"zip 2.4.2",
]
[[package]]
@@ -7124,6 +7557,46 @@ dependencies = [
"zclaw-types",
]
[[package]]
name = "zclaw-saas"
version = "0.1.0"
dependencies = [
"aes-gcm",
"anyhow",
"argon2",
"async-stream",
"axum",
"axum-extra",
"bytes",
"chrono",
"dashmap",
"data-encoding",
"futures",
"hex",
"jsonwebtoken",
"rand 0.8.5",
"reqwest 0.12.28",
"secrecy",
"serde",
"serde_json",
"sha2",
"sqlx",
"tempfile",
"thiserror 2.0.18",
"tokio",
"toml 0.8.2",
"totp-rs",
"tower 0.4.13",
"tower-http 0.5.2",
"tracing",
"tracing-subscriber",
"url",
"urlencoding",
"utoipa 5.4.0",
"utoipa-swagger-ui",
"uuid",
]
[[package]]
name = "zclaw-skills"
version = "0.1.0"
@@ -7231,6 +7704,18 @@ dependencies = [
"syn 2.0.117",
]
[[package]]
name = "zip"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
dependencies = [
"byteorder",
"crc32fast",
"crossbeam-utils",
"flate2",
]
[[package]]
name = "zip"
version = "2.4.2"

View File

@@ -15,6 +15,8 @@ members = [
"crates/zclaw-growth",
# Desktop Application
"desktop/src-tauri",
# SaaS Backend
"crates/zclaw-saas",
]
[workspace.package]
@@ -55,7 +57,7 @@ chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1", features = ["v4", "v5", "serde"] }
# Database
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "chrono"] }
libsqlite3-sys = { version = "0.27", features = ["bundled"] }
# HTTP client (for LLM drivers)
@@ -95,6 +97,16 @@ shlex = "1"
# Testing
tempfile = "3"
# SaaS dependencies
axum = { version = "0.7", features = ["macros"] }
axum-extra = { version = "0.9", features = ["typed-header"] }
tower = { version = "0.4", features = ["util"] }
tower-http = { version = "0.5", features = ["cors", "trace", "limit"] }
jsonwebtoken = "9"
argon2 = "0.5"
totp-rs = "5"
hex = "0.4"
# Internal crates
zclaw-types = { path = "crates/zclaw-types" }
zclaw-memory = { path = "crates/zclaw-memory" }
@@ -106,6 +118,7 @@ zclaw-channels = { path = "crates/zclaw-channels" }
zclaw-protocols = { path = "crates/zclaw-protocols" }
zclaw-pipeline = { path = "crates/zclaw-pipeline" }
zclaw-growth = { path = "crates/zclaw-growth" }
zclaw-saas = { path = "crates/zclaw-saas" }
[profile.release]
lto = true

83
Dockerfile Normal file
View File

@@ -0,0 +1,83 @@
# ============================================================
# ZCLAW SaaS Backend - Multi-stage Docker Build
# ============================================================
# ---- Stage 1: Builder ----
FROM rust:1.75-bookworm AS builder
# Install build dependencies for sqlx (postgres) and libsqlite3-sys (bundled)
RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Copy workspace manifests first to leverage Docker layer caching
COPY Cargo.toml Cargo.lock ./
# Create stub source files so cargo can resolve and cache dependencies
# This avoids rebuilding dependencies when only application code changes
RUN mkdir -p crates/zclaw-saas/src \
&& echo 'fn main() {}' > crates/zclaw-saas/src/main.rs \
&& for crate in zclaw-types zclaw-memory zclaw-runtime zclaw-kernel \
zclaw-skills zclaw-hands zclaw-channels zclaw-protocols \
zclaw-pipeline zclaw-growth; do \
mkdir -p crates/$crate/src && echo '' > crates/$crate/src/lib.rs; \
done \
&& mkdir -p desktop/src-tauri/src && echo 'fn main() {}' > desktop/src-tauri/src/main.rs
# Pre-build dependencies (release profile with caching)
RUN cargo build --release --package zclaw-saas 2>/dev/null || true
# Copy actual source code (invalidates stubs, triggers recompile of app code only)
COPY crates/ crates/
COPY desktop/ desktop/
# Touch source files to invalidate the stub timestamps
RUN touch crates/zclaw-saas/src/main.rs \
&& for crate in zclaw-types zclaw-memory zclaw-runtime zclaw-kernel \
zclaw-skills zclaw-hands zclaw-channels zclaw-protocols \
zclaw-pipeline zclaw-growth; do \
touch crates/$crate/src/lib.rs 2>/dev/null || true; \
done \
&& touch desktop/src-tauri/src/main.rs 2>/dev/null || true
# Build the actual binary
RUN cargo build --release --package zclaw-saas
# ---- Stage 2: Runtime ----
FROM debian:bookworm-slim AS runtime
# Install runtime dependencies (ca-certificates for HTTPS, libgcc for Rust runtime)
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
libgcc-s \
&& rm -rf /var/lib/apt/lists/* \
&& update-ca-certificates
# Create non-root user for security
RUN groupadd --gid 1000 zclaw \
&& useradd --uid 1000 --gid zclaw --shell /bin/false zclaw
WORKDIR /app
# Copy binary from builder
COPY --from=builder /app/target/release/zclaw-saas /app/zclaw-saas
# Copy configuration file
COPY saas-config.toml /app/saas-config.toml
# Ensure the non-root user owns the application files
RUN chown -R zclaw:zclaw /app
USER zclaw
# Expose the SaaS API port
EXPOSE 8080
# Health check endpoint (matches the saas-config.toml port)
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
CMD ["/app/zclaw-saas", "--healthcheck"] || exit 1
ENTRYPOINT ["/app/zclaw-saas"]

View File

@@ -1,7 +1,9 @@
# ZCLAW Makefile
# Cross-platform task runner
.PHONY: help start start-dev start-no-browser desktop desktop-build setup test clean
.PHONY: help start start-dev start-no-browser desktop desktop-build setup test clean \
saas-build saas-run saas-test saas-test-integration saas-clippy saas-migrate \
saas-docker-up saas-docker-down saas-docker-build
help: ## Show this help message
@echo "ZCLAW - AI Agent Desktop Client"
@@ -71,3 +73,32 @@ clean-deep: clean ## Deep clean (including pnpm cache)
@rm -rf desktop/pnpm-lock.yaml
@rm -rf pnpm-lock.yaml
@echo "Deep clean complete. Run 'pnpm install' to reinstall."
# === SaaS Backend ===
saas-build: ## Build zclaw-saas crate
@cargo build -p zclaw-saas
saas-run: ## Start SaaS backend (cargo run)
@cargo run -p zclaw-saas
saas-test: ## Run SaaS unit tests
@cargo test -p zclaw-saas -- --test-threads=1
saas-test-integration: ## Run SaaS integration tests (requires PostgreSQL)
@cargo test -p zclaw-saas -- --ignored --test-threads=1
saas-clippy: ## Run clippy on zclaw-saas
@cargo clippy -p zclaw-saas -- -D warnings
saas-migrate: ## Run database migrations
@cargo run -p zclaw-saas -- --migrate
saas-docker-up: ## Start SaaS services (PostgreSQL + backend)
@docker compose up -d
saas-docker-down: ## Stop SaaS services
@docker compose down
saas-docker-build: ## Build SaaS Docker images
@docker compose build

4
admin/.gitignore vendored Normal file
View File

@@ -0,0 +1,4 @@
.next/
node_modules/
.env.local
.env*.local

5
admin/next-env.d.ts vendored Normal file
View File

@@ -0,0 +1,5 @@
/// <reference types="next" />
/// <reference types="next/image-types/global" />
// NOTE: This file should not be edited
// see https://nextjs.org/docs/app/building-your-application/configuring/typescript for more information.

44
admin/next.config.js Normal file
View File

@@ -0,0 +1,44 @@
/** @type {import('next').NextConfig} */
const nextConfig = {
async headers() {
return [
{
source: '/(.*)',
headers: [
{
key: 'X-Frame-Options',
value: 'DENY',
},
{
key: 'X-Content-Type-Options',
value: 'nosniff',
},
{
key: 'Referrer-Policy',
value: 'strict-origin-when-cross-origin',
},
{
key: 'Content-Security-Policy',
value: [
"default-src 'self'",
"script-src 'self' 'unsafe-eval' 'unsafe-inline'",
"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com",
"font-src 'self' https://fonts.gstatic.com",
"img-src 'self' data: blob:",
"connect-src 'self'",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
].join('; '),
},
{
key: 'Permissions-Policy',
value: 'camera=(), microphone=(), geolocation=()',
},
],
},
]
},
}
module.exports = nextConfig

38
admin/package.json Normal file
View File

@@ -0,0 +1,38 @@
{
"name": "zclaw-admin",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev",
"build": "next build",
"start": "next start",
"lint": "next lint"
},
"dependencies": {
"@radix-ui/react-dialog": "^1.1.14",
"@radix-ui/react-select": "^2.2.5",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-switch": "^1.2.5",
"@radix-ui/react-tabs": "^1.1.12",
"@radix-ui/react-tooltip": "^1.2.7",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"lucide-react": "^0.484.0",
"next": "14.2.29",
"react": "^18.3.1",
"react-dom": "^18.3.1",
"recharts": "^2.15.3",
"sonner": "^2.0.7",
"tailwind-merge": "^3.0.2"
},
"devDependencies": {
"@types/node": "^20.17.19",
"@types/react": "^18.3.18",
"@types/react-dom": "^18.3.5",
"autoprefixer": "^10.4.20",
"postcss": "^8.5.3",
"tailwindcss": "^3.4.17",
"typescript": "^5.7.3"
},
"packageManager": "pnpm@10.30.2"
}

2185
admin/pnpm-lock.yaml generated Normal file

File diff suppressed because it is too large Load Diff

6
admin/postcss.config.js Normal file
View File

@@ -0,0 +1,6 @@
module.exports = {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
}

View File

@@ -0,0 +1,400 @@
'use client'
import { useEffect, useState, useCallback } from 'react'
import {
Search,
Plus,
Loader2,
ChevronLeft,
ChevronRight,
Pencil,
Ban,
CheckCircle2,
} from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Badge } from '@/components/ui/badge'
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from '@/components/ui/table'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
DialogDescription,
} from '@/components/ui/dialog'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import { formatDate } from '@/lib/utils'
import type { AccountPublic } from '@/lib/types'
const PAGE_SIZE = 20
const roleLabels: Record<string, string> = {
super_admin: '超级管理员',
admin: '管理员',
user: '普通用户',
}
const statusColors: Record<string, 'success' | 'destructive' | 'warning'> = {
active: 'success',
disabled: 'destructive',
suspended: 'warning',
}
const statusLabels: Record<string, string> = {
active: '正常',
disabled: '已禁用',
suspended: '已暂停',
}
export default function AccountsPage() {
const [accounts, setAccounts] = useState<AccountPublic[]>([])
const [total, setTotal] = useState(0)
const [page, setPage] = useState(1)
const [search, setSearch] = useState('')
// 搜索 debounce: 输入后 300ms 再触发请求
const [debouncedSearchState, setDebouncedSearchState] = useState('')
useEffect(() => {
const timer = setTimeout(() => setDebouncedSearchState(search), 300)
return () => clearTimeout(timer)
}, [search])
const [roleFilter, setRoleFilter] = useState<string>('all')
const [statusFilter, setStatusFilter] = useState<string>('all')
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
// 编辑 Dialog
const [editTarget, setEditTarget] = useState<AccountPublic | null>(null)
const [editForm, setEditForm] = useState({ display_name: '', email: '', role: 'user' })
const [editSaving, setEditSaving] = useState(false)
// 确认 Dialog
const [confirmTarget, setConfirmTarget] = useState<{ id: string; action: string; status: string } | null>(null)
const [confirmSaving, setConfirmSaving] = useState(false)
const fetchAccounts = useCallback(async () => {
setLoading(true)
setError('')
try {
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
if (debouncedSearchState.trim()) params.search = debouncedSearchState.trim()
if (roleFilter !== 'all') params.role = roleFilter
if (statusFilter !== 'all') params.status = statusFilter
const res = await api.accounts.list(params)
setAccounts(res.items)
setTotal(res.total)
} catch (err) {
if (err instanceof ApiRequestError) {
setError(err.body.message)
} else {
setError('加载失败')
}
} finally {
setLoading(false)
}
}, [page, debouncedSearchState, roleFilter, statusFilter])
useEffect(() => {
fetchAccounts()
}, [fetchAccounts])
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
function openEditDialog(account: AccountPublic) {
setEditTarget(account)
setEditForm({
display_name: account.display_name,
email: account.email,
role: account.role,
})
}
async function handleEditSave() {
if (!editTarget) return
setEditSaving(true)
try {
await api.accounts.update(editTarget.id, {
display_name: editForm.display_name,
email: editForm.email,
role: editForm.role as AccountPublic['role'],
})
setEditTarget(null)
fetchAccounts()
} catch (err) {
if (err instanceof ApiRequestError) {
setError(err.body.message)
}
} finally {
setEditSaving(false)
}
}
function openConfirmDialog(account: AccountPublic) {
const newStatus = account.status === 'active' ? 'disabled' : 'active'
setConfirmTarget({
id: account.id,
action: newStatus === 'disabled' ? '禁用' : '启用',
status: newStatus,
})
}
async function handleConfirmSave() {
if (!confirmTarget) return
setConfirmSaving(true)
try {
await api.accounts.updateStatus(confirmTarget.id, {
status: confirmTarget.status as AccountPublic['status'],
})
setConfirmTarget(null)
fetchAccounts()
} catch (err) {
if (err instanceof ApiRequestError) {
setError(err.body.message)
}
} finally {
setConfirmSaving(false)
}
}
return (
<div className="space-y-4">
{/* 搜索和筛选 */}
<div className="flex flex-wrap items-center gap-3">
<div className="relative flex-1 min-w-[200px] max-w-sm">
<Search className="absolute left-3 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<Input
placeholder="搜索用户名 / 邮箱 / 显示名..."
value={search}
onChange={(e) => { setSearch(e.target.value); setPage(1) }}
className="pl-10"
/>
</div>
<Select value={roleFilter} onValueChange={(v) => { setRoleFilter(v); setPage(1) }}>
<SelectTrigger className="w-[140px]">
<SelectValue placeholder="角色筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all"></SelectItem>
<SelectItem value="super_admin"></SelectItem>
<SelectItem value="admin"></SelectItem>
<SelectItem value="user"></SelectItem>
</SelectContent>
</Select>
<Select value={statusFilter} onValueChange={(v) => { setStatusFilter(v); setPage(1) }}>
<SelectTrigger className="w-[140px]">
<SelectValue placeholder="状态筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all"></SelectItem>
<SelectItem value="active"></SelectItem>
<SelectItem value="disabled"></SelectItem>
<SelectItem value="suspended"></SelectItem>
</SelectContent>
</Select>
</div>
{/* 错误提示 */}
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">
</button>
</div>
)}
{/* 表格 */}
{loading ? (
<div className="flex h-64 items-center justify-center">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : accounts.length === 0 ? (
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
</div>
) : (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead className="text-right"></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{accounts.map((account) => (
<TableRow key={account.id}>
<TableCell className="font-medium">{account.username}</TableCell>
<TableCell className="text-muted-foreground">{account.email}</TableCell>
<TableCell>{account.display_name || '-'}</TableCell>
<TableCell>
<Badge variant={account.role === 'super_admin' ? 'default' : account.role === 'admin' ? 'info' : 'secondary'}>
{roleLabels[account.role] || account.role}
</Badge>
</TableCell>
<TableCell>
<Badge variant={statusColors[account.status] || 'secondary'}>
<span className="mr-1 inline-block h-1.5 w-1.5 rounded-full bg-current" />
{statusLabels[account.status] || account.status}
</Badge>
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatDate(account.created_at)}
</TableCell>
<TableCell className="text-right">
<div className="flex items-center justify-end gap-1">
<Button
variant="ghost"
size="icon"
onClick={() => openEditDialog(account)}
title="编辑"
>
<Pencil className="h-4 w-4" />
</Button>
<Button
variant="ghost"
size="icon"
onClick={() => openConfirmDialog(account)}
title={account.status === 'active' ? '禁用' : '启用'}
>
{account.status === 'active' ? (
<Ban className="h-4 w-4 text-destructive" />
) : (
<CheckCircle2 className="h-4 w-4 text-green-400" />
)}
</Button>
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
{/* 分页 */}
<div className="flex items-center justify-between text-sm">
<p className="text-muted-foreground">
{page} / {totalPages} ({total} )
</p>
<div className="flex items-center gap-2">
<Button
variant="outline"
size="sm"
disabled={page <= 1}
onClick={() => setPage(page - 1)}
>
<ChevronLeft className="h-4 w-4 mr-1" />
</Button>
<Button
variant="outline"
size="sm"
disabled={page >= totalPages}
onClick={() => setPage(page + 1)}
>
<ChevronRight className="h-4 w-4 ml-1" />
</Button>
</div>
</div>
</>
)}
{/* 编辑 Dialog */}
<Dialog open={!!editTarget} onOpenChange={() => setEditTarget(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription></DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label></Label>
<Input
value={editForm.display_name}
onChange={(e) => setEditForm({ ...editForm, display_name: e.target.value })}
/>
</div>
<div className="space-y-2">
<Label></Label>
<Input
type="email"
value={editForm.email}
onChange={(e) => setEditForm({ ...editForm, email: e.target.value })}
/>
</div>
<div className="space-y-2">
<Label></Label>
<Select value={editForm.role} onValueChange={(v) => setEditForm({ ...editForm, role: v })}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="user"></SelectItem>
<SelectItem value="admin"></SelectItem>
<SelectItem value="super_admin"></SelectItem>
</SelectContent>
</Select>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setEditTarget(null)}>
</Button>
<Button onClick={handleEditSave} disabled={editSaving}>
{editSaving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* 确认 Dialog */}
<Dialog open={!!confirmTarget} onOpenChange={() => setConfirmTarget(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle>{confirmTarget?.action}</DialogTitle>
<DialogDescription>
{confirmTarget?.action}
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setConfirmTarget(null)}>
</Button>
<Button
variant={confirmTarget?.status === 'disabled' ? 'destructive' : 'default'}
onClick={handleConfirmSave}
disabled={confirmSaving}
>
{confirmSaving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
{confirmTarget?.action}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
)
}

View File

@@ -0,0 +1,351 @@
'use client'
import { useEffect, useState, useCallback } from 'react'
import {
Plus,
Loader2,
ChevronLeft,
ChevronRight,
Trash2,
Copy,
Check,
AlertTriangle,
} from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Badge } from '@/components/ui/badge'
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from '@/components/ui/table'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
DialogDescription,
} from '@/components/ui/dialog'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import { formatDate } from '@/lib/utils'
import type { TokenInfo } from '@/lib/types'
const PAGE_SIZE = 20
const allPermissions = [
{ key: 'chat', label: '对话' },
{ key: 'relay', label: '中转' },
{ key: 'admin', label: '管理' },
]
export default function ApiKeysPage() {
const [tokens, setTokens] = useState<TokenInfo[]>([])
const [total, setTotal] = useState(0)
const [page, setPage] = useState(1)
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
// 创建 Dialog
const [createOpen, setCreateOpen] = useState(false)
const [createForm, setCreateForm] = useState({ name: '', expires_days: '', permissions: ['chat'] as string[] })
const [creating, setCreating] = useState(false)
// 创建成功显示 token
const [createdToken, setCreatedToken] = useState<TokenInfo | null>(null)
const [copied, setCopied] = useState(false)
// 撤销确认
const [revokeTarget, setRevokeTarget] = useState<TokenInfo | null>(null)
const [revoking, setRevoking] = useState(false)
const fetchTokens = useCallback(async () => {
setLoading(true)
setError('')
try {
const res = await api.tokens.list({ page, page_size: PAGE_SIZE })
setTokens(res.items)
setTotal(res.total)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载失败')
} finally {
setLoading(false)
}
}, [page])
useEffect(() => {
fetchTokens()
}, [fetchTokens])
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
function togglePermission(perm: string) {
setCreateForm((prev) => ({
...prev,
permissions: prev.permissions.includes(perm)
? prev.permissions.filter((p) => p !== perm)
: [...prev.permissions, perm],
}))
}
async function handleCreate() {
if (!createForm.name.trim() || createForm.permissions.length === 0) return
setCreating(true)
try {
const payload = {
name: createForm.name.trim(),
expires_days: createForm.expires_days ? parseInt(createForm.expires_days, 10) : undefined,
permissions: createForm.permissions,
}
const res = await api.tokens.create(payload)
setCreateOpen(false)
setCreatedToken(res)
setCreateForm({ name: '', expires_days: '', permissions: ['chat'] })
fetchTokens()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
} finally {
setCreating(false)
}
}
async function handleRevoke() {
if (!revokeTarget) return
setRevoking(true)
try {
await api.tokens.revoke(revokeTarget.id)
setRevokeTarget(null)
fetchTokens()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
} finally {
setRevoking(false)
}
}
async function copyToken() {
if (!createdToken?.token) return
try {
await navigator.clipboard.writeText(createdToken.token)
setCopied(true)
setTimeout(() => setCopied(false), 2000)
} catch {
// Fallback
const textarea = document.createElement('textarea')
textarea.value = createdToken.token
document.body.appendChild(textarea)
textarea.select()
document.execCommand('copy')
document.body.removeChild(textarea)
setCopied(true)
setTimeout(() => setCopied(false), 2000)
}
}
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<div />
<Button onClick={() => setCreateOpen(true)}>
<Plus className="h-4 w-4 mr-2" />
</Button>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer"></button>
</div>
)}
{loading ? (
<div className="flex h-64 items-center justify-center">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : tokens.length === 0 ? (
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
</div>
) : (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead>使</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead className="text-right"></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{tokens.map((t) => (
<TableRow key={t.id}>
<TableCell className="font-medium">{t.name}</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{t.token_prefix}...
</TableCell>
<TableCell>
<div className="flex gap-1">
{t.permissions.map((p) => (
<Badge key={p} variant="outline" className="text-xs">
{p}
</Badge>
))}
</div>
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{t.last_used_at ? formatDate(t.last_used_at) : '未使用'}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{t.expires_at ? formatDate(t.expires_at) : '永不过期'}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatDate(t.created_at)}
</TableCell>
<TableCell className="text-right">
<Button variant="ghost" size="icon" onClick={() => setRevokeTarget(t)} title="撤销">
<Trash2 className="h-4 w-4 text-destructive" />
</Button>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
<div className="flex items-center justify-between text-sm">
<p className="text-muted-foreground">
{page} / {totalPages} ({total} )
</p>
<div className="flex items-center gap-2">
<Button variant="outline" size="sm" disabled={page <= 1} onClick={() => setPage(page - 1)}>
<ChevronLeft className="h-4 w-4 mr-1" />
</Button>
<Button variant="outline" size="sm" disabled={page >= totalPages} onClick={() => setPage(page + 1)}>
<ChevronRight className="h-4 w-4 ml-1" />
</Button>
</div>
</div>
</>
)}
{/* 创建 Dialog */}
<Dialog open={createOpen} onOpenChange={setCreateOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle> API </DialogTitle>
<DialogDescription> API </DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label> *</Label>
<Input
value={createForm.name}
onChange={(e) => setCreateForm({ ...createForm, name: e.target.value })}
placeholder="例如: 生产环境"
/>
</div>
<div className="space-y-2">
<Label> ()</Label>
<Input
type="number"
value={createForm.expires_days}
onChange={(e) => setCreateForm({ ...createForm, expires_days: e.target.value })}
placeholder="365"
/>
</div>
<div className="space-y-2">
<Label> *</Label>
<div className="flex flex-wrap gap-3 mt-1">
{allPermissions.map((perm) => (
<label
key={perm.key}
className="flex items-center gap-2 cursor-pointer"
>
<input
type="checkbox"
checked={createForm.permissions.includes(perm.key)}
onChange={() => togglePermission(perm.key)}
className="h-4 w-4 rounded border-input bg-transparent accent-primary cursor-pointer"
/>
<span className="text-sm text-foreground">{perm.label}</span>
</label>
))}
</div>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setCreateOpen(false)}></Button>
<Button onClick={handleCreate} disabled={creating || !createForm.name.trim() || createForm.permissions.length === 0}>
{creating && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* 创建成功 Dialog */}
<Dialog open={!!createdToken} onOpenChange={() => setCreatedToken(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<AlertTriangle className="h-5 w-5 text-yellow-400" />
</DialogTitle>
<DialogDescription>
</DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="rounded-md bg-muted p-4">
<p className="text-xs text-muted-foreground mb-2"></p>
<p className="font-mono text-sm break-all text-foreground">
{createdToken?.token}
</p>
</div>
<div className="rounded-md bg-yellow-500/10 border border-yellow-500/20 p-3 text-sm text-yellow-400">
</div>
</div>
<DialogFooter>
<Button onClick={copyToken} variant="outline">
{copied ? <Check className="h-4 w-4 mr-2" /> : <Copy className="h-4 w-4 mr-2" />}
{copied ? '已复制' : '复制密钥'}
</Button>
<Button onClick={() => setCreatedToken(null)}></Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* 撤销确认 */}
<Dialog open={!!revokeTarget} onOpenChange={() => setRevokeTarget(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
&quot;{revokeTarget?.name}&quot; 使访
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setRevokeTarget(null)}></Button>
<Button variant="destructive" onClick={handleRevoke} disabled={revoking}>
{revoking && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
)
}

View File

@@ -0,0 +1,283 @@
'use client'
import { useEffect, useState, useCallback } from 'react'
import {
Loader2,
Pencil,
RotateCcw,
} from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Badge } from '@/components/ui/badge'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from '@/components/ui/table'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
DialogDescription,
} from '@/components/ui/dialog'
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import type { ConfigItem } from '@/lib/types'
const sourceLabels: Record<string, string> = {
default: '默认值',
env: '环境变量',
db: '数据库',
}
const sourceVariants: Record<string, 'secondary' | 'info' | 'default'> = {
default: 'secondary',
env: 'info',
db: 'default',
}
export default function ConfigPage() {
const [configs, setConfigs] = useState<ConfigItem[]>([])
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
const [activeTab, setActiveTab] = useState('all')
// 编辑 Dialog
const [editTarget, setEditTarget] = useState<ConfigItem | null>(null)
const [editValue, setEditValue] = useState('')
const [saving, setSaving] = useState(false)
const fetchConfigs = useCallback(async (category?: string) => {
setLoading(true)
setError('')
try {
const params: Record<string, unknown> = {}
if (category && category !== 'all') params.category = category
const res = await api.config.list(params)
setConfigs(res)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载失败')
} finally {
setLoading(false)
}
}, [])
useEffect(() => {
fetchConfigs(activeTab)
}, [fetchConfigs, activeTab])
function openEditDialog(config: ConfigItem) {
setEditTarget(config)
setEditValue(config.current_value !== undefined ? String(config.current_value) : '')
}
async function handleSave() {
if (!editTarget) return
// 表单验证
if (editValue.trim() === '') {
setError('配置值不能为空')
return
}
if (editTarget.value_type === 'number' && isNaN(Number(editValue))) {
setError('请输入有效的数字')
return
}
if (editTarget.value_type === 'boolean' && editValue !== 'true' && editValue !== 'false') {
setError('布尔值只能为 true 或 false')
return
}
setSaving(true)
try {
let parsedValue: string | number | boolean = editValue
if (editTarget.value_type === 'number') {
parsedValue = parseFloat(editValue) || 0
} else if (editTarget.value_type === 'boolean') {
parsedValue = editValue === 'true'
}
await api.config.update(editTarget.id, { current_value: parsedValue })
setEditTarget(null)
fetchConfigs(activeTab)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
} finally {
setSaving(false)
}
}
function formatValue(value: unknown): string {
if (value === undefined || value === null) return '-'
if (typeof value === 'boolean') return value ? 'true' : 'false'
return String(value)
}
const categories = ['all', 'auth', 'relay', 'model', 'system']
return (
<div className="space-y-4">
{/* 分类 Tabs */}
<Tabs value={activeTab} onValueChange={setActiveTab}>
<TabsList>
{categories.map((cat) => (
<TabsTrigger key={cat} value={cat}>
{cat === 'all' ? '全部' : cat}
</TabsTrigger>
))}
</TabsList>
</Tabs>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer"></button>
</div>
)}
{loading ? (
<div className="flex h-64 items-center justify-center">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : configs.length === 0 ? (
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
</div>
) : (
<Table>
<TableHeader>
<TableRow>
<TableHead></TableHead>
<TableHead>Key</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead className="text-right"></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{configs.map((config) => (
<TableRow key={config.id}>
<TableCell>
<Badge variant="outline">{config.category}</Badge>
</TableCell>
<TableCell className="font-mono text-sm">{config.key_path}</TableCell>
<TableCell className="font-mono text-sm max-w-[200px] truncate">
{formatValue(config.current_value)}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground max-w-[200px] truncate">
{formatValue(config.default_value)}
</TableCell>
<TableCell>
<Badge variant={sourceVariants[config.source] || 'secondary'}>
{sourceLabels[config.source] || config.source}
</Badge>
</TableCell>
<TableCell>
{config.requires_restart ? (
<Badge variant="warning"></Badge>
) : (
<span className="text-muted-foreground"></span>
)}
</TableCell>
<TableCell className="text-sm text-muted-foreground max-w-[250px] truncate">
{config.description || '-'}
</TableCell>
<TableCell className="text-right">
<Button variant="ghost" size="icon" onClick={() => openEditDialog(config)} title="编辑">
<Pencil className="h-4 w-4" />
</Button>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
)}
{/* 编辑 Dialog */}
<Dialog open={!!editTarget} onOpenChange={() => setEditTarget(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
{editTarget?.key_path}
{editTarget?.requires_restart && (
<span className="block mt-1 text-yellow-400 text-xs">
注意: 修改此配置需要重启服务才能生效
</span>
)}
</DialogDescription>
</DialogHeader>
<div className="space-y-4">
<div className="space-y-2">
<Label>Key</Label>
<Input value={editTarget?.key_path || ''} disabled />
</div>
<div className="space-y-2">
<Label></Label>
<Input value={editTarget?.value_type || ''} disabled />
</div>
<div className="space-y-2">
<Label>
{editTarget?.default_value !== undefined && (
<span className="text-xs text-muted-foreground ml-2">
(: {formatValue(editTarget.default_value)})
</span>
)}
</Label>
{editTarget?.value_type === 'boolean' ? (
<Select value={editValue} onValueChange={setEditValue}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="true">true</SelectItem>
<SelectItem value="false">false</SelectItem>
</SelectContent>
</Select>
) : (
<Input
type={editTarget?.value_type === 'number' ? 'number' : 'text'}
value={editValue}
onChange={(e) => setEditValue(e.target.value)}
/>
)}
</div>
</div>
<DialogFooter>
<Button
variant="outline"
onClick={() => {
if (editTarget?.default_value !== undefined) {
setEditValue(String(editTarget.default_value))
}
}}
>
<RotateCcw className="h-4 w-4 mr-2" />
</Button>
<Button variant="outline" onClick={() => setEditTarget(null)}></Button>
<Button onClick={handleSave} disabled={saving}>
{saving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
)
}

View File

@@ -0,0 +1,125 @@
'use client'
import { useEffect, useState } from 'react'
import { Monitor, Loader2, RefreshCw } from 'lucide-react'
import { Badge } from '@/components/ui/badge'
import {
Table, TableBody, TableCell, TableHead, TableHeader, TableRow,
} from '@/components/ui/table'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import type { DeviceInfo } from '@/lib/types'
function formatRelativeTime(dateStr: string): string {
const now = Date.now()
const then = new Date(dateStr).getTime()
const diffMs = now - then
const diffMin = Math.floor(diffMs / 60000)
const diffHour = Math.floor(diffMs / 3600000)
const diffDay = Math.floor(diffMs / 86400000)
if (diffMin < 1) return '刚刚'
if (diffMin < 60) return `${diffMin} 分钟前`
if (diffHour < 24) return `${diffHour} 小时前`
return `${diffDay} 天前`
}
function isOnline(lastSeen: string): boolean {
return Date.now() - new Date(lastSeen).getTime() < 5 * 60 * 1000
}
export default function DevicesPage() {
const [devices, setDevices] = useState<DeviceInfo[]>([])
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
async function fetchDevices() {
setLoading(true)
setError('')
try {
const res = await api.devices.list()
setDevices(res)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载失败')
} finally {
setLoading(false)
}
}
useEffect(() => { fetchDevices() }, [])
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<h2 className="text-lg font-semibold text-foreground"></h2>
<button
onClick={fetchDevices}
disabled={loading}
className="flex items-center gap-2 rounded-md border border-border px-3 py-1.5 text-sm text-muted-foreground hover:bg-muted hover:text-foreground transition-colors cursor-pointer disabled:opacity-50"
>
<RefreshCw className={`h-4 w-4 ${loading ? 'animate-spin' : ''}`} />
</button>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
</div>
)}
{loading && !devices.length ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : devices.length === 0 ? (
<div className="flex flex-col items-center justify-center py-12 text-muted-foreground">
<Monitor className="h-10 w-10 mb-3" />
<p className="text-sm"></p>
</div>
) : (
<div className="rounded-md border border-border">
<Table>
<TableHeader>
<TableRow>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{devices.map((d) => (
<TableRow key={d.id}>
<TableCell className="font-medium">
{d.device_name || d.device_id}
</TableCell>
<TableCell>
<Badge variant="secondary">{d.platform || 'unknown'}</Badge>
</TableCell>
<TableCell className="text-muted-foreground">
{d.app_version || '-'}
</TableCell>
<TableCell>
<Badge variant={isOnline(d.last_seen_at) ? 'success' : 'outline'}>
{isOnline(d.last_seen_at) ? '在线' : '离线'}
</Badge>
</TableCell>
<TableCell className="text-muted-foreground">
{formatRelativeTime(d.last_seen_at)}
</TableCell>
<TableCell className="text-muted-foreground text-xs">
{new Date(d.created_at).toLocaleString('zh-CN')}
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
)}
</div>
)
}

View File

@@ -0,0 +1,305 @@
'use client'
import { useState, useEffect, type ReactNode } from 'react'
import Link from 'next/link'
import { usePathname, useRouter } from 'next/navigation'
import {
LayoutDashboard,
Users,
Server,
Cpu,
Key,
BarChart3,
ArrowLeftRight,
Settings,
FileText,
LogOut,
ChevronLeft,
Menu,
Bell,
UserCog,
ShieldCheck,
Monitor,
} from 'lucide-react'
import { AuthGuard, useAuth } from '@/components/auth-guard'
import { logout } from '@/lib/auth'
import { cn } from '@/lib/utils'
const navItems = [
{ href: '/', label: '仪表盘', icon: LayoutDashboard, permission: null },
{ href: '/accounts', label: '账号管理', icon: Users, permission: 'account:admin' },
{ href: '/providers', label: '服务商', icon: Server, permission: 'model:admin' },
{ href: '/models', label: '模型管理', icon: Cpu, permission: 'model:admin' },
{ href: '/api-keys', label: 'API 密钥', icon: Key, permission: null },
{ href: '/usage', label: '用量统计', icon: BarChart3, permission: null },
{ href: '/relay', label: '中转任务', icon: ArrowLeftRight, permission: 'relay:admin' },
{ href: '/config', label: '系统配置', icon: Settings, permission: 'admin:full' },
{ href: '/logs', label: '操作日志', icon: FileText, permission: 'admin:full' },
{ href: '/profile', label: '个人设置', icon: UserCog, permission: null },
{ href: '/security', label: '安全设置', icon: ShieldCheck, permission: null },
{ href: '/devices', label: '设备管理', icon: Monitor, permission: null },
]
function Sidebar({
collapsed,
onToggle,
mobileOpen,
onMobileClose,
}: {
collapsed: boolean
onToggle: () => void
mobileOpen: boolean
onMobileClose: () => void
}) {
const pathname = usePathname()
const router = useRouter()
const { account } = useAuth()
// 路由变化时关闭移动端菜单
useEffect(() => {
onMobileClose()
}, [pathname, onMobileClose])
function handleLogout() {
logout()
router.replace('/login')
}
return (
<>
{/* 移动端 overlay */}
{mobileOpen && (
<div
className="fixed inset-0 z-40 bg-black/50 lg:hidden"
onClick={onMobileClose}
/>
)}
<aside
className={cn(
'fixed left-0 top-0 z-50 flex h-screen flex-col border-r border-border bg-card transition-all duration-300',
collapsed ? 'w-16' : 'w-64',
'lg:z-40',
mobileOpen ? 'translate-x-0' : '-translate-x-full lg:translate-x-0',
)}
>
{/* Logo */}
<div className="flex h-14 items-center border-b border-border px-4">
<Link href="/" className="flex items-center gap-2 cursor-pointer">
<div className="flex h-8 w-8 items-center justify-center rounded-md bg-primary text-primary-foreground font-bold text-sm">
Z
</div>
{!collapsed && (
<div className="flex flex-col">
<span className="text-sm font-bold text-foreground">ZCLAW</span>
<span className="text-[10px] text-muted-foreground">Admin</span>
</div>
)}
</Link>
</div>
{/* 导航 */}
<nav className="flex-1 overflow-y-auto scrollbar-thin py-2 px-2">
<ul className="space-y-1">
{navItems
.filter((item) => {
if (!item.permission) return true
if (!account) return false
// super_admin 拥有所有权限
if (account.role === 'super_admin') return true
return account.permissions?.includes(item.permission) ?? false
})
.map((item) => {
const isActive =
item.href === '/'
? pathname === '/'
: pathname.startsWith(item.href)
const Icon = item.icon
return (
<li key={item.href}>
<Link
href={item.href}
className={cn(
'flex items-center gap-3 rounded-md px-3 py-2 text-sm font-medium transition-colors duration-200 cursor-pointer',
isActive
? 'bg-muted text-green-400'
: 'text-muted-foreground hover:bg-muted hover:text-foreground',
collapsed && 'justify-center px-2',
)}
title={collapsed ? item.label : undefined}
>
<Icon className="h-4 w-4 shrink-0" />
{!collapsed && <span>{item.label}</span>}
</Link>
</li>
)
})}
</ul>
</nav>
{/* 底部折叠按钮 */}
<div className="border-t border-border p-2">
<button
onClick={onToggle}
className="flex w-full items-center justify-center rounded-md px-3 py-2 text-muted-foreground hover:bg-muted hover:text-foreground transition-colors duration-200 cursor-pointer"
>
<ChevronLeft
className={cn(
'h-4 w-4 transition-transform duration-200',
collapsed && 'rotate-180',
)}
/>
</button>
</div>
{/* 折叠时显示退出按钮 */}
{collapsed && (
<div className="border-t border-border p-2">
<button
onClick={handleLogout}
className="flex w-full items-center justify-center rounded-md px-3 py-2 text-muted-foreground hover:bg-muted hover:text-destructive transition-colors duration-200 cursor-pointer"
title="退出登录"
>
<LogOut className="h-4 w-4" />
</button>
</div>
)}
{/* 用户信息 */}
{!collapsed && (
<div className="border-t border-border p-3">
<div className="flex items-center gap-3">
<div className="flex h-8 w-8 shrink-0 items-center justify-center rounded-full bg-muted text-xs font-medium text-foreground">
{account?.display_name?.[0] || account?.username?.[0] || 'A'}
</div>
<div className="flex-1 min-w-0">
<p className="truncate text-sm font-medium text-foreground">
{account?.display_name || account?.username || 'Admin'}
</p>
<p className="truncate text-xs text-muted-foreground">
{account?.role || 'admin'}
</p>
</div>
<button
onClick={handleLogout}
className="rounded-md p-1.5 text-muted-foreground hover:bg-muted hover:text-destructive transition-colors duration-200 cursor-pointer"
title="退出登录"
>
<LogOut className="h-4 w-4" />
</button>
</div>
</div>
)}
</aside>
</>
)
}
function Header({ children }: { children?: ReactNode }) {
const pathname = usePathname()
const currentNav = navItems.find(
(item) =>
item.href === '/'
? pathname === '/'
: pathname.startsWith(item.href),
)
return (
<header className="sticky top-0 z-30 flex h-14 items-center border-b border-border bg-background/80 backdrop-blur-sm px-6">
{/* 移动端菜单按钮 */}
{children}
{/* 页面标题 */}
<h1 className="text-lg font-semibold text-foreground">
{currentNav?.label || '仪表盘'}
</h1>
<div className="ml-auto flex items-center gap-2">
{/* 通知 */}
<button
className="relative rounded-md p-2 text-muted-foreground hover:bg-muted hover:text-foreground transition-colors duration-200 cursor-pointer"
title="通知"
>
<Bell className="h-4 w-4" />
</button>
</div>
</header>
)
}
function MobileMenuButton({ onClick }: { onClick: () => void }) {
return (
<button
onClick={onClick}
className="mr-3 rounded-md p-2 text-muted-foreground hover:bg-muted hover:text-foreground transition-colors duration-200 lg:hidden cursor-pointer"
>
<Menu className="h-5 w-5" />
</button>
)
}
/** 路由级权限守卫:隐藏导航项但用户直接访问 URL 时拦截 */
function PageGuard({ children }: { children: ReactNode }) {
const pathname = usePathname()
const router = useRouter()
const { account } = useAuth()
const matchedNav = navItems.find((item) =>
item.href === '/' ? pathname === '/' : pathname.startsWith(item.href),
)
if (matchedNav?.permission && account) {
if (account.role !== 'super_admin' && !(account.permissions?.includes(matchedNav.permission) ?? false)) {
return (
<div className="flex flex-1 items-center justify-center">
<div className="text-center space-y-3">
<p className="text-lg font-medium text-muted-foreground"></p>
<p className="text-sm text-muted-foreground">访{matchedNav.label}</p>
<button
onClick={() => router.replace('/')}
className="text-sm text-primary hover:underline cursor-pointer"
>
</button>
</div>
</div>
)
}
}
return <>{children}</>
}
export default function DashboardLayout({ children }: { children: ReactNode }) {
const [sidebarCollapsed, setSidebarCollapsed] = useState(false)
const [mobileOpen, setMobileOpen] = useState(false)
return (
<AuthGuard>
<PageGuard>
<div className="flex min-h-screen">
<Sidebar
collapsed={sidebarCollapsed}
onToggle={() => setSidebarCollapsed(!sidebarCollapsed)}
mobileOpen={mobileOpen}
onMobileClose={() => setMobileOpen(false)}
/>
<div
className={cn(
'flex flex-1 flex-col transition-all duration-300',
'ml-0 lg:transition-all',
sidebarCollapsed ? 'lg:ml-16' : 'lg:ml-64',
)}
>
<Header>
<MobileMenuButton onClick={() => setMobileOpen(true)} />
</Header>
<main className="flex-1 overflow-auto p-6 scrollbar-thin">
{children}
</main>
</div>
</div>
</PageGuard>
</AuthGuard>
)
}

View File

@@ -0,0 +1,436 @@
'use client'
import { useEffect, useState, useCallback } from 'react'
import {
Plus,
Loader2,
ChevronLeft,
ChevronRight,
Pencil,
Trash2,
} from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Badge } from '@/components/ui/badge'
import { Switch } from '@/components/ui/switch'
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from '@/components/ui/table'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
DialogDescription,
} from '@/components/ui/dialog'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import { formatNumber } from '@/lib/utils'
import type { Model, Provider } from '@/lib/types'
const PAGE_SIZE = 20
interface ModelForm {
provider_id: string
model_id: string
alias: string
context_window: string
max_output_tokens: string
supports_streaming: boolean
supports_vision: boolean
enabled: boolean
pricing_input: string
pricing_output: string
}
const emptyForm: ModelForm = {
provider_id: '',
model_id: '',
alias: '',
context_window: '4096',
max_output_tokens: '4096',
supports_streaming: true,
supports_vision: false,
enabled: true,
pricing_input: '',
pricing_output: '',
}
export default function ModelsPage() {
const [models, setModels] = useState<Model[]>([])
const [providers, setProviders] = useState<Provider[]>([])
const [total, setTotal] = useState(0)
const [page, setPage] = useState(1)
const [providerFilter, setProviderFilter] = useState<string>('all')
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
// Dialog
const [dialogOpen, setDialogOpen] = useState(false)
const [editTarget, setEditTarget] = useState<Model | null>(null)
const [form, setForm] = useState<ModelForm>(emptyForm)
const [saving, setSaving] = useState(false)
// 删除
const [deleteTarget, setDeleteTarget] = useState<Model | null>(null)
const [deleting, setDeleting] = useState(false)
const fetchModels = useCallback(async () => {
setLoading(true)
setError('')
try {
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
if (providerFilter !== 'all') params.provider_id = providerFilter
const res = await api.models.list(params)
setModels(res.items)
setTotal(res.total)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载失败')
} finally {
setLoading(false)
}
}, [page, providerFilter])
const fetchProviders = useCallback(async () => {
try {
const res = await api.providers.list()
setProviders(res)
} catch {
// ignore
}
}, [])
useEffect(() => {
fetchModels()
fetchProviders()
}, [fetchModels, fetchProviders])
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
const providerMap = new Map(providers.map((p) => [p.id, p.display_name || p.name]))
function openCreateDialog() {
setEditTarget(null)
setForm(emptyForm)
setDialogOpen(true)
}
function openEditDialog(model: Model) {
setEditTarget(model)
setForm({
provider_id: model.provider_id,
model_id: model.model_id,
alias: model.alias,
context_window: model.context_window.toString(),
max_output_tokens: model.max_output_tokens.toString(),
supports_streaming: model.supports_streaming,
supports_vision: model.supports_vision,
enabled: model.enabled,
pricing_input: model.pricing_input.toString(),
pricing_output: model.pricing_output.toString(),
})
setDialogOpen(true)
}
async function handleSave() {
if (!form.model_id.trim() || !form.provider_id) return
setSaving(true)
try {
const payload = {
provider_id: form.provider_id,
model_id: form.model_id.trim(),
alias: form.alias.trim(),
context_window: parseInt(form.context_window, 10) || 4096,
max_output_tokens: parseInt(form.max_output_tokens, 10) || 4096,
supports_streaming: form.supports_streaming,
supports_vision: form.supports_vision,
enabled: form.enabled,
pricing_input: parseFloat(form.pricing_input) || 0,
pricing_output: parseFloat(form.pricing_output) || 0,
}
if (editTarget) {
await api.models.update(editTarget.id, payload)
} else {
await api.models.create(payload)
}
setDialogOpen(false)
fetchModels()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
} finally {
setSaving(false)
}
}
async function handleDelete() {
if (!deleteTarget) return
setDeleting(true)
try {
await api.models.delete(deleteTarget.id)
setDeleteTarget(null)
fetchModels()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
} finally {
setDeleting(false)
}
}
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<Select value={providerFilter} onValueChange={(v) => { setProviderFilter(v); setPage(1) }}>
<SelectTrigger className="w-[200px]">
<SelectValue placeholder="按服务商筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all"></SelectItem>
{providers.map((p) => (
<SelectItem key={p.id} value={p.id}>
{p.display_name || p.name}
</SelectItem>
))}
</SelectContent>
</Select>
<Button onClick={openCreateDialog}>
<Plus className="h-4 w-4 mr-2" />
</Button>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer"></button>
</div>
)}
{loading ? (
<div className="flex h-64 items-center justify-center">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : models.length === 0 ? (
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
</div>
) : (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead> ID</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead className="text-right"></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{models.map((m) => (
<TableRow key={m.id}>
<TableCell className="font-mono text-sm">{m.model_id}</TableCell>
<TableCell>{m.alias || '-'}</TableCell>
<TableCell className="text-muted-foreground">
{providerMap.get(m.provider_id) || m.provider_id.slice(0, 8)}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatNumber(m.context_window)}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatNumber(m.max_output_tokens)}
</TableCell>
<TableCell>
<Badge variant={m.supports_streaming ? 'success' : 'secondary'}>
{m.supports_streaming ? '是' : '否'}
</Badge>
</TableCell>
<TableCell>
<Badge variant={m.supports_vision ? 'success' : 'secondary'}>
{m.supports_vision ? '是' : '否'}
</Badge>
</TableCell>
<TableCell>
<Badge variant={m.enabled ? 'success' : 'destructive'}>
{m.enabled ? '启用' : '禁用'}
</Badge>
</TableCell>
<TableCell className="text-right">
<div className="flex items-center justify-end gap-1">
<Button variant="ghost" size="icon" onClick={() => openEditDialog(m)} title="编辑">
<Pencil className="h-4 w-4" />
</Button>
<Button variant="ghost" size="icon" onClick={() => setDeleteTarget(m)} title="删除">
<Trash2 className="h-4 w-4 text-destructive" />
</Button>
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
<div className="flex items-center justify-between text-sm">
<p className="text-muted-foreground">
{page} / {totalPages} ({total} )
</p>
<div className="flex items-center gap-2">
<Button variant="outline" size="sm" disabled={page <= 1} onClick={() => setPage(page - 1)}>
<ChevronLeft className="h-4 w-4 mr-1" />
</Button>
<Button variant="outline" size="sm" disabled={page >= totalPages} onClick={() => setPage(page + 1)}>
<ChevronRight className="h-4 w-4 ml-1" />
</Button>
</div>
</div>
</>
)}
{/* 创建/编辑 Dialog */}
<Dialog open={dialogOpen} onOpenChange={setDialogOpen}>
<DialogContent className="max-w-lg">
<DialogHeader>
<DialogTitle>{editTarget ? '编辑模型' : '新建模型'}</DialogTitle>
<DialogDescription>
{editTarget ? '修改模型配置' : '添加新的 AI 模型'}
</DialogDescription>
</DialogHeader>
<div className="space-y-4 max-h-[60vh] overflow-y-auto scrollbar-thin pr-1">
<div className="space-y-2">
<Label> *</Label>
<Select value={form.provider_id} onValueChange={(v) => setForm({ ...form, provider_id: v })} disabled={!!editTarget}>
<SelectTrigger>
<SelectValue placeholder="选择服务商" />
</SelectTrigger>
<SelectContent>
{providers.map((p) => (
<SelectItem key={p.id} value={p.id}>
{p.display_name || p.name}
</SelectItem>
))}
</SelectContent>
</Select>
</div>
<div className="space-y-2">
<Label> ID *</Label>
<Input
value={form.model_id}
onChange={(e) => setForm({ ...form, model_id: e.target.value })}
placeholder="gpt-4o"
disabled={!!editTarget}
/>
</div>
<div className="space-y-2">
<Label></Label>
<Input
value={form.alias}
onChange={(e) => setForm({ ...form, alias: e.target.value })}
placeholder="GPT-4o"
/>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label></Label>
<Input
type="number"
value={form.context_window}
onChange={(e) => setForm({ ...form, context_window: e.target.value })}
/>
</div>
<div className="space-y-2">
<Label> Tokens</Label>
<Input
type="number"
value={form.max_output_tokens}
onChange={(e) => setForm({ ...form, max_output_tokens: e.target.value })}
/>
</div>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label>Input ($/1M tokens)</Label>
<Input
type="number"
step="0.01"
value={form.pricing_input}
onChange={(e) => setForm({ ...form, pricing_input: e.target.value })}
placeholder="0"
/>
</div>
<div className="space-y-2">
<Label>Output ($/1M tokens)</Label>
<Input
type="number"
step="0.01"
value={form.pricing_output}
onChange={(e) => setForm({ ...form, pricing_output: e.target.value })}
placeholder="0"
/>
</div>
</div>
<div className="flex items-center gap-6">
<div className="flex items-center gap-2">
<Switch checked={form.supports_streaming} onCheckedChange={(v) => setForm({ ...form, supports_streaming: v })} />
<Label></Label>
</div>
<div className="flex items-center gap-2">
<Switch checked={form.supports_vision} onCheckedChange={(v) => setForm({ ...form, supports_vision: v })} />
<Label></Label>
</div>
<div className="flex items-center gap-2">
<Switch checked={form.enabled} onCheckedChange={(v) => setForm({ ...form, enabled: v })} />
<Label></Label>
</div>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setDialogOpen(false)}></Button>
<Button onClick={handleSave} disabled={saving || !form.model_id.trim() || !form.provider_id}>
{saving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* 删除确认 */}
<Dialog open={!!deleteTarget} onOpenChange={() => setDeleteTarget(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
&quot;{deleteTarget?.alias || deleteTarget?.model_id}&quot;
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setDeleteTarget(null)}></Button>
<Button variant="destructive" onClick={handleDelete} disabled={deleting}>
{deleting && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
)
}

View File

@@ -0,0 +1,338 @@
'use client'
import { useEffect, useState } from 'react'
import {
Users,
Server,
ArrowLeftRight,
Zap,
Loader2,
TrendingUp,
} from 'lucide-react'
import {
AreaChart,
Area,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
ResponsiveContainer,
BarChart,
Bar,
Legend,
} from 'recharts'
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
import { Badge } from '@/components/ui/badge'
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from '@/components/ui/table'
import { api } from '@/lib/api-client'
import { formatNumber, formatDate } from '@/lib/utils'
import type {
DashboardStats,
UsageStats,
OperationLog,
} from '@/lib/types'
interface StatCardProps {
title: string
value: string | number
icon: React.ReactNode
color: string
subtitle?: string
}
function StatCard({ title, value, icon, color, subtitle }: StatCardProps) {
return (
<Card>
<CardContent className="p-6">
<div className="flex items-center justify-between">
<div>
<p className="text-sm text-muted-foreground">{title}</p>
<p className="mt-1 text-2xl font-bold text-foreground">{value}</p>
{subtitle && (
<p className="mt-1 text-xs text-muted-foreground">{subtitle}</p>
)}
</div>
<div
className={`flex h-10 w-10 items-center justify-center rounded-lg ${color}`}
>
{icon}
</div>
</div>
</CardContent>
</Card>
)
}
function StatusBadge({ status }: { status: string }) {
const variantMap: Record<string, 'success' | 'destructive' | 'warning' | 'info' | 'secondary'> = {
active: 'success',
completed: 'success',
disabled: 'destructive',
failed: 'destructive',
processing: 'info',
queued: 'warning',
suspended: 'destructive',
}
return (
<Badge variant={variantMap[status] || 'secondary'}>{status}</Badge>
)
}
export default function DashboardPage() {
const [stats, setStats] = useState<DashboardStats | null>(null)
const [usageStats, setUsageStats] = useState<UsageStats | null>(null)
const [recentLogs, setRecentLogs] = useState<OperationLog[]>([])
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
useEffect(() => {
async function fetchData() {
try {
const [statsRes, usageRes, logsRes] = await Promise.allSettled([
api.stats.dashboard(),
api.usage.get(),
api.logs.list({ page: 1, page_size: 5 }),
])
if (statsRes.status === 'fulfilled') setStats(statsRes.value)
if (usageRes.status === 'fulfilled') setUsageStats(usageRes.value)
if (logsRes.status === 'fulfilled') setRecentLogs(logsRes.value)
if (statsRes.status === 'rejected' && usageRes.status === 'rejected' && logsRes.status === 'rejected') {
setError('加载数据失败,请检查后端服务是否启动')
}
} finally {
setLoading(false)
}
}
fetchData()
}, [])
if (loading) {
return (
<div className="flex h-[60vh] items-center justify-center">
<div className="flex flex-col items-center gap-3">
<Loader2 className="h-8 w-8 animate-spin text-primary" />
<p className="text-sm text-muted-foreground">...</p>
</div>
</div>
)
}
if (error) {
return (
<div className="flex h-[60vh] items-center justify-center">
<div className="text-center">
<p className="text-destructive">{error}</p>
<button
onClick={() => window.location.reload()}
className="mt-4 text-sm text-primary hover:underline cursor-pointer"
>
</button>
</div>
</div>
)
}
const chartData = (usageStats?.by_day ?? []).map((r) => ({
day: r.date.slice(5), // MM-DD
请求量: r.request_count,
Input: r.input_tokens,
Output: r.output_tokens,
}))
return (
<div className="space-y-6">
{/* 统计卡片 */}
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-4">
<StatCard
title="总账号数"
value={stats?.total_accounts ?? '-'}
icon={<Users className="h-5 w-5 text-blue-400" />}
color="bg-blue-500/10"
subtitle={`活跃 ${stats?.active_accounts ?? 0}`}
/>
<StatCard
title="活跃服务商"
value={stats?.active_providers ?? '-'}
icon={<Server className="h-5 w-5 text-green-400" />}
color="bg-green-500/10"
subtitle={`模型 ${stats?.active_models ?? 0}`}
/>
<StatCard
title="今日请求"
value={stats?.tasks_today ?? '-'}
icon={<ArrowLeftRight className="h-5 w-5 text-purple-400" />}
color="bg-purple-500/10"
subtitle="中转任务"
/>
<StatCard
title="今日 Token"
value={formatNumber((stats?.tokens_today_input ?? 0) + (stats?.tokens_today_output ?? 0))}
icon={<Zap className="h-5 w-5 text-orange-400" />}
color="bg-orange-500/10"
subtitle={`In: ${formatNumber(stats?.tokens_today_input ?? 0)} / Out: ${formatNumber(stats?.tokens_today_output ?? 0)}`}
/>
</div>
{/* 图表 */}
<div className="grid grid-cols-1 gap-4 lg:grid-cols-2">
{/* 请求趋势 */}
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2 text-base">
<TrendingUp className="h-4 w-4 text-primary" />
(30 )
</CardTitle>
</CardHeader>
<CardContent>
{chartData.length > 0 ? (
<ResponsiveContainer width="100%" height={280}>
<AreaChart data={chartData}>
<defs>
<linearGradient id="colorRequests" x1="0" y1="0" x2="0" y2="1">
<stop offset="5%" stopColor="#22C55E" stopOpacity={0.3} />
<stop offset="95%" stopColor="#22C55E" stopOpacity={0} />
</linearGradient>
</defs>
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
<XAxis
dataKey="day"
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
/>
<YAxis
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
/>
<Tooltip
contentStyle={{
backgroundColor: '#0F172A',
border: '1px solid #1E293B',
borderRadius: '8px',
color: '#F8FAFC',
fontSize: '12px',
}}
/>
<Area
type="monotone"
dataKey="请求量"
stroke="#22C55E"
fillOpacity={1}
fill="url(#colorRequests)"
strokeWidth={2}
/>
</AreaChart>
</ResponsiveContainer>
) : (
<div className="flex h-[280px] items-center justify-center text-muted-foreground text-sm">
</div>
)}
</CardContent>
</Card>
{/* Token 用量 */}
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2 text-base">
<Zap className="h-4 w-4 text-orange-400" />
Token (30 )
</CardTitle>
</CardHeader>
<CardContent>
{chartData.length > 0 ? (
<ResponsiveContainer width="100%" height={280}>
<BarChart data={chartData}>
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
<XAxis
dataKey="day"
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
/>
<YAxis
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
/>
<Tooltip
contentStyle={{
backgroundColor: '#0F172A',
border: '1px solid #1E293B',
borderRadius: '8px',
color: '#F8FAFC',
fontSize: '12px',
}}
/>
<Legend
wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }}
/>
<Bar dataKey="Input" fill="#3B82F6" radius={[2, 2, 0, 0]} />
<Bar dataKey="Output" fill="#F97316" radius={[2, 2, 0, 0]} />
</BarChart>
</ResponsiveContainer>
) : (
<div className="flex h-[280px] items-center justify-center text-muted-foreground text-sm">
</div>
)}
</CardContent>
</Card>
</div>
{/* 最近操作日志 */}
<Card>
<CardHeader>
<CardTitle className="text-base"></CardTitle>
</CardHeader>
<CardContent>
{recentLogs.length > 0 ? (
<Table>
<TableHeader>
<TableRow>
<TableHead></TableHead>
<TableHead> ID</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead> ID</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{recentLogs.map((log) => (
<TableRow key={log.id}>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatDate(log.created_at)}
</TableCell>
<TableCell className="font-mono text-xs">
{log.account_id.slice(0, 8)}...
</TableCell>
<TableCell>
<Badge variant="outline">{log.action}</Badge>
</TableCell>
<TableCell className="text-muted-foreground">
{log.target_type}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{log.target_id.slice(0, 8)}...
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
) : (
<div className="flex h-32 items-center justify-center text-muted-foreground text-sm">
</div>
)}
</CardContent>
</Card>
</div>
)
}

View File

@@ -0,0 +1,154 @@
'use client'
import { useState } from 'react'
import { Lock, Loader2, Eye, EyeOff, Check } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
export default function ProfilePage() {
const [oldPassword, setOldPassword] = useState('')
const [newPassword, setNewPassword] = useState('')
const [confirmPassword, setConfirmPassword] = useState('')
const [showOld, setShowOld] = useState(false)
const [showNew, setShowNew] = useState(false)
const [showConfirm, setShowConfirm] = useState(false)
const [saving, setSaving] = useState(false)
const [error, setError] = useState('')
const [success, setSuccess] = useState('')
async function handleSubmit(e: React.FormEvent) {
e.preventDefault()
setError('')
setSuccess('')
if (newPassword.length < 8) {
setError('新密码至少 8 个字符')
return
}
if (newPassword !== confirmPassword) {
setError('两次输入的新密码不一致')
return
}
setSaving(true)
try {
await api.auth.changePassword({ old_password: oldPassword, new_password: newPassword })
setSuccess('密码修改成功')
setOldPassword('')
setNewPassword('')
setConfirmPassword('')
} catch (err) {
if (err instanceof ApiRequestError) {
setError(err.body.message || '修改失败')
} else {
setError('网络错误,请稍后重试')
}
} finally {
setSaving(false)
}
}
return (
<div className="max-w-lg">
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<Lock className="h-5 w-5" />
</CardTitle>
<CardDescription></CardDescription>
</CardHeader>
<CardContent>
<form onSubmit={handleSubmit} className="space-y-4">
<div className="space-y-2">
<Label htmlFor="old-password"></Label>
<div className="relative">
<Input
id="old-password"
type={showOld ? 'text' : 'password'}
value={oldPassword}
onChange={(e) => setOldPassword(e.target.value)}
placeholder="请输入当前密码"
required
/>
<button
type="button"
onClick={() => setShowOld(!showOld)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showOld ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
<div className="space-y-2">
<Label htmlFor="new-password"></Label>
<div className="relative">
<Input
id="new-password"
type={showNew ? 'text' : 'password'}
value={newPassword}
onChange={(e) => setNewPassword(e.target.value)}
placeholder="至少 8 个字符"
required
minLength={8}
/>
<button
type="button"
onClick={() => setShowNew(!showNew)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showNew ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
<div className="space-y-2">
<Label htmlFor="confirm-password"></Label>
<div className="relative">
<Input
id="confirm-password"
type={showConfirm ? 'text' : 'password'}
value={confirmPassword}
onChange={(e) => setConfirmPassword(e.target.value)}
placeholder="再次输入新密码"
required
minLength={8}
/>
<button
type="button"
onClick={() => setShowConfirm(!showConfirm)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showConfirm ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
</div>
)}
{success && (
<div className="rounded-md bg-emerald-500/10 border border-emerald-500/20 px-4 py-3 text-sm text-emerald-500 flex items-center gap-2">
<Check className="h-4 w-4" />
{success}
</div>
)}
<Button type="submit" disabled={saving || !oldPassword || !newPassword || !confirmPassword}>
{saving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</form>
</CardContent>
</Card>
</div>
)
}

View File

@@ -0,0 +1,369 @@
'use client'
import { useEffect, useState, useCallback } from 'react'
import {
Plus,
Loader2,
ChevronLeft,
ChevronRight,
Pencil,
Trash2,
} from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Badge } from '@/components/ui/badge'
import { Switch } from '@/components/ui/switch'
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from '@/components/ui/table'
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogFooter,
DialogDescription,
} from '@/components/ui/dialog'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import { formatDate } from '@/lib/utils'
import type { Provider } from '@/lib/types'
const PAGE_SIZE = 20
interface ProviderForm {
name: string
display_name: string
base_url: string
api_protocol: 'openai' | 'anthropic'
enabled: boolean
rate_limit_rpm: string
rate_limit_tpm: string
}
const emptyForm: ProviderForm = {
name: '',
display_name: '',
base_url: '',
api_protocol: 'openai',
enabled: true,
rate_limit_rpm: '',
rate_limit_tpm: '',
}
export default function ProvidersPage() {
const [providers, setProviders] = useState<Provider[]>([])
const [total, setTotal] = useState(0)
const [page, setPage] = useState(1)
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
// 创建/编辑 Dialog
const [dialogOpen, setDialogOpen] = useState(false)
const [editTarget, setEditTarget] = useState<Provider | null>(null)
const [form, setForm] = useState<ProviderForm>(emptyForm)
const [saving, setSaving] = useState(false)
// 删除确认 Dialog
const [deleteTarget, setDeleteTarget] = useState<Provider | null>(null)
const [deleting, setDeleting] = useState(false)
const fetchProviders = useCallback(async () => {
setLoading(true)
setError('')
try {
const res = await api.providers.list({ page, page_size: PAGE_SIZE })
setProviders(res.items)
setTotal(res.total)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载失败')
} finally {
setLoading(false)
}
}, [page])
useEffect(() => {
fetchProviders()
}, [fetchProviders])
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
function openCreateDialog() {
setEditTarget(null)
setForm(emptyForm)
setDialogOpen(true)
}
function openEditDialog(provider: Provider) {
setEditTarget(provider)
setForm({
name: provider.name,
display_name: provider.display_name,
base_url: provider.base_url,
api_protocol: provider.api_protocol,
enabled: provider.enabled,
rate_limit_rpm: provider.rate_limit_rpm?.toString() || '',
rate_limit_tpm: provider.rate_limit_tpm?.toString() || '',
})
setDialogOpen(true)
}
async function handleSave() {
if (!form.name.trim() || !form.base_url.trim()) return
setSaving(true)
try {
const payload = {
name: form.name.trim(),
display_name: form.display_name.trim(),
base_url: form.base_url.trim(),
api_protocol: form.api_protocol,
enabled: form.enabled,
rate_limit_rpm: form.rate_limit_rpm ? parseInt(form.rate_limit_rpm, 10) : undefined,
rate_limit_tpm: form.rate_limit_tpm ? parseInt(form.rate_limit_tpm, 10) : undefined,
}
if (editTarget) {
await api.providers.update(editTarget.id, payload)
} else {
await api.providers.create(payload)
}
setDialogOpen(false)
fetchProviders()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
} finally {
setSaving(false)
}
}
async function handleDelete() {
if (!deleteTarget) return
setDeleting(true)
try {
await api.providers.delete(deleteTarget.id)
setDeleteTarget(null)
fetchProviders()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
} finally {
setDeleting(false)
}
}
return (
<div className="space-y-4">
{/* 工具栏 */}
<div className="flex items-center justify-between">
<div />
<Button onClick={openCreateDialog}>
<Plus className="h-4 w-4 mr-2" />
</Button>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer"></button>
</div>
)}
{loading ? (
<div className="flex h-64 items-center justify-center">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : providers.length === 0 ? (
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
</div>
) : (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead>Base URL</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead>RPM </TableHead>
<TableHead></TableHead>
<TableHead className="text-right"></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{providers.map((p) => (
<TableRow key={p.id}>
<TableCell className="font-medium">{p.name}</TableCell>
<TableCell>{p.display_name || '-'}</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground max-w-[200px] truncate">
{p.base_url}
</TableCell>
<TableCell>
<Badge variant={p.api_protocol === 'openai' ? 'default' : 'info'}>
{p.api_protocol}
</Badge>
</TableCell>
<TableCell>
<Badge variant={p.enabled ? 'success' : 'secondary'}>
{p.enabled ? '是' : '否'}
</Badge>
</TableCell>
<TableCell className="text-muted-foreground">
{p.rate_limit_rpm ?? '-'}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatDate(p.created_at)}
</TableCell>
<TableCell className="text-right">
<div className="flex items-center justify-end gap-1">
<Button variant="ghost" size="icon" onClick={() => openEditDialog(p)} title="编辑">
<Pencil className="h-4 w-4" />
</Button>
<Button variant="ghost" size="icon" onClick={() => setDeleteTarget(p)} title="删除">
<Trash2 className="h-4 w-4 text-destructive" />
</Button>
</div>
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
<div className="flex items-center justify-between text-sm">
<p className="text-muted-foreground">
{page} / {totalPages} ({total} )
</p>
<div className="flex items-center gap-2">
<Button variant="outline" size="sm" disabled={page <= 1} onClick={() => setPage(page - 1)}>
<ChevronLeft className="h-4 w-4 mr-1" />
</Button>
<Button variant="outline" size="sm" disabled={page >= totalPages} onClick={() => setPage(page + 1)}>
<ChevronRight className="h-4 w-4 ml-1" />
</Button>
</div>
</div>
</>
)}
{/* 创建/编辑 Dialog */}
<Dialog open={dialogOpen} onOpenChange={setDialogOpen}>
<DialogContent>
<DialogHeader>
<DialogTitle>{editTarget ? '编辑服务商' : '新建服务商'}</DialogTitle>
<DialogDescription>
{editTarget ? '修改服务商配置' : '添加新的 AI 服务商'}
</DialogDescription>
</DialogHeader>
<div className="space-y-4 max-h-[60vh] overflow-y-auto scrollbar-thin pr-1">
<div className="space-y-2">
<Label> *</Label>
<Input
value={form.name}
onChange={(e) => setForm({ ...form, name: e.target.value })}
placeholder="例如: openai"
disabled={!!editTarget}
/>
</div>
<div className="space-y-2">
<Label></Label>
<Input
value={form.display_name}
onChange={(e) => setForm({ ...form, display_name: e.target.value })}
placeholder="例如: OpenAI"
/>
</div>
<div className="space-y-2">
<Label>Base URL *</Label>
<Input
value={form.base_url}
onChange={(e) => setForm({ ...form, base_url: e.target.value })}
placeholder="https://api.openai.com/v1"
/>
</div>
<div className="space-y-2">
<Label>API </Label>
<Select value={form.api_protocol} onValueChange={(v) => setForm({ ...form, api_protocol: v as 'openai' | 'anthropic' })}>
<SelectTrigger>
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="openai">OpenAI</SelectItem>
<SelectItem value="anthropic">Anthropic</SelectItem>
</SelectContent>
</Select>
</div>
<div className="flex items-center gap-3">
<Switch
checked={form.enabled}
onCheckedChange={(v) => setForm({ ...form, enabled: v })}
/>
<Label></Label>
</div>
<div className="grid grid-cols-2 gap-4">
<div className="space-y-2">
<Label>RPM </Label>
<Input
type="number"
value={form.rate_limit_rpm}
onChange={(e) => setForm({ ...form, rate_limit_rpm: e.target.value })}
placeholder="不限"
/>
</div>
<div className="space-y-2">
<Label>TPM </Label>
<Input
type="number"
value={form.rate_limit_tpm}
onChange={(e) => setForm({ ...form, rate_limit_tpm: e.target.value })}
placeholder="不限"
/>
</div>
</div>
</div>
<DialogFooter>
<Button variant="outline" onClick={() => setDialogOpen(false)}></Button>
<Button onClick={handleSave} disabled={saving || !form.name.trim() || !form.base_url.trim()}>
{saving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
{/* 删除确认 Dialog */}
<Dialog open={!!deleteTarget} onOpenChange={() => setDeleteTarget(null)}>
<DialogContent>
<DialogHeader>
<DialogTitle></DialogTitle>
<DialogDescription>
&quot;{deleteTarget?.display_name || deleteTarget?.name}&quot;
</DialogDescription>
</DialogHeader>
<DialogFooter>
<Button variant="outline" onClick={() => setDeleteTarget(null)}></Button>
<Button variant="destructive" onClick={handleDelete} disabled={deleting}>
{deleting && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</DialogFooter>
</DialogContent>
</Dialog>
</div>
)
}

View File

@@ -0,0 +1,278 @@
'use client'
import { useEffect, useState, useCallback } from 'react'
import {
Loader2,
ChevronLeft,
ChevronRight,
ChevronDown,
ChevronUp,
RotateCcw,
} from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Badge } from '@/components/ui/badge'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from '@/components/ui/table'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import { formatDate, formatNumber } from '@/lib/utils'
import type { RelayTask } from '@/lib/types'
const PAGE_SIZE = 20
const statusVariants: Record<string, 'success' | 'info' | 'warning' | 'destructive' | 'secondary'> = {
queued: 'warning',
processing: 'info',
completed: 'success',
failed: 'destructive',
}
const statusLabels: Record<string, string> = {
queued: '排队中',
processing: '处理中',
completed: '已完成',
failed: '失败',
}
export default function RelayPage() {
const [tasks, setTasks] = useState<RelayTask[]>([])
const [total, setTotal] = useState(0)
const [page, setPage] = useState(1)
const [statusFilter, setStatusFilter] = useState<string>('all')
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
const [expandedId, setExpandedId] = useState<string | null>(null)
const [retryingId, setRetryingId] = useState<string | null>(null)
const fetchTasks = useCallback(async () => {
setLoading(true)
setError('')
try {
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
if (statusFilter !== 'all') params.status = statusFilter
const res = await api.relay.list(params)
setTasks(res.items)
setTotal(res.total)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载失败')
} finally {
setLoading(false)
}
}, [page, statusFilter])
useEffect(() => {
fetchTasks()
}, [fetchTasks])
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
function toggleExpand(id: string) {
setExpandedId((prev) => (prev === id ? null : id))
}
async function handleRetry(taskId: string, e: React.MouseEvent) {
e.stopPropagation()
setRetryingId(taskId)
try {
await api.relay.retry(taskId)
fetchTasks()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('重试失败')
} finally {
setRetryingId(null)
}
}
return (
<div className="space-y-4">
{/* 筛选 */}
<div className="flex items-center gap-3">
<Select value={statusFilter} onValueChange={(v) => { setStatusFilter(v); setPage(1) }}>
<SelectTrigger className="w-[140px]">
<SelectValue placeholder="状态筛选" />
</SelectTrigger>
<SelectContent>
<SelectItem value="all"></SelectItem>
<SelectItem value="queued"></SelectItem>
<SelectItem value="processing"></SelectItem>
<SelectItem value="completed"></SelectItem>
<SelectItem value="failed"></SelectItem>
</SelectContent>
</Select>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer"></button>
</div>
)}
{loading ? (
<div className="flex h-64 items-center justify-center">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : tasks.length === 0 ? (
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
</div>
) : (
<>
<Table>
<TableHeader>
<TableRow>
<TableHead className="w-8" />
<TableHead> ID</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead>Input Tokens</TableHead>
<TableHead>Output Tokens</TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead className="text-right"></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{tasks.map((task) => (
<>
<TableRow key={task.id} className="cursor-pointer" onClick={() => toggleExpand(task.id)}>
<TableCell>
{expandedId === task.id ? (
<ChevronUp className="h-4 w-4 text-muted-foreground" />
) : (
<ChevronDown className="h-4 w-4 text-muted-foreground" />
)}
</TableCell>
<TableCell className="font-mono text-xs">
{task.id.slice(0, 8)}...
</TableCell>
<TableCell className="font-mono text-xs">
{task.model_id}
</TableCell>
<TableCell>
<Badge variant={statusVariants[task.status] || 'secondary'}>
{statusLabels[task.status] || task.status}
</Badge>
</TableCell>
<TableCell className="text-muted-foreground">{task.priority}</TableCell>
<TableCell className="text-muted-foreground">{task.attempt_count}</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatNumber(task.input_tokens)}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatNumber(task.output_tokens)}
</TableCell>
<TableCell className="max-w-[200px] truncate text-xs text-destructive">
{task.error_message || '-'}
</TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{formatDate(task.created_at)}
</TableCell>
<TableCell className="text-right">
{task.status === 'failed' && (
<Button
variant="ghost"
size="icon"
onClick={(e) => handleRetry(task.id, e)}
disabled={retryingId === task.id}
title="重试"
>
{retryingId === task.id ? (
<Loader2 className="h-4 w-4 animate-spin" />
) : (
<RotateCcw className="h-4 w-4" />
)}
</Button>
)}
</TableCell>
</TableRow>
{expandedId === task.id && (
<TableRow key={`${task.id}-detail`}>
<TableCell colSpan={11} className="bg-muted/20 px-8 py-4">
<div className="grid grid-cols-2 gap-4 text-sm">
<div>
<p className="text-muted-foreground"> ID</p>
<p className="font-mono text-xs">{task.id}</p>
</div>
<div>
<p className="text-muted-foreground"> ID</p>
<p className="font-mono text-xs">{task.account_id}</p>
</div>
<div>
<p className="text-muted-foreground"> ID</p>
<p className="font-mono text-xs">{task.provider_id}</p>
</div>
<div>
<p className="text-muted-foreground"> ID</p>
<p className="font-mono text-xs">{task.model_id}</p>
</div>
{task.queued_at && (
<div>
<p className="text-muted-foreground"></p>
<p className="font-mono text-xs">{formatDate(task.queued_at)}</p>
</div>
)}
{task.started_at && (
<div>
<p className="text-muted-foreground"></p>
<p className="font-mono text-xs">{formatDate(task.started_at)}</p>
</div>
)}
{task.completed_at && (
<div>
<p className="text-muted-foreground"></p>
<p className="font-mono text-xs">{formatDate(task.completed_at)}</p>
</div>
)}
{task.error_message && (
<div className="col-span-2">
<p className="text-muted-foreground"></p>
<p className="text-xs text-destructive mt-1">{task.error_message}</p>
</div>
)}
</div>
</TableCell>
</TableRow>
)}
</>
))}
</TableBody>
</Table>
<div className="flex items-center justify-between text-sm">
<p className="text-muted-foreground">
{page} / {totalPages} ({total} )
</p>
<div className="flex items-center gap-2">
<Button variant="outline" size="sm" disabled={page <= 1} onClick={() => setPage(page - 1)}>
<ChevronLeft className="h-4 w-4 mr-1" />
</Button>
<Button variant="outline" size="sm" disabled={page >= totalPages} onClick={() => setPage(page + 1)}>
<ChevronRight className="h-4 w-4 ml-1" />
</Button>
</div>
</div>
</>
)}
</div>
)
}

View File

@@ -0,0 +1,203 @@
'use client'
import { useState } from 'react'
import { ShieldCheck, Loader2, Eye, EyeOff, QrCode, Key, AlertTriangle } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
import { Badge } from '@/components/ui/badge'
import { api } from '@/lib/api-client'
import { useAuth } from '@/components/auth-guard'
import { ApiRequestError } from '@/lib/api-client'
export default function SecurityPage() {
const { account } = useAuth()
const totpEnabled = account?.totp_enabled ?? false
// Setup state
const [step, setStep] = useState<'idle' | 'verify' | 'done'>('idle')
const [otpauthUri, setOtpauthUri] = useState('')
const [secret, setSecret] = useState('')
const [verifyCode, setVerifyCode] = useState('')
const [loading, setLoading] = useState(false)
const [error, setError] = useState('')
// Disable state
const [disablePassword, setDisablePassword] = useState('')
const [showDisablePassword, setShowDisablePassword] = useState(false)
const [disabling, setDisabling] = useState(false)
async function handleSetup() {
setLoading(true)
setError('')
try {
const res = await api.auth.totpSetup()
setOtpauthUri(res.otpauth_uri)
setSecret(res.secret)
setStep('verify')
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message || '获取密钥失败')
else setError('网络错误')
} finally {
setLoading(false)
}
}
async function handleVerify() {
if (verifyCode.length !== 6) {
setError('请输入 6 位验证码')
return
}
setLoading(true)
setError('')
try {
await api.auth.totpVerify({ code: verifyCode })
setStep('done')
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message || '验证失败')
else setError('网络错误')
} finally {
setLoading(false)
}
}
async function handleDisable() {
if (!disablePassword) {
setError('请输入密码以确认禁用')
return
}
setDisabling(true)
setError('')
try {
await api.auth.totpDisable({ password: disablePassword })
setDisablePassword('')
window.location.reload()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message || '禁用失败')
else setError('网络错误')
} finally {
setDisabling(false)
}
}
return (
<div className="max-w-lg space-y-6">
{/* TOTP 状态 */}
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<ShieldCheck className="h-5 w-5" />
(TOTP)
</CardTitle>
<CardDescription>
使 Google Authenticator
</CardDescription>
</CardHeader>
<CardContent>
<div className="flex items-center gap-3 mb-4">
<span className="text-sm text-muted-foreground">:</span>
<Badge variant={totpEnabled ? 'success' : 'secondary'}>
{totpEnabled ? '已启用' : '未启用'}
</Badge>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive mb-4">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer"></button>
</div>
)}
{/* 未启用: 设置流程 */}
{!totpEnabled && step === 'idle' && (
<Button onClick={handleSetup} disabled={loading}>
{loading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
<Key className="mr-2 h-4 w-4" />
</Button>
)}
{!totpEnabled && step === 'verify' && (
<div className="space-y-4">
<div className="rounded-md border border-border p-4 space-y-3">
<div className="flex items-center gap-2 text-sm font-medium">
<QrCode className="h-4 w-4" />
1: 扫描二维码或手动输入密钥
</div>
<div className="bg-muted rounded-md p-3 font-mono text-xs break-all">
{otpauthUri}
</div>
<div className="space-y-1">
<p className="text-xs text-muted-foreground">:</p>
<p className="font-mono text-sm font-medium select-all">{secret}</p>
</div>
</div>
<div className="space-y-2">
<Label>
2: 输入 6
</Label>
<Input
value={verifyCode}
onChange={(e) => setVerifyCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
placeholder="请输入应用中显示的 6 位数字"
maxLength={6}
className="font-mono tracking-widest text-center"
/>
</div>
<div className="flex gap-2">
<Button variant="outline" onClick={() => { setStep('idle'); setVerifyCode('') }}>
</Button>
<Button onClick={handleVerify} disabled={loading || verifyCode.length !== 6}>
{loading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</div>
</div>
)}
{!totpEnabled && step === 'done' && (
<div className="rounded-md bg-emerald-500/10 border border-emerald-500/20 p-4 text-sm text-emerald-500">
</div>
)}
{/* 已启用: 禁用流程 */}
{totpEnabled && (
<div className="space-y-4">
<div className="rounded-md bg-amber-500/10 border border-amber-500/20 p-3 flex items-start gap-2 text-sm text-amber-600">
<AlertTriangle className="h-4 w-4 mt-0.5 shrink-0" />
<span></span>
</div>
<div className="space-y-2">
<Label></Label>
<div className="relative">
<Input
type={showDisablePassword ? 'text' : 'password'}
value={disablePassword}
onChange={(e) => setDisablePassword(e.target.value)}
placeholder="请输入当前密码"
/>
<button
type="button"
onClick={() => setShowDisablePassword(!showDisablePassword)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showDisablePassword ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
<Button variant="destructive" onClick={handleDisable} disabled={disabling || !disablePassword}>
{disabling && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</div>
)}
</CardContent>
</Card>
</div>
)
}

View File

@@ -0,0 +1,234 @@
'use client'
import { useEffect, useState, useCallback } from 'react'
import { Loader2, Zap } from 'lucide-react'
import {
LineChart,
Line,
XAxis,
YAxis,
CartesianGrid,
Tooltip,
ResponsiveContainer,
BarChart,
Bar,
Legend,
} from 'recharts'
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
import {
Select,
SelectContent,
SelectItem,
SelectTrigger,
SelectValue,
} from '@/components/ui/select'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import { formatNumber } from '@/lib/utils'
import type { UsageStats } from '@/lib/types'
export default function UsagePage() {
const [days, setDays] = useState(7)
const [usageStats, setUsageStats] = useState<UsageStats | null>(null)
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
const fetchData = useCallback(async () => {
setLoading(true)
setError('')
try {
const from = new Date()
from.setDate(from.getDate() - days)
const fromStr = from.toISOString().slice(0, 10)
const res = await api.usage.get({ from: fromStr })
setUsageStats(res)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载数据失败')
} finally {
setLoading(false)
}
}, [days])
useEffect(() => {
fetchData()
}, [fetchData])
const byDay = usageStats?.by_day ?? []
const lineChartData = byDay.map((r) => ({
day: r.date.slice(5),
Input: r.input_tokens,
Output: r.output_tokens,
}))
const barChartData = (usageStats?.by_model ?? []).map((r) => ({
model: r.model_id,
请求量: r.request_count,
Input: r.input_tokens,
Output: r.output_tokens,
}))
const totalInput = byDay.reduce((s, r) => s + r.input_tokens, 0)
const totalOutput = byDay.reduce((s, r) => s + r.output_tokens, 0)
const totalRequests = byDay.reduce((s, r) => s + r.request_count, 0)
if (loading) {
return (
<div className="flex h-[60vh] items-center justify-center">
<div className="flex flex-col items-center gap-3">
<Loader2 className="h-8 w-8 animate-spin text-primary" />
<p className="text-sm text-muted-foreground">...</p>
</div>
</div>
)
}
if (error) {
return (
<div className="flex h-[60vh] items-center justify-center">
<div className="text-center">
<p className="text-destructive">{error}</p>
<button onClick={() => fetchData()} className="mt-4 text-sm text-primary hover:underline cursor-pointer">
</button>
</div>
</div>
)
}
return (
<div className="space-y-6">
{/* 时间范围 */}
<div className="flex items-center gap-3">
<span className="text-sm text-muted-foreground">:</span>
<Select value={String(days)} onValueChange={(v) => setDays(Number(v))}>
<SelectTrigger className="w-[140px]">
<SelectValue />
</SelectTrigger>
<SelectContent>
<SelectItem value="7"> 7 </SelectItem>
<SelectItem value="30"> 30 </SelectItem>
<SelectItem value="90"> 90 </SelectItem>
</SelectContent>
</Select>
</div>
{/* 汇总统计 */}
<div className="grid grid-cols-1 gap-4 sm:grid-cols-3">
<Card>
<CardContent className="p-6">
<p className="text-sm text-muted-foreground"></p>
<p className="mt-1 text-2xl font-bold text-foreground">
{formatNumber(totalRequests)}
</p>
</CardContent>
</Card>
<Card>
<CardContent className="p-6">
<p className="text-sm text-muted-foreground">Input Tokens</p>
<p className="mt-1 text-2xl font-bold text-blue-400">
{formatNumber(totalInput)}
</p>
</CardContent>
</Card>
<Card>
<CardContent className="p-6">
<p className="text-sm text-muted-foreground">Output Tokens</p>
<p className="mt-1 text-2xl font-bold text-orange-400">
{formatNumber(totalOutput)}
</p>
</CardContent>
</Card>
</div>
{/* Token 用量趋势 */}
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2 text-base">
<Zap className="h-4 w-4 text-primary" />
Token
</CardTitle>
</CardHeader>
<CardContent>
{lineChartData.length > 0 ? (
<ResponsiveContainer width="100%" height={320}>
<LineChart data={lineChartData}>
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
<XAxis
dataKey="day"
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
/>
<YAxis
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
/>
<Tooltip
contentStyle={{
backgroundColor: '#0F172A',
border: '1px solid #1E293B',
borderRadius: '8px',
color: '#F8FAFC',
fontSize: '12px',
}}
/>
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
<Line type="monotone" dataKey="Input" stroke="#3B82F6" strokeWidth={2} dot={false} />
<Line type="monotone" dataKey="Output" stroke="#F97316" strokeWidth={2} dot={false} />
</LineChart>
</ResponsiveContainer>
) : (
<div className="flex h-[320px] items-center justify-center text-muted-foreground text-sm">
</div>
)}
</CardContent>
</Card>
{/* 按模型分布 */}
<Card>
<CardHeader>
<CardTitle className="text-base"></CardTitle>
</CardHeader>
<CardContent>
{barChartData.length > 0 ? (
<ResponsiveContainer width="100%" height={320}>
<BarChart data={barChartData} layout="vertical">
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
<XAxis
type="number"
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
/>
<YAxis
type="category"
dataKey="model"
tick={{ fontSize: 12, fill: '#94A3B8' }}
axisLine={{ stroke: '#1E293B' }}
width={120}
/>
<Tooltip
contentStyle={{
backgroundColor: '#0F172A',
border: '1px solid #1E293B',
borderRadius: '8px',
color: '#F8FAFC',
fontSize: '12px',
}}
/>
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
<Bar dataKey="Input" fill="#3B82F6" radius={[0, 2, 2, 0]} />
<Bar dataKey="Output" fill="#F97316" radius={[0, 2, 2, 0]} />
</BarChart>
</ResponsiveContainer>
) : (
<div className="flex h-[320px] items-center justify-center text-muted-foreground text-sm">
</div>
)}
</CardContent>
</Card>
</div>
)
}

66
admin/src/app/globals.css Normal file
View File

@@ -0,0 +1,66 @@
@tailwind base;
@tailwind components;
@tailwind utilities;
@layer base {
:root {
--background: 222 47% 5%;
--foreground: 210 40% 98%;
--card: 222 47% 8%;
--card-foreground: 210 40% 98%;
--primary: 142 71% 45%;
--primary-foreground: 222 47% 5%;
--muted: 217 33% 17%;
--muted-foreground: 215 20% 65%;
--accent: 215 28% 23%;
--accent-foreground: 210 40% 98%;
--destructive: 0 84% 60%;
--destructive-foreground: 210 40% 98%;
--border: 217 33% 17%;
--input: 217 33% 17%;
--ring: 142 71% 45%;
}
* {
border-color: hsl(var(--border));
}
body {
background-color: hsl(var(--background));
color: hsl(var(--foreground));
font-family: 'Inter', system-ui, -apple-system, sans-serif;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
}
@layer utilities {
.scrollbar-thin {
scrollbar-width: thin;
scrollbar-color: hsl(var(--muted)) transparent;
}
.scrollbar-thin::-webkit-scrollbar {
width: 6px;
height: 6px;
}
.scrollbar-thin::-webkit-scrollbar-track {
background: transparent;
}
.scrollbar-thin::-webkit-scrollbar-thumb {
background-color: hsl(var(--muted));
border-radius: 3px;
}
.scrollbar-thin::-webkit-scrollbar-thumb:hover {
background-color: hsl(var(--accent));
}
}
@layer components {
.glass-card {
@apply bg-card/80 backdrop-blur-sm border border-border rounded-lg;
}
}

29
admin/src/app/layout.tsx Normal file
View File

@@ -0,0 +1,29 @@
import type { Metadata } from 'next'
import { Toaster } from 'sonner'
import './globals.css'
export const metadata: Metadata = {
title: 'ZCLAW Admin',
description: 'ZCLAW AI Agent 管理平台',
}
export default function RootLayout({
children,
}: {
children: React.ReactNode
}) {
return (
<html lang="zh-CN" className="dark">
<head>
<link
href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap"
rel="stylesheet"
/>
</head>
<body className="min-h-screen bg-background font-sans antialiased">
{children}
<Toaster richColors position="top-right" />
</body>
</html>
)
}

View File

@@ -0,0 +1,218 @@
'use client'
import { useState, type FormEvent } from 'react'
import { useRouter } from 'next/navigation'
import { Lock, User, Loader2, Eye, EyeOff, ShieldCheck } from 'lucide-react'
import { api } from '@/lib/api-client'
import { login } from '@/lib/auth'
import { ApiRequestError } from '@/lib/api-client'
export default function LoginPage() {
const router = useRouter()
const [username, setUsername] = useState('')
const [password, setPassword] = useState('')
const [showPassword, setShowPassword] = useState(false)
const [totpCode, setTotpCode] = useState('')
const [showTotp, setShowTotp] = useState(false)
const [loading, setLoading] = useState(false)
const [error, setError] = useState('')
async function handleSubmit(e: FormEvent) {
e.preventDefault()
setError('')
if (!username.trim()) {
setError('请输入用户名')
return
}
if (!password.trim()) {
setError('请输入密码')
return
}
setLoading(true)
try {
const res = await api.auth.login({
username: username.trim(),
password,
totp_code: showTotp ? totpCode.trim() || undefined : undefined,
})
login(res.token, res.account)
router.replace('/')
} catch (err) {
if (err instanceof ApiRequestError) {
// 检测 TOTP 错误码,自动显示验证码输入框
if (err.body.error === 'totp_required' || err.body.message?.includes('双因素认证') || err.body.message?.includes('TOTP')) {
setShowTotp(true)
setError(err.body.message || '此账号已启用双因素认证,请输入验证码')
} else {
setError(err.body.message || '登录失败,请检查用户名和密码')
}
} else {
setError('网络错误,请稍后重试')
}
} finally {
setLoading(false)
}
}
return (
<div className="flex min-h-screen">
{/* 左侧品牌区域 */}
<div className="hidden lg:flex lg:w-1/2 relative overflow-hidden bg-gradient-to-br from-slate-900 via-slate-800 to-slate-900">
{/* 装饰性背景 */}
<div className="absolute inset-0">
<div className="absolute top-1/4 left-1/4 w-96 h-96 bg-green-500/5 rounded-full blur-3xl" />
<div className="absolute bottom-1/4 right-1/4 w-64 h-64 bg-green-500/8 rounded-full blur-3xl" />
<div className="absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[600px] h-[600px] border border-green-500/10 rounded-full" />
<div className="absolute top-1/2 left-1/2 -translate-x-1/2 -translate-y-1/2 w-[400px] h-[400px] border border-green-500/10 rounded-full" />
</div>
{/* 品牌内容 */}
<div className="relative z-10 flex flex-col items-center justify-center w-full p-12">
<div className="text-center">
<h1 className="text-6xl font-bold tracking-tight text-foreground mb-4">
ZCLAW
</h1>
<p className="text-xl text-muted-foreground font-light">
AI Agent
</p>
<div className="mt-8 flex items-center justify-center gap-2">
<div className="h-px w-12 bg-green-500/50" />
<div className="w-2 h-2 rounded-full bg-green-500" />
<div className="h-px w-12 bg-green-500/50" />
</div>
<p className="mt-6 text-sm text-muted-foreground/60 max-w-sm">
AI API
</p>
</div>
</div>
</div>
{/* 右侧登录表单 */}
<div className="flex w-full lg:w-1/2 items-center justify-center p-8">
<div className="w-full max-w-sm space-y-8">
{/* 移动端 Logo */}
<div className="lg:hidden text-center">
<h1 className="text-4xl font-bold tracking-tight text-foreground mb-2">
ZCLAW
</h1>
<p className="text-sm text-muted-foreground">AI Agent </p>
</div>
<div>
<h2 className="text-2xl font-semibold text-foreground"></h2>
<p className="mt-2 text-sm text-muted-foreground">
</p>
</div>
<form onSubmit={handleSubmit} className="space-y-4">
{/* 用户名 */}
<div className="space-y-2">
<label
htmlFor="username"
className="text-sm font-medium text-foreground"
>
</label>
<div className="relative">
<User className="absolute left-3 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<input
id="username"
type="text"
placeholder="请输入用户名"
value={username}
onChange={(e) => setUsername(e.target.value)}
className="flex h-10 w-full rounded-md border border-input bg-transparent pl-10 pr-3 py-2 text-sm shadow-sm transition-colors duration-200 placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
autoComplete="username"
/>
</div>
</div>
{/* 密码 */}
<div className="space-y-2">
<label
htmlFor="password"
className="text-sm font-medium text-foreground"
>
</label>
<div className="relative">
<Lock className="absolute left-3 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
<input
id="password"
type={showPassword ? 'text' : 'password'}
placeholder="请输入密码"
value={password}
onChange={(e) => setPassword(e.target.value)}
className="flex h-10 w-full rounded-md border border-input bg-transparent pl-10 pr-10 py-2 text-sm shadow-sm transition-colors duration-200 placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
autoComplete="current-password"
/>
<button
type="button"
onClick={() => setShowPassword(!showPassword)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground transition-colors duration-200 cursor-pointer"
>
{showPassword ? (
<EyeOff className="h-4 w-4" />
) : (
<Eye className="h-4 w-4" />
)}
</button>
</div>
</div>
{/* TOTP 验证码 (仅账号启用 2FA 时显示) */}
{showTotp && (
<div className="space-y-2">
<label
htmlFor="totp_code"
className="text-sm font-medium text-foreground"
>
<span className="inline-flex items-center gap-1">
<ShieldCheck className="h-3.5 w-3.5" />
</span>
</label>
<input
id="totp_code"
type="text"
placeholder="请输入 6 位验证码"
value={totpCode}
onChange={(e) => setTotpCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
className="flex h-10 w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm tracking-widest text-center font-mono shadow-sm transition-colors duration-200 placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
maxLength={6}
autoFocus
/>
</div>
)}
{/* 错误信息 */}
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
</div>
)}
{/* 登录按钮 */}
<button
type="submit"
disabled={loading}
className="flex h-10 w-full items-center justify-center rounded-md bg-primary text-primary-foreground font-medium text-sm shadow-sm transition-colors duration-200 hover:bg-primary-hover focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 focus-visible:ring-offset-background disabled:pointer-events-none disabled:opacity-50 cursor-pointer"
>
{loading ? (
<>
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
...
</>
) : (
'登录'
)}
</button>
</form>
</div>
</div>
</div>
)
}

View File

@@ -0,0 +1,85 @@
'use client'
import { createContext, useContext, useEffect, useState, useCallback, type ReactNode } from 'react'
import { useRouter } from 'next/navigation'
import { isAuthenticated, getAccount, logout as clearCredentials, scheduleTokenRefresh, cancelTokenRefresh, setOnSessionExpired } from '@/lib/auth'
import { api } from '@/lib/api-client'
import type { AccountPublic } from '@/lib/types'
interface AuthContextValue {
account: AccountPublic | null
loading: boolean
refresh: () => Promise<void>
}
const AuthContext = createContext<AuthContextValue>({
account: null,
loading: true,
refresh: async () => {},
})
export function useAuth() {
return useContext(AuthContext)
}
interface AuthGuardProps {
children: ReactNode
}
export function AuthGuard({ children }: AuthGuardProps) {
const router = useRouter()
const [account, setAccount] = useState<AccountPublic | null>(null)
const [loading, setLoading] = useState(true)
const refresh = useCallback(async () => {
try {
const me = await api.auth.me()
setAccount(me)
} catch {
clearCredentials()
router.replace('/login')
}
}, [router])
useEffect(() => {
if (!isAuthenticated()) {
router.replace('/login')
return
}
// 验证 token 有效性并获取最新账号信息
refresh().finally(() => setLoading(false))
}, [router, refresh])
// Set up proactive token refresh with session-expired handler
useEffect(() => {
const handleSessionExpired = () => {
clearCredentials()
router.replace('/login')
}
setOnSessionExpired(handleSessionExpired)
scheduleTokenRefresh()
return () => {
cancelTokenRefresh()
setOnSessionExpired(null)
}
}, [router])
if (loading) {
return (
<div className="flex h-screen w-screen items-center justify-center bg-background">
<div className="h-8 w-8 animate-spin rounded-full border-2 border-primary border-t-transparent" />
</div>
)
}
if (!account) {
return null
}
return (
<AuthContext.Provider value={{ account, loading, refresh }}>
{children}
</AuthContext.Provider>
)
}

View File

@@ -0,0 +1,42 @@
import * as React from 'react'
import { cva, type VariantProps } from 'class-variance-authority'
import { cn } from '@/lib/utils'
const badgeVariants = cva(
'inline-flex items-center rounded-full border px-2.5 py-0.5 text-xs font-semibold transition-colors duration-200 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2',
{
variants: {
variant: {
default:
'border-transparent bg-primary/15 text-primary',
secondary:
'border-transparent bg-muted text-muted-foreground',
destructive:
'border-transparent bg-destructive/15 text-destructive',
outline:
'text-foreground border-border',
success:
'border-transparent bg-green-500/15 text-green-400',
warning:
'border-transparent bg-yellow-500/15 text-yellow-400',
info:
'border-transparent bg-blue-500/15 text-blue-400',
},
},
defaultVariants: {
variant: 'default',
},
},
)
export interface BadgeProps
extends React.HTMLAttributes<HTMLDivElement>,
VariantProps<typeof badgeVariants> {}
function Badge({ className, variant, ...props }: BadgeProps) {
return (
<div className={cn(badgeVariants({ variant }), className)} {...props} />
)
}
export { Badge, badgeVariants }

View File

@@ -0,0 +1,56 @@
'use client'
import * as React from 'react'
import { cva, type VariantProps } from 'class-variance-authority'
import { cn } from '@/lib/utils'
const buttonVariants = cva(
'inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium transition-colors duration-200 focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 focus-visible:ring-offset-background disabled:pointer-events-none disabled:opacity-50',
{
variants: {
variant: {
default:
'bg-primary text-primary-foreground hover:bg-primary-hover shadow-sm',
secondary:
'bg-muted text-muted-foreground hover:bg-accent hover:text-accent-foreground',
destructive:
'bg-destructive text-destructive-foreground hover:bg-red-600 shadow-sm',
outline:
'border border-border bg-transparent hover:bg-accent hover:text-accent-foreground',
ghost:
'hover:bg-accent hover:text-accent-foreground',
link:
'text-primary underline-offset-4 hover:underline',
},
size: {
default: 'h-9 px-4 py-2',
sm: 'h-8 rounded-md px-3 text-xs',
lg: 'h-10 rounded-md px-8',
icon: 'h-9 w-9',
},
},
defaultVariants: {
variant: 'default',
size: 'default',
},
},
)
export interface ButtonProps
extends React.ButtonHTMLAttributes<HTMLButtonElement>,
VariantProps<typeof buttonVariants> {}
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
({ className, variant, size, ...props }, ref) => {
return (
<button
className={cn(buttonVariants({ variant, size, className }))}
ref={ref}
{...props}
/>
)
},
)
Button.displayName = 'Button'
export { Button, buttonVariants }

View File

@@ -0,0 +1,75 @@
import * as React from 'react'
import { cn } from '@/lib/utils'
const Card = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn(
'rounded-lg border border-border bg-card text-card-foreground shadow-sm',
className,
)}
{...props}
/>
))
Card.displayName = 'Card'
const CardHeader = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn('flex flex-col space-y-1.5 p-6', className)}
{...props}
/>
))
CardHeader.displayName = 'CardHeader'
const CardTitle = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLHeadingElement>
>(({ className, ...props }, ref) => (
<h3
ref={ref}
className={cn('font-semibold leading-none tracking-tight', className)}
{...props}
/>
))
CardTitle.displayName = 'CardTitle'
const CardDescription = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => (
<p
ref={ref}
className={cn('text-sm text-muted-foreground', className)}
{...props}
/>
))
CardDescription.displayName = 'CardDescription'
const CardContent = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div ref={ref} className={cn('p-6 pt-0', className)} {...props} />
))
CardContent.displayName = 'CardContent'
const CardFooter = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn('flex items-center p-6 pt-0', className)}
{...props}
/>
))
CardFooter.displayName = 'CardFooter'
export { Card, CardHeader, CardFooter, CardTitle, CardDescription, CardContent }

View File

@@ -0,0 +1,118 @@
'use client'
import * as React from 'react'
import * as DialogPrimitive from '@radix-ui/react-dialog'
import { X } from 'lucide-react'
import { cn } from '@/lib/utils'
const Dialog = DialogPrimitive.Root
const DialogTrigger = DialogPrimitive.Trigger
const DialogPortal = DialogPrimitive.Portal
const DialogClose = DialogPrimitive.Close
const DialogOverlay = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Overlay>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Overlay>
>(({ className, ...props }, ref) => (
<DialogPrimitive.Overlay
ref={ref}
className={cn(
'fixed inset-0 z-50 bg-black/60 backdrop-blur-sm',
'data-[state=open]:animate-in data-[state=closed]:animate-out',
'data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0',
className,
)}
{...props}
/>
))
DialogOverlay.displayName = DialogPrimitive.Overlay.displayName
const DialogContent = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Content>
>(({ className, children, ...props }, ref) => (
<DialogPortal>
<DialogOverlay />
<DialogPrimitive.Content
ref={ref}
className={cn(
'fixed left-[50%] top-[50%] z-50 grid w-full max-w-lg translate-x-[-50%] translate-y-[-50%]',
'gap-4 border border-border bg-card p-6 shadow-lg duration-200',
'data-[state=open]:animate-in data-[state=closed]:animate-out',
'data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0',
'data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
'data-[state=closed]:slide-out-to-left-1/2 data-[state=closed]:slide-out-to-top-[48%]',
'data-[state=open]:slide-in-from-left-1/2 data-[state=open]:slide-in-from-top-[48%]',
'rounded-lg',
className,
)}
{...props}
>
{children}
<DialogPrimitive.Close className="absolute right-4 top-4 rounded-sm opacity-70 ring-offset-background transition-opacity hover:opacity-100 focus:outline-none focus:ring-2 focus:ring-ring focus:ring-offset-2 disabled:pointer-events-none data-[state=open]:bg-accent data-[state=open]:text-muted-foreground">
<X className="h-4 w-4" />
<span className="sr-only">Close</span>
</DialogPrimitive.Close>
</DialogPrimitive.Content>
</DialogPortal>
))
DialogContent.displayName = DialogPrimitive.Content.displayName
const DialogHeader = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn('flex flex-col space-y-1.5 text-center sm:text-left', className)}
{...props}
/>
)
DialogHeader.displayName = 'DialogHeader'
const DialogFooter = ({
className,
...props
}: React.HTMLAttributes<HTMLDivElement>) => (
<div
className={cn('flex flex-col-reverse sm:flex-row sm:justify-end sm:space-x-2', className)}
{...props}
/>
)
DialogFooter.displayName = 'DialogFooter'
const DialogTitle = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Title>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Title>
>(({ className, ...props }, ref) => (
<DialogPrimitive.Title
ref={ref}
className={cn('text-lg font-semibold leading-none tracking-tight', className)}
{...props}
/>
))
DialogTitle.displayName = DialogPrimitive.Title.displayName
const DialogDescription = React.forwardRef<
React.ElementRef<typeof DialogPrimitive.Description>,
React.ComponentPropsWithoutRef<typeof DialogPrimitive.Description>
>(({ className, ...props }, ref) => (
<DialogPrimitive.Description
ref={ref}
className={cn('text-sm text-muted-foreground', className)}
{...props}
/>
))
DialogDescription.displayName = DialogPrimitive.Description.displayName
export {
Dialog,
DialogPortal,
DialogOverlay,
DialogClose,
DialogTrigger,
DialogContent,
DialogHeader,
DialogFooter,
DialogTitle,
DialogDescription,
}

View File

@@ -0,0 +1,28 @@
import * as React from 'react'
import { cn } from '@/lib/utils'
export interface InputProps
extends React.InputHTMLAttributes<HTMLInputElement> {}
const Input = React.forwardRef<HTMLInputElement, InputProps>(
({ className, type, ...props }, ref) => {
return (
<input
type={type}
className={cn(
'flex h-9 w-full rounded-md border border-input bg-transparent px-3 py-1 text-sm shadow-sm transition-colors duration-200',
'file:border-0 file:bg-transparent file:text-sm file:font-medium',
'placeholder:text-muted-foreground',
'focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring',
'disabled:cursor-not-allowed disabled:opacity-50',
className,
)}
ref={ref}
{...props}
/>
)
},
)
Input.displayName = 'Input'
export { Input }

View File

@@ -0,0 +1,23 @@
import * as React from 'react'
import { cn } from '@/lib/utils'
export interface LabelProps
extends React.LabelHTMLAttributes<HTMLLabelElement> {}
const Label = React.forwardRef<HTMLLabelElement, LabelProps>(
({ className, ...props }, ref) => {
return (
<label
ref={ref}
className={cn(
'text-sm font-medium leading-none peer-disabled:cursor-not-allowed peer-disabled:opacity-70',
className,
)}
{...props}
/>
)
},
)
Label.displayName = 'Label'
export { Label }

View File

@@ -0,0 +1,100 @@
'use client'
import * as React from 'react'
import * as SelectPrimitive from '@radix-ui/react-select'
import { Check, ChevronDown } from 'lucide-react'
import { cn } from '@/lib/utils'
const Select = SelectPrimitive.Root
const SelectGroup = SelectPrimitive.Group
const SelectValue = SelectPrimitive.Value
const SelectTrigger = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Trigger>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Trigger>
>(({ className, children, ...props }, ref) => (
<SelectPrimitive.Trigger
ref={ref}
className={cn(
'flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background',
'placeholder:text-muted-foreground',
'focus:outline-none focus:ring-1 focus:ring-ring',
'disabled:cursor-not-allowed disabled:opacity-50',
'[&>span]:line-clamp-1',
className,
)}
{...props}
>
{children}
<SelectPrimitive.Icon asChild>
<ChevronDown className="h-4 w-4 opacity-50" />
</SelectPrimitive.Icon>
</SelectPrimitive.Trigger>
))
SelectTrigger.displayName = SelectPrimitive.Trigger.displayName
const SelectContent = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Content>
>(({ className, children, position = 'popper', ...props }, ref) => (
<SelectPrimitive.Portal>
<SelectPrimitive.Content
ref={ref}
className={cn(
'relative z-50 max-h-96 min-w-[8rem] overflow-hidden rounded-md border border-border bg-card text-foreground shadow-md',
'data-[state=open]:animate-in data-[state=closed]:animate-out',
'data-[state=closed]:fade-out-0 data-[state=open]:fade-in-0',
'data-[state=closed]:zoom-out-95 data-[state=open]:zoom-in-95',
position === 'popper' &&
'data-[side=bottom]:translate-y-1 data-[side=left]:-translate-x-1 data-[side=right]:translate-x-1 data-[side=top]:-translate-y-1',
className,
)}
position={position}
{...props}
>
<SelectPrimitive.Viewport
className={cn(
'p-1',
position === 'popper' &&
'h-[var(--radix-select-trigger-height)] w-full min-w-[var(--radix-select-trigger-width)]',
)}
>
{children}
</SelectPrimitive.Viewport>
</SelectPrimitive.Content>
</SelectPrimitive.Portal>
))
SelectContent.displayName = SelectPrimitive.Content.displayName
const SelectItem = React.forwardRef<
React.ElementRef<typeof SelectPrimitive.Item>,
React.ComponentPropsWithoutRef<typeof SelectPrimitive.Item>
>(({ className, children, ...props }, ref) => (
<SelectPrimitive.Item
ref={ref}
className={cn(
'relative flex w-full cursor-pointer select-none items-center rounded-sm py-1.5 pl-8 pr-2 text-sm outline-none',
'focus:bg-accent focus:text-accent-foreground',
'data-[disabled]:pointer-events-none data-[disabled]:opacity-50',
className,
)}
{...props}
>
<span className="absolute left-2 flex h-3.5 w-3.5 items-center justify-center">
<SelectPrimitive.ItemIndicator>
<Check className="h-4 w-4" />
</SelectPrimitive.ItemIndicator>
</span>
<SelectPrimitive.ItemText>{children}</SelectPrimitive.ItemText>
</SelectPrimitive.Item>
))
SelectItem.displayName = SelectPrimitive.Item.displayName
export {
Select,
SelectGroup,
SelectValue,
SelectTrigger,
SelectContent,
SelectItem,
}

View File

@@ -0,0 +1,30 @@
'use client'
import * as React from 'react'
import * as SeparatorPrimitive from '@radix-ui/react-separator'
import { cn } from '@/lib/utils'
const Separator = React.forwardRef<
React.ElementRef<typeof SeparatorPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SeparatorPrimitive.Root>
>(
(
{ className, orientation = 'horizontal', decorative = true, ...props },
ref,
) => (
<SeparatorPrimitive.Root
ref={ref}
decorative={decorative}
orientation={orientation}
className={cn(
'shrink-0 bg-border',
orientation === 'horizontal' ? 'h-[1px] w-full' : 'h-full w-[1px]',
className,
)}
{...props}
/>
),
)
Separator.displayName = SeparatorPrimitive.Root.displayName
export { Separator }

View File

@@ -0,0 +1,32 @@
'use client'
import * as React from 'react'
import * as SwitchPrimitive from '@radix-ui/react-switch'
import { cn } from '@/lib/utils'
const Switch = React.forwardRef<
React.ElementRef<typeof SwitchPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SwitchPrimitive.Root>
>(({ className, ...props }, ref) => (
<SwitchPrimitive.Root
className={cn(
'peer inline-flex h-5 w-9 shrink-0 cursor-pointer items-center rounded-full border-2 border-transparent shadow-sm transition-colors duration-200',
'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 focus-visible:ring-offset-background',
'disabled:cursor-not-allowed disabled:opacity-50',
'data-[state=checked]:bg-primary data-[state=unchecked]:bg-input',
className,
)}
{...props}
ref={ref}
>
<SwitchPrimitive.Thumb
className={cn(
'pointer-events-none block h-4 w-4 rounded-full bg-background shadow-lg ring-0 transition-transform duration-200',
'data-[state=checked]:translate-x-4 data-[state=unchecked]:translate-x-0',
)}
/>
</SwitchPrimitive.Root>
))
Switch.displayName = SwitchPrimitive.Root.displayName
export { Switch }

View File

@@ -0,0 +1,119 @@
import * as React from 'react'
import { cn } from '@/lib/utils'
const Table = React.forwardRef<
HTMLTableElement,
React.HTMLAttributes<HTMLTableElement>
>(({ className, ...props }, ref) => (
<div className="relative w-full overflow-auto scrollbar-thin">
<table
ref={ref}
className={cn('w-full caption-bottom text-sm', className)}
{...props}
/>
</div>
))
Table.displayName = 'Table'
const TableHeader = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<thead ref={ref} className={cn('[&_tr]:border-b', className)} {...props} />
))
TableHeader.displayName = 'TableHeader'
const TableBody = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<tbody
ref={ref}
className={cn('[&_tr:last-child]:border-0', className)}
{...props}
/>
))
TableBody.displayName = 'TableBody'
const TableFooter = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<tfoot
ref={ref}
className={cn(
'border-t bg-muted/50 font-medium [&>tr]:last:border-b-0',
className,
)}
{...props}
/>
))
TableFooter.displayName = 'TableFooter'
const TableRow = React.forwardRef<
HTMLTableRowElement,
React.HTMLAttributes<HTMLTableRowElement>
>(({ className, ...props }, ref) => (
<tr
ref={ref}
className={cn(
'border-b border-border transition-colors duration-200 hover:bg-muted/50',
className,
)}
{...props}
/>
))
TableRow.displayName = 'TableRow'
const TableHead = React.forwardRef<
HTMLTableCellElement,
React.ThHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
<th
ref={ref}
className={cn(
'h-10 px-4 text-left align-middle font-medium text-muted-foreground [&:has([role=checkbox])]:pr-0',
className,
)}
{...props}
/>
))
TableHead.displayName = 'TableHead'
const TableCell = React.forwardRef<
HTMLTableCellElement,
React.TdHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
<td
ref={ref}
className={cn(
'p-4 align-middle [&:has([role=checkbox])]:pr-0',
className,
)}
{...props}
/>
))
TableCell.displayName = 'TableCell'
const TableCaption = React.forwardRef<
HTMLTableCaptionElement,
React.HTMLAttributes<HTMLTableCaptionElement>
>(({ className, ...props }, ref) => (
<caption
ref={ref}
className={cn('mt-4 text-sm text-muted-foreground', className)}
{...props}
/>
))
TableCaption.displayName = 'TableCaption'
export {
Table,
TableHeader,
TableBody,
TableFooter,
TableHead,
TableRow,
TableCell,
TableCaption,
}

View File

@@ -0,0 +1,57 @@
'use client'
import * as React from 'react'
import * as TabsPrimitive from '@radix-ui/react-tabs'
import { cn } from '@/lib/utils'
const Tabs = TabsPrimitive.Root
const TabsList = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.List>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.List>
>(({ className, ...props }, ref) => (
<TabsPrimitive.List
ref={ref}
className={cn(
'inline-flex h-9 items-center justify-center rounded-lg bg-muted p-1 text-muted-foreground',
className,
)}
{...props}
/>
))
TabsList.displayName = TabsPrimitive.List.displayName
const TabsTrigger = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.Trigger>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Trigger>
>(({ className, ...props }, ref) => (
<TabsPrimitive.Trigger
ref={ref}
className={cn(
'inline-flex items-center justify-center whitespace-nowrap rounded-md px-3 py-1 text-sm font-medium ring-offset-background transition-all duration-200',
'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2',
'disabled:pointer-events-none disabled:opacity-50',
'data-[state=active]:bg-card data-[state=active]:text-foreground data-[state=active]:shadow',
className,
)}
{...props}
/>
))
TabsTrigger.displayName = TabsPrimitive.Trigger.displayName
const TabsContent = React.forwardRef<
React.ElementRef<typeof TabsPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof TabsPrimitive.Content>
>(({ className, ...props }, ref) => (
<TabsPrimitive.Content
ref={ref}
className={cn(
'mt-2 ring-offset-background focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2',
className,
)}
{...props}
/>
))
TabsContent.displayName = TabsPrimitive.Content.displayName
export { Tabs, TabsList, TabsTrigger, TabsContent }

View File

@@ -0,0 +1,31 @@
'use client'
import * as React from 'react'
import * as TooltipPrimitive from '@radix-ui/react-tooltip'
import { cn } from '@/lib/utils'
const TooltipProvider = TooltipPrimitive.Provider
const Tooltip = TooltipPrimitive.Root
const TooltipTrigger = TooltipPrimitive.Trigger
const TooltipContent = React.forwardRef<
React.ElementRef<typeof TooltipPrimitive.Content>,
React.ComponentPropsWithoutRef<typeof TooltipPrimitive.Content>
>(({ className, sideOffset = 4, ...props }, ref) => (
<TooltipPrimitive.Content
ref={ref}
sideOffset={sideOffset}
className={cn(
'z-50 overflow-hidden rounded-md bg-card border border-border px-3 py-1.5 text-sm text-foreground shadow-md',
'animate-in fade-in-0 zoom-in-95',
'data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95',
'data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2',
'data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2',
className,
)}
{...props}
/>
))
TooltipContent.displayName = TooltipPrimitive.Content.displayName
export { Tooltip, TooltipTrigger, TooltipContent, TooltipProvider }

347
admin/src/lib/api-client.ts Normal file
View File

@@ -0,0 +1,347 @@
// ============================================================
// ZCLAW SaaS Admin — 类型化 HTTP 客户端
// ============================================================
import { getToken, logout, refreshToken } from './auth'
import { toast } from 'sonner'
import type {
AccountPublic,
ApiError,
ConfigItem,
CreateTokenRequest,
DashboardStats,
DeviceInfo,
LoginRequest,
LoginResponse,
Model,
OperationLog,
PaginatedResponse,
Provider,
RelayTask,
TokenInfo,
UsageByModel,
UsageStats,
} from './types'
// ── 错误类 ────────────────────────────────────────────────
export class ApiRequestError extends Error {
constructor(
public status: number,
public body: ApiError,
) {
super(body.message || `Request failed with status ${status}`)
this.name = 'ApiRequestError'
}
}
// ── 基础请求 ──────────────────────────────────────────────
const BASE_URL = process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080'
const API_PREFIX = '/api/v1'
async function request<T>(
method: string,
path: string,
body?: unknown,
): Promise<T> {
const token = getToken()
const headers: Record<string, string> = {
'Content-Type': 'application/json',
}
if (token) {
headers['Authorization'] = `Bearer ${token}`
}
const res = await fetch(`${BASE_URL}${API_PREFIX}${path}`, {
method,
headers,
body: body ? JSON.stringify(body) : undefined,
})
if (res.status === 401) {
// 尝试刷新 token 后重试
try {
const newToken = await refreshToken()
headers['Authorization'] = `Bearer ${newToken}`
const retryRes = await fetch(`${BASE_URL}${API_PREFIX}${path}`, {
method,
headers,
body: body ? JSON.stringify(body) : undefined,
})
if (retryRes.ok || retryRes.status === 204) {
return retryRes.status === 204 ? (undefined as T) : retryRes.json()
}
// 刷新成功但重试仍失败,走正常错误处理
if (!retryRes.ok) {
let errorBody: ApiError
try { errorBody = await retryRes.json() } catch { errorBody = { error: 'unknown', message: `请求失败 (${retryRes.status})` } }
throw new ApiRequestError(retryRes.status, errorBody)
}
} catch {
// 刷新失败,执行登出
}
logout()
if (typeof window !== 'undefined') {
window.location.href = '/login'
}
throw new ApiRequestError(401, { error: 'unauthorized', message: '登录已过期,请重新登录' })
}
if (!res.ok) {
let errorBody: ApiError
try {
errorBody = await res.json()
} catch {
errorBody = { error: 'unknown', message: `请求失败 (${res.status})` }
}
if (typeof window !== 'undefined') {
toast.error(errorBody.message || `请求失败 (${res.status})`)
}
throw new ApiRequestError(res.status, errorBody)
}
// 204 No Content
if (res.status === 204) {
return undefined as T
}
return res.json() as Promise<T>
}
// ── API 客户端 ────────────────────────────────────────────
export const api = {
// ── 认证 ──────────────────────────────────────────────
auth: {
async login(data: LoginRequest): Promise<LoginResponse> {
return request<LoginResponse>('POST', '/auth/login', data)
},
async register(data: {
username: string
password: string
email: string
display_name?: string
}): Promise<LoginResponse> {
return request<LoginResponse>('POST', '/auth/register', data)
},
async me(): Promise<AccountPublic> {
return request<AccountPublic>('GET', '/auth/me')
},
async changePassword(data: { old_password: string; new_password: string }): Promise<void> {
return request<void>('PUT', '/auth/password', data)
},
async totpSetup(): Promise<{ otpauth_uri: string; secret: string; issuer: string }> {
return request<{ otpauth_uri: string; secret: string; issuer: string }>('POST', '/auth/totp/setup')
},
async totpVerify(data: { code: string }): Promise<void> {
return request<void>('POST', '/auth/totp/verify', data)
},
async totpDisable(data: { password: string }): Promise<void> {
return request<void>('POST', '/auth/totp/disable', data)
},
},
// ── 账号管理 ──────────────────────────────────────────
accounts: {
async list(params?: {
page?: number
page_size?: number
search?: string
role?: string
status?: string
}): Promise<PaginatedResponse<AccountPublic>> {
const qs = buildQueryString(params)
return request<PaginatedResponse<AccountPublic>>('GET', `/accounts${qs}`)
},
async get(id: string): Promise<AccountPublic> {
return request<AccountPublic>('GET', `/accounts/${id}`)
},
async update(
id: string,
data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>,
): Promise<AccountPublic> {
return request<AccountPublic>('PUT', `/accounts/${id}`, data)
},
async updateStatus(
id: string,
data: { status: AccountPublic['status'] },
): Promise<void> {
return request<void>('PATCH', `/accounts/${id}/status`, data)
},
},
// ── 服务商管理 ────────────────────────────────────────
providers: {
async list(params?: {
page?: number
page_size?: number
}): Promise<PaginatedResponse<Provider>> {
const qs = buildQueryString(params)
return request<PaginatedResponse<Provider>>('GET', `/providers${qs}`)
},
async get(id: string): Promise<Provider> {
return request<Provider>('GET', `/providers/${id}`)
},
async create(data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>): Promise<Provider> {
return request<Provider>('POST', '/providers', data)
},
async update(
id: string,
data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>,
): Promise<Provider> {
return request<Provider>('PUT', `/providers/${id}`, data)
},
async delete(id: string): Promise<void> {
return request<void>('DELETE', `/providers/${id}`)
},
},
// ── 模型管理 ──────────────────────────────────────────
models: {
async list(params?: {
page?: number
page_size?: number
provider_id?: string
}): Promise<PaginatedResponse<Model>> {
const qs = buildQueryString(params)
return request<PaginatedResponse<Model>>('GET', `/models${qs}`)
},
async get(id: string): Promise<Model> {
return request<Model>('GET', `/models/${id}`)
},
async create(data: Partial<Omit<Model, 'id'>>): Promise<Model> {
return request<Model>('POST', '/models', data)
},
async update(id: string, data: Partial<Omit<Model, 'id'>>): Promise<Model> {
return request<Model>('PUT', `/models/${id}`, data)
},
async delete(id: string): Promise<void> {
return request<void>('DELETE', `/models/${id}`)
},
},
// ── API 密钥 ──────────────────────────────────────────
tokens: {
async list(params?: {
page?: number
page_size?: number
}): Promise<PaginatedResponse<TokenInfo>> {
const qs = buildQueryString(params)
return request<PaginatedResponse<TokenInfo>>('GET', `/tokens${qs}`)
},
async create(data: CreateTokenRequest): Promise<TokenInfo> {
return request<TokenInfo>('POST', '/tokens', data)
},
async revoke(id: string): Promise<void> {
return request<void>('DELETE', `/tokens/${id}`)
},
},
// ── 用量统计 ──────────────────────────────────────────
usage: {
async get(params?: { from?: string; to?: string; provider_id?: string; model_id?: string }): Promise<UsageStats> {
const qs = buildQueryString(params)
return request<UsageStats>('GET', `/usage${qs}`)
},
},
// ── 中转任务 ──────────────────────────────────────────
relay: {
async list(params?: {
page?: number
page_size?: number
status?: string
}): Promise<PaginatedResponse<RelayTask>> {
const qs = buildQueryString(params)
return request<PaginatedResponse<RelayTask>>('GET', `/relay/tasks${qs}`)
},
async get(id: string): Promise<RelayTask> {
return request<RelayTask>('GET', `/relay/tasks/${id}`)
},
async retry(id: string): Promise<void> {
return request<void>('POST', `/relay/tasks/${id}/retry`)
},
},
// ── 系统配置 ──────────────────────────────────────────
config: {
async list(params?: {
category?: string
}): Promise<ConfigItem[]> {
const qs = buildQueryString(params)
return request<ConfigItem[]>('GET', `/config/items${qs}`)
},
async update(id: string, data: { current_value: string | number | boolean }): Promise<ConfigItem> {
return request<ConfigItem>('PUT', `/config/items/${id}`, data)
},
},
// ── 操作日志 ──────────────────────────────────────────
logs: {
async list(params?: {
page?: number
page_size?: number
action?: string
}): Promise<OperationLog[]> {
const qs = buildQueryString(params)
return request<OperationLog[]>('GET', `/logs/operations${qs}`)
},
},
// ── 仪表盘 ────────────────────────────────────────────
stats: {
async dashboard(): Promise<DashboardStats> {
return request<DashboardStats>('GET', '/stats/dashboard')
},
},
// ── 设备管理 ──────────────────────────────────────────
devices: {
async list(): Promise<DeviceInfo[]> {
return request<DeviceInfo[]>('GET', '/devices')
},
async register(data: { device_id: string; device_name?: string; platform?: string; app_version?: string }) {
return request<{ ok: boolean; device_id: string }>('POST', '/devices/register', data)
},
async heartbeat(data: { device_id: string }) {
return request<{ ok: boolean }>('POST', '/devices/heartbeat', data)
},
},
}
// ── 工具函数 ──────────────────────────────────────────────
function buildQueryString(params?: Record<string, unknown>): string {
if (!params) return ''
const entries = Object.entries(params).filter(
([, v]) => v !== undefined && v !== null && v !== '',
)
if (entries.length === 0) return ''
const qs = entries
.map(([k, v]) => `${encodeURIComponent(k)}=${encodeURIComponent(String(v))}`)
.join('&')
return `?${qs}`
}

216
admin/src/lib/auth.ts Normal file
View File

@@ -0,0 +1,216 @@
// ============================================================
// ZCLAW SaaS Admin — JWT Token 管理
// ============================================================
import type { AccountPublic, LoginResponse } from './types'
const TOKEN_KEY = 'zclaw_admin_token'
const ACCOUNT_KEY = 'zclaw_admin_account'
// ── JWT 辅助函数 ────────────────────────────────────────────
interface JwtPayload {
exp?: number
iat?: number
sub?: string
}
/**
* Decode a JWT payload without verifying the signature.
* Returns the parsed JSON payload, or null if the token is malformed.
*/
function decodeJwtPayload<T = Record<string, unknown>>(token: string): T | null {
try {
const parts = token.split('.')
if (parts.length !== 3) return null
const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/')
const json = decodeURIComponent(
atob(base64)
.split('')
.map((c) => '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2))
.join(''),
)
return JSON.parse(json) as T
} catch {
return null
}
}
/**
* Calculate the delay (ms) until 80% of the token's remaining lifetime
* has elapsed. Returns null if the token is already past that point.
*/
function getRefreshDelay(exp: number): number | null {
const now = Math.floor(Date.now() / 1000)
const totalLifetime = exp - now
if (totalLifetime <= 0) return null
const refreshAt = now + Math.floor(totalLifetime * 0.8)
const delayMs = (refreshAt - now) * 1000
return delayMs > 5000 ? delayMs : 5000
}
// ── 定时刷新状态 ────────────────────────────────────────────
let refreshTimerId: ReturnType<typeof setTimeout> | null = null
let visibilityHandler: (() => void) | null = null
let sessionExpiredCallback: (() => void) | null = null
// ── 凭证操作 ────────────────────────────────────────────────
/** 保存登录凭证并启动自动刷新 */
export function login(token: string, account: AccountPublic): void {
if (typeof window === 'undefined') return
localStorage.setItem(TOKEN_KEY, token)
localStorage.setItem(ACCOUNT_KEY, JSON.stringify(account))
scheduleTokenRefresh()
}
/** 清除登录凭证并停止自动刷新 */
export function logout(): void {
if (typeof window === 'undefined') return
cancelTokenRefresh()
localStorage.removeItem(TOKEN_KEY)
localStorage.removeItem(ACCOUNT_KEY)
}
/** 获取 JWT token */
export function getToken(): string | null {
if (typeof window === 'undefined') return null
return localStorage.getItem(TOKEN_KEY)
}
/** 获取当前登录用户信息 */
export function getAccount(): AccountPublic | null {
if (typeof window === 'undefined') return null
const raw = localStorage.getItem(ACCOUNT_KEY)
if (!raw) return null
try {
return JSON.parse(raw) as AccountPublic
} catch {
return null
}
}
/** 是否已认证 */
export function isAuthenticated(): boolean {
return !!getToken()
}
/** 尝试刷新 token成功则更新 localStorage 并返回新 token */
export async function refreshToken(): Promise<string> {
const res = await fetch(
`${process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080'}/api/v1/auth/refresh`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${getToken()}`,
},
},
)
if (!res.ok) {
throw new Error('Token 刷新失败')
}
const data: LoginResponse = await res.json()
login(data.token, data.account)
return data.token
}
// ── 自动刷新调度 ────────────────────────────────────────────
/**
* Register a callback invoked when the proactive token refresh fails.
* The caller should use this to trigger a logout/redirect flow.
*/
export function setOnSessionExpired(handler: (() => void) | null): void {
sessionExpiredCallback = handler
}
/**
* Schedule a proactive token refresh at 80% of the token's remaining lifetime.
* Also registers a visibilitychange listener to re-check when the tab regains focus.
*/
export function scheduleTokenRefresh(): void {
cancelTokenRefresh()
const token = getToken()
if (!token) return
const payload = decodeJwtPayload<JwtPayload>(token)
if (!payload?.exp) return
const delay = getRefreshDelay(payload.exp)
if (delay === null) {
attemptTokenRefresh()
return
}
refreshTimerId = setTimeout(() => {
attemptTokenRefresh()
}, delay)
if (typeof document !== 'undefined' && !visibilityHandler) {
visibilityHandler = () => {
if (document.visibilityState === 'visible') {
checkAndRefreshToken()
}
}
document.addEventListener('visibilitychange', visibilityHandler)
}
}
/**
* Cancel any pending token refresh timer and remove the visibility listener.
*/
export function cancelTokenRefresh(): void {
if (refreshTimerId !== null) {
clearTimeout(refreshTimerId)
refreshTimerId = null
}
if (visibilityHandler !== null && typeof document !== 'undefined') {
document.removeEventListener('visibilitychange', visibilityHandler)
visibilityHandler = null
}
}
/**
* Check if the current token is close to expiry and refresh if needed.
* Called on visibility change to handle clock skew / long background tabs.
*/
function checkAndRefreshToken(): void {
const token = getToken()
if (!token) return
const payload = decodeJwtPayload<JwtPayload>(token)
if (!payload?.exp) return
const now = Math.floor(Date.now() / 1000)
const remaining = payload.exp - now
if (remaining <= 0) {
attemptTokenRefresh()
return
}
const delay = getRefreshDelay(payload.exp)
if (delay !== null && delay < 60_000) {
attemptTokenRefresh()
}
}
/**
* Attempt to refresh the token. On success, the new token is persisted via
* login() which also reschedules the next refresh. On failure, invoke the
* session-expired callback.
*/
async function attemptTokenRefresh(): Promise<void> {
try {
await refreshToken()
} catch {
cancelTokenRefresh()
if (sessionExpiredCallback) {
sessionExpiredCallback()
}
}
}

193
admin/src/lib/types.ts Normal file
View File

@@ -0,0 +1,193 @@
// ============================================================
// ZCLAW SaaS Admin — 全局类型定义
// ============================================================
/** 公共账号信息 */
export interface AccountPublic {
id: string
username: string
email: string
display_name: string
role: 'super_admin' | 'admin' | 'user'
permissions: string[]
status: 'active' | 'disabled' | 'suspended'
totp_enabled: boolean
created_at: string
}
/** 登录请求 */
export interface LoginRequest {
username: string
password: string
totp_code?: string
}
/** 登录响应 */
export interface LoginResponse {
token: string
account: AccountPublic
}
/** 注册请求 */
export interface RegisterRequest {
username: string
password: string
email: string
display_name?: string
}
/** 分页响应 */
export interface PaginatedResponse<T> {
items: T[]
total: number
page: number
page_size: number
}
/** 服务商 (Provider) */
export interface Provider {
id: string
name: string
display_name: string
base_url: string
api_protocol: 'openai' | 'anthropic'
enabled: boolean
rate_limit_rpm?: number
rate_limit_tpm?: number
created_at: string
updated_at: string
}
/** 模型 */
export interface Model {
id: string
provider_id: string
model_id: string
alias: string
context_window: number
max_output_tokens: number
supports_streaming: boolean
supports_vision: boolean
enabled: boolean
pricing_input: number
pricing_output: number
}
/** API 密钥信息 */
export interface TokenInfo {
id: string
name: string
token_prefix: string
permissions: string[]
last_used_at?: string
expires_at?: string
created_at: string
token?: string
}
/** 创建 Token 请求 */
export interface CreateTokenRequest {
name: string
expires_days?: number
permissions: string[]
}
/** 中转任务 */
export interface RelayTask {
id: string
account_id: string
provider_id: string
model_id: string
status: 'queued' | 'processing' | 'completed' | 'failed'
priority: number
attempt_count: number
input_tokens: number
output_tokens: number
error_message?: string
queued_at?: string
started_at?: string
completed_at?: string
created_at: string
}
/** 用量统计 — 后端返回的完整结构 */
export interface UsageStats {
total_requests: number
total_input_tokens: number
total_output_tokens: number
by_model: UsageByModel[]
by_day: DailyUsage[]
}
/** 每日用量 */
export interface DailyUsage {
date: string
request_count: number
input_tokens: number
output_tokens: number
}
/** 按模型用量 */
export interface UsageByModel {
provider_id: string
model_id: string
request_count: number
input_tokens: number
output_tokens: number
}
/** 系统配置项 */
export interface ConfigItem {
id: string
category: string
key_path: string
value_type: 'string' | 'number' | 'boolean'
current_value?: string
default_value?: string
source: 'default' | 'env' | 'db'
description?: string
requires_restart: boolean
created_at: string
updated_at: string
}
/** 操作日志 */
export interface OperationLog {
id: number
account_id: string
action: string
target_type: string
target_id: string
details?: Record<string, unknown>
ip_address?: string
created_at: string
}
/** 仪表盘统计 */
export interface DashboardStats {
total_accounts: number
active_accounts: number
tasks_today: number
active_providers: number
active_models: number
tokens_today_input: number
tokens_today_output: number
}
/** 设备信息 */
export interface DeviceInfo {
id: string
device_id: string
device_name?: string
platform?: string
app_version?: string
last_seen_at: string
created_at: string
}
/** API 错误响应 */
export interface ApiError {
error: string
message: string
status?: number
}

34
admin/src/lib/utils.ts Normal file
View File

@@ -0,0 +1,34 @@
import { type ClassValue, clsx } from 'clsx'
import { twMerge } from 'tailwind-merge'
export function cn(...inputs: ClassValue[]) {
return twMerge(clsx(inputs))
}
export function formatDate(date: string | Date): string {
const d = new Date(date)
return d.toLocaleString('zh-CN', {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
})
}
export function formatNumber(n: number): string {
if (n >= 1_000_000) return `${(n / 1_000_000).toFixed(1)}M`
if (n >= 1_000) return `${(n / 1_000).toFixed(1)}K`
return n.toLocaleString()
}
export function maskApiKey(key?: string): string {
if (!key) return '-'
if (key.length <= 8) return '****'
return `${key.slice(0, 4)}${'*'.repeat(key.length - 8)}${key.slice(-4)}`
}
export function sleep(ms: number): Promise<void> {
return new Promise(resolve => setTimeout(resolve, ms))
}

62
admin/tailwind.config.ts Normal file
View File

@@ -0,0 +1,62 @@
import type { Config } from 'tailwindcss'
const config: Config = {
darkMode: 'class',
content: [
'./src/pages/**/*.{js,ts,jsx,tsx,mdx}',
'./src/components/**/*.{js,ts,jsx,tsx,mdx}',
'./src/app/**/*.{js,ts,jsx,tsx,mdx}',
],
theme: {
extend: {
colors: {
background: '#020617',
foreground: '#F8FAFC',
card: {
DEFAULT: '#0F172A',
foreground: '#F8FAFC',
},
primary: {
DEFAULT: '#22C55E',
foreground: '#020617',
hover: '#16A34A',
},
muted: {
DEFAULT: '#1E293B',
foreground: '#94A3B8',
},
accent: {
DEFAULT: '#334155',
foreground: '#F8FAFC',
},
destructive: {
DEFAULT: '#EF4444',
foreground: '#F8FAFC',
},
border: '#1E293B',
input: '#1E293B',
ring: '#22C55E',
},
fontFamily: {
sans: ['Inter', 'system-ui', '-apple-system', 'sans-serif'],
mono: ['JetBrains Mono', 'Fira Code', 'monospace'],
},
keyframes: {
'fade-in': {
'0%': { opacity: '0', transform: 'translateY(4px)' },
'100%': { opacity: '1', transform: 'translateY(0)' },
},
'slide-in': {
'0%': { opacity: '0', transform: 'translateX(-8px)' },
'100%': { opacity: '1', transform: 'translateX(0)' },
},
},
animation: {
'fade-in': 'fade-in 0.2s ease-out',
'slide-in': 'slide-in 0.2s ease-out',
},
},
},
plugins: [],
}
export default config

21
admin/tsconfig.json Normal file
View File

@@ -0,0 +1,21 @@
{
"compilerOptions": {
"target": "es2017",
"lib": ["dom", "dom.iterable", "esnext"],
"allowJs": true,
"skipLibCheck": true,
"strict": true,
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
"moduleResolution": "bundler",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
"incremental": true,
"plugins": [{ "name": "next" }],
"paths": { "@/*": ["./src/*"] }
},
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
"exclude": ["node_modules"]
}

View File

@@ -0,0 +1,48 @@
[package]
name = "zclaw-saas"
version.workspace = true
edition.workspace = true
description = "ZCLAW SaaS backend - account, API config, relay, migration"
[[bin]]
name = "zclaw-saas"
path = "src/main.rs"
[dependencies]
tokio = { workspace = true }
futures = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
toml = { workspace = true }
thiserror = { workspace = true }
anyhow = { workspace = true }
uuid = { workspace = true }
chrono = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
sqlx = { workspace = true }
reqwest = { workspace = true }
secrecy = { workspace = true }
sha2 = { workspace = true }
rand = { workspace = true }
dashmap = { workspace = true }
hex = { workspace = true }
url = "2"
axum = { workspace = true }
axum-extra = { workspace = true }
bytes = { workspace = true }
async-stream = { workspace = true }
tower = { workspace = true }
tower-http = { workspace = true }
jsonwebtoken = { workspace = true }
argon2 = { workspace = true }
totp-rs = { workspace = true }
urlencoding = "2"
data-encoding = "2"
aes-gcm = { workspace = true }
utoipa = { version = "5", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "5", features = ["axum"] }
[dev-dependencies]
tempfile = { workspace = true }

View File

@@ -0,0 +1,279 @@
//! 账号管理 HTTP 处理器
use axum::{
extract::{Extension, Path, Query, State},
Json,
};
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::{log_operation, check_permission};
use super::{types::*, service};
fn require_admin(ctx: &AuthContext) -> SaasResult<()> {
check_permission(ctx, "account:admin")
}
/// GET /api/v1/accounts (admin only)
pub async fn list_accounts(
State(state): State<AppState>,
Query(query): Query<ListAccountsQuery>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<PaginatedResponse<serde_json::Value>>> {
require_admin(&ctx)?;
service::list_accounts(&state.db, &query).await.map(Json)
}
/// GET /api/v1/accounts/:id
pub async fn get_account(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
// 只能查看自己,或 admin 查看任何人
if id != ctx.account_id {
require_admin(&ctx)?;
}
service::get_account(&state.db, &id).await.map(Json)
}
/// PUT /api/v1/accounts/:id (admin or self for limited fields)
pub async fn update_account(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<UpdateAccountRequest>,
) -> SaasResult<Json<serde_json::Value>> {
let is_self_update = id == ctx.account_id;
// 非管理员只能修改自己的资料
if !is_self_update {
require_admin(&ctx)?;
}
// 安全限制: 非管理员修改自己时,剥离 role 字段防止自角色提升
let safe_req = if is_self_update && !ctx.permissions.contains(&"admin:full".to_string()) {
UpdateAccountRequest {
role: None,
..req
}
} else {
req
};
let result = service::update_account(&state.db, &id, &safe_req).await?;
log_operation(&state.db, &ctx.account_id, "account.update", "account", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(result))
}
/// PATCH /api/v1/accounts/:id/status (admin only)
pub async fn update_status(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<UpdateStatusRequest>,
) -> SaasResult<Json<serde_json::Value>> {
require_admin(&ctx)?;
service::update_account_status(&state.db, &id, &req.status).await?;
log_operation(&state.db, &ctx.account_id, "account.update_status", "account", &id,
Some(serde_json::json!({"status": &req.status})), ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
/// GET /api/v1/tokens
pub async fn list_tokens(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<Vec<TokenInfo>>> {
service::list_api_tokens(&state.db, &ctx.account_id).await.map(Json)
}
/// POST /api/v1/tokens
pub async fn create_token(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateTokenRequest>,
) -> SaasResult<Json<TokenInfo>> {
let token = service::create_api_token(&state.db, &ctx.account_id, &req).await?;
log_operation(&state.db, &ctx.account_id, "token.create", "api_token", &token.id,
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
Ok(Json(token))
}
/// DELETE /api/v1/tokens/:id
pub async fn revoke_token(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
service::revoke_api_token(&state.db, &id, &ctx.account_id).await?;
log_operation(&state.db, &ctx.account_id, "token.revoke", "api_token", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
/// GET /api/v1/logs/operations (admin only)
pub async fn list_operation_logs(
State(state): State<AppState>,
Query(params): Query<std::collections::HashMap<String, String>>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> {
require_admin(&ctx)?;
let page: i64 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1);
let page_size: i64 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50);
let offset = (page - 1) * page_size;
let action_filter = params.get("action").map(|s| s.as_str());
let target_type_filter = params.get("target_type").map(|s| s.as_str());
let mut sql = String::from(
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
FROM operation_logs"
);
let mut param_idx: usize = 1;
if action_filter.is_some() || target_type_filter.is_some() {
sql.push_str(" WHERE 1=1");
if action_filter.is_some() {
sql.push_str(&format!(" AND action = ${}", param_idx));
param_idx += 1;
}
if target_type_filter.is_some() {
sql.push_str(&format!(" AND target_type = ${}", param_idx));
param_idx += 1;
}
}
sql.push_str(&format!(" ORDER BY created_at DESC LIMIT ${} OFFSET ${}", param_idx, param_idx + 1));
let mut query = sqlx::query_as::<_, (i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>)>(&sql);
if let Some(action) = action_filter {
query = query.bind(action);
}
if let Some(target_type) = target_type_filter {
query = query.bind(target_type);
}
query = query.bind(page_size).bind(offset);
let rows = query.fetch_all(&state.db).await?;
let items: Vec<serde_json::Value> = rows.into_iter().map(|(id, account_id, action, target_type, target_id, details, ip_address, created_at)| {
serde_json::json!({
"id": id, "account_id": account_id, "action": action,
"target_type": target_type, "target_id": target_id,
"details": details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
"ip_address": ip_address, "created_at": created_at.to_rfc3339(),
})
}).collect();
Ok(Json(items))
}
/// GET /api/v1/stats/dashboard — 仪表盘聚合统计 (需要 admin 权限)
pub async fn dashboard_stats(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
require_admin(&ctx)?;
let row: (i64, i64, i64, i64, i64, i64, i64) = sqlx::query_as(
"SELECT
(SELECT COUNT(*) FROM accounts),
(SELECT COUNT(*) FROM accounts WHERE status = 'active'),
(SELECT COUNT(*) FROM relay_tasks WHERE DATE(created_at) = CURRENT_DATE),
(SELECT COUNT(*) FROM providers WHERE enabled = true),
(SELECT COUNT(*) FROM models WHERE enabled = true),
(SELECT COALESCE(SUM(input_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE),
(SELECT COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE)"
)
.fetch_one(&state.db)
.await?;
Ok(Json(serde_json::json!({
"total_accounts": row.0,
"active_accounts": row.1,
"tasks_today": row.2,
"active_providers": row.3,
"active_models": row.4,
"tokens_today_input": row.5,
"tokens_today_output": row.6,
})))
}
// ============ Devices ============
/// POST /api/v1/devices/register — 注册或更新设备
pub async fn register_device(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<super::types::RegisterDeviceRequest>,
) -> SaasResult<Json<serde_json::Value>> {
let now = chrono::Utc::now();
let device_uuid = uuid::Uuid::new_v4().to_string();
// UPSERT: 已存在则更新 last_seen_at不存在则插入
sqlx::query(
"INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $7)
ON CONFLICT(account_id, device_id) DO UPDATE SET
device_name = EXCLUDED.device_name, platform = EXCLUDED.platform, app_version = EXCLUDED.app_version, last_seen_at = EXCLUDED.last_seen_at"
)
.bind(&device_uuid)
.bind(&ctx.account_id)
.bind(&req.device_id)
.bind(&req.device_name)
.bind(&req.platform)
.bind(&req.app_version)
.bind(&now)
.execute(&state.db)
.await?;
log_operation(&state.db, &ctx.account_id, "device.register", "device", &req.device_id,
Some(serde_json::json!({"device_name": req.device_name, "platform": req.platform})),
ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true, "device_id": req.device_id})))
}
/// POST /api/v1/devices/heartbeat — 设备心跳
pub async fn device_heartbeat(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<super::types::DeviceHeartbeatRequest>,
) -> SaasResult<Json<serde_json::Value>> {
let now = chrono::Utc::now();
let result = sqlx::query(
"UPDATE devices SET last_seen_at = $1 WHERE account_id = $2 AND device_id = $3"
)
.bind(&now)
.bind(&ctx.account_id)
.bind(&req.device_id)
.execute(&state.db)
.await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound("设备未注册".into()));
}
Ok(Json(serde_json::json!({"ok": true})))
}
/// GET /api/v1/devices — 列出当前用户的设备
pub async fn list_devices(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<Vec<super::types::DeviceInfo>>> {
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC"
)
.bind(&ctx.account_id)
.fetch_all(&state.db)
.await?;
let items: Vec<super::types::DeviceInfo> = rows.into_iter().map(|r| {
super::types::DeviceInfo {
id: r.0, device_id: r.1,
device_name: r.2, platform: r.3, app_version: r.4,
last_seen_at: r.5.to_rfc3339(), created_at: r.6.to_rfc3339(),
}
}).collect();
Ok(Json(items))
}

View File

@@ -0,0 +1,23 @@
//! 账号管理模块
pub mod types;
pub mod service;
pub mod handlers;
use axum::routing::{delete, get, patch, post, put};
pub fn routes() -> axum::Router<crate::state::AppState> {
axum::Router::new()
.route("/api/v1/accounts", get(handlers::list_accounts))
.route("/api/v1/accounts/{id}", get(handlers::get_account))
.route("/api/v1/accounts/{id}", put(handlers::update_account))
.route("/api/v1/accounts/{id}/status", patch(handlers::update_status))
.route("/api/v1/tokens", get(handlers::list_tokens))
.route("/api/v1/tokens", post(handlers::create_token))
.route("/api/v1/tokens/{id}", delete(handlers::revoke_token))
.route("/api/v1/logs/operations", get(handlers::list_operation_logs))
.route("/api/v1/stats/dashboard", get(handlers::dashboard_stats))
.route("/api/v1/devices", get(handlers::list_devices))
.route("/api/v1/devices/register", post(handlers::register_device))
.route("/api/v1/devices/heartbeat", post(handlers::device_heartbeat))
}

View File

@@ -0,0 +1,230 @@
//! 账号管理业务逻辑
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use super::types::*;
pub async fn list_accounts(
db: &PgPool,
query: &ListAccountsQuery,
) -> SaasResult<PaginatedResponse<serde_json::Value>> {
let page = query.page.unwrap_or(1).max(1);
let page_size = query.page_size.unwrap_or(20).min(100);
let offset = (page - 1) * page_size;
let mut where_clauses = Vec::new();
let mut params: Vec<String> = Vec::new();
let mut param_idx: usize = 1;
if let Some(role) = &query.role {
where_clauses.push(format!("role = ${}", param_idx));
params.push(role.clone());
param_idx += 1;
}
if let Some(status) = &query.status {
where_clauses.push(format!("status = ${}", param_idx));
params.push(status.clone());
param_idx += 1;
}
if let Some(search) = &query.search {
where_clauses.push(format!("(username LIKE ${} OR email LIKE ${} OR display_name LIKE ${})", param_idx, param_idx + 1, param_idx + 2));
let pattern = format!("%{}%", search);
params.push(pattern.clone());
params.push(pattern.clone());
params.push(pattern);
param_idx += 3;
}
let where_sql = if where_clauses.is_empty() {
String::new()
} else {
format!("WHERE {}", where_clauses.join(" AND "))
};
let count_sql = format!("SELECT COUNT(*) as count FROM accounts {}", where_sql);
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
for p in &params {
count_query = count_query.bind(p);
}
let total: i64 = count_query.fetch_one(db).await?;
let data_sql = format!(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
where_sql, param_idx, param_idx + 1
);
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(&data_sql);
for p in &params {
data_query = data_query.bind(p);
}
let rows = data_query.bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
let items: Vec<serde_json::Value> = rows
.into_iter()
.map(|(id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at)| {
serde_json::json!({
"id": id, "username": username, "email": email, "display_name": display_name,
"role": role, "status": status, "totp_enabled": totp_enabled,
"last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
})
})
.collect();
Ok(PaginatedResponse { items, total, page, page_size })
}
pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
let row: Option<(String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE id = $1"
)
.bind(account_id)
.fetch_optional(db)
.await?;
let (id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at) =
row.ok_or_else(|| SaasError::NotFound(format!("账号 {} 不存在", account_id)))?;
Ok(serde_json::json!({
"id": id, "username": username, "email": email, "display_name": display_name,
"role": role, "status": status, "totp_enabled": totp_enabled,
"last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
}))
}
pub async fn update_account(
db: &PgPool,
account_id: &str,
req: &UpdateAccountRequest,
) -> SaasResult<serde_json::Value> {
let now = chrono::Utc::now();
let mut updates = Vec::new();
let mut params: Vec<String> = Vec::new();
let mut param_idx: usize = 1;
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.email { updates.push(format!("email = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.role { updates.push(format!("role = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.avatar_url { updates.push(format!("avatar_url = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if updates.is_empty() {
return get_account(db, account_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
param_idx += 1;
params.push(account_id.to_string());
let sql = format!("UPDATE accounts SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(p);
}
query = query.bind(now);
query.execute(db).await?;
get_account(db, account_id).await
}
pub async fn update_account_status(
db: &PgPool,
account_id: &str,
status: &str,
) -> SaasResult<()> {
let valid = ["active", "disabled", "suspended"];
if !valid.contains(&status) {
return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid)));
}
let now = chrono::Utc::now();
let result = sqlx::query("UPDATE accounts SET status = $1, updated_at = $2 WHERE id = $3")
.bind(status).bind(&now).bind(account_id)
.execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound(format!("账号 {} 不存在", account_id)));
}
Ok(())
}
pub async fn create_api_token(
db: &PgPool,
account_id: &str,
req: &CreateTokenRequest,
) -> SaasResult<TokenInfo> {
use sha2::{Sha256, Digest};
let mut bytes = [0u8; 48];
use rand::RngCore;
rand::thread_rng().fill_bytes(&mut bytes);
let raw_token = format!("zclaw_{}", hex::encode(bytes));
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
let token_prefix = raw_token[..8].to_string();
let now = chrono::Utc::now();
let now_str = now.to_rfc3339();
let expires_at = req.expires_days.map(|d| {
chrono::Utc::now() + chrono::Duration::days(d)
});
let expires_at_str = expires_at.as_ref().map(|t| t.to_rfc3339());
let permissions = serde_json::to_string(&req.permissions)?;
let token_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO api_tokens (id, account_id, name, token_hash, token_prefix, permissions, created_at, expires_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
)
.bind(&token_id)
.bind(account_id)
.bind(&req.name)
.bind(&token_hash)
.bind(&token_prefix)
.bind(&permissions)
.bind(&now)
.bind(&expires_at)
.execute(db)
.await?;
Ok(TokenInfo {
id: token_id,
name: req.name.clone(),
token_prefix,
permissions: req.permissions.clone(),
last_used_at: None,
expires_at: expires_at_str,
created_at: now_str,
token: Some(raw_token),
})
}
pub async fn list_api_tokens(
db: &PgPool,
account_id: &str,
) -> SaasResult<Vec<TokenInfo>> {
let rows: Vec<(String, String, String, String, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
)
.bind(account_id)
.fetch_all(db)
.await?;
Ok(rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used.map(|t| t.to_rfc3339()), expires_at: expires.map(|t| t.to_rfc3339()), created_at: created.to_rfc3339(), token: None, }
}).collect())
}
pub async fn revoke_api_token(db: &PgPool, token_id: &str, account_id: &str) -> SaasResult<()> {
let now = chrono::Utc::now();
let result = sqlx::query(
"UPDATE api_tokens SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
)
.bind(&now).bind(token_id).bind(account_id)
.execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound("Token 不存在或已撤销".into()));
}
Ok(())
}

View File

@@ -0,0 +1,99 @@
//! 账号管理类型
use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateAccountRequest {
pub display_name: Option<String>,
pub email: Option<String>,
pub role: Option<String>,
pub avatar_url: Option<String>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateStatusRequest {
pub status: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct ListAccountsQuery {
pub page: Option<u32>,
pub page_size: Option<u32>,
pub role: Option<String>,
pub status: Option<String>,
pub search: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct PaginatedResponse<T: Serialize> {
pub items: Vec<T>,
pub total: i64,
pub page: u32,
pub page_size: u32,
}
/// Concrete type alias for OpenAPI schema generation.
///
/// NOTE: This is intentionally a concrete (non-generic) type because utoipa
/// requires concrete types for schema generation. It is functionally
/// identical to `Paginated<AccountPublic>`.
#[derive(Debug, Serialize, utoipa::ToSchema)]
#[allow(clippy::manual_non_exhaustive)] // kept for OpenAPI schema
pub struct AccountPublicPaginatedResponse {
pub items: Vec<crate::auth::types::AccountPublic>,
pub total: i64,
pub page: u32,
pub page_size: u32,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateTokenRequest {
pub name: String,
pub permissions: Vec<String>,
pub expires_days: Option<i64>,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct TokenInfo {
pub id: String,
pub name: String,
pub token_prefix: String,
pub permissions: Vec<String>,
pub last_used_at: Option<String>,
pub expires_at: Option<String>,
pub created_at: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub token: Option<String>,
}
// ============ Device Types ============
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RegisterDeviceRequest {
pub device_id: String,
#[serde(default = "default_device_name")]
pub device_name: String,
#[serde(default = "default_platform")]
pub platform: String,
#[serde(default)]
pub app_version: String,
}
fn default_device_name() -> String { "Unknown".into() }
fn default_platform() -> String { "unknown".into() }
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct DeviceHeartbeatRequest {
pub device_id: String,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct DeviceInfo {
pub id: String,
pub device_id: String,
pub device_name: Option<String>,
pub platform: Option<String>,
pub app_version: Option<String>,
pub last_seen_at: String,
pub created_at: String,
}

View File

@@ -0,0 +1,364 @@
//! 认证 HTTP 处理器
use axum::{extract::{State, ConnectInfo}, http::StatusCode, Json};
use std::net::SocketAddr;
use secrecy::ExposeSecret;
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use super::{
jwt::create_token,
password::{hash_password, verify_password},
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic},
};
/// POST /api/v1/auth/register
pub async fn register(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(req): Json<RegisterRequest>,
) -> SaasResult<(StatusCode, Json<LoginResponse>)> {
// 4.6: 用户名格式验证 — 3-32 字符,仅允许字母数字下划线
if req.username.len() < 3 || req.username.len() > 32 {
return Err(SaasError::InvalidInput("用户名长度需在 3-32 个字符之间".into()));
}
if !req.username.chars().all(|c| c.is_alphanumeric() || c == '_') {
return Err(SaasError::InvalidInput("用户名仅允许字母、数字和下划线".into()));
}
// 4.7: 邮箱格式验证
if !req.email.contains('@') || !req.email.split('@').nth(1).map_or(false, |d| d.contains('.')) {
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
}
if req.password.len() < 8 {
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
}
let existing: Vec<(String,)> = sqlx::query_as(
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
)
.bind(&req.username)
.bind(&req.email)
.fetch_all(&state.db)
.await?;
if !existing.is_empty() {
return Err(SaasError::AlreadyExists("用户名或邮箱已存在".into()));
}
let password_hash = hash_password(&req.password)?;
let account_id = uuid::Uuid::new_v4().to_string();
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
let display_name = req.display_name.unwrap_or_default();
let now = chrono::Utc::now();
sqlx::query(
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, 'active', $7, $7)"
)
.bind(&account_id)
.bind(&req.username)
.bind(&req.email)
.bind(&password_hash)
.bind(&display_name)
.bind(&role)
.bind(now)
.execute(&state.db)
.await?;
let client_ip = addr.ip().to_string();
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
// Generate JWT token for auto-login after registration
let config = state.config.read().await;
let token = create_token(
&account_id, &role, vec![],
state.jwt_secret.expose_secret(), config.auth.jwt_expiration_hours,
)?;
Ok((StatusCode::CREATED, Json(LoginResponse {
token,
account: AccountPublic {
id: account_id,
username: req.username,
email: req.email,
display_name,
role,
permissions: vec![],
status: "active".into(),
totp_enabled: false,
created_at: now.to_rfc3339(),
},
})))
}
/// POST /api/v1/auth/login
pub async fn login(
State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(req): Json<LoginRequest>,
) -> SaasResult<Json<LoginResponse>> {
let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE username = $1 OR email = $1"
)
.bind(&req.username)
.fetch_optional(&state.db)
.await?;
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
let created_at = created_at.to_rfc3339();
if status != "active" {
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", status)));
}
let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = $1"
)
.bind(&id)
.fetch_one(&state.db)
.await?;
if !verify_password(&req.password, &password_hash)? {
return Err(SaasError::AuthError("用户名或密码错误".into()));
}
// TOTP 验证: 如果用户已启用 2FA必须提供有效 TOTP 码
if totp_enabled {
let code = req.totp_code.as_deref()
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = $1"
)
.bind(&id)
.fetch_one(&state.db)
.await?;
let secret = totp_secret.ok_or_else(|| {
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
})?;
// 解密 TOTP 密钥(兼容迁移期间的明文数据)
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
if !super::totp::verify_totp_code(&decrypted_secret, code) {
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
}
}
let permissions = get_role_permissions(&state.db, &role).await?;
let config = state.config.read().await;
let token = create_token(
&id, &role, permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
)?;
let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
.bind(now).bind(&id)
.execute(&state.db).await?;
let client_ip = addr.ip().to_string();
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
Ok(Json(LoginResponse {
token,
account: AccountPublic {
id, username, email, display_name, role, permissions, status, totp_enabled, created_at,
},
}))
}
/// POST /api/v1/auth/refresh
pub async fn refresh(
State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
) -> SaasResult<Json<LoginResponse>> {
let config = state.config.read().await;
let token = create_token(
&ctx.account_id, &ctx.role, ctx.permissions.clone(),
state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours,
)?;
// 查询账号信息以返回完整 LoginResponse
let row = sqlx::query_as::<_, (String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)>(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
let (id, username, email, display_name, role, status, totp_enabled, created_at) = row;
let created_at = created_at.to_rfc3339();
Ok(Json(LoginResponse {
token,
account: AccountPublic { id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at },
}))
}
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
pub async fn me(
State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
) -> SaasResult<Json<AccountPublic>> {
let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_optional(&state.db)
.await?;
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
let created_at = created_at.to_rfc3339();
Ok(Json(AccountPublic {
id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at,
}))
}
/// PUT /api/v1/auth/password — 修改密码
pub async fn change_password(
State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
Json(req): Json<ChangePasswordRequest>,
) -> SaasResult<Json<serde_json::Value>> {
if req.new_password.len() < 8 {
return Err(SaasError::InvalidInput("新密码至少 8 个字符".into()));
}
// 获取当前密码哈希
let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
.await?;
// 验证旧密码
if !verify_password(&req.old_password, &password_hash)? {
return Err(SaasError::AuthError("旧密码错误".into()));
}
// 更新密码
let new_hash = hash_password(&req.new_password)?;
let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
.bind(&new_hash)
.bind(now)
.bind(&ctx.account_id)
.execute(&state.db)
.await?;
log_operation(&state.db, &ctx.account_id, "account.change_password", "account", &ctx.account_id,
None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
}
pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
let row: Option<(String,)> = sqlx::query_as(
"SELECT permissions FROM roles WHERE id = $1"
)
.bind(role)
.fetch_optional(db)
.await?;
let permissions_str = row
.ok_or_else(|| SaasError::Forbidden(format!("角色 {} 不存在或无权限", role)))?
.0;
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
Ok(permissions)
}
/// 检查权限 (admin:full 自动通过所有检查)
pub fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
if ctx.permissions.contains(&"admin:full".to_string()) {
return Ok(());
}
if !ctx.permissions.contains(&permission.to_string()) {
return Err(SaasError::Forbidden(format!("需要 {} 权限", permission)));
}
Ok(())
}
/// 记录操作日志
pub async fn log_operation(
db: &sqlx::PgPool,
account_id: &str,
action: &str,
target_type: &str,
target_id: &str,
details: Option<serde_json::Value>,
ip_address: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now();
sqlx::query(
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7)"
)
.bind(account_id)
.bind(action)
.bind(target_type)
.bind(target_id)
.bind(details.map(|d| d.to_string()))
.bind(ip_address)
.bind(now)
.execute(db)
.await?;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::types::AuthContext;
fn ctx(permissions: Vec<&str>) -> AuthContext {
AuthContext {
account_id: "test-id".into(),
role: "user".into(),
permissions: permissions.into_iter().map(String::from).collect(),
client_ip: None,
}
}
#[test]
fn test_check_permission_admin_full() {
let c = ctx(vec!["admin:full"]);
assert!(check_permission(&c, "config:write").is_ok());
assert!(check_permission(&c, "account:admin").is_ok());
assert!(check_permission(&c, "any:permission").is_ok());
}
#[test]
fn test_check_permission_has_permission() {
let c = ctx(vec!["config:write", "model:read"]);
assert!(check_permission(&c, "config:write").is_ok());
assert!(check_permission(&c, "model:read").is_ok());
}
#[test]
fn test_check_permission_missing() {
let c = ctx(vec!["model:read"]);
let result = check_permission(&c, "config:write");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("config:write"));
}
#[test]
fn test_check_permission_empty_list() {
let c = ctx(vec![]);
assert!(check_permission(&c, "config:write").is_err());
assert!(check_permission(&c, "admin:full").is_err());
}
}

View File

@@ -0,0 +1,102 @@
//! JWT Token 创建与验证
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use crate::error::SaasResult;
/// JWT Claims
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
pub sub: String,
pub aud: String,
pub iss: String,
pub role: String,
pub permissions: Vec<String>,
pub iat: i64,
pub exp: i64,
}
const JWT_AUDIENCE: &str = "zclaw-saas";
const JWT_ISSUER: &str = "zclaw-saas";
impl Claims {
pub fn new(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
let now = Utc::now();
Self {
sub: account_id.to_string(),
aud: JWT_AUDIENCE.to_string(),
iss: JWT_ISSUER.to_string(),
role: role.to_string(),
permissions,
iat: now.timestamp(),
exp: (now + Duration::hours(expiration_hours)).timestamp(),
}
}
}
/// 创建 JWT Token
pub fn create_token(
account_id: &str,
role: &str,
permissions: Vec<String>,
secret: &str,
expiration_hours: i64,
) -> SaasResult<String> {
let claims = Claims::new(account_id, role, permissions, expiration_hours);
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(secret.as_bytes()),
)?;
Ok(token)
}
/// 验证 JWT Token
pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> {
let mut validation = Validation::default();
validation.set_audience(&[JWT_AUDIENCE]);
validation.set_issuer(&[JWT_ISSUER]);
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(secret.as_bytes()),
&validation,
)?;
Ok(token_data.claims)
}
#[cfg(test)]
mod tests {
use super::*;
const TEST_SECRET: &str = "test-secret-key";
#[test]
fn test_create_and_verify_token() {
let token = create_token(
"account-123", "admin",
vec!["model:read".to_string()],
TEST_SECRET, 24,
).unwrap();
let claims = verify_token(&token, TEST_SECRET).unwrap();
assert_eq!(claims.sub, "account-123");
assert_eq!(claims.role, "admin");
assert_eq!(claims.permissions, vec!["model:read"]);
}
#[test]
fn test_invalid_token() {
let result = verify_token("invalid.token.here", TEST_SECRET);
assert!(result.is_err());
}
#[test]
fn test_wrong_secret() {
let token = create_token("account-123", "admin", vec![], TEST_SECRET, 24).unwrap();
let result = verify_token(&token, "wrong-secret");
assert!(result.is_err());
}
}

View File

@@ -0,0 +1,157 @@
//! 认证模块
pub mod jwt;
pub mod password;
pub mod types;
pub mod handlers;
pub mod totp;
use axum::{
extract::{Request, State},
http::header,
middleware::Next,
response::{IntoResponse, Response},
extract::ConnectInfo,
};
use secrecy::ExposeSecret;
use crate::error::SaasError;
use crate::state::AppState;
use types::AuthContext;
use std::net::SocketAddr;
/// 通过 API Token 验证身份
///
/// 流程: SHA-256 哈希 → 查 api_tokens 表 → 检查有效期 → 获取关联账号角色权限 → 更新 last_used_at
async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<String>) -> Result<AuthContext, SaasError> {
use sha2::{Sha256, Digest};
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
let row: Option<(String, Option<String>, String)> = sqlx::query_as(
"SELECT account_id, expires_at, permissions FROM api_tokens
WHERE token_hash = $1 AND revoked_at IS NULL"
)
.bind(&token_hash)
.fetch_optional(&state.db)
.await?;
let (account_id, expires_at, permissions_json) = row
.ok_or(SaasError::Unauthorized)?;
// 检查是否过期
if let Some(ref exp) = expires_at {
let now = chrono::Utc::now();
if let Ok(exp_time) = chrono::DateTime::parse_from_rfc3339(exp) {
if now >= exp_time.with_timezone(&chrono::Utc) {
return Err(SaasError::Unauthorized);
}
}
}
// 查询关联账号的角色
let (role,): (String,) = sqlx::query_as(
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
)
.bind(&account_id)
.fetch_optional(&state.db)
.await?
.ok_or(SaasError::Unauthorized)?;
// 合并 token 权限与角色权限(去重)
let role_permissions = handlers::get_role_permissions(&state.db, &role).await?;
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
let mut permissions = role_permissions;
for p in token_permissions {
if !permissions.contains(&p) {
permissions.push(p);
}
}
// 异步更新 last_used_at不阻塞请求
let db = state.db.clone();
tokio::spawn(async move {
let now = chrono::Utc::now();
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
.bind(now).bind(&token_hash)
.execute(&db).await;
});
Ok(AuthContext {
account_id,
role,
permissions,
client_ip,
})
}
/// 从请求中提取客户端 IP仅信任直连 IP不信任可伪造的 proxy header
fn extract_client_ip(req: &Request) -> Option<String> {
req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|addr| addr.ip().to_string())
}
/// 认证中间件: 从 JWT 或 API Token 提取身份
pub async fn auth_middleware(
State(state): State<AppState>,
mut req: Request,
next: Next,
) -> Response {
let client_ip = extract_client_ip(&req);
let auth_header = req.headers()
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok());
let result = if let Some(auth) = auth_header {
if let Some(token) = auth.strip_prefix("Bearer ") {
if token.starts_with("zclaw_") {
// API Token 路径
verify_api_token(&state, token, client_ip.clone()).await
} else {
// JWT 路径
jwt::verify_token(token, state.jwt_secret.expose_secret())
.map(|claims| AuthContext {
account_id: claims.sub,
role: claims.role,
permissions: claims.permissions,
client_ip,
})
.map_err(|_| SaasError::Unauthorized)
}
} else {
Err(SaasError::Unauthorized)
}
} else {
Err(SaasError::Unauthorized)
};
match result {
Ok(ctx) => {
req.extensions_mut().insert(ctx);
next.run(req).await
}
Err(e) => e.into_response(),
}
}
/// 路由 (无需认证的端点)
pub fn routes() -> axum::Router<AppState> {
use axum::routing::post;
axum::Router::new()
.route("/api/v1/auth/register", post(handlers::register))
.route("/api/v1/auth/login", post(handlers::login))
}
/// 需要认证的路由
pub fn protected_routes() -> axum::Router<AppState> {
use axum::routing::{get, post, put};
axum::Router::new()
.route("/api/v1/auth/refresh", post(handlers::refresh))
.route("/api/v1/auth/me", get(handlers::me))
.route("/api/v1/auth/password", put(handlers::change_password))
.route("/api/v1/auth/totp/setup", post(totp::setup_totp))
.route("/api/v1/auth/totp/verify", post(totp::verify_totp))
.route("/api/v1/auth/totp/disable", post(totp::disable_totp))
}

View File

@@ -0,0 +1,48 @@
//! 密码哈希 (Argon2id)
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2,
};
use crate::error::{SaasError, SaasResult};
/// 哈希密码
pub fn hash_password(password: &str) -> SaasResult<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| SaasError::PasswordHash(e.to_string()))?;
Ok(hash.to_string())
}
/// 验证密码
pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| SaasError::PasswordHash(e.to_string()))?;
Ok(Argon2::default()
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hash_and_verify() {
let hash = hash_password("correct_password").unwrap();
assert!(verify_password("correct_password", &hash).unwrap());
assert!(!verify_password("wrong_password", &hash).unwrap());
}
#[test]
fn test_different_hashes_for_same_password() {
let hash1 = hash_password("same_password").unwrap();
let hash2 = hash_password("same_password").unwrap();
assert_ne!(hash1, hash2);
assert!(verify_password("same_password", &hash1).unwrap());
assert!(verify_password("same_password", &hash2).unwrap());
}
}

View File

@@ -0,0 +1,258 @@
//! TOTP 双因素认证
use axum::{
extract::{Extension, State},
Json,
};
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::log_operation;
use serde::{Deserialize, Serialize};
/// TOTP 设置响应
#[derive(Debug, Serialize)]
pub struct TotpSetupResponse {
/// otpauth:// URI用于扫码绑定
pub otpauth_uri: String,
/// Base32 编码的密钥(备用手动输入)
pub secret: String,
/// issuer 名称
pub issuer: String,
}
/// TOTP 验证请求
#[derive(Debug, Deserialize)]
pub struct TotpVerifyRequest {
pub code: String,
}
/// TOTP 禁用请求
#[derive(Debug, Deserialize)]
pub struct TotpDisableRequest {
pub password: String,
}
/// 生成随机 Base32 密钥 (20 字节 = 32 字符 Base32)
fn generate_random_secret() -> String {
use rand::Rng;
let mut bytes = [0u8; 20];
rand::thread_rng().fill(&mut bytes);
data_encoding::BASE32.encode(&bytes)
}
/// Base32 解码
fn base32_decode(data: &str) -> Option<Vec<u8>> {
data_encoding::BASE32.decode(data.as_bytes()).ok()
}
/// 生成 TOTP 密钥并返回 otpauth URI
pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse {
let secret = generate_random_secret();
let otpauth_uri = format!(
"otpauth://totp/{}:{}?secret={}&issuer={}&algorithm=SHA1&digits=6&period=30",
urlencoding::encode(issuer),
urlencoding::encode(account_name),
secret,
urlencoding::encode(issuer),
);
TotpSetupResponse {
otpauth_uri,
secret,
issuer: issuer.to_string(),
}
}
/// 验证 TOTP 6 位码
pub fn verify_totp_code(secret: &str, code: &str) -> bool {
let secret_bytes = match base32_decode(secret) {
Some(b) => b,
None => return false,
};
let totp = match totp_rs::TOTP::new(
totp_rs::Algorithm::SHA1,
6, // digits
1, // skew (允许 1 个周期偏差)
30, // step (秒)
secret_bytes,
) {
Ok(t) => t,
Err(_) => return false,
};
totp.check_current(code).unwrap_or(false)
}
/// POST /api/v1/auth/totp/setup
/// 生成 TOTP 密钥并返回 otpauth URI
/// 用户扫码后需要调用 /verify 验证一个码才能激活
pub async fn setup_totp(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<TotpSetupResponse>> {
// 如果已启用 TOTP先清除旧密钥
let (username,): (String,) = sqlx::query_as(
"SELECT username FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
.await?;
let config = state.config.read().await;
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
// 加密 TOTP 密钥后存储 (但不启用,需要 /verify 确认)
let encrypted_secret = state.field_encryption.encrypt(&setup.secret)?;
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
.bind(&encrypted_secret)
.bind(&ctx.account_id)
.execute(&state.db)
.await?;
log_operation(&state.db, &ctx.account_id, "totp.setup", "account", &ctx.account_id,
None, ctx.client_ip.as_deref()).await?;
Ok(Json(setup))
}
/// POST /api/v1/auth/totp/verify
/// 验证 TOTP 码并启用 2FA
pub async fn verify_totp(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<TotpVerifyRequest>,
) -> SaasResult<Json<serde_json::Value>> {
let code = req.code.trim();
if code.len() != 6 || !code.chars().all(|c| c.is_ascii_digit()) {
return Err(SaasError::InvalidInput("TOTP 码必须是 6 位数字".into()));
}
// 获取存储的密钥
let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
.await?;
let secret = totp_secret.ok_or_else(|| {
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
})?;
// 解密 TOTP 密钥(兼容迁移期间的明文数据)
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
if !verify_totp_code(&decrypted_secret, code) {
return Err(SaasError::Totp("TOTP 码验证失败".into()));
}
// 验证成功 → 启用 TOTP
let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET totp_enabled = true, updated_at = $1 WHERE id = $2")
.bind(now)
.bind(&ctx.account_id)
.execute(&state.db)
.await?;
log_operation(&state.db, &ctx.account_id, "totp.verify", "account", &ctx.account_id,
None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": true, "message": "TOTP 已启用"})))
}
/// POST /api/v1/auth/totp/disable
/// 禁用 TOTP (需要密码确认)
pub async fn disable_totp(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<TotpDisableRequest>,
) -> SaasResult<Json<serde_json::Value>> {
// 验证密码
let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
.await?;
if !crate::auth::password::verify_password(&req.password, &password_hash)? {
return Err(SaasError::AuthError("密码错误".into()));
}
// 清除 TOTP
let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
.bind(now)
.bind(&ctx.account_id)
.execute(&state.db)
.await?;
log_operation(&state.db, &ctx.account_id, "totp.disable", "account", &ctx.account_id,
None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_totp_secret_format() {
let result = generate_totp_secret("TestIssuer", "user@example.com");
assert!(result.otpauth_uri.starts_with("otpauth://totp/"));
assert!(result.otpauth_uri.contains("secret="));
assert!(result.otpauth_uri.contains("issuer=TestIssuer"));
assert!(result.otpauth_uri.contains("algorithm=SHA1"));
assert!(result.otpauth_uri.contains("digits=6"));
assert!(result.otpauth_uri.contains("period=30"));
// Base32 编码的 20 字节 = 32 字符
assert_eq!(result.secret.len(), 32);
assert_eq!(result.issuer, "TestIssuer");
}
#[test]
fn test_generate_totp_secret_special_chars() {
let result = generate_totp_secret("My App", "user@domain:8080");
// 特殊字符应被 URL 编码
assert!(!result.otpauth_uri.contains("user@domain:8080"));
assert!(result.otpauth_uri.contains("user%40domain"));
}
#[test]
fn test_verify_totp_code_valid() {
// 使用 generate_random_secret 创建合法 secret然后生成并验证码
let secret = generate_random_secret();
let secret_bytes = data_encoding::BASE32.decode(secret.as_bytes()).unwrap();
let totp = totp_rs::TOTP::new(
totp_rs::Algorithm::SHA1, 6, 1, 30, secret_bytes,
).unwrap();
let valid_code = totp.generate(chrono::Utc::now().timestamp() as u64);
assert!(verify_totp_code(&secret, &valid_code));
}
#[test]
fn test_verify_totp_code_invalid() {
let secret = generate_random_secret();
assert!(!verify_totp_code(&secret, "000000"));
assert!(!verify_totp_code(&secret, "999999"));
assert!(!verify_totp_code(&secret, "abcdef"));
}
#[test]
fn test_verify_totp_code_invalid_secret() {
assert!(!verify_totp_code("not-valid-base32!!!", "123456"));
assert!(!verify_totp_code("", "123456"));
assert!(!verify_totp_code("", "123456"));
}
#[test]
fn test_verify_totp_code_empty() {
let secret = "JBSWY3DPEHPK3PXP";
assert!(!verify_totp_code(secret, ""));
assert!(!verify_totp_code(secret, "12345"));
assert!(!verify_totp_code(secret, "1234567"));
}
}

View File

@@ -0,0 +1,57 @@
//! 认证相关类型
use serde::{Deserialize, Serialize};
/// 登录请求
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct LoginRequest {
pub username: String,
pub password: String,
pub totp_code: Option<String>,
}
/// 登录响应
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct LoginResponse {
pub token: String,
pub account: AccountPublic,
}
/// 注册请求
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RegisterRequest {
pub username: String,
pub email: String,
pub password: String,
pub display_name: Option<String>,
}
/// 修改密码请求
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct ChangePasswordRequest {
pub old_password: String,
pub new_password: String,
}
/// 公开账号信息 (无敏感数据)
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct AccountPublic {
pub id: String,
pub username: String,
pub email: String,
pub display_name: String,
pub role: String,
pub permissions: Vec<String>,
pub status: String,
pub totp_enabled: bool,
pub created_at: String,
}
/// 认证上下文 (注入到 request extensions)
#[derive(Debug, Clone)]
pub struct AuthContext {
pub account_id: String,
pub role: String,
pub permissions: Vec<String>,
pub client_ip: Option<String>,
}

View File

@@ -0,0 +1,303 @@
//! SaaS 服务器配置
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use secrecy::SecretString;
/// SaaS 服务器完整配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SaaSConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub auth: AuthConfig,
pub relay: RelayConfig,
#[serde(default)]
pub rate_limit: RateLimitConfig,
}
/// 服务器配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerConfig {
#[serde(default = "default_host")]
pub host: String,
#[serde(default = "default_port")]
pub port: u16,
#[serde(default)]
pub cors_origins: Vec<String>,
}
/// 数据库配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
#[serde(default = "default_db_url")]
pub url: String,
}
/// 认证配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
#[serde(default = "default_jwt_hours")]
pub jwt_expiration_hours: i64,
#[serde(default = "default_totp_issuer")]
pub totp_issuer: String,
}
/// 中转服务配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelayConfig {
#[doc(hidden)]
#[serde(default = "default_max_queue")]
pub max_queue_size: usize,
#[doc(hidden)]
#[serde(default = "default_max_concurrent")]
pub max_concurrent_per_provider: usize,
#[doc(hidden)]
#[serde(default = "default_batch_window")]
pub batch_window_ms: u64,
#[serde(default = "default_retry_delay")]
pub retry_delay_ms: u64,
#[serde(default = "default_max_attempts")]
pub max_attempts: u32,
}
fn default_host() -> String { "0.0.0.0".into() }
fn default_port() -> u16 { 8080 }
fn default_db_url() -> String {
// 无默认值:生产环境必须通过 DATABASE_URL 或配置文件设置
// 开发环境可设置 ZCLAW_SAAS_DEV=true 使用 postgres://localhost:5432/zclaw
std::env::var("DATABASE_URL")
.unwrap_or_else(|_| {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if is_dev {
"postgres://localhost:5432/zclaw".into()
} else {
tracing::error!("DATABASE_URL 未设置且非开发环境");
String::new()
}
})
}
fn default_jwt_hours() -> i64 { 24 }
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
fn default_max_queue() -> usize { 1000 }
fn default_max_concurrent() -> usize { 5 }
fn default_batch_window() -> u64 { 50 }
fn default_retry_delay() -> u64 { 1000 }
fn default_max_attempts() -> u32 { 3 }
/// 速率限制配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
/// 每分钟最大请求数 (滑动窗口)
#[serde(default = "default_rpm")]
pub requests_per_minute: u32,
/// 突发允许的额外请求数
#[serde(default = "default_burst")]
pub burst: u32,
}
fn default_rpm() -> u32 { 60 }
fn default_burst() -> u32 { 10 }
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
requests_per_minute: default_rpm(),
burst: default_burst(),
}
}
}
impl Default for SaaSConfig {
fn default() -> Self {
Self {
server: ServerConfig::default(),
database: DatabaseConfig::default(),
auth: AuthConfig::default(),
relay: RelayConfig::default(),
rate_limit: RateLimitConfig::default(),
}
}
}
impl Default for ServerConfig {
fn default() -> Self {
Self {
host: default_host(),
port: default_port(),
cors_origins: Vec::new(),
}
}
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self { url: default_db_url() }
}
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
jwt_expiration_hours: default_jwt_hours(),
totp_issuer: default_totp_issuer(),
}
}
}
impl Default for RelayConfig {
fn default() -> Self {
Self {
max_queue_size: default_max_queue(),
max_concurrent_per_provider: default_max_concurrent(),
batch_window_ms: default_batch_window(),
retry_delay_ms: default_retry_delay(),
max_attempts: default_max_attempts(),
}
}
}
impl SaaSConfig {
/// 加载配置文件,优先级: 环境变量 > ZCLAW_SAAS_CONFIG > ./saas-config.toml
pub fn load() -> anyhow::Result<Self> {
let config_path = std::env::var("ZCLAW_SAAS_CONFIG")
.map(PathBuf::from)
.unwrap_or_else(|_| PathBuf::from("saas-config.toml"));
let config = if config_path.exists() {
let content = std::fs::read_to_string(&config_path)?;
toml::from_str(&content)?
} else {
tracing::warn!("Config file {:?} not found, using defaults", config_path);
SaaSConfig::default()
};
// 验证数据库 URL 已配置
if config.database.url.is_empty() {
anyhow::bail!(
"数据库 URL 未配置。请通过以下方式之一设置:\n\
1. 在配置文件中设置 [database].url\n\
2. 设置 DATABASE_URL 环境变量\n\
开发环境可设置 ZCLAW_SAAS_DEV=true 使用默认值。"
);
}
Ok(config)
}
/// 获取 JWT 密钥 (从环境变量或生成临时值)
/// 生产环境必须设置 ZCLAW_SAAS_JWT_SECRET
pub fn jwt_secret(&self) -> anyhow::Result<SecretString> {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
match std::env::var("ZCLAW_SAAS_JWT_SECRET") {
Ok(secret) => Ok(SecretString::from(secret)),
Err(_) => {
if is_dev {
tracing::warn!("ZCLAW_SAAS_JWT_SECRET not set, using development default (INSECURE)");
Ok(SecretString::from("zclaw-dev-only-secret-do-not-use-in-prod".to_string()))
} else {
anyhow::bail!(
"ZCLAW_SAAS_JWT_SECRET 环境变量未设置。\
请设置一个强随机密钥 (至少 32 字符)。\
开发环境可设置 ZCLAW_SAAS_DEV=true 使用默认值。"
)
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_has_expected_values() {
let config = SaaSConfig::default();
assert_eq!(config.server.host, "0.0.0.0");
assert_eq!(config.server.port, 8080);
assert!(config.server.cors_origins.is_empty());
assert_eq!(config.auth.jwt_expiration_hours, 24);
assert_eq!(config.auth.totp_issuer, "ZCLAW SaaS");
assert_eq!(config.rate_limit.requests_per_minute, 60);
assert_eq!(config.rate_limit.burst, 10);
assert_eq!(config.relay.max_queue_size, 1000);
assert_eq!(config.relay.max_concurrent_per_provider, 5);
assert_eq!(config.relay.max_attempts, 3);
}
#[test]
fn rate_limit_default_matches_manual() {
let config = SaaSConfig::default();
assert_eq!(config.rate_limit.requests_per_minute, 60);
assert_eq!(config.rate_limit.burst, 10);
}
#[test]
fn parse_minimal_config_toml() {
let toml_str = r#"
[server]
host = "127.0.0.1"
port = 9090
[database]
url = "postgres://localhost/zclaw"
[auth]
jwt_expiration_hours = 48
[relay]
max_queue_size = 500
"#;
let config: SaaSConfig = toml::from_str(toml_str).expect("parse should succeed");
assert_eq!(config.server.host, "127.0.0.1");
assert_eq!(config.server.port, 9090);
assert_eq!(config.database.url, "postgres://localhost/zclaw");
assert_eq!(config.auth.jwt_expiration_hours, 48);
assert_eq!(config.relay.max_queue_size, 500);
// defaults should fill in
assert_eq!(config.rate_limit.requests_per_minute, 60);
assert_eq!(config.relay.max_attempts, 3);
}
#[test]
fn parse_full_config_with_rate_limit() {
let toml_str = r#"
[server]
host = "0.0.0.0"
port = 8080
cors_origins = ["http://localhost:3000", "http://admin.example.com"]
[database]
url = "postgres://db:5432/zclaw"
[auth]
jwt_expiration_hours = 12
totp_issuer = "MyCorp"
[relay]
max_queue_size = 2000
max_concurrent_per_provider = 10
batch_window_ms = 100
retry_delay_ms = 2000
max_attempts = 5
[rate_limit]
requests_per_minute = 120
burst = 20
"#;
let config: SaaSConfig = toml::from_str(toml_str).expect("parse should succeed");
assert_eq!(config.server.cors_origins.len(), 2);
assert_eq!(config.auth.jwt_expiration_hours, 12);
assert_eq!(config.auth.totp_issuer, "MyCorp");
assert_eq!(config.relay.max_concurrent_per_provider, 10);
assert_eq!(config.relay.retry_delay_ms, 2000);
assert_eq!(config.relay.max_attempts, 5);
assert_eq!(config.rate_limit.requests_per_minute, 120);
assert_eq!(config.rate_limit.burst, 20);
}
}

View File

@@ -0,0 +1,277 @@
//! AES-256-GCM 字段级加密
//!
//! 用于加密数据库中存储的敏感字段(如 API Key
//! 每次加密生成随机 12 字节 nonce密文格式: `base64(nonce || ciphertext || tag)`。
use aes_gcm::aead::{AeadInPlace, KeyInit, OsRng};
use aes_gcm::{Aes256Gcm, AeadCore, Nonce};
use data_encoding::BASE64;
use std::fmt;
use crate::error::{SaasError, SaasResult};
/// AES-256-GCM 密钥字节长度
const KEY_LEN: usize = 32;
/// GCM nonce 字节长度 (96-bit推荐值)
const NONCE_LEN: usize = 12;
/// 字段加密器,持有 AES-256-GCM 密钥
///
/// 线程安全,可通过 `Arc` 在多任务间共享。
pub struct FieldEncryption {
cipher: Aes256Gcm,
}
impl fmt::Debug for FieldEncryption {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FieldEncryption")
.field("cipher", &"<redacted>")
.finish()
}
}
impl FieldEncryption {
/// 从环境变量加载或生成加密密钥
///
/// - **生产环境**: 必须设置 `ZCLAW_SAAS_FIELD_ENCRYPTION_KEY`32 字节 hex 编码)
/// - **开发环境** (`ZCLAW_SAAS_DEV=true`): 自动生成随机密钥并输出警告
pub fn new() -> anyhow::Result<Self> {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
let key_bytes = match std::env::var("ZCLAW_SAAS_FIELD_ENCRYPTION_KEY") {
Ok(hex_key) => {
let bytes = hex::decode(&hex_key).map_err(|e| {
anyhow::anyhow!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 格式无效 (期望 64 字符 hex): {e}"
)
})?;
if bytes.len() != KEY_LEN {
anyhow::bail!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 长度错误: 期望 {KEY_LEN} 字节, 实际 {} 字节",
bytes.len()
);
}
tracing::info!("Field encryption key loaded from environment");
bytes
}
Err(_) => {
if is_dev {
let random_key: [u8; KEY_LEN] = rand::random();
let hex_key = hex::encode(random_key);
tracing::warn!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 未设置,已生成随机密钥 (仅限开发环境):\n {hex_key}\n\
生产环境必须设置此环境变量!"
);
random_key.to_vec()
} else {
anyhow::bail!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 环境变量未设置。\n\
请设置一个 32 字节 hex 编码密钥 (64 字符)。\n\
生成方式: openssl rand -hex 32\n\
开发环境可设置 ZCLAW_SAAS_DEV=true 自动生成。"
);
}
}
};
let key = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes);
let cipher = Aes256Gcm::new(key);
Ok(Self { cipher })
}
/// 加密明文,返回 base64 编码密文
///
/// 密文格式: `base64(nonce_12bytes || ciphertext || gcm_tag_16bytes)`
pub fn encrypt(&self, plaintext: &str) -> SaasResult<String> {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let payload = plaintext.as_bytes();
// AeadInPlace::encrypt_in_place_append_tag 会在 payload 后面追加 16 字节 tag
let mut buffer = payload.to_vec();
self.cipher
.encrypt_in_place(&nonce, &[], &mut buffer)
.map_err(|e| SaasError::Encryption(format!("加密失败: {e}")))?;
// 构造输出: nonce (12) || ciphertext + tag
let mut output = Vec::with_capacity(NONCE_LEN + buffer.len());
output.extend_from_slice(&nonce);
output.extend_from_slice(&buffer);
Ok(BASE64.encode(&output))
}
/// 解密 base64 编码密文,返回原始明文
///
/// 输入格式: `base64(nonce_12bytes || ciphertext || gcm_tag_16bytes)`
pub fn decrypt(&self, ciphertext: &str) -> SaasResult<String> {
let raw = BASE64
.decode(ciphertext.as_bytes())
.map_err(|e| SaasError::Encryption(format!("Base64 解码失败: {e}")))?;
if raw.len() < NONCE_LEN {
return Err(SaasError::Encryption(
"密文长度不足: 无法提取 nonce".to_string(),
));
}
let (nonce_bytes, encrypted) = raw.split_at(NONCE_LEN);
let nonce = Nonce::from_slice(nonce_bytes);
let mut buffer = encrypted.to_vec();
self.cipher
.decrypt_in_place(nonce, &[], &mut buffer)
.map_err(|e| SaasError::Encryption(format!("解密失败 (密文可能已损坏或密钥不匹配): {e}")))?;
String::from_utf8(buffer)
.map_err(|e| SaasError::Encryption(format!("解密结果非有效 UTF-8: {e}")))
}
/// 尝试解密,失败时返回原始明文(用于迁移期间兼容未加密的旧数据)
///
/// 在字段加密上线前,数据库中可能已存在未加密的明文数据。
/// 此方法先尝试解密若解密失败Base64 解码失败、GCM 认证失败等),
/// 则假设数据是旧版明文,直接返回原值。
pub fn decrypt_or_plaintext(&self, value: &str) -> String {
self.decrypt(value).unwrap_or_else(|_| value.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
/// 辅助: 用固定密钥创建 FieldEncryption测试专用
fn test_encryption() -> FieldEncryption {
// 固定 32 字节密钥,仅用于测试
let key_bytes: [u8; KEY_LEN] = [
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
];
let key = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes);
let cipher = Aes256Gcm::new(key);
FieldEncryption { cipher }
}
#[test]
fn encrypt_produces_base64_output() {
let enc = test_encryption();
let result = enc.encrypt("hello world");
assert!(result.is_ok());
let ciphertext = result.unwrap();
// base64 输出应该能被 BASE64 解码
assert!(BASE64.decode(ciphertext.as_bytes()).is_ok());
}
#[test]
fn encrypt_decrypt_roundtrip() {
let enc = test_encryption();
let plaintext = "sk-proj-abc123SECRET_API_KEY_!@#$%";
let ciphertext = enc.encrypt(plaintext).expect("encrypt should succeed");
let decrypted = enc.decrypt(&ciphertext).expect("decrypt should succeed");
assert_eq!(decrypted, plaintext);
}
#[test]
fn encrypt_decrypt_roundtrip_chinese() {
let enc = test_encryption();
let plaintext = "这是一个包含中文的敏感字段测试";
let ciphertext = enc.encrypt(plaintext).expect("encrypt should succeed");
let decrypted = enc.decrypt(&ciphertext).expect("decrypt should succeed");
assert_eq!(decrypted, plaintext);
}
#[test]
fn different_encryptions_produce_different_ciphertexts() {
let enc = test_encryption();
let plaintext = "same-plaintext";
let ct1 = enc.encrypt(plaintext).unwrap();
let ct2 = enc.encrypt(plaintext).unwrap();
// 由于随机 nonce相同明文的密文应该不同
assert_ne!(ct1, ct2);
}
#[test]
fn decrypt_wrong_key_fails() {
let enc1 = test_encryption();
// 用不同密钥创建另一个加密器
let key_bytes2: [u8; KEY_LEN] = [
0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8,
0xf7, 0xf6, 0xf5, 0xf4, 0xf3, 0xf2, 0xf1, 0xf0,
0xef, 0xee, 0xed, 0xec, 0xeb, 0xea, 0xe9, 0xe8,
0xe7, 0xe6, 0xe5, 0xe4, 0xe3, 0xe2, 0xe1, 0xe0,
];
let key2 = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes2);
let cipher2 = Aes256Gcm::new(key2);
let enc2 = FieldEncryption { cipher: cipher2 };
let ciphertext = enc1.encrypt("secret").unwrap();
let result = enc2.decrypt(&ciphertext);
assert!(result.is_err());
}
#[test]
fn decrypt_invalid_base64_fails() {
let enc = test_encryption();
let result = enc.decrypt("not-valid-base64!!!");
assert!(result.is_err());
}
#[test]
fn decrypt_too_short_ciphertext_fails() {
let enc = test_encryption();
// 构造一个短于 12 字节 nonce 的有效 base64 字符串
let short = BASE64.encode(&[0x01, 0x02, 0x03]);
let result = enc.decrypt(&short);
assert!(result.is_err());
}
#[test]
fn decrypt_tampered_ciphertext_fails() {
let enc = test_encryption();
let ciphertext = enc.encrypt("sensitive-data").unwrap();
// 解码、篡改、重新编码
let mut raw = BASE64.decode(ciphertext.as_bytes()).unwrap();
// 翻转 nonce 后的一个字节
let tamper_pos = NONCE_LEN + 2;
if tamper_pos < raw.len() {
raw[tamper_pos] ^= 0xff;
}
let tampered = BASE64.encode(&raw);
let result = enc.decrypt(&tampered);
assert!(result.is_err());
}
#[test]
fn encrypt_empty_string_roundtrip() {
let enc = test_encryption();
let ciphertext = enc.encrypt("").unwrap();
let decrypted = enc.decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, "");
}
#[test]
fn ciphertext_format_has_nonce_prefix() {
let enc = test_encryption();
let ciphertext = enc.encrypt("test").unwrap();
let raw = BASE64.decode(ciphertext.as_bytes()).unwrap();
// raw 应该 = nonce(12) + ciphertext + tag(16)
// 至少 12 + 16 = 28 字节(明文 4 字节加密后 4 字节 + 16 字节 tag
assert!(raw.len() >= NONCE_LEN + 16);
}
}

View File

@@ -0,0 +1,243 @@
//! CSRF 防护: Origin 校验中间件
//!
//! 对所有状态变更请求 (POST/PUT/PATCH/DELETE) 校验 `Origin` 请求头,
//! 确保其与 `server.cors_origins` 白名单中的某项匹配。
//!
//! - GET / HEAD / OPTIONS 请求跳过校验 (安全方法)
//! - 缺少 Origin 头时拒绝 (403)
//! - Origin 不匹配白名单时拒绝 (403)
//! - `ZCLAW_SAAS_DEV=true` 时跳过校验
//!
//! 这是 Bearer Token API 最合适的 CSRF 防护方案。
//! 如果未来迁移到 Cookie 认证,需要升级为 CSRF Token 方案。
use axum::{
extract::{Request, State},
http::{header, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use tracing::warn;
use crate::state::AppState;
/// 需要进行 Origin 校验的 HTTP 方法
const CSRF_UNSAFE_METHODS: &[&str] = &["POST", "PUT", "PATCH", "DELETE"];
/// Origin 校验中间件
///
/// 在 auth_middleware 之后、rate_limit_middleware 之前执行。
/// 已认证的请求若缺少或不匹配 Origin 头,返回 403 Forbidden。
pub async fn origin_check_middleware(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Response {
// 开发模式跳过校验
if is_dev_mode() {
return next.run(req).await;
}
// 安全方法跳过校验
let method = req.method().as_str().to_uppercase();
if !CSRF_UNSAFE_METHODS.contains(&method.as_str()) {
return next.run(req).await;
}
// 获取 Origin 头
let origin_header = match req.headers().get(header::ORIGIN) {
Some(value) => match value.to_str() {
Ok(origin) => origin,
Err(_) => {
warn!("CSRF: Origin header contains invalid UTF-8");
return csrf_reject("ORIGIN_INVALID", "Origin 请求头格式无效");
}
},
None => {
warn!("CSRF: Missing Origin header on {} {}", method, req.uri());
return csrf_reject("ORIGIN_MISSING", "缺少 Origin 请求头");
}
};
// 从配置读取白名单
let allowed_origins = {
let config = state.config.read().await;
config.server.cors_origins.clone()
};
// 白名单为空时不校验 (生产环境已在 main.rs 中强制要求配置)
if allowed_origins.is_empty() {
return next.run(req).await;
}
// 校验 Origin 是否在白名单中
if !origin_matches_whitelist(origin_header, &allowed_origins) {
warn!(
"CSRF: Origin '{}' not in whitelist for {} {}",
origin_header,
method,
req.uri()
);
return csrf_reject("ORIGIN_NOT_ALLOWED", "Origin 不在允许列表中");
}
next.run(req).await
}
/// 判断是否为开发模式
fn is_dev_mode() -> bool {
std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false)
}
/// 校验 Origin 是否匹配白名单中的某项
///
/// 匹配规则: 精确匹配 (scheme + host + port)。
/// 例如白名单 `https://admin.zclaw.com` 只匹配该 Origin
/// 不匹配 `https://evil.zclaw.com`。
fn origin_matches_whitelist(origin: &str, whitelist: &[String]) -> bool {
// 使用 url::Url 进行规范化比较,避免字符串拼接攻击
let parsed_origin = match url::Url::parse(origin) {
Ok(url) => url,
Err(_) => return false,
};
for allowed in whitelist {
if let Ok(allowed_url) = url::Url::parse(allowed) {
if origins_equal(&parsed_origin, &allowed_url) {
return true;
}
} else {
// 白名单条目本身无法解析,降级为字符串比较
if origin == allowed {
return true;
}
}
}
false
}
/// 比较两个 Origin URL 是否相等 (scheme + host + port)
///
/// 同时拒绝包含路径的 URL: 真实的 Origin 头永远不会包含路径。
/// 如果传入的 origin 字符串包含路径,视为不合法的 Origin。
fn origins_equal(a: &url::Url, b: &url::Url) -> bool {
// scheme 必须完全一致
if a.scheme() != b.scheme() {
return false;
}
// host 必须完全一致
if a.host_str() != b.host_str() {
return false;
}
// port 必须完全一致 (url::Url 会规范化默认端口: 80/HTTP, 443/HTTPS)
if a.port() != b.port() {
return false;
}
// 防御性检查: 合法的 Origin 不应包含路径、query string 或 fragment
// 如果任一 URL 的 path 不是 "/" 或有 query/fragment视为可疑请求
if a.path() != "/" || b.path() != "/" {
return false;
}
if a.query().is_some() || b.query().is_some() {
return false;
}
if a.fragment().is_some() || b.fragment().is_some() {
return false;
}
true
}
/// 返回 403 拒绝响应
fn csrf_reject(error_code: &str, message: &str) -> Response {
(
StatusCode::FORBIDDEN,
[("Content-Type", "application/json")],
axum::Json(serde_json::json!({
"error": error_code,
"message": message,
})),
)
.into_response()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_origin_matches_whitelist_exact() {
let whitelist = vec![
"https://admin.zclaw.com".to_string(),
"http://localhost:3000".to_string(),
];
assert!(origin_matches_whitelist("https://admin.zclaw.com", &whitelist));
assert!(origin_matches_whitelist("http://localhost:3000", &whitelist));
assert!(!origin_matches_whitelist("https://evil.zclaw.com", &whitelist));
// url::Url normalizes port 443 for HTTPS to None, so these match
assert!(origin_matches_whitelist("https://admin.zclaw.com:443", &whitelist));
assert!(!origin_matches_whitelist("http://localhost:3001", &whitelist));
}
#[test]
fn test_origin_matches_whitelist_empty() {
let whitelist: Vec<String> = vec![];
assert!(!origin_matches_whitelist("https://example.com", &whitelist));
}
#[test]
fn test_origin_matches_whitelist_with_path() {
let whitelist = vec!["https://admin.zclaw.com".to_string()];
// 标准 Origin 不包含路径,应该匹配
assert!(origin_matches_whitelist("https://admin.zclaw.com", &whitelist));
// 包含路径的 Origin 不合法 (浏览器永远不会发送带路径的 Origin)
assert!(!origin_matches_whitelist("https://admin.zclaw.com/evil", &whitelist));
// 带查询字符串的 Origin 也不合法
assert!(!origin_matches_whitelist("https://admin.zclaw.com/?evil=1", &whitelist));
}
#[test]
fn test_origin_matches_whitelist_invalid_origin() {
let whitelist = vec!["https://admin.zclaw.com".to_string()];
assert!(!origin_matches_whitelist("not-a-url", &whitelist));
assert!(!origin_matches_whitelist("", &whitelist));
}
#[test]
fn test_origins_equal() {
let a = url::Url::parse("https://admin.zclaw.com").unwrap();
let b = url::Url::parse("https://admin.zclaw.com").unwrap();
assert!(origins_equal(&a, &b));
// Different scheme
let c = url::Url::parse("http://admin.zclaw.com").unwrap();
assert!(!origins_equal(&a, &c));
// Different host
let d = url::Url::parse("https://evil.zclaw.com").unwrap();
assert!(!origins_equal(&a, &d));
// Different port
let e = url::Url::parse("https://admin.zclaw.com:8443").unwrap();
assert!(!origins_equal(&a, &e));
// Explicit default port vs implicit
let f = url::Url::parse("https://admin.zclaw.com:443").unwrap();
// url::Url normalizes 443 for HTTPS, so both have None port
assert!(origins_equal(&a, &f));
}
#[test]
fn test_is_dev_mode() {
// Don't modify env in tests; just verify the function signature works
// Actual env-var-based behavior tested in integration tests
let _ = is_dev_mode();
}
}

387
crates/zclaw-saas/src/db.rs Normal file
View File

@@ -0,0 +1,387 @@
//! 数据库初始化与 Schema (PostgreSQL)
use sqlx::PgPool;
use crate::error::SaasResult;
const SCHEMA_VERSION: i32 = 2;
const SCHEMA_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS saas_schema_version (
version INTEGER PRIMARY KEY
);
CREATE TABLE IF NOT EXISTS accounts (
id TEXT PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
password_hash TEXT NOT NULL,
display_name TEXT NOT NULL DEFAULT '',
avatar_url TEXT,
role TEXT NOT NULL DEFAULT 'user',
status TEXT NOT NULL DEFAULT 'active',
totp_secret TEXT,
totp_enabled BOOLEAN NOT NULL DEFAULT false,
last_login_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_accounts_email ON accounts(email);
CREATE INDEX IF NOT EXISTS idx_accounts_role ON accounts(role);
CREATE TABLE IF NOT EXISTS api_tokens (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL,
name TEXT NOT NULL,
token_hash TEXT NOT NULL,
token_prefix TEXT NOT NULL,
permissions TEXT NOT NULL DEFAULT '[]',
last_used_at TIMESTAMPTZ,
expires_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_api_tokens_account ON api_tokens(account_id);
CREATE INDEX IF NOT EXISTS idx_api_tokens_hash ON api_tokens(token_hash);
CREATE TABLE IF NOT EXISTS roles (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
description TEXT,
permissions TEXT NOT NULL DEFAULT '[]',
is_system BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS operation_logs (
id BIGSERIAL PRIMARY KEY,
account_id TEXT,
action TEXT NOT NULL,
target_type TEXT,
target_id TEXT,
details TEXT,
ip_address TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_op_logs_account ON operation_logs(account_id);
CREATE INDEX IF NOT EXISTS idx_op_logs_action ON operation_logs(action);
CREATE INDEX IF NOT EXISTS idx_op_logs_time ON operation_logs(created_at);
CREATE TABLE IF NOT EXISTS providers (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL,
api_key TEXT,
base_url TEXT NOT NULL,
api_protocol TEXT NOT NULL DEFAULT 'openai',
enabled BOOLEAN NOT NULL DEFAULT true,
rate_limit_rpm INTEGER,
rate_limit_tpm INTEGER,
config_json TEXT DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS models (
id TEXT PRIMARY KEY,
provider_id TEXT NOT NULL,
model_id TEXT NOT NULL,
alias TEXT NOT NULL,
context_window INTEGER NOT NULL DEFAULT 8192,
max_output_tokens INTEGER NOT NULL DEFAULT 4096,
supports_streaming BOOLEAN NOT NULL DEFAULT true,
supports_vision BOOLEAN NOT NULL DEFAULT false,
enabled BOOLEAN NOT NULL DEFAULT true,
pricing_input DOUBLE PRECISION DEFAULT 0,
pricing_output DOUBLE PRECISION DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(provider_id, model_id),
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_models_provider ON models(provider_id);
CREATE TABLE IF NOT EXISTS account_api_keys (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL,
provider_id TEXT NOT NULL,
key_value TEXT NOT NULL,
key_label TEXT,
permissions TEXT NOT NULL DEFAULT '[]',
enabled BOOLEAN NOT NULL DEFAULT true,
last_used_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TIMESTAMPTZ,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id);
CREATE TABLE IF NOT EXISTS usage_records (
id BIGSERIAL PRIMARY KEY,
account_id TEXT NOT NULL,
provider_id TEXT NOT NULL,
model_id TEXT NOT NULL,
input_tokens INTEGER NOT NULL DEFAULT 0,
output_tokens INTEGER NOT NULL DEFAULT 0,
latency_ms INTEGER,
status TEXT NOT NULL DEFAULT 'success',
error_message TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id);
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
CREATE INDEX IF NOT EXISTS idx_usage_provider ON usage_records(provider_id);
CREATE INDEX IF NOT EXISTS idx_usage_model ON usage_records(model_id);
CREATE TABLE IF NOT EXISTS relay_tasks (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL,
provider_id TEXT NOT NULL,
model_id TEXT NOT NULL,
request_hash TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
priority INTEGER NOT NULL DEFAULT 0,
attempt_count INTEGER NOT NULL DEFAULT 0,
max_attempts INTEGER NOT NULL DEFAULT 3,
request_body TEXT NOT NULL,
response_body TEXT,
input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0,
error_message TEXT,
queued_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
started_at TIMESTAMPTZ,
completed_at TIMESTAMPTZ,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status);
CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id);
CREATE INDEX IF NOT EXISTS idx_relay_provider ON relay_tasks(provider_id);
CREATE INDEX IF NOT EXISTS idx_relay_account_status ON relay_tasks(account_id, status);
CREATE TABLE IF NOT EXISTS config_items (
id TEXT PRIMARY KEY,
category TEXT NOT NULL,
key_path TEXT NOT NULL,
value_type TEXT NOT NULL,
current_value TEXT,
default_value TEXT,
source TEXT NOT NULL DEFAULT 'local',
description TEXT,
requires_restart BOOLEAN NOT NULL DEFAULT false,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(category, key_path)
);
CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category);
CREATE TABLE IF NOT EXISTS config_sync_log (
id BIGSERIAL PRIMARY KEY,
account_id TEXT NOT NULL,
client_fingerprint TEXT NOT NULL,
action TEXT NOT NULL,
config_keys TEXT NOT NULL,
client_values TEXT,
saas_values TEXT,
resolution TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_sync_account ON config_sync_log(account_id);
CREATE TABLE IF NOT EXISTS devices (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL,
device_id TEXT NOT NULL,
device_name TEXT,
platform TEXT,
app_version TEXT,
last_seen_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_devices_account ON devices(account_id);
CREATE INDEX IF NOT EXISTS idx_devices_device_id ON devices(device_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_unique ON devices(account_id, device_id);
"#;
const SEED_ROLES: &str = r#"
INSERT INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
VALUES
('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write"]', true, NOW(), NOW()),
('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write"]', true, NOW(), NOW()),
('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read"]', true, NOW(), NOW())
ON CONFLICT (id) DO NOTHING;
"#;
/// PostgreSQL 不支持在单条 prepared statement 中执行多条 SQL 命令,
/// 因此需要拆分后逐条执行。
async fn execute_multi_statements(pool: &PgPool, sql: &str) -> SaasResult<()> {
for stmt in sql.split(';') {
let trimmed = stmt.trim();
if trimmed.is_empty() {
continue;
}
if let Err(e) = sqlx::query(trimmed).execute(pool).await {
let err_str = e.to_string();
// 忽略 "已存在" 类错误 (并发初始化或重复调用)
let is_already_exists = err_str.contains("already exists")
|| err_str.contains("已经存在")
|| err_str.contains("重复键");
if !is_already_exists {
return Err(e.into());
}
}
}
Ok(())
}
/// 初始化数据库
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
tracing::info!("Connecting to database: {}", database_url);
let pool = PgPool::connect(database_url).await?;
execute_multi_statements(&pool, SCHEMA_SQL).await?;
execute_multi_statements(&pool, SEED_ROLES).await?;
seed_admin_account(&pool).await?;
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
Ok(pool)
}
/// 创建测试数据库 (连接到真实 PG 实例)
/// 测试前清空所有数据,确保每次从干净状态开始
pub async fn init_test_db() -> SaasResult<PgPool> {
let url = std::env::var("ZCLAW_TEST_DATABASE_URL")
.unwrap_or_else(|_| "postgres://localhost:5432/zclaw_test".to_string());
let pool = PgPool::connect(&url).await?;
execute_multi_statements(&pool, SCHEMA_SQL).await?;
clean_test_data(&pool).await?;
execute_multi_statements(&pool, SEED_ROLES).await?;
Ok(pool)
}
/// 清空所有表数据 (按外键依赖顺序,使用 DELETE 而非 TRUNCATE)
/// DELETE 不获取 ACCESS EXCLUSIVE 锁,对并发更友好
pub async fn clean_test_data(pool: &PgPool) -> SaasResult<()> {
let tables_to_clean = [
"config_sync_log", "config_items", "usage_records", "relay_tasks",
"account_api_keys", "models", "providers", "operation_logs",
"api_tokens", "devices", "roles", "accounts",
];
for table in &tables_to_clean {
let _ = sqlx::query(&format!("DELETE FROM {}", table))
.execute(pool).await;
}
Ok(())
}
/// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号
async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
let has_accounts: (bool,) = sqlx::query_as(
"SELECT EXISTS(SELECT 1 FROM accounts LIMIT 1) as has"
)
.fetch_one(pool)
.await?;
if has_accounts.0 {
return Ok(());
}
let admin_username = std::env::var("ZCLAW_ADMIN_USERNAME")
.unwrap_or_else(|_| "admin".to_string());
let admin_password = match std::env::var("ZCLAW_ADMIN_PASSWORD") {
Ok(pwd) => pwd,
Err(_) => {
tracing::warn!(
"accounts 表为空但未设置 ZCLAW_ADMIN_PASSWORD 环境变量。\
请通过 POST /api/v1/auth/register 注册首个用户,然后手动将其 role 改为 super_admin。\
或设置 ZCLAW_ADMIN_USERNAME 和 ZCLAW_ADMIN_PASSWORD 环境变量后重启服务。"
);
return Ok(());
}
};
use crate::auth::password::hash_password;
let password_hash = hash_password(&admin_password)?;
let account_id = uuid::Uuid::new_v4().to_string();
let email = format!("{}@zclaw.local", admin_username);
sqlx::query(
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, 'super_admin', 'active', NOW(), NOW())"
)
.bind(&account_id)
.bind(&admin_username)
.bind(&email)
.bind(&password_hash)
.bind(&admin_username)
.execute(pool)
.await?;
tracing::info!(
"自动创建 super_admin 账号: username={}, email={}", admin_username, email
);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
/// 全局 Mutex 用于序列化所有数据库测试,避免并行测试之间的数据竞争
static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// 共享测试连接池,避免每次测试都创建新连接
static TEST_POOL: tokio::sync::OnceCell<PgPool> = tokio::sync::OnceCell::const_new();
/// 获取测试连接池(异步初始化,避免嵌套 runtime 问题)
async fn get_test_pool() -> &'static PgPool {
TEST_POOL.get_or_init(|| async {
init_test_db().await.expect("init_test_db failed")
}).await
}
/// 每个测试前清理数据,确保隔离
async fn clean_before_test(pool: &PgPool) {
clean_test_data(pool).await.expect("clean_test_data failed");
execute_multi_statements(pool, SEED_ROLES).await.expect("seed roles failed");
}
#[tokio::test]
async fn test_init_test_db() {
// 获取全局锁,确保测试串行执行
let _guard = TEST_LOCK.lock().unwrap();
let pool = get_test_pool().await;
clean_before_test(pool).await;
let roles: Vec<(String,)> = sqlx::query_as(
"SELECT id FROM roles WHERE is_system = true"
)
.fetch_all(pool)
.await
.unwrap();
assert_eq!(roles.len(), 3);
}
#[tokio::test]
async fn test_schema_tables_exist() {
let _guard = TEST_LOCK.lock().unwrap();
let pool = get_test_pool().await;
clean_before_test(pool).await;
let tables = [
"accounts", "api_tokens", "roles",
"operation_logs", "providers", "models", "account_api_keys",
"usage_records", "relay_tasks", "config_items", "config_sync_log", "devices",
];
for table in tables {
let count: (i64,) = sqlx::query_as(&format!(
"SELECT COUNT(*) FROM information_schema.tables WHERE table_schema='public' AND table_name='{}'", table
))
.fetch_one(pool)
.await
.unwrap();
assert_eq!(count.0, 1, "Table {} should exist", table);
}
}
}

View File

@@ -0,0 +1,188 @@
//! SaaS 错误类型
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use serde_json::json;
/// SaaS 服务错误类型
#[derive(Debug, thiserror::Error)]
pub enum SaasError {
#[error("未找到: {0}")]
NotFound(String),
#[error("权限不足: {0}")]
Forbidden(String),
#[error("未认证")]
Unauthorized,
#[error("无效输入: {0}")]
InvalidInput(String),
#[error("认证失败: {0}")]
AuthError(String),
#[error("用户已存在: {0}")]
AlreadyExists(String),
#[error("序列化错误: {0}")]
Serialization(#[from] serde_json::Error),
#[error("IO 错误: {0}")]
Io(#[from] std::io::Error),
#[error("数据库错误: {0}")]
Database(#[from] sqlx::Error),
#[error("配置错误: {0}")]
Config(#[from] toml::de::Error),
#[error("JWT 错误: {0}")]
Jwt(#[from] jsonwebtoken::errors::Error),
#[error("密码哈希错误: {0}")]
PasswordHash(String),
#[error("TOTP 错误: {0}")]
Totp(String),
#[error("加密错误: {0}")]
Encryption(String),
#[error("中转错误: {0}")]
Relay(String),
#[error("速率限制: {0}")]
RateLimited(String),
#[error("内部错误: {0}")]
Internal(String),
}
impl SaasError {
/// 获取 HTTP 状态码
pub fn status_code(&self) -> StatusCode {
match self {
Self::NotFound(_) => StatusCode::NOT_FOUND,
Self::Forbidden(_) => StatusCode::FORBIDDEN,
Self::Unauthorized => StatusCode::UNAUTHORIZED,
Self::InvalidInput(_) => StatusCode::BAD_REQUEST,
Self::AlreadyExists(_) => StatusCode::CONFLICT,
Self::RateLimited(_) => StatusCode::TOO_MANY_REQUESTS,
Self::Database(_) | Self::Internal(_) | Self::Io(_) | Self::Serialization(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::AuthError(_) => StatusCode::UNAUTHORIZED,
Self::Jwt(_) | Self::PasswordHash(_) | Self::Encryption(_) => {
StatusCode::INTERNAL_SERVER_ERROR
}
Self::Totp(_) => StatusCode::BAD_REQUEST,
Self::Config(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::Relay(_) => StatusCode::BAD_GATEWAY,
}
}
/// 获取错误代码
pub fn error_code(&self) -> &str {
match self {
Self::NotFound(_) => "NOT_FOUND",
Self::Forbidden(_) => "FORBIDDEN",
Self::Unauthorized => "UNAUTHORIZED",
Self::InvalidInput(_) => "INVALID_INPUT",
Self::AlreadyExists(_) => "ALREADY_EXISTS",
Self::RateLimited(_) => "RATE_LIMITED",
Self::Database(_) => "DATABASE_ERROR",
Self::Io(_) => "IO_ERROR",
Self::Serialization(_) => "SERIALIZATION_ERROR",
Self::Internal(_) => "INTERNAL_ERROR",
Self::AuthError(_) => "AUTH_ERROR",
Self::Jwt(_) => "JWT_ERROR",
Self::PasswordHash(_) => "PASSWORD_HASH_ERROR",
Self::Totp(_) => "TOTP_ERROR",
Self::Encryption(_) => "ENCRYPTION_ERROR",
Self::Config(_) => "CONFIG_ERROR",
Self::Relay(_) => "RELAY_ERROR",
}
}
}
/// 实现 Axum 响应
impl IntoResponse for SaasError {
fn into_response(self) -> Response {
let status = self.status_code();
let (error_code, message) = match &self {
// 500 错误不泄露内部细节给客户端
Self::Database(_) | Self::Internal(_) | Self::Io(_)
| Self::Jwt(_) | Self::Config(_) => {
tracing::error!("内部错误 [{}]: {}", self.error_code(), self);
(self.error_code().to_string(), "服务内部错误".to_string())
}
_ => (self.error_code().to_string(), self.to_string()),
};
let body = json!({
"error": error_code,
"message": message,
});
(status, axum::Json(body)).into_response()
}
}
/// Result 类型别名
pub type SaasResult<T> = std::result::Result<T, SaasError>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn status_code_maps_correctly() {
assert_eq!(SaasError::NotFound("x".into()).status_code(), StatusCode::NOT_FOUND);
assert_eq!(SaasError::Forbidden("x".into()).status_code(), StatusCode::FORBIDDEN);
assert_eq!(SaasError::Unauthorized.status_code(), StatusCode::UNAUTHORIZED);
assert_eq!(SaasError::InvalidInput("x".into()).status_code(), StatusCode::BAD_REQUEST);
assert_eq!(SaasError::AlreadyExists("x".into()).status_code(), StatusCode::CONFLICT);
assert_eq!(SaasError::RateLimited("x".into()).status_code(), StatusCode::TOO_MANY_REQUESTS);
assert_eq!(SaasError::Relay("x".into()).status_code(), StatusCode::BAD_GATEWAY);
assert_eq!(SaasError::Totp("x".into()).status_code(), StatusCode::BAD_REQUEST);
assert_eq!(SaasError::Internal("x".into()).status_code(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(SaasError::AuthError("x".into()).status_code(), StatusCode::UNAUTHORIZED);
}
#[test]
fn error_code_returns_expected_strings() {
assert_eq!(SaasError::NotFound("x".into()).error_code(), "NOT_FOUND");
assert_eq!(SaasError::RateLimited("x".into()).error_code(), "RATE_LIMITED");
assert_eq!(SaasError::Unauthorized.error_code(), "UNAUTHORIZED");
assert_eq!(SaasError::Encryption("x".into()).error_code(), "ENCRYPTION_ERROR");
}
#[tokio::test]
async fn into_response_hides_internal_errors() {
// 内部错误不应泄露细节
let err = SaasError::Internal("secret database password exposed".into());
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body_bytes = axum::body::to_bytes(resp.into_body(), 1024)
.await
.expect("body should be readable");
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"], "INTERNAL_ERROR");
assert_eq!(body["message"], "服务内部错误");
assert!(!body["message"].as_str().unwrap().contains("secret"));
}
#[tokio::test]
async fn into_response_shows_user_facing_errors() {
let err = SaasError::InvalidInput("用户名不能为空".into());
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body_bytes = axum::body::to_bytes(resp.into_body(), 1024)
.await
.expect("body should be readable");
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"], "INVALID_INPUT");
// InvalidInput includes the "无效输入: " prefix from Display impl
let msg = body["message"].as_str().unwrap();
assert!(msg.contains("用户名不能为空"));
}
}

View File

@@ -0,0 +1,18 @@
//! ZCLAW SaaS Backend
//!
//! 独立的 SaaS 后端服务,提供账号权限管理、模型配置、请求中转和配置迁移。
pub mod config;
pub mod crypto;
pub mod csrf;
pub mod db;
pub mod error;
pub mod middleware;
pub mod openapi;
pub mod state;
pub mod auth;
pub mod account;
pub mod model_config;
pub mod relay;
pub mod migration;

View File

@@ -0,0 +1,149 @@
//! ZCLAW SaaS 服务入口
use std::time::{Duration, Instant};
use tracing::info;
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
use axum::{extract::State, Json};
async fn health_handler(State(_state): State<AppState>) -> Json<serde_json::Value> {
Json(serde_json::json!({
"status": "ok",
"service": "zclaw-saas",
}))
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "zclaw_saas=debug,tower_http=debug".into()),
)
.init();
let config = SaaSConfig::load()?;
info!("SaaS config loaded: {}:{}", config.server.host, config.server.port);
let db = init_db(&config.database.url).await?;
info!("Database initialized");
let state = AppState::new(db, config.clone())?;
// SEC-14: 后台清理 rate_limit_entries DashMap防止不活跃账号条目无限增长。
// 中间件仅在被请求命中时清理对应 entry不活跃的 account 永远不会被回收。
// 此任务每 5 分钟扫描一次,移除所有时间戳均已超过 2 分钟的 entry
// (滑动窗口为 1 分钟2 分钟是安全的 2x 余量)。
{
let rate_limit_entries = state.rate_limit_entries.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(5 * 60)).await;
let cutoff = Instant::now() - Duration::from_secs(2 * 60);
let mut removed = 0usize;
rate_limit_entries.retain(|_account_id, timestamps| {
timestamps.retain(|&ts| ts > cutoff);
let keep = !timestamps.is_empty();
if !keep {
removed += 1;
}
keep
});
if removed > 0 {
info!(
removed,
remaining = rate_limit_entries.len(),
"rate limiter cleanup: removed stale entries"
);
}
}
});
}
// CORS 安全检查:生产环境必须配置 cors_origins
if config.server.cors_origins.is_empty() {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if !is_dev {
anyhow::bail!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
}
}
let app = build_router(state, &config);
// Swagger UI / OpenAPI 文档
// TODO: 启用 Swagger UI 后取消注释 (需要 utoipa / utoipa-swagger-ui 版本对齐)
// let app = {
// use utoipa_swagger_ui::SwaggerUi;
// use utoipa::OpenApi;
// let openapi = zclaw_saas::openapi::ApiDoc::openapi();
// app.merge(
// SwaggerUi::new("/api-docs/openapi.json")
// .url("/api-docs/openapi.json", openapi),
// )
// };
let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.host, config.server.port))
.await?;
info!("SaaS server listening on {}:{}", config.server.host, config.server.port);
axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>()).await?;
Ok(())
}
fn build_router(state: AppState, config: &SaaSConfig) -> axum::Router {
use axum::middleware;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use axum::http::HeaderValue;
let cors = {
if config.server.cors_origins.is_empty() {
// 开发环境允许任意 origin生产环境已在 main 中拦截)
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
} else {
let origins: Vec<HeaderValue> = config.server.cors_origins.iter()
.filter_map(|o: &String| o.parse::<HeaderValue>().ok())
.collect();
CorsLayer::new()
.allow_origin(origins)
.allow_methods(Any)
.allow_headers(Any)
}
};
let public_routes = zclaw_saas::auth::routes();
let protected_routes = zclaw_saas::auth::protected_routes()
.merge(zclaw_saas::account::routes())
.merge(zclaw_saas::model_config::routes())
.merge(zclaw_saas::relay::routes())
.merge(zclaw_saas::migration::routes())
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::rate_limit_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::csrf::origin_check_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::auth::auth_middleware,
));
axum::Router::new()
.route("/api/health", axum::routing::get(health_handler))
.merge(public_routes)
.merge(protected_routes)
.layer(axum::extract::DefaultBodyLimit::max(10 * 1024 * 1024)) // 10MB 请求体限制,防止 DoS
.layer(TraceLayer::new_for_http())
.layer(cors)
.with_state(state)
}

View File

@@ -0,0 +1,274 @@
//! 通用中间件
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::time::Instant;
use crate::state::AppState;
/// 速率限制检查结果
#[derive(Debug, PartialEq)]
pub(crate) enum RateLimitResult {
/// 允许通过
Allowed,
/// 被限制,附带 Retry-After 秒数
Limited { retry_after_secs: u64 },
}
/// 滑动窗口速率限制核心逻辑(纯函数,便于测试)
///
/// 返回 `RateLimitResult::Allowed` 表示未超限(已记录本次请求),
/// `RateLimitResult::Limited` 表示超限。
pub(crate) fn check_rate_limit(
entries: &mut Vec<Instant>,
now: Instant,
window_duration: std::time::Duration,
max_requests: u64,
) -> RateLimitResult {
let window_start = now - window_duration;
// 清理过期条目
entries.retain(|&ts| ts > window_start);
let count = entries.len() as u64;
if count < max_requests {
entries.push(now);
RateLimitResult::Allowed
} else {
// 计算最早条目的过期时间作为 Retry-After
entries.sort();
let earliest = *entries.first().unwrap_or(&now);
let elapsed = now.duration_since(earliest).as_secs();
let retry_after = window_duration.as_secs().saturating_sub(elapsed);
RateLimitResult::Limited {
retry_after_secs: retry_after,
}
}
}
#[cfg(test)]
/// 清理过期条目并移除空 entry
fn cleanup_stale_entries(
map: &dashmap::DashMap<String, Vec<Instant>>,
cutoff: Instant,
) {
map.retain(|_, entries| {
entries.retain(|&ts| ts > cutoff);
!entries.is_empty()
});
}
/// 滑动窗口速率限制中间件
///
/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。
/// 超限时返回 429 Too Many Requests + Retry-After header。
pub async fn rate_limit_middleware(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Response {
// 从 AuthContext 提取 account_id由 auth_middleware 在此之前注入)
let account_id = req
.extensions()
.get::<crate::auth::types::AuthContext>()
.map(|ctx| ctx.account_id.clone());
let account_id = match account_id {
Some(id) => id,
None => return next.run(req).await,
};
let config = state.config.read().await;
let rpm = config.rate_limit.requests_per_minute as u64;
let burst = config.rate_limit.burst as u64;
let max_requests = rpm + burst;
drop(config);
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let current_count = {
let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default();
let result = check_rate_limit(&mut entries, now, window, max_requests);
if let RateLimitResult::Limited { retry_after_secs } = result {
if let Some(entries) = state.rate_limit_entries.get_mut(&account_id) {
if entries.is_empty() {
drop(entries);
state.rate_limit_entries.remove(&account_id);
}
}
return (
StatusCode::TOO_MANY_REQUESTS,
[
("Retry-After", retry_after_secs.to_string()),
("Content-Type", "application/json".to_string()),
],
axum::Json(serde_json::json!({
"error": "RATE_LIMITED",
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after_secs),
})),
)
.into_response();
}
entries.len() as u64
};
// 清理空 entry (不再活跃的用户)
if current_count == 0 {
if let Some(entries) = state.rate_limit_entries.get_mut(&account_id) {
if entries.is_empty() {
drop(entries);
state.rate_limit_entries.remove(&account_id);
}
}
}
next.run(req).await
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn allows_under_limit() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
for i in 0..5 {
let result = check_rate_limit(&mut entries, now, window, 10);
assert_eq!(result, RateLimitResult::Allowed, "request {} should be allowed", i);
}
assert_eq!(entries.len() as u64, 5);
}
#[test]
fn blocks_at_limit() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let limit: u64 = 3;
// 填到限额
for _ in 0..limit {
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Allowed);
}
assert_eq!(entries.len() as u64, limit);
// 下一个应该被限流
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
// 不应该增加新条目
assert_eq!(entries.len() as u64, limit);
}
#[test]
fn expired_entries_are_cleaned() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
// 插入一个 61 秒前的旧条目
entries.push(now - std::time::Duration::from_secs(61));
assert_eq!(entries.len(), 1);
// 旧条目应该被清理,然后允许新请求
let result = check_rate_limit(&mut entries, now, window, 1);
assert_eq!(result, RateLimitResult::Allowed);
}
#[test]
fn retry_after_reflects_earliest_entry() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let limit: u64 = 2;
// 第一个请求在 10 秒前
let first_time = now - std::time::Duration::from_secs(10);
entries.push(first_time);
// 第二个请求现在
entries.push(now);
assert_eq!(entries.len() as u64, limit);
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 50 });
}
#[test]
fn burst_allows_extra_requests() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let rpm: u64 = 5;
let burst: u64 = 3;
let max = rpm + burst; // 8
// 前 8 个请求应该全部通过
for _ in 0..max {
let result = check_rate_limit(&mut entries, now, window, max);
assert_eq!(result, RateLimitResult::Allowed);
}
// 第 9 个被限流
let result = check_rate_limit(&mut entries, now, window, max);
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
}
#[test]
fn cleanup_removes_expired_and_empty() {
let map: dashmap::DashMap<String, Vec<Instant>> = dashmap::DashMap::new();
let now = Instant::now();
let cutoff = now - std::time::Duration::from_secs(120);
// 活跃用户
map.insert("active".to_string(), vec![now]);
// 过期用户
map.insert(
"expired".to_string(),
vec![now - std::time::Duration::from_secs(200)],
);
// 空用户
map.insert("empty".to_string(), vec![]);
cleanup_stale_entries(&map, cutoff);
assert!(map.contains_key("active"));
assert!(!map.contains_key("expired"));
assert!(!map.contains_key("empty"));
}
#[test]
fn empty_entries_allowed() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let result = check_rate_limit(&mut entries, now, window, 0);
// limit=0 means always limited
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
}
#[test]
fn single_request_with_large_window() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(3600);
let limit: u64 = 100;
for _ in 0..limit {
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Allowed);
}
assert_eq!(entries.len() as u64, limit);
let result = check_rate_limit(&mut entries, now, window, limit);
assert!(matches!(result, RateLimitResult::Limited { .. }));
}
}

View File

@@ -0,0 +1,123 @@
//! 配置迁移 HTTP 处理器
use axum::{
extract::{Extension, Path, Query, State},
http::StatusCode, Json,
};
use crate::state::AppState;
use crate::error::SaasResult;
use crate::auth::types::AuthContext;
use crate::auth::handlers::{check_permission, log_operation};
use super::{types::*, service};
/// GET /api/v1/config/items?category=xxx&source=xxx
pub async fn list_config_items(
State(state): State<AppState>,
Query(query): Query<ConfigQuery>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ConfigItemInfo>>> {
service::list_config_items(&state.db, &query).await.map(Json)
}
/// GET /api/v1/config/items/:id
pub async fn get_config_item(
State(state): State<AppState>,
Path(id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<ConfigItemInfo>> {
service::get_config_item(&state.db, &id).await.map(Json)
}
/// POST /api/v1/config/items (admin only)
pub async fn create_config_item(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateConfigItemRequest>,
) -> SaasResult<(StatusCode, Json<ConfigItemInfo>)> {
check_permission(&ctx, "config:write")?;
let item = service::create_config_item(&state.db, &req).await?;
log_operation(&state.db, &ctx.account_id, "config.create", "config_item", &item.id,
Some(serde_json::json!({"category": req.category, "key_path": req.key_path})),
ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(item)))
}
/// PUT /api/v1/config/items/:id (admin only)
pub async fn update_config_item(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<UpdateConfigItemRequest>,
) -> SaasResult<Json<ConfigItemInfo>> {
check_permission(&ctx, "config:write")?;
let item = service::update_config_item(&state.db, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "config.update", "config_item", &id, None,
ctx.client_ip.as_deref()).await?;
Ok(Json(item))
}
/// DELETE /api/v1/config/items/:id (admin only)
pub async fn delete_config_item(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "config:write")?;
service::delete_config_item(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "config.delete", "config_item", &id, None,
ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
/// GET /api/v1/config/analysis
pub async fn analyze_config(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<ConfigAnalysis>> {
service::analyze_config(&state.db).await.map(Json)
}
/// POST /api/v1/config/seed (admin only)
pub async fn seed_config(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "config:write")?;
let count = service::seed_default_config_items(&state.db).await?;
log_operation(&state.db, &ctx.account_id, "config.seed", "config_items", "batch",
Some(serde_json::json!({"created": count})),
ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"created": count})))
}
/// POST /api/v1/config/sync (admin only)
pub async fn sync_config(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<SyncConfigRequest>,
) -> SaasResult<Json<super::service::ConfigSyncResult>> {
check_permission(&ctx, "config:write")?;
let result = super::service::sync_config(&state.db, &ctx.account_id, &req).await?;
log_operation(&state.db, &ctx.account_id, "config.sync", "config_sync", &ctx.account_id,
Some(serde_json::json!({"action": req.action, "updated": result.updated, "created": result.created, "skipped": result.skipped})),
ctx.client_ip.as_deref()).await?;
Ok(Json(result))
}
/// POST /api/v1/config/diff
/// 计算客户端与 SaaS 端的配置差异 (不修改数据)
pub async fn config_diff(
State(state): State<AppState>,
Extension(_ctx): Extension<AuthContext>,
Json(req): Json<SyncConfigRequest>,
) -> SaasResult<Json<ConfigDiffResponse>> {
service::compute_config_diff(&state.db, &req).await.map(Json)
}
/// GET /api/v1/config/sync-logs
pub async fn list_sync_logs(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<Vec<ConfigSyncLogInfo>>> {
service::list_sync_logs(&state.db, &ctx.account_id).await.map(Json)
}

View File

@@ -0,0 +1,20 @@
//! 配置迁移模块
pub mod types;
pub mod service;
pub mod handlers;
use axum::routing::{get, post};
use crate::state::AppState;
/// 配置迁移路由 (需要认证)
pub fn routes() -> axum::Router<AppState> {
axum::Router::new()
.route("/api/v1/config/items", get(handlers::list_config_items).post(handlers::create_config_item))
.route("/api/v1/config/items/{id}", get(handlers::get_config_item).put(handlers::update_config_item).delete(handlers::delete_config_item))
.route("/api/v1/config/analysis", get(handlers::analyze_config))
.route("/api/v1/config/seed", post(handlers::seed_config))
.route("/api/v1/config/sync", post(handlers::sync_config))
.route("/api/v1/config/diff", post(handlers::config_diff))
.route("/api/v1/config/sync-logs", get(handlers::list_sync_logs))
}

View File

@@ -0,0 +1,495 @@
//! 配置迁移业务逻辑
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use super::types::*;
use serde::Serialize;
// ============ Config Items ============
pub async fn list_config_items(
db: &PgPool, query: &ConfigQuery,
) -> SaasResult<Vec<ConfigItemInfo>> {
let sql = match (&query.category, &query.source) {
(Some(_), Some(_)) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE category = $1 AND source = $2 ORDER BY category, key_path"
}
(Some(_), None) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE category = $1 ORDER BY key_path"
}
(None, Some(_)) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE source = $1 ORDER BY category, key_path"
}
(None, None) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items ORDER BY category, key_path"
}
};
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>(sql);
if let Some(cat) = &query.category {
query_builder = query_builder.bind(cat);
}
if let Some(src) = &query.source {
query_builder = query_builder.bind(src);
}
let rows = query_builder.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
}).collect())
}
pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigItemInfo> {
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE id = $1"
)
.bind(item_id)
.fetch_optional(db)
.await?;
let (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?;
Ok(ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
}
pub async fn create_config_item(
db: &PgPool, req: &CreateConfigItemRequest,
) -> SaasResult<ConfigItemInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
let source = req.source.as_deref().unwrap_or("local");
let requires_restart = req.requires_restart.unwrap_or(false);
// 检查唯一性
let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
)
.bind(&req.category).bind(&req.key_path)
.fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!(
"配置项 {}:{} 已存在", req.category, req.key_path
)));
}
sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $10)"
)
.bind(&id).bind(&req.category).bind(&req.key_path).bind(&req.value_type)
.bind(&req.current_value).bind(&req.default_value).bind(source)
.bind(&req.description).bind(requires_restart).bind(&now)
.execute(db).await?;
get_config_item(db, &id).await
}
pub async fn update_config_item(
db: &PgPool, item_id: &str, req: &UpdateConfigItemRequest,
) -> SaasResult<ConfigItemInfo> {
let now = chrono::Utc::now();
let mut updates = Vec::new();
let mut params: Vec<String> = Vec::new();
let mut param_idx: i32 = 1;
if let Some(ref v) = req.current_value { updates.push(format!("current_value = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.source { updates.push(format!("source = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.description { updates.push(format!("description = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if updates.is_empty() {
return get_config_item(db, item_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
param_idx += 1;
params.push(item_id.to_string());
let sql = format!("UPDATE config_items SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(p);
}
query = query.bind(now);
query.execute(db).await?;
get_config_item(db, item_id).await
}
pub async fn delete_config_item(db: &PgPool, item_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM config_items WHERE id = $1")
.bind(item_id).execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound(format!("配置项 {} 不存在", item_id)));
}
Ok(())
}
// ============ Config Analysis ============
pub async fn analyze_config(db: &PgPool) -> SaasResult<ConfigAnalysis> {
let items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
let mut categories: std::collections::HashMap<String, (i64, i64)> = std::collections::HashMap::new();
for item in &items {
let entry = categories.entry(item.category.clone()).or_insert((0, 0));
entry.0 += 1;
if item.source == "saas" {
entry.1 += 1;
}
}
let category_summaries: Vec<CategorySummary> = categories.into_iter()
.map(|(category, (count, saas_managed))| CategorySummary { category, count, saas_managed })
.collect();
Ok(ConfigAnalysis {
total_items: items.len() as i64,
categories: category_summaries,
items,
})
}
/// 种子默认配置项
pub async fn seed_default_config_items(db: &PgPool) -> SaasResult<usize> {
let defaults = [
("server", "server.host", "string", Some("127.0.0.1"), Some("127.0.0.1"), "服务器监听地址"),
("server", "server.port", "integer", Some("4200"), Some("4200"), "服务器端口"),
("server", "server.cors_origins", "array", None, None, "CORS 允许的源"),
("agent", "agent.defaults.default_model", "string", Some("zhipu/glm-4-plus"), Some("zhipu/glm-4-plus"), "默认模型"),
("agent", "agent.defaults.fallback_models", "array", None, None, "回退模型列表"),
("agent", "agent.defaults.max_sessions", "integer", Some("10"), Some("10"), "最大并发会话数"),
("agent", "agent.defaults.heartbeat_interval", "duration", Some("1h"), Some("1h"), "心跳间隔"),
("agent", "agent.defaults.session_timeout", "duration", Some("24h"), Some("24h"), "会话超时"),
("memory", "agent.defaults.memory.max_history_length", "integer", Some("100"), Some("100"), "最大历史长度"),
("memory", "agent.defaults.memory.summarize_threshold", "integer", Some("50"), Some("50"), "摘要阈值"),
("llm", "llm.default_provider", "string", Some("zhipu"), Some("zhipu"), "默认 LLM Provider"),
("llm", "llm.temperature", "float", Some("0.7"), Some("0.7"), "默认温度"),
("llm", "llm.max_tokens", "integer", Some("4096"), Some("4096"), "默认最大 token 数"),
];
let mut created = 0;
let now = chrono::Utc::now();
for (category, key_path, value_type, default_value, current_value, description) in defaults {
let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
)
.bind(category).bind(key_path)
.fetch_optional(db)
.await?;
if existing.is_none() {
let id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, 'local', $7, false, $8, $8)"
)
.bind(&id).bind(category).bind(key_path).bind(value_type)
.bind(current_value).bind(default_value).bind(description).bind(&now)
.execute(db)
.await?;
created += 1;
}
}
Ok(created)
}
// ============ Config Sync ============
/// 纯函数:计算客户端与 SaaS 配置项的差异(不依赖数据库)
pub fn compute_diff_items(
config_keys: &[String],
client_values: &serde_json::Value,
saas_items: &[ConfigItemInfo],
) -> (Vec<ConfigDiffItem>, usize) {
let mut items = Vec::new();
let mut conflicts = 0usize;
for key in config_keys {
let client_val = client_values.get(key)
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let saas_item = saas_items.iter().find(|item| item.key_path == *key);
let saas_val = saas_item.and_then(|item| item.current_value.clone());
let conflict = match (&client_val, &saas_val) {
(Some(a), Some(b)) => a != b,
_ => false,
};
if conflict {
conflicts += 1;
}
items.push(ConfigDiffItem {
key_path: key.clone(),
client_value: client_val,
saas_value: saas_val,
conflict,
});
}
(items, conflicts)
}
/// 计算客户端与 SaaS 端的配置差异
pub async fn compute_config_diff(
db: &PgPool, req: &SyncConfigRequest,
) -> SaasResult<ConfigDiffResponse> {
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
let (items, conflicts) = compute_diff_items(&req.config_keys, &req.client_values, &saas_items);
Ok(ConfigDiffResponse {
total_keys: items.len(),
conflicts,
items,
})
}
/// 执行配置同步 (实际写入 config_items)
pub async fn sync_config(
db: &PgPool, account_id: &str, req: &SyncConfigRequest,
) -> SaasResult<ConfigSyncResult> {
let now = chrono::Utc::now();
let config_keys_str = serde_json::to_string(&req.config_keys)?;
let client_values_str = Some(serde_json::to_string(&req.client_values)?);
// 获取 SaaS 端的配置值
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
let mut updated = 0i64;
let mut created = 0i64;
let mut skipped = 0i64;
for key in &req.config_keys {
let client_val = req.client_values.get(key)
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let saas_item = saas_items.iter().find(|item| item.key_path == *key);
match req.action.as_str() {
"push" => {
// 客户端推送 → 覆盖 SaaS 值
if let Some(val) = &client_val {
if let Some(item) = saas_item {
// 更新已有配置项
sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
.bind(val).bind(&now).bind(&item.id)
.execute(db).await?;
updated += 1;
} else {
// SaaS 不存在该 key → 自动创建
let new_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, source, requires_restart, created_at, updated_at)
VALUES ($1, 'imported', $2, 'string', $3, 'local', false, $4, $4)"
)
.bind(&new_id).bind(key).bind(val).bind(&now)
.execute(db).await?;
created += 1;
}
}
}
"merge" => {
// 合并: 客户端有值且 SaaS 无值 → 创建; 都有值 → SaaS 优先保留
if let Some(val) = &client_val {
if let Some(item) = saas_item {
if item.current_value.is_none() || item.current_value.as_deref() == Some("") {
sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
.bind(val).bind(&now).bind(&item.id)
.execute(db).await?;
updated += 1;
} else {
// 冲突: SaaS 有值 → 保留 SaaS 值
skipped += 1;
}
} else {
// SaaS 完全没有该 key → 创建
let new_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, source, requires_restart, created_at, updated_at)
VALUES ($1, 'imported', $2, 'string', $3, 'local', false, $4, $4)"
)
.bind(&new_id).bind(key).bind(val).bind(&now)
.execute(db).await?;
created += 1;
}
}
}
_ => {
// 默认: 记录日志但不修改 (向后兼容旧行为)
}
}
}
// 记录同步日志
let saas_values: serde_json::Value = saas_items.iter()
.filter(|item| req.config_keys.contains(&item.key_path))
.map(|item| {
serde_json::json!({
"value": item.current_value,
"source": item.source,
})
})
.collect();
let saas_values_str = Some(serde_json::to_string(&saas_values)?);
let resolution = req.action.clone();
sqlx::query(
"INSERT INTO config_sync_log (account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
)
.bind(account_id).bind(&req.client_fingerprint)
.bind(&req.action).bind(&config_keys_str).bind(&client_values_str)
.bind(&saas_values_str).bind(&resolution).bind(&now)
.execute(db)
.await?;
Ok(ConfigSyncResult { updated, created, skipped })
}
/// 同步结果
#[derive(Debug, Serialize)]
pub struct ConfigSyncResult {
pub updated: i64,
pub created: i64,
pub skipped: i64,
}
pub async fn list_sync_logs(
db: &PgPool, account_id: &str,
) -> SaasResult<Vec<ConfigSyncLogInfo>> {
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at
FROM config_sync_log WHERE account_id = $1 ORDER BY created_at DESC LIMIT 50"
)
.bind(account_id)
.fetch_all(db)
.await?;
Ok(rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| {
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at: created_at.to_rfc3339() }
}).collect())
}
#[cfg(test)]
mod tests {
use super::*;
fn make_saas_item(key: &str, value: Option<&str>) -> ConfigItemInfo {
ConfigItemInfo {
id: "test-id".into(),
category: "test".into(),
key_path: key.into(),
value_type: "string".into(),
current_value: value.map(String::from),
default_value: None,
source: "local".into(),
description: None,
requires_restart: false,
created_at: "2026-01-01T00:00:00Z".into(),
updated_at: "2026-01-01T00:00:00Z".into(),
}
}
#[test]
fn test_diff_identical_values() {
let keys = vec!["server.host".into(), "server.port".into()];
let client = serde_json::json!({"server.host": "127.0.0.1", "server.port": "8080"});
let saas = vec![
make_saas_item("server.host", Some("127.0.0.1")),
make_saas_item("server.port", Some("8080")),
];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 0);
assert_eq!(items.len(), 2);
assert!(!items[0].conflict);
assert!(!items[1].conflict);
}
#[test]
fn test_diff_conflict() {
let keys = vec!["server.host".into()];
let client = serde_json::json!({"server.host": "0.0.0.0"});
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 1);
assert!(items[0].conflict);
assert_eq!(items[0].client_value.as_deref(), Some("0.0.0.0"));
assert_eq!(items[0].saas_value.as_deref(), Some("127.0.0.1"));
}
#[test]
fn test_diff_client_only_key() {
let keys = vec!["new.key".into()];
let client = serde_json::json!({"new.key": "value1"});
let saas = vec![]; // SaaS 没有这个 key
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 0);
assert_eq!(items[0].client_value.as_deref(), Some("value1"));
assert!(items[0].saas_value.is_none());
}
#[test]
fn test_diff_missing_client_value() {
let keys = vec!["server.host".into()];
let client = serde_json::json!({}); // 客户端没有这个 key
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 0); // 一方为 null 不算冲突
assert!(items[0].client_value.is_none());
assert_eq!(items[0].saas_value.as_deref(), Some("127.0.0.1"));
}
#[test]
fn test_diff_empty_keys() {
let keys: Vec<String> = vec![];
let client = serde_json::json!({"server.host": "127.0.0.1"});
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert!(items.is_empty());
assert_eq!(conflicts, 0);
}
#[test]
fn test_diff_mixed() {
let keys = vec!["same".into(), "conflict".into(), "client_only".into(), "saas_only".into()];
let client = serde_json::json!({
"same": "val1",
"conflict": "client-val",
"client_only": "new-val",
});
let saas = vec![
make_saas_item("same", Some("val1")),
make_saas_item("conflict", Some("saas-val")),
make_saas_item("saas_only", Some("only-here")),
];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(items.len(), 4);
assert_eq!(conflicts, 1);
// same: no conflict
assert!(!items[0].conflict);
// conflict: has conflict
assert!(items[1].conflict);
// client_only: SaaS has no such key
assert!(items[2].saas_value.is_none());
assert_eq!(items[2].client_value.as_deref(), Some("new-val"));
// saas_only: client has no such key
assert!(items[3].client_value.is_none());
assert_eq!(items[3].saas_value.as_deref(), Some("only-here"));
}
}

View File

@@ -0,0 +1,106 @@
//! 配置迁移类型定义
use serde::{Deserialize, Serialize};
/// 配置项信息
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ConfigItemInfo {
pub id: String,
pub category: String,
pub key_path: String,
pub value_type: String,
pub current_value: Option<String>,
pub default_value: Option<String>,
pub source: String,
pub description: Option<String>,
pub requires_restart: bool,
pub created_at: String,
pub updated_at: String,
}
/// 创建配置项请求
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateConfigItemRequest {
pub category: String,
pub key_path: String,
pub value_type: String,
pub current_value: Option<String>,
pub default_value: Option<String>,
pub source: Option<String>,
pub description: Option<String>,
pub requires_restart: Option<bool>,
}
/// 更新配置项请求
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateConfigItemRequest {
pub current_value: Option<String>,
pub source: Option<String>,
pub description: Option<String>,
}
/// 配置同步日志
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct ConfigSyncLogInfo {
pub id: i64,
pub account_id: String,
pub client_fingerprint: String,
pub action: String,
pub config_keys: String,
pub client_values: Option<String>,
pub saas_values: Option<String>,
pub resolution: Option<String>,
pub created_at: String,
}
/// 配置分析结果
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ConfigAnalysis {
pub total_items: i64,
pub categories: Vec<CategorySummary>,
pub items: Vec<ConfigItemInfo>,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct CategorySummary {
pub category: String,
pub count: i64,
pub saas_managed: i64,
}
/// 配置同步请求
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct SyncConfigRequest {
pub client_fingerprint: String,
/// 同步方向: "push", "merge"
#[serde(default = "default_sync_action")]
pub action: String,
pub config_keys: Vec<String>,
pub client_values: serde_json::Value,
}
fn default_sync_action() -> String { "push".to_string() }
/// 配置差异项
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct ConfigDiffItem {
pub key_path: String,
pub client_value: Option<String>,
pub saas_value: Option<String>,
pub conflict: bool,
}
/// 配置差异响应
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ConfigDiffResponse {
pub items: Vec<ConfigDiffItem>,
pub total_keys: usize,
pub conflicts: usize,
}
/// 配置查询参数
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct ConfigQuery {
pub category: Option<String>,
pub source: Option<String>,
}

View File

@@ -0,0 +1,194 @@
//! 模型配置 HTTP 处理器
use axum::{
extract::{Extension, Path, Query, State},
http::StatusCode, Json,
};
use crate::state::AppState;
use crate::error::SaasResult;
use crate::auth::types::AuthContext;
use crate::auth::handlers::{log_operation, check_permission};
use super::{types::*, service};
// ============ Providers ============
/// GET /api/v1/providers
pub async fn list_providers(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ProviderInfo>>> {
service::list_providers(&state.db).await.map(Json)
}
/// GET /api/v1/providers/:id
pub async fn get_provider(
State(state): State<AppState>,
Path(id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<ProviderInfo>> {
service::get_provider(&state.db, &id).await.map(Json)
}
/// POST /api/v1/providers (admin only)
pub async fn create_provider(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateProviderRequest>,
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
check_permission(&ctx, "provider:manage")?;
let provider = service::create_provider(&state.db, &state.field_encryption, &req).await?;
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(provider)))
}
/// PUT /api/v1/providers/:id (admin only)
pub async fn update_provider(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<UpdateProviderRequest>,
) -> SaasResult<Json<ProviderInfo>> {
check_permission(&ctx, "provider:manage")?;
let provider = service::update_provider(&state.db, &state.field_encryption, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(provider))
}
/// DELETE /api/v1/providers/:id (admin only)
pub async fn delete_provider(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "provider:manage")?;
service::delete_provider(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
// ============ Models ============
/// GET /api/v1/models?provider_id=xxx
pub async fn list_models(
State(state): State<AppState>,
Query(params): Query<std::collections::HashMap<String, String>>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ModelInfo>>> {
let provider_id = params.get("provider_id").map(|s| s.as_str());
service::list_models(&state.db, provider_id).await.map(Json)
}
/// GET /api/v1/models/:id
pub async fn get_model(
State(state): State<AppState>,
Path(id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<ModelInfo>> {
service::get_model(&state.db, &id).await.map(Json)
}
/// POST /api/v1/models (admin only)
pub async fn create_model(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateModelRequest>,
) -> SaasResult<(StatusCode, Json<ModelInfo>)> {
check_permission(&ctx, "model:manage")?;
let model = service::create_model(&state.db, &req).await?;
log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id,
Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(model)))
}
/// PUT /api/v1/models/:id (admin only)
pub async fn update_model(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<UpdateModelRequest>,
) -> SaasResult<Json<ModelInfo>> {
check_permission(&ctx, "model:manage")?;
let model = service::update_model(&state.db, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(model))
}
/// DELETE /api/v1/models/:id (admin only)
pub async fn delete_model(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "model:manage")?;
service::delete_model(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
// ============ Account API Keys ============
/// GET /api/v1/keys?provider_id=xxx
pub async fn list_api_keys(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(params): Query<std::collections::HashMap<String, String>>,
) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> {
let provider_id = params.get("provider_id").map(|s| s.as_str());
service::list_account_api_keys(&state.db, &state.field_encryption, &ctx.account_id, provider_id).await.map(Json)
}
/// POST /api/v1/keys
pub async fn create_api_key(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateAccountApiKeyRequest>,
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
let key = service::create_account_api_key(&state.db, &state.field_encryption, &ctx.account_id, &req).await?;
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
Some(serde_json::json!({"provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(key)))
}
/// POST /api/v1/keys/:id/rotate
pub async fn rotate_api_key(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<RotateApiKeyRequest>,
) -> SaasResult<Json<serde_json::Value>> {
service::rotate_account_api_key(&state.db, &state.field_encryption, &id, &ctx.account_id, &req.new_key_value).await?;
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
/// DELETE /api/v1/keys/:id
pub async fn revoke_api_key(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
service::revoke_account_api_key(&state.db, &id, &ctx.account_id).await?;
log_operation(&state.db, &ctx.account_id, "api_key.revoke", "api_key", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true})))
}
// ============ Usage ============
/// GET /api/v1/usage?from=...&to=...&provider_id=...&model_id=...
pub async fn get_usage(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(params): Query<UsageQuery>,
) -> SaasResult<Json<UsageStats>> {
service::get_usage_stats(&state.db, &ctx.account_id, &params).await.map(Json)
}
/// GET /api/v1/providers/:id/models (便捷路由)
pub async fn list_provider_models(
State(state): State<AppState>,
Path(provider_id): Path<String>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<ModelInfo>>> {
service::list_models(&state.db, Some(&provider_id)).await.map(Json)
}

View File

@@ -0,0 +1,26 @@
//! 模型配置模块
pub mod types;
pub mod service;
pub mod handlers;
use axum::routing::{delete, get, post};
use crate::state::AppState;
/// 模型配置路由 (需要认证)
pub fn routes() -> axum::Router<AppState> {
axum::Router::new()
// Providers
.route("/api/v1/providers", get(handlers::list_providers).post(handlers::create_provider))
.route("/api/v1/providers/{id}", get(handlers::get_provider).put(handlers::update_provider).delete(handlers::delete_provider))
.route("/api/v1/providers/{id}/models", get(handlers::list_provider_models))
// Models
.route("/api/v1/models", get(handlers::list_models).post(handlers::create_model))
.route("/api/v1/models/{id}", get(handlers::get_model).put(handlers::update_model).delete(handlers::delete_model))
// Account API Keys
.route("/api/v1/keys", get(handlers::list_api_keys).post(handlers::create_api_key))
.route("/api/v1/keys/{id}", delete(handlers::revoke_api_key))
.route("/api/v1/keys/{id}/rotate", post(handlers::rotate_api_key))
// Usage
.route("/api/v1/usage", get(handlers::get_usage))
}

View File

@@ -0,0 +1,514 @@
//! 模型配置业务逻辑
use sqlx::PgPool;
use std::sync::Arc;
use crate::crypto::FieldEncryption;
use crate::error::{SaasError, SaasResult};
use super::types::*;
// ============ Providers ============
pub async fn list_providers(db: &PgPool) -> SaasResult<Vec<ProviderInfo>> {
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers ORDER BY name"
)
.fetch_all(db)
.await?;
Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
}).collect())
}
pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers WHERE id = $1"
)
.bind(provider_id)
.fetch_optional(db)
.await?;
let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
}
pub async fn create_provider(
db: &PgPool, encryption: &Arc<FieldEncryption>, req: &CreateProviderRequest,
) -> SaasResult<ProviderInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
// 检查名称唯一性
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
.bind(&req.name).fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
}
// 加密 API Key 后存储
let encrypted_api_key: Option<String> = match &req.api_key {
Some(key) => Some(encryption.encrypt(key)?),
None => None,
};
sqlx::query(
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
)
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
.execute(db).await?;
get_provider(db, &id).await
}
pub async fn update_provider(
db: &PgPool, encryption: &Arc<FieldEncryption>, provider_id: &str, req: &UpdateProviderRequest,
) -> SaasResult<ProviderInfo> {
let now = chrono::Utc::now();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx: i32 = 1;
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.base_url { updates.push(format!("base_url = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_protocol { updates.push(format!("api_protocol = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_key {
// 加密 API Key 后存储
let encrypted = encryption.encrypt(v)?;
updates.push(format!("api_key = ${}", param_idx));
params.push(Box::new(encrypted));
param_idx += 1;
}
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_rpm { updates.push(format!("rate_limit_rpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_tpm { updates.push(format!("rate_limit_tpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if updates.is_empty() {
return get_provider(db, provider_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
param_idx += 1;
params.push(Box::new(provider_id.to_string()));
let sql = format!("UPDATE providers SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
}
query = query.bind(now);
query.execute(db).await?;
get_provider(db, provider_id).await
}
pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM providers WHERE id = $1")
.bind(provider_id).execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound(format!("Provider {} 不存在", provider_id)));
}
Ok(())
}
// ============ Models ============
pub async fn list_models(db: &PgPool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
let sql = if provider_id.is_some() {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE provider_id = $1 ORDER BY alias"
} else {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models ORDER BY provider_id, alias"
};
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>(sql);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
}).collect())
}
pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
// 验证 provider 存在
let provider = get_provider(db, &req.provider_id).await?;
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
// 检查 model 唯一性
let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
)
.bind(&req.provider_id).bind(&req.model_id)
.fetch_optional(db).await?;
if existing.is_some() {
return Err(SaasError::AlreadyExists(format!(
"模型 '{}' 已存在于 provider '{}'", req.model_id, provider.name
)));
}
let ctx = req.context_window.unwrap_or(8192);
let max_out = req.max_output_tokens.unwrap_or(4096);
let streaming = req.supports_streaming.unwrap_or(true);
let vision = req.supports_vision.unwrap_or(false);
let pi = req.pricing_input.unwrap_or(0.0);
let po = req.pricing_output.unwrap_or(0.0);
sqlx::query(
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
)
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
.execute(db).await?;
get_model(db, &id).await
}
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE id = $1"
)
.bind(model_id)
.fetch_optional(db)
.await?;
let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
}
pub async fn update_model(
db: &PgPool, model_id: &str, req: &UpdateModelRequest,
) -> SaasResult<ModelInfo> {
let now = chrono::Utc::now();
let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx: i32 = 1;
if let Some(ref v) = req.alias { updates.push(format!("alias = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(v) = req.context_window { updates.push(format!("context_window = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.max_output_tokens { updates.push(format!("max_output_tokens = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_streaming { updates.push(format!("supports_streaming = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_vision { updates.push(format!("supports_vision = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_input { updates.push(format!("pricing_input = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_output { updates.push(format!("pricing_output = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if updates.is_empty() {
return get_model(db, model_id).await;
}
updates.push(format!("updated_at = ${}", param_idx));
param_idx += 1;
params.push(Box::new(model_id.to_string()));
let sql = format!("UPDATE models SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql);
for p in &params {
query = query.bind(format!("{}", p));
}
query = query.bind(now);
query.execute(db).await?;
get_model(db, model_id).await
}
pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM models WHERE id = $1")
.bind(model_id).execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound(format!("模型 {} 不存在", model_id)));
}
Ok(())
}
// ============ Account API Keys ============
pub async fn list_account_api_keys(
db: &PgPool, encryption: &Arc<FieldEncryption>, account_id: &str, provider_id: Option<&str>,
) -> SaasResult<Vec<AccountApiKeyInfo>> {
let sql = if provider_id.is_some() {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC"
} else {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
};
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>, String)>(sql)
.bind(account_id);
if let Some(pid) = provider_id {
query = query.bind(pid);
}
let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
// 解密 key_value 后再做掩码处理(兼容迁移期间的明文数据)
let decrypted = encryption.decrypt_or_plaintext(&key_value);
let masked = mask_api_key(&decrypted);
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339(), masked_key: masked }
}).collect())
}
pub async fn create_account_api_key(
db: &PgPool, encryption: &Arc<FieldEncryption>, account_id: &str, req: &CreateAccountApiKeyRequest,
) -> SaasResult<AccountApiKeyInfo> {
// 验证 provider 存在
get_provider(db, &req.provider_id).await?;
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
let now_str = now.to_rfc3339();
let permissions = serde_json::to_string(&req.permissions)?;
// 加密 key_value 后存储
let encrypted_key_value = encryption.encrypt(&req.key_value)?;
sqlx::query(
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
)
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value)
.bind(&req.key_label).bind(&permissions).bind(&now)
.execute(db).await?;
let masked = mask_api_key(&req.key_value);
Ok(AccountApiKeyInfo {
id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(),
permissions: req.permissions.clone(), enabled: true, last_used_at: None,
created_at: now_str, masked_key: masked,
})
}
pub async fn rotate_account_api_key(
db: &PgPool, encryption: &Arc<FieldEncryption>, key_id: &str, account_id: &str, new_key_value: &str,
) -> SaasResult<()> {
let now = chrono::Utc::now();
// 加密新 key_value 后存储
let encrypted_key = encryption.encrypt(new_key_value)?;
let result = sqlx::query(
"UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
)
.bind(&encrypted_key).bind(&now).bind(key_id).bind(account_id)
.execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
}
Ok(())
}
pub async fn revoke_account_api_key(
db: &PgPool, key_id: &str, account_id: &str,
) -> SaasResult<()> {
let now = chrono::Utc::now();
let result = sqlx::query(
"UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
)
.bind(&now).bind(key_id).bind(account_id)
.execute(db).await?;
if result.rows_affected() == 0 {
return Err(SaasError::NotFound("API Key 不存在或已撤销".into()));
}
Ok(())
}
// ============ Usage Statistics ============
pub async fn get_usage_stats(
db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> {
let mut param_idx: i32 = 1;
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
param_idx += 1;
let mut params: Vec<String> = vec![account_id.to_string()];
if let Some(ref from) = query.from {
where_clauses.push(format!("created_at >= ${}", param_idx));
param_idx += 1;
params.push(from.clone());
}
if let Some(ref to) = query.to {
where_clauses.push(format!("created_at <= ${}", param_idx));
param_idx += 1;
params.push(to.clone());
}
if let Some(ref pid) = query.provider_id {
where_clauses.push(format!("provider_id = ${}", param_idx));
param_idx += 1;
params.push(pid.clone());
}
if let Some(ref mid) = query.model_id {
where_clauses.push(format!("model_id = ${}", param_idx));
params.push(mid.clone());
}
let where_sql = where_clauses.join(" AND ");
// 总量统计
let total_sql = format!(
"SELECT COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE {}", where_sql
);
let mut total_query = sqlx::query_as::<_, (i64, i64, i64)>(&total_sql);
for p in &params {
total_query = total_query.bind(p);
}
let (total_requests, total_input, total_output) = total_query.fetch_one(db).await?;
// 按模型统计
let by_model_sql = format!(
"SELECT provider_id, model_id, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE {} GROUP BY provider_id, model_id ORDER BY COUNT(*) DESC LIMIT 20",
where_sql
);
let mut by_model_query = sqlx::query_as::<_, (String, String, i64, i64, i64)>(&by_model_sql);
for p in &params {
by_model_query = by_model_query.bind(p);
}
let by_model_rows = by_model_query.fetch_all(db).await?;
let by_model: Vec<ModelUsage> = by_model_rows.into_iter()
.map(|(provider_id, model_id, count, input, output)| {
ModelUsage { provider_id, model_id, request_count: count, input_tokens: input, output_tokens: output }
}).collect();
// 按天统计 (最近 30 天)
let from_30d = chrono::Utc::now() - chrono::Duration::days(30);
let daily_sql = format!(
"SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE account_id = $1 AND created_at >= $2
GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30"
);
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
.bind(account_id).bind(&from_30d)
.fetch_all(db).await?;
let by_day: Vec<DailyUsage> = daily_rows.into_iter()
.map(|(date, count, input, output)| {
DailyUsage { date, request_count: count, input_tokens: input, output_tokens: output }
}).collect();
Ok(UsageStats {
total_requests,
total_input_tokens: total_input,
total_output_tokens: total_output,
by_model,
by_day,
})
}
pub async fn record_usage(
db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
status: &str, error_message: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now();
sqlx::query(
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
)
.bind(account_id).bind(provider_id).bind(model_id)
.bind(input_tokens).bind(output_tokens).bind(latency_ms)
.bind(status).bind(error_message).bind(&now)
.execute(db).await?;
Ok(())
}
// ============ Helpers ============
fn mask_api_key(key: &str) -> String {
if key.len() <= 8 {
return "*".repeat(key.len());
}
format!("{}...{}", &key[..4], &key[key.len()-4..])
}
#[cfg(test)]
mod tests {
use super::*;
// ---- mask_api_key ----
#[test]
fn mask_key_long_key() {
let key = "sk-abcdefghijklmnopqrstuvwxyz123456";
let masked = mask_api_key(key);
assert_eq!(masked, "sk-a...3456");
}
#[test]
fn mask_key_exactly_8_chars() {
// keys <= 8 chars are fully masked
let key = "12345678";
let masked = mask_api_key(key);
assert_eq!(masked, "********");
}
#[test]
fn mask_key_7_chars() {
let key = "abcdefg";
let masked = mask_api_key(key);
assert_eq!(masked, "*******");
}
#[test]
fn mask_key_1_char() {
let key = "a";
let masked = mask_api_key(key);
assert_eq!(masked, "*");
}
#[test]
fn mask_key_empty() {
let key = "";
let masked = mask_api_key(key);
assert_eq!(masked, "");
}
#[test]
fn mask_key_9_chars_boundary() {
// 9 chars is the first that uses prefix...suffix format
let key = "abcdefghi";
let masked = mask_api_key(key);
assert_eq!(masked, "abcd...fghi");
}
#[test]
fn mask_key_standard_openai_format() {
let key = "sk-proj-abcdefghijklmnopqrstuvwx";
let masked = mask_api_key(key);
assert_eq!(masked, "sk-p...uvwx");
}
#[test]
fn mask_key_no_ellipsis_for_short() {
let masked = mask_api_key("short");
assert!(!masked.contains("..."));
}
#[test]
fn mask_key_has_ellipsis_for_long() {
let masked = mask_api_key("this_is_a_very_long_key_value");
assert!(masked.contains("..."));
}
}

View File

@@ -0,0 +1,153 @@
//! 模型配置类型定义
use serde::{Deserialize, Serialize};
// --- Provider ---
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ProviderInfo {
pub id: String,
pub name: String,
pub display_name: String,
pub base_url: String,
pub api_protocol: String,
pub enabled: bool,
pub rate_limit_rpm: Option<i64>,
pub rate_limit_tpm: Option<i64>,
pub created_at: String,
pub updated_at: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateProviderRequest {
pub name: String,
pub display_name: String,
pub base_url: String,
#[serde(default = "default_protocol")]
pub api_protocol: String,
pub api_key: Option<String>,
pub rate_limit_rpm: Option<i64>,
pub rate_limit_tpm: Option<i64>,
}
fn default_protocol() -> String { "openai".into() }
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateProviderRequest {
pub display_name: Option<String>,
pub base_url: Option<String>,
pub api_protocol: Option<String>,
pub api_key: Option<String>,
pub enabled: Option<bool>,
pub rate_limit_rpm: Option<i64>,
pub rate_limit_tpm: Option<i64>,
}
// --- Model ---
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ModelInfo {
pub id: String,
pub provider_id: String,
pub model_id: String,
pub alias: String,
pub context_window: i64,
pub max_output_tokens: i64,
pub supports_streaming: bool,
pub supports_vision: bool,
pub enabled: bool,
pub pricing_input: f64,
pub pricing_output: f64,
pub created_at: String,
pub updated_at: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateModelRequest {
pub provider_id: String,
pub model_id: String,
pub alias: String,
pub context_window: Option<i64>,
pub max_output_tokens: Option<i64>,
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
pub pricing_input: Option<f64>,
pub pricing_output: Option<f64>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateModelRequest {
pub alias: Option<String>,
pub context_window: Option<i64>,
pub max_output_tokens: Option<i64>,
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
pub enabled: Option<bool>,
pub pricing_input: Option<f64>,
pub pricing_output: Option<f64>,
}
// --- Account API Key ---
#[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct AccountApiKeyInfo {
pub id: String,
pub provider_id: String,
pub key_label: Option<String>,
pub permissions: Vec<String>,
pub enabled: bool,
pub last_used_at: Option<String>,
pub created_at: String,
pub masked_key: String,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateAccountApiKeyRequest {
pub provider_id: String,
pub key_value: String,
pub key_label: Option<String>,
#[serde(default)]
pub permissions: Vec<String>,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RotateApiKeyRequest {
pub new_key_value: String,
}
// --- Usage ---
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct UsageStats {
pub total_requests: i64,
pub total_input_tokens: i64,
pub total_output_tokens: i64,
pub by_model: Vec<ModelUsage>,
pub by_day: Vec<DailyUsage>,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ModelUsage {
pub provider_id: String,
pub model_id: String,
pub request_count: i64,
pub input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct DailyUsage {
pub date: String,
pub request_count: i64,
pub input_tokens: i64,
pub output_tokens: i64,
}
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct UsageQuery {
pub from: Option<String>,
pub to: Option<String>,
pub provider_id: Option<String>,
pub model_id: Option<String>,
}

View File

@@ -0,0 +1,790 @@
//! OpenAPI / Swagger 文档定义
//!
//! 聚合所有模块的 schema并在 build_router 中通过 utoipa-swagger-ui 暴露文档。
use utoipa::OpenApi;
/// ZCLAW SaaS API 根 OpenApi 定义
#[derive(OpenApi)]
#[openapi(
info(
title = "ZCLAW SaaS API",
version = "0.1.0",
description = "ZCLAW SaaS 后端服务 API -- 账号权限管理、模型配置、请求中转和配置迁移",
license(name = "Apache-2.0 OR MIT")
),
tags(
(name = "auth", description = "认证 (登录 / 注册 / TOTP)"),
(name = "accounts", description = "账号管理"),
(name = "providers", description = "模型供应商"),
(name = "models", description = "模型配置"),
(name = "keys", description = "API Key 管理"),
(name = "usage", description = "用量统计"),
(name = "relay", description = "请求中转"),
(name = "config", description = "配置迁移"),
),
paths(
crate::openapi::paths::auth::register,
crate::openapi::paths::auth::login,
crate::openapi::paths::auth::refresh,
crate::openapi::paths::auth::me,
crate::openapi::paths::auth::change_password,
crate::openapi::paths::auth::totp_setup,
crate::openapi::paths::auth::totp_verify,
crate::openapi::paths::auth::totp_disable,
crate::openapi::paths::accounts::list_accounts,
crate::openapi::paths::accounts::get_account,
crate::openapi::paths::accounts::update_account,
crate::openapi::paths::accounts::update_status,
crate::openapi::paths::accounts::list_tokens,
crate::openapi::paths::accounts::create_token,
crate::openapi::paths::accounts::revoke_token,
crate::openapi::paths::accounts::list_devices,
crate::openapi::paths::accounts::register_device,
crate::openapi::paths::accounts::device_heartbeat,
crate::openapi::paths::accounts::list_operation_logs,
crate::openapi::paths::accounts::dashboard_stats,
crate::openapi::paths::providers::list_providers,
crate::openapi::paths::providers::get_provider,
crate::openapi::paths::providers::create_provider,
crate::openapi::paths::providers::update_provider,
crate::openapi::paths::providers::delete_provider,
crate::openapi::paths::providers::list_provider_models,
crate::openapi::paths::models::list_models,
crate::openapi::paths::models::get_model,
crate::openapi::paths::models::create_model,
crate::openapi::paths::models::update_model,
crate::openapi::paths::models::delete_model,
crate::openapi::paths::keys::list_api_keys,
crate::openapi::paths::keys::create_api_key,
crate::openapi::paths::keys::revoke_api_key,
crate::openapi::paths::keys::rotate_api_key,
crate::openapi::paths::usage::get_usage,
crate::openapi::paths::relay::chat_completions,
crate::openapi::paths::relay::list_tasks,
crate::openapi::paths::relay::get_task,
crate::openapi::paths::relay::retry_task,
crate::openapi::paths::relay::list_available_models,
crate::openapi::paths::config::list_config_items,
crate::openapi::paths::config::get_config_item,
crate::openapi::paths::config::create_config_item,
crate::openapi::paths::config::update_config_item,
crate::openapi::paths::config::delete_config_item,
crate::openapi::paths::config::analyze_config,
crate::openapi::paths::config::seed_config,
crate::openapi::paths::config::sync_config,
crate::openapi::paths::config::config_diff,
crate::openapi::paths::config::list_sync_logs,
),
components(schemas(
crate::auth::types::LoginRequest,
crate::auth::types::LoginResponse,
crate::auth::types::RegisterRequest,
crate::auth::types::ChangePasswordRequest,
crate::auth::types::AccountPublic,
crate::account::types::UpdateAccountRequest,
crate::account::types::UpdateStatusRequest,
crate::account::types::ListAccountsQuery,
crate::account::types::AccountPublicPaginatedResponse,
crate::account::types::CreateTokenRequest,
crate::account::types::TokenInfo,
crate::account::types::RegisterDeviceRequest,
crate::account::types::DeviceHeartbeatRequest,
crate::account::types::DeviceInfo,
crate::model_config::types::ProviderInfo,
crate::model_config::types::CreateProviderRequest,
crate::model_config::types::UpdateProviderRequest,
crate::model_config::types::ModelInfo,
crate::model_config::types::CreateModelRequest,
crate::model_config::types::UpdateModelRequest,
crate::model_config::types::AccountApiKeyInfo,
crate::model_config::types::CreateAccountApiKeyRequest,
crate::model_config::types::RotateApiKeyRequest,
crate::model_config::types::UsageStats,
crate::model_config::types::ModelUsage,
crate::model_config::types::DailyUsage,
crate::model_config::types::UsageQuery,
crate::relay::types::RelayTaskInfo,
crate::relay::types::RelayTaskQuery,
crate::migration::types::ConfigItemInfo,
crate::migration::types::CreateConfigItemRequest,
crate::migration::types::UpdateConfigItemRequest,
crate::migration::types::ConfigSyncLogInfo,
crate::migration::types::ConfigAnalysis,
crate::migration::types::CategorySummary,
crate::migration::types::SyncConfigRequest,
crate::migration::types::ConfigDiffItem,
crate::migration::types::ConfigDiffResponse,
crate::migration::types::ConfigQuery,
)),
modifiers(&SecurityAddon)
)]
pub struct ApiDoc;
struct SecurityAddon;
impl utoipa::Modify for SecurityAddon {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
if let Some(components) = openapi.components.as_mut() {
components.add_security_scheme(
"bearer_auth",
utoipa::openapi::security::SecurityScheme::Http(
utoipa::openapi::security::Http::new(
utoipa::openapi::security::HttpAuthScheme::Bearer,
),
),
);
}
}
}
/// Path stubs for OpenAPI documentation generation.
/// These functions are never called at runtime -- they exist solely so that
/// `utoipa::path` can produce the correct OpenAPI spec entries.
pub mod paths {
pub mod auth {
#[utoipa::path(
post,
path = "/api/v1/auth/register",
tag = "auth",
request_body = crate::auth::types::RegisterRequest,
responses(
(status = 201, description = "注册成功", body = crate::auth::types::LoginResponse),
(status = 409, description = "用户已存在"),
)
)]
pub async fn register() {}
#[utoipa::path(
post,
path = "/api/v1/auth/login",
tag = "auth",
request_body = crate::auth::types::LoginRequest,
responses(
(status = 200, description = "登录成功", body = crate::auth::types::LoginResponse),
(status = 401, description = "认证失败"),
)
)]
pub async fn login() {}
#[utoipa::path(
post,
path = "/api/v1/auth/refresh",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "刷新 token 成功", body = crate::auth::types::LoginResponse),
(status = 401, description = "认证失败"),
)
)]
pub async fn refresh() {}
#[utoipa::path(
get,
path = "/api/v1/auth/me",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "当前用户信息", body = crate::auth::types::AccountPublic),
(status = 401, description = "未认证"),
)
)]
pub async fn me() {}
#[utoipa::path(
put,
path = "/api/v1/auth/password",
tag = "auth",
security(("bearer_auth" = [])),
request_body = crate::auth::types::ChangePasswordRequest,
responses(
(status = 200, description = "密码修改成功"),
(status = 400, description = "旧密码不正确"),
(status = 401, description = "未认证"),
)
)]
pub async fn change_password() {}
#[utoipa::path(
post,
path = "/api/v1/auth/totp/setup",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "TOTP 设置信息(含 secret 和 QR URI"),
(status = 401, description = "未认证"),
(status = 409, description = "TOTP 已启用"),
)
)]
pub async fn totp_setup() {}
#[utoipa::path(
post,
path = "/api/v1/auth/totp/verify",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "验证成功TOTP 已启用"),
(status = 401, description = "验证码错误"),
)
)]
pub async fn totp_verify() {}
#[utoipa::path(
post,
path = "/api/v1/auth/totp/disable",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "TOTP 已禁用"),
(status = 401, description = "密码错误"),
)
)]
pub async fn totp_disable() {}
}
pub mod accounts {
#[utoipa::path(
get,
path = "/api/v1/accounts",
tag = "accounts",
security(("bearer_auth" = [])),
params(crate::account::types::ListAccountsQuery),
responses(
(status = 200, description = "账号列表", body = crate::account::types::AccountPublicPaginatedResponse),
)
)]
pub async fn list_accounts() {}
#[utoipa::path(
get,
path = "/api/v1/accounts/{id}",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "账号 ID")),
responses(
(status = 200, description = "账号详情", body = crate::auth::types::AccountPublic),
(status = 404, description = "账号不存在"),
)
)]
pub async fn get_account() {}
#[utoipa::path(
put,
path = "/api/v1/accounts/{id}",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "账号 ID")),
request_body = crate::account::types::UpdateAccountRequest,
responses(
(status = 200, description = "更新成功", body = crate::auth::types::AccountPublic),
(status = 404, description = "账号不存在"),
)
)]
pub async fn update_account() {}
#[utoipa::path(
patch,
path = "/api/v1/accounts/{id}/status",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "账号 ID")),
request_body = crate::account::types::UpdateStatusRequest,
responses(
(status = 200, description = "状态更新成功"),
)
)]
pub async fn update_status() {}
#[utoipa::path(
get,
path = "/api/v1/tokens",
tag = "accounts",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "Token 列表", body = Vec<crate::account::types::TokenInfo>),
)
)]
pub async fn list_tokens() {}
#[utoipa::path(
post,
path = "/api/v1/tokens",
tag = "accounts",
security(("bearer_auth" = [])),
request_body = crate::account::types::CreateTokenRequest,
responses(
(status = 201, description = "创建成功", body = crate::account::types::TokenInfo),
)
)]
pub async fn create_token() {}
#[utoipa::path(
delete,
path = "/api/v1/tokens/{id}",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "Token ID")),
responses(
(status = 204, description = "撤销成功"),
)
)]
pub async fn revoke_token() {}
#[utoipa::path(
get,
path = "/api/v1/devices",
tag = "accounts",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "设备列表", body = Vec<crate::account::types::DeviceInfo>),
)
)]
pub async fn list_devices() {}
#[utoipa::path(
post,
path = "/api/v1/devices/register",
tag = "accounts",
security(("bearer_auth" = [])),
request_body = crate::account::types::RegisterDeviceRequest,
responses(
(status = 201, description = "注册成功", body = crate::account::types::DeviceInfo),
)
)]
pub async fn register_device() {}
#[utoipa::path(
post,
path = "/api/v1/devices/heartbeat",
tag = "accounts",
security(("bearer_auth" = [])),
request_body = crate::account::types::DeviceHeartbeatRequest,
responses(
(status = 200, description = "心跳更新成功"),
)
)]
pub async fn device_heartbeat() {}
#[utoipa::path(
get,
path = "/api/v1/logs/operations",
tag = "accounts",
security(("bearer_auth" = [])),
params(
("page" = Option<i32>, Query, description = "页码"),
("page_size" = Option<i32>, Query, description = "每页数量"),
("action" = Option<String>, Query, description = "操作类型过滤"),
("account_id" = Option<String>, Query, description = "账号 ID 过滤"),
),
responses(
(status = 200, description = "操作日志列表"),
)
)]
pub async fn list_operation_logs() {}
#[utoipa::path(
get,
path = "/api/v1/stats/dashboard",
tag = "accounts",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "仪表盘统计数据"),
)
)]
pub async fn dashboard_stats() {}
}
pub mod providers {
#[utoipa::path(
get,
path = "/api/v1/providers",
tag = "providers",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "供应商列表", body = Vec<crate::model_config::types::ProviderInfo>),
)
)]
pub async fn list_providers() {}
#[utoipa::path(
get,
path = "/api/v1/providers/{id}",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
responses(
(status = 200, description = "供应商详情", body = crate::model_config::types::ProviderInfo),
(status = 404, description = "供应商不存在"),
)
)]
pub async fn get_provider() {}
#[utoipa::path(
post,
path = "/api/v1/providers",
tag = "providers",
security(("bearer_auth" = [])),
request_body = crate::model_config::types::CreateProviderRequest,
responses(
(status = 201, description = "创建成功", body = crate::model_config::types::ProviderInfo),
)
)]
pub async fn create_provider() {}
#[utoipa::path(
put,
path = "/api/v1/providers/{id}",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
request_body = crate::model_config::types::UpdateProviderRequest,
responses(
(status = 200, description = "更新成功", body = crate::model_config::types::ProviderInfo),
(status = 404, description = "供应商不存在"),
)
)]
pub async fn update_provider() {}
#[utoipa::path(
delete,
path = "/api/v1/providers/{id}",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
responses(
(status = 204, description = "删除成功"),
(status = 404, description = "供应商不存在"),
)
)]
pub async fn delete_provider() {}
#[utoipa::path(
get,
path = "/api/v1/providers/{id}/models",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
responses(
(status = 200, description = "供应商下的模型列表", body = Vec<crate::model_config::types::ModelInfo>),
)
)]
pub async fn list_provider_models() {}
}
pub mod models {
#[utoipa::path(
get,
path = "/api/v1/models",
tag = "models",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "模型列表", body = Vec<crate::model_config::types::ModelInfo>),
)
)]
pub async fn list_models() {}
#[utoipa::path(
get,
path = "/api/v1/models/{id}",
tag = "models",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "模型 ID")),
responses(
(status = 200, description = "模型详情", body = crate::model_config::types::ModelInfo),
(status = 404, description = "模型不存在"),
)
)]
pub async fn get_model() {}
#[utoipa::path(
post,
path = "/api/v1/models",
tag = "models",
security(("bearer_auth" = [])),
request_body = crate::model_config::types::CreateModelRequest,
responses(
(status = 201, description = "创建成功", body = crate::model_config::types::ModelInfo),
)
)]
pub async fn create_model() {}
#[utoipa::path(
put,
path = "/api/v1/models/{id}",
tag = "models",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "模型 ID")),
request_body = crate::model_config::types::UpdateModelRequest,
responses(
(status = 200, description = "更新成功", body = crate::model_config::types::ModelInfo),
(status = 404, description = "模型不存在"),
)
)]
pub async fn update_model() {}
#[utoipa::path(
delete,
path = "/api/v1/models/{id}",
tag = "models",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "模型 ID")),
responses(
(status = 204, description = "删除成功"),
(status = 404, description = "模型不存在"),
)
)]
pub async fn delete_model() {}
}
pub mod keys {
#[utoipa::path(
get,
path = "/api/v1/keys",
tag = "keys",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "API Key 列表", body = Vec<crate::model_config::types::AccountApiKeyInfo>),
)
)]
pub async fn list_api_keys() {}
#[utoipa::path(
post,
path = "/api/v1/keys",
tag = "keys",
security(("bearer_auth" = [])),
request_body = crate::model_config::types::CreateAccountApiKeyRequest,
responses(
(status = 201, description = "创建成功", body = crate::model_config::types::AccountApiKeyInfo),
)
)]
pub async fn create_api_key() {}
#[utoipa::path(
delete,
path = "/api/v1/keys/{id}",
tag = "keys",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "Key ID")),
responses(
(status = 204, description = "撤销成功"),
)
)]
pub async fn revoke_api_key() {}
#[utoipa::path(
post,
path = "/api/v1/keys/{id}/rotate",
tag = "keys",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "Key ID")),
request_body = crate::model_config::types::RotateApiKeyRequest,
responses(
(status = 200, description = "轮换成功", body = crate::model_config::types::AccountApiKeyInfo),
)
)]
pub async fn rotate_api_key() {}
}
pub mod usage {
#[utoipa::path(
get,
path = "/api/v1/usage",
tag = "usage",
security(("bearer_auth" = [])),
params(crate::model_config::types::UsageQuery),
responses(
(status = 200, description = "用量统计", body = crate::model_config::types::UsageStats),
)
)]
pub async fn get_usage() {}
}
pub mod relay {
#[utoipa::path(
post,
path = "/api/v1/relay/chat/completions",
tag = "relay",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "聊天补全响应JSON 或 SSE 流)"),
(status = 402, description = "上游服务错误"),
(status = 404, description = "模型不存在或未启用"),
)
)]
pub async fn chat_completions() {}
#[utoipa::path(
get,
path = "/api/v1/relay/tasks",
tag = "relay",
security(("bearer_auth" = [])),
params(crate::relay::types::RelayTaskQuery),
responses(
(status = 200, description = "中转任务列表", body = Vec<crate::relay::types::RelayTaskInfo>),
)
)]
pub async fn list_tasks() {}
#[utoipa::path(
get,
path = "/api/v1/relay/tasks/{id}",
tag = "relay",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "任务 ID")),
responses(
(status = 200, description = "任务详情", body = crate::relay::types::RelayTaskInfo),
(status = 404, description = "任务不存在"),
)
)]
pub async fn get_task() {}
#[utoipa::path(
post,
path = "/api/v1/relay/tasks/{id}/retry",
tag = "relay",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "任务 ID")),
responses(
(status = 200, description = "重试成功", body = crate::relay::types::RelayTaskInfo),
(status = 404, description = "任务不存在"),
)
)]
pub async fn retry_task() {}
#[utoipa::path(
get,
path = "/api/v1/relay/models",
tag = "relay",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "可用模型列表", body = Vec<crate::model_config::types::ModelInfo>),
)
)]
pub async fn list_available_models() {}
}
pub mod config {
#[utoipa::path(
get,
path = "/api/v1/config/items",
tag = "config",
security(("bearer_auth" = [])),
params(crate::migration::types::ConfigQuery),
responses(
(status = 200, description = "配置项列表", body = Vec<crate::migration::types::ConfigItemInfo>),
)
)]
pub async fn list_config_items() {}
#[utoipa::path(
get,
path = "/api/v1/config/items/{id}",
tag = "config",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "配置项 ID")),
responses(
(status = 200, description = "配置项详情", body = crate::migration::types::ConfigItemInfo),
(status = 404, description = "配置项不存在"),
)
)]
pub async fn get_config_item() {}
#[utoipa::path(
post,
path = "/api/v1/config/items",
tag = "config",
security(("bearer_auth" = [])),
request_body = crate::migration::types::CreateConfigItemRequest,
responses(
(status = 201, description = "创建成功", body = crate::migration::types::ConfigItemInfo),
)
)]
pub async fn create_config_item() {}
#[utoipa::path(
put,
path = "/api/v1/config/items/{id}",
tag = "config",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "配置项 ID")),
request_body = crate::migration::types::UpdateConfigItemRequest,
responses(
(status = 200, description = "更新成功", body = crate::migration::types::ConfigItemInfo),
(status = 404, description = "配置项不存在"),
)
)]
pub async fn update_config_item() {}
#[utoipa::path(
delete,
path = "/api/v1/config/items/{id}",
tag = "config",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "配置项 ID")),
responses(
(status = 204, description = "删除成功"),
(status = 404, description = "配置项不存在"),
)
)]
pub async fn delete_config_item() {}
#[utoipa::path(
get,
path = "/api/v1/config/analysis",
tag = "config",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "配置分析结果", body = crate::migration::types::ConfigAnalysis),
)
)]
pub async fn analyze_config() {}
#[utoipa::path(
post,
path = "/api/v1/config/seed",
tag = "config",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "种子数据初始化成功"),
)
)]
pub async fn seed_config() {}
#[utoipa::path(
post,
path = "/api/v1/config/sync",
tag = "config",
security(("bearer_auth" = [])),
request_body = crate::migration::types::SyncConfigRequest,
responses(
(status = 200, description = "同步成功"),
)
)]
pub async fn sync_config() {}
#[utoipa::path(
post,
path = "/api/v1/config/diff",
tag = "config",
security(("bearer_auth" = [])),
request_body = crate::migration::types::SyncConfigRequest,
responses(
(status = 200, description = "配置差异", body = crate::migration::types::ConfigDiffResponse),
)
)]
pub async fn config_diff() {}
#[utoipa::path(
get,
path = "/api/v1/config/sync-logs",
tag = "config",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "同步日志列表", body = Vec<crate::migration::types::ConfigSyncLogInfo>),
)
)]
pub async fn list_sync_logs() {}
}
}

View File

@@ -0,0 +1,410 @@
//! 中转服务 HTTP 处理器
use std::sync::Arc;
use tokio::sync::Mutex;
use axum::body::Bytes;
use axum::{
extract::{Extension, Path, Query, State},
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
Json,
};
use crate::state::AppState;
use crate::error::{SaasError, SaasResult};
use crate::auth::types::AuthContext;
use crate::auth::handlers::{log_operation, check_permission};
use crate::model_config::service as model_service;
use super::{types::*, service};
/// POST /api/v1/relay/chat/completions
/// OpenAI 兼容的聊天补全端点
pub async fn chat_completions(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
_headers: HeaderMap,
Json(req): Json<serde_json::Value>,
) -> SaasResult<Response> {
check_permission(&ctx, "relay:use")?;
let model_name = req.get("model")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("缺少 model 字段".into()))?;
let stream = req.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model 对应的 provider (直接 SQL 查询,避免全量加载)
let target_model = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool)>(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled
FROM models WHERE model_id = $1 AND enabled = true"
)
.bind(model_name)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
let (_model_id, provider_id, model_name_db, _, _, _, _, _, _) = target_model;
// 获取 provider 信息
let provider = model_service::get_provider(&state.db, &provider_id).await?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
// 优先使用用户级 account_api_key回退到 provider 级 key
let account_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT key_value FROM account_api_keys
WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true
ORDER BY created_at DESC LIMIT 1"
)
.bind(&ctx.account_id)
.bind(&provider_id)
.fetch_optional(&state.db)
.await?
.flatten();
let api_key: Option<String> = if let Some(encrypted) = account_key_encrypted {
// 更新 last_used_at
let _ = sqlx::query(
"UPDATE account_api_keys SET last_used_at = NOW() WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true"
)
.bind(&ctx.account_id)
.bind(&provider_id)
.execute(&state.db)
.await;
Some(state.field_encryption.decrypt_or_plaintext(&encrypted))
} else {
// 回退到 provider 级 key
let provider_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = $1"
)
.bind(&provider_id)
.fetch_optional(&state.db)
.await?
.flatten();
provider_key_encrypted.map(|k| state.field_encryption.decrypt_or_plaintext(&k))
};
if api_key.is_none() {
return Err(SaasError::Internal(format!(
"Provider {} 没有可用的 API Key", provider.name
)));
}
let request_body = serde_json::to_string(&req)?;
// 创建中转任务
let config = state.config.read().await;
let task = service::create_relay_task(
&state.db, &ctx.account_id, &provider_id,
&model_name_db, &request_body, 0,
config.relay.max_attempts,
).await?;
log_operation(&state.db, &ctx.account_id, "relay.request", "relay_task", &task.id,
Some(serde_json::json!({"model": model_name, "stream": stream})), ctx.client_ip.as_deref()).await?;
// 执行中转 (带重试)
let response = service::execute_relay(
&state.db, &task.id, &ctx.account_id, &provider_id, &model_name_db,
&provider.base_url,
api_key.as_deref(), &request_body, stream,
config.relay.max_attempts,
config.relay.retry_delay_ms,
).await;
match response {
Ok(service::RelayResponse::Json(body)) => {
// 记录用量
let parsed: serde_json::Value = serde_json::from_str(&body).unwrap_or_default();
let input_tokens = parsed.get("usage")
.and_then(|u| u.get("prompt_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
let output_tokens = parsed.get("usage")
.and_then(|u| u.get("completion_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
model_service::record_usage(
&state.db, &ctx.account_id, &provider_id,
&model_name_db, input_tokens, output_tokens,
None, "success", None,
).await?;
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
}
Ok(service::RelayResponse::SseWithUsage { body, task_id: relay_task_id, account_id: relay_account_id, provider_id: relay_provider_id, model_id: relay_model_id }) => {
// 流式响应: 使用 async_stream 包装器提取 SSE 末尾的 usage
let wrapped = sse_usage_wrapper(
state.db.clone(),
relay_task_id, relay_account_id, relay_provider_id, relay_model_id,
body,
);
let wrapped_body = axum::body::Body::from_stream(wrapped);
let response = axum::response::Response::builder()
.status(StatusCode::OK)
.header(axum::http::header::CONTENT_TYPE, "text/event-stream")
.header("Cache-Control", "no-cache")
.header("Connection", "keep-alive")
.body(wrapped_body)
.map_err(|e| SaasError::Internal(format!("SSE 响应构建失败: {}", e)))?;
Ok(response)
}
Err(e) => {
model_service::record_usage(
&state.db, &ctx.account_id, &provider_id,
&model_name_db, 0, 0,
None, "failed", Some(&e.to_string()),
).await?;
Err(e)
}
}
}
/// GET /api/v1/relay/tasks
pub async fn list_tasks(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(query): Query<RelayTaskQuery>,
) -> SaasResult<Json<Vec<RelayTaskInfo>>> {
service::list_relay_tasks(&state.db, &ctx.account_id, &query).await.map(Json)
}
/// GET /api/v1/relay/tasks/:id
pub async fn get_task(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<RelayTaskInfo>> {
let task = service::get_relay_task(&state.db, &id).await?;
// 只允许查看自己的任务 (admin 可查看全部)
if task.account_id != ctx.account_id {
check_permission(&ctx, "relay:admin")?;
}
Ok(Json(task))
}
/// GET /api/v1/relay/models
/// 列出可用的中转模型 (enabled providers + enabled models)
pub async fn list_available_models(
State(state): State<AppState>,
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> {
let providers = model_service::list_providers(&state.db).await?;
let enabled_provider_ids: std::collections::HashSet<String> =
providers.iter().filter(|p| p.enabled).map(|p| p.id.clone()).collect();
let models = model_service::list_models(&state.db, None).await?;
let available: Vec<serde_json::Value> = models.into_iter()
.filter(|m| m.enabled && enabled_provider_ids.contains(&m.provider_id))
.map(|m| {
serde_json::json!({
"id": m.model_id,
"provider_id": m.provider_id,
"alias": m.alias,
"context_window": m.context_window,
"max_output_tokens": m.max_output_tokens,
"supports_streaming": m.supports_streaming,
"supports_vision": m.supports_vision,
})
})
.collect();
Ok(Json(available))
}
/// POST /api/v1/relay/tasks/:id/retry (admin only)
/// 重试失败的中转任务
pub async fn retry_task(
State(state): State<AppState>,
Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<(StatusCode, Json<serde_json::Value>)> {
check_permission(&ctx, "relay:admin")?;
let task = service::get_relay_task(&state.db, &id).await?;
if task.status != "failed" {
return Err(SaasError::InvalidInput(format!(
"只能重试失败的任务,当前状态: {}", task.status
)));
}
// 获取 provider 信息
let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
// 重试时使用原始任务所属用户的 account key回退到 provider key
let account_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT key_value FROM account_api_keys
WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true
ORDER BY created_at DESC LIMIT 1"
)
.bind(&task.account_id)
.bind(&task.provider_id)
.fetch_optional(&state.db)
.await?
.flatten();
let api_key: Option<String> = if let Some(encrypted) = account_key_encrypted {
Some(state.field_encryption.decrypt_or_plaintext(&encrypted))
} else {
let provider_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = $1"
)
.bind(&task.provider_id)
.fetch_optional(&state.db)
.await?
.flatten();
provider_key_encrypted.map(|k| state.field_encryption.decrypt_or_plaintext(&k))
};
// 读取原始请求体
let request_body: Option<String> = sqlx::query_scalar(
"SELECT request_body FROM relay_tasks WHERE id = $1"
)
.bind(&id)
.fetch_optional(&state.db)
.await?
.flatten();
let body = request_body.ok_or_else(|| SaasError::Internal("任务请求体丢失".into()))?;
// 从 request body 解析 stream 标志
let stream: bool = serde_json::from_str::<serde_json::Value>(&body)
.ok()
.and_then(|v| v.get("stream").and_then(|s| s.as_bool()))
.unwrap_or(false);
let max_attempts = task.max_attempts as u32;
let config = state.config.read().await;
let base_delay_ms = config.relay.retry_delay_ms;
// 重置任务状态为 queued 以允许新的 processing
sqlx::query(
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
)
.bind(&id)
.execute(&state.db)
.await?;
// 异步执行重试
let db = state.db.clone();
let task_id = id.clone();
let retry_account_id = ctx.account_id.clone();
let retry_provider_id = task.provider_id.clone();
let retry_model_id = task.model_id.clone();
tokio::spawn(async move {
match service::execute_relay(
&db, &task_id, &retry_account_id, &retry_provider_id, &retry_model_id,
&provider.base_url,
api_key.as_deref(), &body, stream,
max_attempts, base_delay_ms,
).await {
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
Err(e) => tracing::warn!("Relay task {} 重试失败: {}", task_id, e),
}
});
log_operation(&state.db, &ctx.account_id, "relay.retry", "relay_task", &id,
None, ctx.client_ip.as_deref()).await?;
Ok((StatusCode::ACCEPTED, Json(serde_json::json!({"ok": true, "task_id": id}))))
}
/// 包装 SSE 流,提取末尾的 usage 数据并异步记录
///
/// 支持客户端断连检测:当 body stream 返回错误(通常表示客户端提前断开连接),
/// 记录日志并将任务标记为 "cancelled" 而非 "completed"。
fn sse_usage_wrapper(
db: sqlx::PgPool,
task_id: String,
account_id: String,
provider_id: String,
model_id: String,
body: axum::body::Body,
) -> impl futures::Stream<Item = Result<Bytes, std::io::Error>> + Send {
use futures::StreamExt;
let last_usage: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
let mut saw_done = false;
async_stream::stream! {
let mut data_stream = std::pin::pin!(body.into_data_stream().map(|r| r.map_err(std::io::Error::other)));
loop {
match StreamExt::next(&mut data_stream).await {
Some(Ok(chunk)) => {
let text = String::from_utf8_lossy(&chunk);
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
let trimmed = data.trim();
if trimmed == "[DONE]" {
saw_done = true;
let usage_str = last_usage.lock().await.take();
if let Some(s) = usage_str {
let (input, output) = service::extract_token_usage(&s);
if input > 0 || output > 0 {
let db2 = db.clone();
let tid = task_id.clone();
let aid = account_id.clone();
let pid = provider_id.clone();
let mid = model_id.clone();
tokio::spawn(async move {
let _ = service::update_task_status(
&db2, &tid, "completed",
Some(input), Some(output), None
).await;
let _ = model_service::record_usage(
&db2, &aid, &pid, &mid,
input, output, None, "success", None,
).await;
});
}
}
} else if serde_json::from_str::<serde_json::Value>(trimmed)
.ok()
.and_then(|v| if v.get("usage").is_some() { Some(trimmed.to_string()) } else { None })
.is_some()
{
*last_usage.lock().await = Some(trimmed.to_owned());
}
}
}
yield Ok(chunk);
}
Some(Err(e)) => {
// 客户端断连或上游连接中断
if !saw_done {
tracing::warn!(
"SSE stream error for task {} (client disconnected): {}",
task_id, e
);
// 将任务标记为 cancelled区别于 completed 和 failed
let db2 = db.clone();
let tid = task_id.clone();
tokio::spawn(async move {
let _ = service::update_task_status(
&db2, &tid, "cancelled",
None, None, Some("客户端断开连接"),
).await;
});
}
break;
}
None => {
// Stream 正常结束(上游发送完毕)
if !saw_done {
// 上游关闭但未发送 [DONE],仍记录完成
tracing::info!(
"SSE stream ended without [DONE] for task {}",
task_id,
);
}
break;
}
}
}
}
}

View File

@@ -0,0 +1,18 @@
//! 中转服务模块
pub mod types;
pub mod service;
pub mod handlers;
use axum::routing::{get, post};
use crate::state::AppState;
/// 中转服务路由 (需要认证)
pub fn routes() -> axum::Router<AppState> {
axum::Router::new()
.route("/api/v1/relay/chat/completions", post(handlers::chat_completions))
.route("/api/v1/relay/tasks", get(handlers::list_tasks))
.route("/api/v1/relay/tasks/{id}", get(handlers::get_task))
.route("/api/v1/relay/tasks/{id}/retry", post(handlers::retry_task))
.route("/api/v1/relay/models", get(handlers::list_available_models))
}

View File

@@ -0,0 +1,680 @@
//! 中转服务核心逻辑
use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use super::types::*;
use futures::StreamExt;
/// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429)
fn is_retryable_status(status: u16) -> bool {
status == 429 || (500..600).contains(&status)
}
/// 判断 reqwest 错误是否为可重试的网络错误
fn is_retryable_error(e: &reqwest::Error) -> bool {
e.is_timeout() || e.is_connect() || e.is_request()
}
// ============ Relay Task Management ============
pub async fn create_relay_task(
db: &PgPool,
account_id: &str,
provider_id: &str,
model_id: &str,
request_body: &str,
_priority: i64,
max_attempts: u32,
) -> SaasResult<RelayTaskInfo> {
let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
let max_attempts = max_attempts.max(1).min(5);
sqlx::query(
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
VALUES ($1, $2, $3, $4, '', $5, 'queued', 0, 0, $6, $7, $7)"
)
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
.bind(request_body).bind(max_attempts as i64).bind(&now)
.execute(db).await?;
get_relay_task(db, &id).await
}
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, chrono::DateTime<chrono::Utc>, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as(
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE id = $1"
)
.bind(task_id)
.fetch_optional(db)
.await?;
let (id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at) =
row.ok_or_else(|| SaasError::NotFound(format!("中转任务 {} 不存在", task_id)))?;
Ok(RelayTaskInfo {
id, account_id, provider_id, model_id, status, priority,
attempt_count, max_attempts, input_tokens, output_tokens,
error_message, queued_at: queued_at.to_rfc3339(), started_at: started_at.map(|t| t.to_rfc3339()), completed_at: completed_at.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339(),
})
}
pub async fn list_relay_tasks(
db: &PgPool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<Vec<RelayTaskInfo>> {
let page = query.page.unwrap_or(1).max(1);
let page_size = query.page_size.unwrap_or(20).min(100);
let offset = (page - 1) * page_size;
let sql = if query.status.is_some() {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
} else {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
};
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, chrono::DateTime<chrono::Utc>, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(sql)
.bind(account_id);
if let Some(ref status) = query.status {
query_builder = query_builder.bind(status);
}
query_builder = query_builder.bind(page_size).bind(offset);
let rows = query_builder.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at: queued_at.to_rfc3339(), started_at: started_at.map(|t| t.to_rfc3339()), completed_at: completed_at.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339() }
}).collect())
}
pub async fn update_task_status(
db: &PgPool, task_id: &str, status: &str,
input_tokens: Option<i64>, output_tokens: Option<i64>,
error_message: Option<&str>,
) -> SaasResult<()> {
let now = chrono::Utc::now();
let update_sql = match status {
"processing" => "started_at = $1, status = 'processing', attempt_count = attempt_count + 1",
"completed" => "completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens)",
"failed" => "completed_at = $1, status = 'failed', error_message = $2",
"cancelled" => "completed_at = $1, status = 'cancelled', error_message = $2",
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
};
let sql = format!("UPDATE relay_tasks SET {} WHERE id = $4", update_sql);
let mut query = sqlx::query(&sql).bind(&now);
if status == "completed" {
query = query.bind(input_tokens).bind(output_tokens);
}
if status == "failed" || status == "cancelled" {
query = query.bind(error_message);
}
query = query.bind(task_id);
query.execute(db).await?;
Ok(())
}
// ============ Relay Execution ============
pub async fn execute_relay(
db: &PgPool,
task_id: &str,
account_id: &str,
provider_id: &str,
model_id: &str,
provider_base_url: &str,
provider_api_key: Option<&str>,
request_body: &str,
stream: bool,
max_attempts: u32,
base_delay_ms: u64,
) -> SaasResult<RelayResponse> {
validate_provider_url(provider_base_url)?;
// DNS Rebinding 防护: 解析 host 并验证所有 resolved IP 非私有
let parsed_url: url::Url = provider_base_url.trim_end_matches('/').parse()
.map_err(|_| SaasError::InvalidInput(format!("无效的 provider URL: {}", provider_base_url)))?;
let host = parsed_url.host_str()
.ok_or_else(|| SaasError::InvalidInput("provider URL 缺少 host".into()))?;
// 仅对非 IP 的 host 做 DNS 解析(纯 IP 已在 validate_provider_url 中检查)
if host.parse::<std::net::IpAddr>().is_err() {
let port = parsed_url.port_or_known_default().unwrap_or(443);
let addr_str = format!("{}:{}", host, port);
let addrs: Vec<std::net::SocketAddr> = std::net::ToSocketAddrs::to_socket_addrs(&addr_str)
.map_err(|e| SaasError::InvalidInput(format!("DNS 解析失败: {}", e)))?
.collect();
if addrs.is_empty() {
return Err(SaasError::InvalidInput(format!("DNS 解析无结果: {}", host)));
}
for addr in &addrs {
if is_private_ip(&addr.ip()) {
return Err(SaasError::InvalidInput(format!(
"provider URL {} 解析到私有 IP: {}", host, addr.ip()
)));
}
}
}
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(if stream { 300 } else { 30 }))
.build()
.map_err(|e| SaasError::Internal(format!("HTTP 客户端构建失败: {}", e)))?;
let max_attempts = max_attempts.max(1).min(5);
for attempt in 0..max_attempts {
let is_first = attempt == 0;
if is_first {
update_task_status(db, task_id, "processing", None, None, None).await?;
}
let mut req_builder = client.post(&url)
.header("Content-Type", "application/json")
.body(request_body.to_string());
if let Some(key) = provider_api_key {
req_builder = req_builder.header("Authorization", format!("Bearer {}", key));
}
let result = req_builder.send().await;
match result {
Ok(resp) if resp.status().is_success() => {
// 成功
if stream {
let byte_stream = resp.bytes_stream()
.map(|result| result.map_err(std::io::Error::other));
let body = axum::body::Body::from_stream(byte_stream);
update_task_status(db, task_id, "completed", None, None, None).await?;
return Ok(RelayResponse::SseWithUsage {
body,
task_id: task_id.to_string(),
account_id: account_id.to_string(),
provider_id: provider_id.to_string(),
model_id: model_id.to_string(),
});
} else {
let body = resp.text().await.unwrap_or_default();
let (input_tokens, output_tokens) = extract_token_usage(&body);
update_task_status(db, task_id, "completed",
Some(input_tokens), Some(output_tokens), None).await?;
return Ok(RelayResponse::Json(body));
}
}
Ok(resp) => {
let status = resp.status().as_u16();
if !is_retryable_status(status) || attempt + 1 >= max_attempts {
// 4xx 客户端错误或已达最大重试次数 → 立即失败
let body = resp.text().await.unwrap_or_default();
// 仅记录日志,不将上游错误体暴露给客户端(可能含敏感信息如 API key
tracing::warn!(
"Relay task {} 上游返回 HTTP {} (body truncated): {}",
task_id, status, &body[..body.len().min(200)]
);
let err_msg = format!("上游服务返回错误 (HTTP {})", status);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
return Err(SaasError::Relay(err_msg));
}
// 可重试的服务端错误 → 继续循环
tracing::warn!(
"Relay task {} 可重试错误 HTTP {} (attempt {}/{})",
task_id, status, attempt + 1, max_attempts
);
}
Err(e) => {
if !is_retryable_error(&e) || attempt + 1 >= max_attempts {
let err_msg = format!("请求上游失败: {}", e);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
return Err(SaasError::Relay(err_msg));
}
tracing::warn!(
"Relay task {} 网络错误 (attempt {}/{}): {}",
task_id, attempt + 1, max_attempts, e
);
}
}
// 指数退避: base_delay * 2^attempt
let delay_ms = base_delay_ms * (1 << attempt);
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
}
// 理论上不会到达 (循环内已处理),但满足编译器
Err(SaasError::Relay("重试次数已耗尽".into()))
}
/// 中转响应类型
#[derive(Debug)]
pub enum RelayResponse {
Json(String),
/// SSE 流式响应 + 上下文信息
SseWithUsage {
body: axum::body::Body,
task_id: String,
account_id: String,
provider_id: String,
model_id: String,
},
}
// ============ Helpers ============
pub fn extract_token_usage(body: &str) -> (i64, i64) {
let parsed: serde_json::Value = match serde_json::from_str(body) {
Ok(v) => v,
Err(_) => return (0, 0),
};
let usage = parsed.get("usage");
let input = usage
.and_then(|u| u.get("prompt_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
let output = usage
.and_then(|u| u.get("completion_tokens"))
.and_then(|v| v.as_i64())
.unwrap_or(0);
(input, output)
}
/// SSRF 防护: 验证 provider URL 不指向内网
fn validate_provider_url(url: &str) -> SaasResult<()> {
let parsed: url::Url = url.parse().map_err(|_| {
SaasError::InvalidInput(format!("无效的 provider URL: {}", url))
})?;
// 只允许 https
match parsed.scheme() {
"https" => {}
"http" => {
// 开发环境允许 http
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if !is_dev {
return Err(SaasError::InvalidInput("生产环境禁止 http scheme请使用 https".into()));
}
}
_ => return Err(SaasError::InvalidInput(format!("不允许的 URL scheme: {}", parsed.scheme()))),
}
// 禁止内网地址
let host = match parsed.host_str() {
Some(h) => h,
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
};
// url crate 的 host_str() 对 IPv6 地址保留方括号 (如 "[::1]")
// 需要去掉方括号才能与阻止列表匹配和解析为 IpAddr
let host = host.trim_start_matches('[').trim_end_matches(']');
// 精确匹配的阻止列表
let blocked_exact = [
"127.0.0.1", "0.0.0.0", "localhost", "::1", "::ffff:127.0.0.1",
"0:0:0:0:0:ffff:7f00:1", "169.254.169.254", "metadata.google.internal",
"10.0.0.1", "172.16.0.1", "192.168.0.1",
];
if blocked_exact.contains(&host) {
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
}
// 后缀匹配 (阻止子域名)
let blocked_suffixes = ["localhost", "internal", "local", "localhost.localdomain"];
for suffix in &blocked_suffixes {
if host.ends_with(&format!(".{}", suffix)) {
return Err(SaasError::InvalidInput(format!("provider URL 指向禁止的内网地址: {}", host)));
}
}
// 阻止 IPv4 私有网段 (通过解析 IP)
if let Ok(ip) = host.parse::<std::net::IpAddr>() {
if is_private_ip(&ip) {
return Err(SaasError::InvalidInput(format!("provider URL 指向私有 IP 地址: {}", host)));
}
}
// 阻止纯数字 host (可能是十进制 IP 表示法,如 2130706433 = 127.0.0.1)
if host.parse::<u64>().is_ok() {
return Err(SaasError::InvalidInput(format!("provider URL 使用了不允许的 IP 格式: {}", host)));
}
Ok(())
}
/// 检查 IP 是否属于私有/内网地址范围
fn is_private_ip(ip: &std::net::IpAddr) -> bool {
match ip {
std::net::IpAddr::V4(v4) => {
let octets = v4.octets();
// 10.0.0.0/8
octets[0] == 10
// 172.16.0.0/12
|| (octets[0] == 172 && octets[1] >= 16 && octets[1] <= 31)
// 192.168.0.0/16
|| (octets[0] == 192 && octets[1] == 168)
// 127.0.0.0/8 (loopback)
|| octets[0] == 127
// 169.254.0.0/16 (link-local)
|| (octets[0] == 169 && octets[1] == 254)
// 0.0.0.0/8
|| octets[0] == 0
}
std::net::IpAddr::V6(v6) => {
// ::1 (loopback)
v6.is_loopback()
// ::ffff:x.x.x.x (IPv6-mapped IPv4)
|| v6.to_ipv4_mapped().map_or(false, |v4| is_private_ip(&std::net::IpAddr::V4(v4)))
// fe80::/10 (link-local)
|| (v6.segments()[0] & 0xffc0) == 0xfe80
}
}
}
#[cfg(test)]
mod tests {
use super::*;
// ---- is_retryable_status ----
#[test]
fn retryable_status_429() {
assert!(is_retryable_status(429));
}
#[test]
fn retryable_status_5xx_range() {
for code in 500u16..600 {
assert!(is_retryable_status(code), "expected {code} to be retryable");
}
}
#[test]
fn not_retryable_status_200() {
assert!(!is_retryable_status(200));
}
#[test]
fn not_retryable_status_400() {
assert!(!is_retryable_status(400));
}
#[test]
fn not_retryable_status_404() {
assert!(!is_retryable_status(404));
}
#[test]
fn not_retryable_status_422() {
assert!(!is_retryable_status(422));
}
// ---- extract_token_usage ----
#[test]
fn extract_usage_normal() {
let body = r#"{"usage":{"prompt_tokens":100,"completion_tokens":50}}"#;
assert_eq!(extract_token_usage(body), (100, 50));
}
#[test]
fn extract_usage_no_usage_field() {
let body = r#"{"id":"chatcmpl-abc","object":"chat.completion"}"#;
assert_eq!(extract_token_usage(body), (0, 0));
}
#[test]
fn extract_usage_invalid_json() {
assert_eq!(extract_token_usage("not json at all"), (0, 0));
}
#[test]
fn extract_usage_empty_body() {
assert_eq!(extract_token_usage(""), (0, 0));
}
#[test]
fn extract_usage_partial_tokens() {
// only prompt_tokens present, completion_tokens missing
let body = r#"{"usage":{"prompt_tokens":200}}"#;
assert_eq!(extract_token_usage(body), (200, 0));
}
#[test]
fn extract_usage_completion_only() {
let body = r#"{"usage":{"completion_tokens":75}}"#;
assert_eq!(extract_token_usage(body), (0, 75));
}
#[test]
fn extract_usage_zero_tokens() {
let body = r#"{"usage":{"prompt_tokens":0,"completion_tokens":0}}"#;
assert_eq!(extract_token_usage(body), (0, 0));
}
#[test]
fn extract_usage_string_instead_of_int() {
// non-integer token values should fall back to 0
let body = r#"{"usage":{"prompt_tokens":"abc","completion_tokens":null}}"#;
assert_eq!(extract_token_usage(body), (0, 0));
}
// ---- is_private_ip ----
#[test]
fn private_ip_10_range() {
let ip: std::net::IpAddr = "10.0.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_172_16_range() {
let ip: std::net::IpAddr = "172.16.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_172_31_range() {
let ip: std::net::IpAddr = "172.31.255.255".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_172_15_not_private() {
// 172.15.x.x is NOT in the private range (starts at 172.16)
let ip: std::net::IpAddr = "172.15.255.255".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn private_ip_172_32_not_private() {
// 172.32.x.x is NOT in the private range (ends at 172.31)
let ip: std::net::IpAddr = "172.32.0.0".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn private_ip_192_168_range() {
let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_127_loopback() {
let ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_127_any() {
let ip: std::net::IpAddr = "127.255.255.255".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_169_254_link_local() {
let ip: std::net::IpAddr = "169.254.1.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_0_0_0_0() {
let ip: std::net::IpAddr = "0.0.0.0".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_loopback() {
let ip: std::net::IpAddr = "::1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_link_local() {
let ip: std::net::IpAddr = "fe80::1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_mapped_ipv4_loopback() {
let ip: std::net::IpAddr = "::ffff:127.0.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_mapped_ipv4_private() {
let ip: std::net::IpAddr = "::ffff:192.168.1.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn public_ip_8_8_8_8() {
let ip: std::net::IpAddr = "8.8.8.8".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn public_ip_1_1_1_1() {
let ip: std::net::IpAddr = "1.1.1.1".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn public_ip_v6_google() {
let ip: std::net::IpAddr = "2001:4860:4860::8888".parse().unwrap();
assert!(!is_private_ip(&ip));
}
// ---- validate_provider_url ----
#[test]
fn validate_url_https_valid() {
assert!(validate_provider_url("https://api.openai.com").is_ok());
}
#[test]
fn validate_url_https_with_path() {
assert!(validate_provider_url("https://api.openai.com/v1").is_ok());
}
#[test]
fn validate_url_https_with_port() {
assert!(validate_provider_url("https://api.openai.com:443").is_ok());
}
#[test]
fn validate_url_blocks_localhost() {
assert!(validate_provider_url("https://localhost").is_err());
}
#[test]
fn validate_url_blocks_127_0_0_1() {
assert!(validate_provider_url("https://127.0.0.1").is_err());
}
#[test]
fn validate_url_blocks_0_0_0_0() {
assert!(validate_provider_url("https://0.0.0.0").is_err());
}
#[test]
fn validate_url_blocks_169_254_169_254() {
assert!(validate_provider_url("https://169.254.169.254").is_err());
}
#[test]
fn validate_url_blocks_metadata_google_internal() {
assert!(validate_provider_url("https://metadata.google.internal").is_err());
}
#[test]
fn validate_url_blocks_private_ip_10() {
assert!(validate_provider_url("https://10.0.0.1").is_err());
}
#[test]
fn validate_url_blocks_private_ip_172_16() {
assert!(validate_provider_url("https://172.16.0.1").is_err());
}
#[test]
fn validate_url_blocks_private_ip_192_168() {
assert!(validate_provider_url("https://192.168.0.1").is_err());
}
#[test]
fn validate_url_blocks_numeric_host() {
// decimal IP representation (e.g. 2130706433 = 127.0.0.1)
assert!(validate_provider_url("https://2130706433").is_err());
}
#[test]
fn validate_url_blocks_subdomain_localhost() {
assert!(validate_provider_url("https://evil.localhost").is_err());
}
#[test]
fn validate_url_blocks_subdomain_internal() {
assert!(validate_provider_url("https://app.internal").is_err());
}
#[test]
fn validate_url_blocks_subdomain_local() {
assert!(validate_provider_url("https://myapp.local").is_err());
}
#[test]
fn validate_url_blocks_ipv6_loopback() {
assert!(validate_provider_url("https://[::1]").is_err());
}
#[test]
fn validate_url_invalid_format() {
assert!(validate_provider_url("not a url").is_err());
}
#[test]
fn validate_url_missing_host() {
assert!(validate_provider_url("https://").is_err());
}
#[test]
fn validate_url_blocks_ftp_scheme() {
assert!(validate_provider_url("ftp://api.openai.com").is_err());
}
#[test]
fn validate_url_blocks_http_in_production() {
// In CI / default env, ZCLAW_SAAS_DEV is not set, so http is blocked
assert!(validate_provider_url("http://api.openai.com").is_err());
}
}

View File

@@ -0,0 +1,32 @@
//! 中转服务类型定义
use serde::{Deserialize, Serialize};
/// 中转任务信息
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct RelayTaskInfo {
pub id: String,
pub account_id: String,
pub provider_id: String,
pub model_id: String,
pub status: String,
pub priority: i64,
pub attempt_count: i64,
pub max_attempts: i64,
pub input_tokens: i64,
pub output_tokens: i64,
pub error_message: Option<String>,
pub queued_at: String,
pub started_at: Option<String>,
pub completed_at: Option<String>,
pub created_at: String,
}
/// 中转任务查询
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct RelayTaskQuery {
pub status: Option<String>,
pub page: Option<i64>,
pub page_size: Option<i64>,
}

View File

@@ -0,0 +1,37 @@
//! 应用状态
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use crate::config::SaaSConfig;
use crate::crypto::FieldEncryption;
/// 全局应用状态,通过 Axum State 共享
#[derive(Clone)]
pub struct AppState {
/// 数据库连接池
pub db: PgPool,
/// 服务器配置 (可热更新)
pub config: Arc<RwLock<SaaSConfig>>,
/// JWT 密钥
pub jwt_secret: secrecy::SecretString,
/// 字段级加密器 (AES-256-GCM)
pub field_encryption: Arc<FieldEncryption>,
/// 速率限制: account_id → 请求时间戳列表
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
}
impl AppState {
pub fn new(db: PgPool, config: SaaSConfig) -> anyhow::Result<Self> {
let jwt_secret = config.jwt_secret()?;
let field_encryption = Arc::new(FieldEncryption::new()?);
Ok(Self {
db,
config: Arc::new(RwLock::new(config)),
jwt_secret,
field_encryption,
rate_limit_entries: Arc::new(dashmap::DashMap::new()),
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,206 @@
# Design System Master File
> **LOGIC:** When building a specific page, first check `design-system/pages/[page-name].md`.
> If that file exists, its rules **override** this Master file.
> If not, strictly follow the rules below.
---
**Project:** ZCLAW Admin
**Generated:** 2026-03-27 13:52:31
**Category:** Financial Dashboard
---
## Global Rules
### Color Palette
| Role | Hex | CSS Variable |
|------|-----|--------------|
| Primary | `#0F172A` | `--color-primary` |
| Secondary | `#1E293B` | `--color-secondary` |
| CTA/Accent | `#22C55E` | `--color-cta` |
| Background | `#020617` | `--color-background` |
| Text | `#F8FAFC` | `--color-text` |
**Color Notes:** Dark bg + green positive indicators
### Typography
- **Heading Font:** Fira Code
- **Body Font:** Fira Sans
- **Mood:** dashboard, data, analytics, code, technical, precise
- **Google Fonts:** [Fira Code + Fira Sans](https://fonts.google.com/share?selection.family=Fira+Code:wght@400;500;600;700|Fira+Sans:wght@300;400;500;600;700)
**CSS Import:**
```css
@import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@400;500;600;700&family=Fira+Sans:wght@300;400;500;600;700&display=swap');
```
### Spacing Variables
| Token | Value | Usage |
|-------|-------|-------|
| `--space-xs` | `4px` / `0.25rem` | Tight gaps |
| `--space-sm` | `8px` / `0.5rem` | Icon gaps, inline spacing |
| `--space-md` | `16px` / `1rem` | Standard padding |
| `--space-lg` | `24px` / `1.5rem` | Section padding |
| `--space-xl` | `32px` / `2rem` | Large gaps |
| `--space-2xl` | `48px` / `3rem` | Section margins |
| `--space-3xl` | `64px` / `4rem` | Hero padding |
### Shadow Depths
| Level | Value | Usage |
|-------|-------|-------|
| `--shadow-sm` | `0 1px 2px rgba(0,0,0,0.05)` | Subtle lift |
| `--shadow-md` | `0 4px 6px rgba(0,0,0,0.1)` | Cards, buttons |
| `--shadow-lg` | `0 10px 15px rgba(0,0,0,0.1)` | Modals, dropdowns |
| `--shadow-xl` | `0 20px 25px rgba(0,0,0,0.15)` | Hero images, featured cards |
---
## Component Specs
### Buttons
```css
/* Primary Button */
.btn-primary {
background: #22C55E;
color: white;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
transition: all 200ms ease;
cursor: pointer;
}
.btn-primary:hover {
opacity: 0.9;
transform: translateY(-1px);
}
/* Secondary Button */
.btn-secondary {
background: transparent;
color: #0F172A;
border: 2px solid #0F172A;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
transition: all 200ms ease;
cursor: pointer;
}
```
### Cards
```css
.card {
background: #020617;
border-radius: 12px;
padding: 24px;
box-shadow: var(--shadow-md);
transition: all 200ms ease;
cursor: pointer;
}
.card:hover {
box-shadow: var(--shadow-lg);
transform: translateY(-2px);
}
```
### Inputs
```css
.input {
padding: 12px 16px;
border: 1px solid #E2E8F0;
border-radius: 8px;
font-size: 16px;
transition: border-color 200ms ease;
}
.input:focus {
border-color: #0F172A;
outline: none;
box-shadow: 0 0 0 3px #0F172A20;
}
```
### Modals
```css
.modal-overlay {
background: rgba(0, 0, 0, 0.5);
backdrop-filter: blur(4px);
}
.modal {
background: white;
border-radius: 16px;
padding: 32px;
box-shadow: var(--shadow-xl);
max-width: 500px;
width: 90%;
}
```
---
## Style Guidelines
**Style:** Dark Mode (OLED)
**Keywords:** Dark theme, low light, high contrast, deep black, midnight blue, eye-friendly, OLED, night mode, power efficient
**Best For:** Night-mode apps, coding platforms, entertainment, eye-strain prevention, OLED devices, low-light
**Key Effects:** Minimal glow (text-shadow: 0 0 10px), dark-to-light transitions, low white emission, high readability, visible focus
### Page Pattern
**Pattern Name:** Horizontal Scroll Journey
- **Conversion Strategy:** Immersive product discovery. High engagement. Keep navigation visible.
28,Bento Grid Showcase,bento, grid, features, modular, apple-style, showcase", 1. Hero, 2. Bento Grid (Key Features), 3. Detail Cards, 4. Tech Specs, 5. CTA, Floating Action Button or Bottom of Grid, Card backgrounds: #F5F5F7 or Glass. Icons: Vibrant brand colors. Text: Dark., Hover card scale (1.02), video inside cards, tilt effect, staggered reveal, Scannable value props. High information density without clutter. Mobile stack.
29,Interactive 3D Configurator,3d, configurator, customizer, interactive, product", 1. Hero (Configurator), 2. Feature Highlight (synced), 3. Price/Specs, 4. Purchase, Inside Configurator UI + Sticky Bottom Bar, Neutral studio background. Product: Realistic materials. UI: Minimal overlay., Real-time rendering, material swap animation, camera rotate/zoom, light reflection, Increases ownership feeling. 360 view reduces return rates. Direct add-to-cart.
30,AI-Driven Dynamic Landing,ai, dynamic, personalized, adaptive, generative", 1. Prompt/Input Hero, 2. Generated Result Preview, 3. How it Works, 4. Value Prop, Input Field (Hero) + 'Try it' Buttons, Adaptive to user input. Dark mode for compute feel. Neon accents., Typing text effects, shimmering generation loaders, morphing layouts, Immediate value demonstration. 'Show, don't tell'. Low friction start.
- **CTA Placement:** Floating Sticky CTA or End of Horizontal Track
- **Section Order:** 1. Intro (Vertical), 2. The Journey (Horizontal Track), 3. Detail Reveal, 4. Vertical Footer
---
## Anti-Patterns (Do NOT Use)
- ❌ Light mode default
- ❌ Slow rendering
### Additional Forbidden Patterns
-**Emojis as icons** — Use SVG icons (Heroicons, Lucide, Simple Icons)
-**Missing cursor:pointer** — All clickable elements must have cursor:pointer
-**Layout-shifting hovers** — Avoid scale transforms that shift layout
-**Low contrast text** — Maintain 4.5:1 minimum contrast ratio
-**Instant state changes** — Always use transitions (150-300ms)
-**Invisible focus states** — Focus states must be visible for a11y
---
## Pre-Delivery Checklist
Before delivering any UI code, verify:
- [ ] No emojis used as icons (use SVG instead)
- [ ] All icons from consistent icon set (Heroicons/Lucide)
- [ ] `cursor-pointer` on all clickable elements
- [ ] Hover states with smooth transitions (150-300ms)
- [ ] Light mode: text contrast 4.5:1 minimum
- [ ] Focus states visible for keyboard navigation
- [ ] `prefers-reduced-motion` respected
- [ ] Responsive: 375px, 768px, 1024px, 1440px
- [ ] No content hidden behind fixed navbars
- [ ] No horizontal scroll on mobile

View File

@@ -0,0 +1,339 @@
import { useState, useEffect } from 'react';
import { saasClient, type SaaSConfigItem } from '../../lib/saas-client';
import { ArrowLeft, ArrowRight, Upload, Check, Loader2, RefreshCw } from 'lucide-react';
interface LocalModel {
id: string;
name: string;
provider: string;
[key: string]: unknown;
}
type SyncDirection = 'local-to-saas' | 'saas-to-local' | 'merge';
interface SyncConflict {
key: string;
localValue: string | null;
saasValue: string | null;
}
export function ConfigMigrationWizard({ onDone }: { onDone: () => void }) {
const [step, setStep] = useState<1 | 2 | 3>(1);
const [direction, setDirection] = useState<SyncDirection>('local-to-saas');
const [isSyncing, setIsSyncing] = useState(false);
const [syncResult, setSyncResult] = useState<'success' | 'partial' | null>(null);
const [error, setError] = useState<string | null>(null);
// Data
const [localModels, setLocalModels] = useState<LocalModel[]>([]);
const [saasConfigs, setSaasConfigs] = useState<SaaSConfigItem[]>([]);
const [conflicts, setConflicts] = useState<SyncConflict[]>([]);
const [selectedKeys, setSelectedKeys] = useState<Set<string>>(new Set());
// Step 1: Load data
useEffect(() => {
if (step !== 1) return;
// Load local models from localStorage
try {
const raw = localStorage.getItem('zclaw-custom-models');
if (raw) {
const parsed = JSON.parse(raw) as LocalModel[];
setLocalModels(Array.isArray(parsed) ? parsed : []);
}
} catch {
setLocalModels([]);
}
// Load SaaS config items
saasClient.listConfig().then(setSaasConfigs).catch(() => setSaasConfigs([]));
}, [step]);
const localCount = localModels.length;
const saasCount = saasConfigs.length;
// Step 2: Compute conflicts based on direction
useEffect(() => {
if (step !== 2) return;
const found: SyncConflict[] = [];
if (direction === 'local-to-saas' || direction === 'merge') {
// Check which local models already exist in SaaS
for (const model of localModels) {
const exists = saasConfigs.some((c) => c.key_path === `models.${model.id}`);
if (exists) {
found.push({
key: model.id,
localValue: JSON.stringify({ name: model.name, provider: model.provider }),
saasValue: '已存在',
});
}
}
}
if (direction === 'saas-to-local' || direction === 'merge') {
// SaaS configs that have values not in local
for (const config of saasConfigs) {
if (!config.current_value) continue;
const localRaw = localStorage.getItem('zclaw-custom-models');
const localModels: LocalModel[] = localRaw ? JSON.parse(localRaw) : [];
const isLocal = localModels.some((m) => m.id === config.key_path.replace('models.', ''));
if (!isLocal && config.category === 'model') {
found.push({
key: config.key_path,
localValue: null,
saasValue: config.current_value,
});
}
}
}
setConflicts(found);
setSelectedKeys(new Set(found.map((c) => c.key)));
}, [step, direction, localModels, saasConfigs]);
// Step 3: Execute sync
async function executeSync() {
setIsSyncing(true);
setError(null);
try {
if (direction === 'local-to-saas' && localModels.length > 0) {
// Push local models as config items
for (const model of localModels) {
const existingItem = saasConfigs.find((c) => c.key_path === `models.${model.id}`);
if (existingItem && !selectedKeys.has(model.id)) continue;
const body = {
category: 'model',
key_path: `models.${model.id}`,
value_type: 'json',
current_value: JSON.stringify({ name: model.name, provider: model.provider }),
source: 'desktop',
description: `从桌面端同步: ${model.name}`,
};
if (existingItem) {
await saasClient.request<unknown>('PUT', `/api/v1/config/items/${existingItem.id}`, body);
} else {
await saasClient.request<unknown>('POST', '/api/v1/config/items', body);
}
}
} else if (direction === 'saas-to-local' && saasConfigs.length > 0) {
// Pull SaaS models to local
const syncedModels = localModels.filter((m) => !selectedKeys.has(m.id));
const saasModels = saasConfigs
.filter((c) => c.category === 'model' && c.current_value)
.map((c) => {
try {
return JSON.parse(c.current_value!) as LocalModel;
} catch {
return null;
}
})
.filter((m): m is LocalModel => m !== null);
const merged = [...syncedModels, ...saasModels];
localStorage.setItem('zclaw-custom-models', JSON.stringify(merged));
} else if (direction === 'merge') {
// Merge: local wins for conflicts
const kept = localModels.filter((m) => !selectedKeys.has(m.id));
const saasOnly = saasConfigs
.filter((c) => c.category === 'model' && c.current_value)
.map((c) => {
try {
return JSON.parse(c.current_value!) as LocalModel;
} catch {
return null;
}
})
.filter((m): m is LocalModel => m !== null)
.filter((m) => !localModels.some((lm) => lm.id === m.id));
const merged = [...kept, ...saasOnly];
localStorage.setItem('zclaw-custom-models', JSON.stringify(merged));
}
setSyncResult(conflicts.length > 0 && conflicts.length === selectedKeys.size ? 'partial' : 'success');
} catch (err: unknown) {
setError(err instanceof Error ? err.message : '同步失败');
} finally {
setIsSyncing(false);
}
}
// Reset
function reset() {
setStep(1);
setDirection('local-to-saas');
setSyncResult(null);
setError(null);
setSelectedKeys(new Set());
}
return (
<div className="bg-white rounded-xl border border-gray-200 p-5 shadow-sm">
{/* Header */}
<div className="flex items-center justify-between mb-4">
<div className="flex items-center gap-2">
<Upload className="w-4 h-4 text-gray-500" />
<span className="text-sm font-medium text-gray-700"></span>
</div>
{step > 1 && (
<button onClick={() => setStep((step - 1) as 1 | 2)} className="text-xs text-gray-500 hover:text-gray-700 cursor-pointer">
<ArrowLeft className="w-3.5 h-3.5 inline" />
</button>
)}
</div>
{/* Step 1: Direction & Preview */}
{step === 1 && (
<div className="space-y-4">
<p className="text-sm text-gray-500">
SaaS
</p>
<div className="space-y-2">
<DirectionOption
label="本地 → SaaS"
description={`${localCount} 个本地模型推送到 SaaS 平台`}
selected={direction === 'local-to-saas'}
onClick={() => setDirection('local-to-saas')}
/>
<DirectionOption
label="SaaS → 本地"
description={`从 SaaS 平台拉取 ${saasCount} 项配置到本地`}
selected={direction === 'saas-to-local'}
onClick={() => setDirection('saas-to-local')}
/>
<DirectionOption
label="双向合并"
description="合并两边配置,冲突时保留本地版本"
selected={direction === 'merge'}
onClick={() => setDirection('merge')}
/>
</div>
<button
onClick={() => setStep(2)}
disabled={localCount === 0 && saasCount === 0}
className="w-full py-2 text-sm font-medium text-white bg-emerald-600 rounded-lg hover:bg-emerald-700 disabled:opacity-50 transition-colors"
>
<ArrowRight className="w-4 h-4 inline" />
</button>
</div>
)}
{/* Step 2: Resolve conflicts */}
{step === 2 && (
<div className="space-y-4">
{conflicts.length > 0 ? (
<>
<p className="text-sm text-amber-600">
{conflicts.length} {direction === 'local-to-saas' ? '本地' : 'SaaS'}
</p>
<div className="space-y-1.5">
{conflicts.map((c) => (
<label key={c.key} className="flex items-center gap-2 p-2 rounded-lg bg-gray-50 cursor-pointer text-sm">
<input
type="checkbox"
checked={selectedKeys.has(c.key)}
onChange={(e) => {
setSelectedKeys((prev) => {
const next = new Set(prev);
if (e.target.checked) next.add(c.key);
else next.delete(c.key);
return next;
});
}}
className="rounded"
/>
<span className="font-medium text-gray-800">{c.key}</span>
<span className="text-xs text-gray-400 truncate">
({direction === 'local-to-saas' ? '本地' : 'SaaS'}: {c.saasValue})
</span>
</label>
))}
</div>
</>
) : (
<div className="flex items-center gap-2 text-sm text-emerald-600">
<Check className="w-4 h-4" />
<span></span>
</div>
)}
<button
onClick={() => { setStep(3); executeSync(); }}
className="w-full py-2 text-sm font-medium text-white bg-emerald-600 rounded-lg hover:bg-emerald-700 transition-colors"
>
{isSyncing ? (
<><Loader2 className="w-4 h-4 inline animate-spin" /> ...</>
) : (
<><ArrowRight className="w-4 h-4 inline" /> </>
)}
</button>
</div>
)}
{/* Step 3: Result */}
{step === 3 && (
<div className="space-y-4">
{syncResult === 'success' ? (
<div className="flex items-center gap-2 text-sm text-emerald-600">
<Check className="w-5 h-5" />
<span></span>
</div>
) : syncResult === 'partial' ? (
<div className="flex items-center gap-2 text-amber-600">
<Check className="w-5 h-5" />
<span>{conflicts.length} </span>
</div>
) : error ? (
<div className="text-sm text-red-500">{error}</div>
) : null}
<div className="flex gap-2">
<button
onClick={reset}
className="flex-1 py-2 text-sm text-gray-600 border border-gray-300 rounded-lg hover:bg-gray-50 transition-colors cursor-pointer"
>
<RefreshCw className="w-3.5 h-3.5 inline" />
</button>
<button
onClick={onDone}
className="flex-1 py-2 text-sm font-medium text-white bg-emerald-600 rounded-lg hover:bg-emerald-700 transition-colors"
>
</button>
</div>
</div>
)}
</div>
);
}
function DirectionOption({
label,
description,
selected,
onClick,
}: {
label: string;
description: string;
selected: boolean;
onClick: () => void;
}) {
return (
<button
onClick={onClick}
className={`w-full text-left p-3 rounded-lg border transition-colors cursor-pointer ${
selected ? 'border-emerald-500 bg-emerald-50' : 'border-gray-200 hover:border-gray-300'
}`}
>
<div className="text-sm font-medium text-gray-800">{label}</div>
<div className="text-xs text-gray-500">{description}</div>
</button>
);
}

View File

@@ -0,0 +1,190 @@
import { useState, useEffect, useCallback } from 'react';
import { saasClient, type RelayTaskInfo } from '../../lib/saas-client';
import { useSaaSStore } from '../../store/saasStore';
import {
RefreshCw, RotateCw, Loader2, AlertCircle,
CheckCircle, XCircle, Clock, Zap,
} from 'lucide-react';
const STATUS_TABS = [
{ key: '', label: '全部' },
{ key: 'completed', label: '成功' },
{ key: 'failed', label: '失败' },
{ key: 'processing', label: '处理中' },
{ key: 'queued', label: '排队中' },
] as const;
function StatusBadge({ status }: { status: string }) {
const config: Record<string, { bg: string; text: string; icon: typeof CheckCircle }> = {
completed: { bg: 'bg-emerald-100 text-emerald-700', text: '成功', icon: CheckCircle },
failed: { bg: 'bg-red-100 text-red-700', text: '失败', icon: XCircle },
processing: { bg: 'bg-amber-100 text-amber-700', text: '处理中', icon: Zap },
queued: { bg: 'bg-gray-100 text-gray-500', text: '排队中', icon: Clock },
};
const c = config[status] ?? config.queued;
const Icon = c.icon;
return (
<span className={`inline-flex items-center gap-1 text-xs px-2 py-0.5 rounded-full font-medium ${c.bg}`}>
<Icon className="w-3 h-3" />
{c.text}
</span>
);
}
function formatTime(iso: string | null): string {
if (!iso) return '-';
try {
const d = new Date(iso);
return d.toLocaleString('zh-CN', { month: '2-digit', day: '2-digit', hour: '2-digit', minute: '2-digit' });
} catch {
return iso;
}
}
export function RelayTasksPanel() {
const account = useSaaSStore((s) => s.account);
const isAdmin = account?.role === 'admin';
const [tasks, setTasks] = useState<RelayTaskInfo[]>([]);
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const [statusFilter, setStatusFilter] = useState('');
const [retryingId, setRetryingId] = useState<string | null>(null);
const fetchTasks = useCallback(async () => {
setIsLoading(true);
setError(null);
try {
const query = statusFilter ? { status: statusFilter } : undefined;
const data = await saasClient.listRelayTasks(query);
setTasks(data);
} catch (err: unknown) {
setError(err instanceof Error ? err.message : '加载失败');
setTasks([]);
} finally {
setIsLoading(false);
}
}, [statusFilter]);
useEffect(() => {
fetchTasks();
}, [fetchTasks]);
const handleRetry = async (taskId: string) => {
setRetryingId(taskId);
try {
await saasClient.retryRelayTask(taskId);
await fetchTasks();
} catch (err: unknown) {
setError(err instanceof Error ? err.message : '重试失败');
} finally {
setRetryingId(null);
}
};
return (
<div className="bg-white rounded-xl border border-gray-200 p-5 shadow-sm space-y-4">
{/* Header */}
<div className="flex items-center justify-between">
<h3 className="text-sm font-semibold text-gray-900"></h3>
<button
type="button"
onClick={fetchTasks}
disabled={isLoading}
className="p-1.5 text-gray-400 hover:text-gray-600 hover:bg-gray-100 rounded-lg transition-colors cursor-pointer disabled:opacity-50"
>
<RefreshCw className={`w-4 h-4 ${isLoading ? 'animate-spin' : ''}`} />
</button>
</div>
{/* Status filter tabs */}
<div className="flex gap-1 border-b border-gray-200">
{STATUS_TABS.map((tab) => (
<button
key={tab.key}
type="button"
onClick={() => setStatusFilter(tab.key)}
className={`px-3 py-1.5 text-xs font-medium cursor-pointer transition-colors border-b-2 ${
statusFilter === tab.key
? 'border-emerald-500 text-emerald-600'
: 'border-transparent text-gray-500 hover:text-gray-700'
}`}
>
{tab.label}
</button>
))}
</div>
{error && (
<div className="flex items-start gap-2 text-sm text-red-600 bg-red-50 rounded-lg p-3">
<AlertCircle className="w-4 h-4 mt-0.5 flex-shrink-0" />
<span>{error}</span>
</div>
)}
{isLoading && tasks.length === 0 ? (
<div className="flex items-center justify-center py-8 text-gray-400">
<Loader2 className="w-5 h-5 animate-spin mr-2" />
...
</div>
) : tasks.length === 0 ? (
<div className="text-center py-8 text-sm text-gray-400">
</div>
) : (
<div className="space-y-2 max-h-80 overflow-y-auto">
{tasks.map((task) => (
<div
key={task.id}
className="flex items-center gap-3 px-3 py-2 rounded-lg border border-gray-100 hover:bg-gray-50 transition-colors"
>
{/* Status */}
<StatusBadge status={task.status} />
{/* Info */}
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2">
<span className="text-sm font-medium text-gray-900 truncate">
{task.model_id}
</span>
<span className="text-xs text-gray-400">
{task.input_tokens > 0 || task.output_tokens > 0
? `(${task.input_tokens}in / ${task.output_tokens}out)`
: ''}
</span>
</div>
{task.error_message && (
<p className="text-xs text-red-500 truncate mt-0.5" title={task.error_message}>
{task.error_message}
</p>
)}
</div>
{/* Time */}
<span className="text-xs text-gray-400 whitespace-nowrap">
{formatTime(task.created_at)}
</span>
{/* Retry button (admin only, failed tasks only) */}
{isAdmin && task.status === 'failed' && (
<button
type="button"
onClick={() => handleRetry(task.id)}
disabled={retryingId === task.id}
className="flex-shrink-0 p-1 text-gray-400 hover:text-emerald-600 hover:bg-emerald-50 rounded transition-colors cursor-pointer disabled:opacity-50"
title="重试"
>
{retryingId === task.id ? (
<Loader2 className="w-3.5 h-3.5 animate-spin" />
) : (
<RotateCw className="w-3.5 h-3.5" />
)}
</button>
)}
</div>
))}
</div>
)}
</div>
);
}

View File

@@ -0,0 +1,394 @@
import { useState } from 'react';
import { LogIn, UserPlus, Globe, Eye, EyeOff, Loader2, AlertCircle, Mail, Shield, ShieldCheck, ArrowLeft } from 'lucide-react';
interface SaaSLoginProps {
onLogin: (saasUrl: string, username: string, password: string) => Promise<void>;
onLoginWithTotp?: (saasUrl: string, username: string, password: string, totpCode: string) => Promise<void>;
onRegister?: (saasUrl: string, username: string, email: string, password: string, displayName?: string) => Promise<void>;
initialUrl?: string;
isLoggingIn?: boolean;
totpRequired?: boolean;
error?: string | null;
}
export function SaaSLogin({ onLogin, onLoginWithTotp, onRegister, initialUrl, isLoggingIn, totpRequired, error }: SaaSLoginProps) {
const [serverUrl, setServerUrl] = useState(initialUrl || '');
const [username, setUsername] = useState('');
const [email, setEmail] = useState('');
const [password, setPassword] = useState('');
const [confirmPassword, setConfirmPassword] = useState('');
const [displayName, setDisplayName] = useState('');
const [showPassword, setShowPassword] = useState(false);
const [isRegister, setIsRegister] = useState(false);
const [localError, setLocalError] = useState<string | null>(null);
const [totpCode, setTotpCode] = useState('');
const [showTotpStep, setShowTotpStep] = useState(false);
// Sync with parent prop
if (totpRequired && !showTotpStep) {
setShowTotpStep(true);
}
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setLocalError(null);
if (!serverUrl.trim()) {
setLocalError('请输入服务器地址');
return;
}
if (!username.trim()) {
setLocalError('请输入用户名');
return;
}
if (!password) {
setLocalError('请输入密码');
return;
}
if (isRegister) {
if (!email.trim()) {
setLocalError('请输入邮箱地址');
return;
}
if (!/^[^\s@]+@[^\s@]+\.[^\s@]+$/.test(email.trim())) {
setLocalError('邮箱格式不正确');
return;
}
if (password.length < 6) {
setLocalError('密码长度至少 6 个字符');
return;
}
if (password !== confirmPassword) {
setLocalError('两次输入的密码不一致');
return;
}
if (onRegister) {
try {
await onRegister(
serverUrl.trim(),
username.trim(),
email.trim(),
password,
displayName.trim() || undefined,
);
} catch (err: unknown) {
const message = err instanceof Error ? err.message : String(err);
setLocalError(message);
}
return;
}
}
try {
await onLogin(serverUrl.trim(), username.trim(), password);
// If TOTP required, login() won't throw but store sets totpRequired
// The effect above will switch to TOTP step
} catch (err: unknown) {
const message = err instanceof Error ? err.message : String(err);
setLocalError(message);
}
};
const handleTotpSubmit = async () => {
if (!onLoginWithTotp || totpCode.length !== 6) return;
setLocalError(null);
try {
await onLoginWithTotp(serverUrl.trim(), username.trim(), password, totpCode);
setTotpCode('');
setShowTotpStep(false);
} catch (err: unknown) {
const message = err instanceof Error ? err.message : String(err);
setLocalError(message);
}
};
const handleBackToLogin = () => {
setShowTotpStep(false);
setTotpCode('');
setLocalError(null);
};
const displayError = error || localError;
const handleTabSwitch = (register: boolean) => {
setIsRegister(register);
setLocalError(null);
setConfirmPassword('');
setEmail('');
setDisplayName('');
};
return (
<div className="bg-white rounded-xl border border-gray-200 p-6 shadow-sm">
{/* TOTP Verification Step */}
{showTotpStep ? (
<div className="space-y-4">
<div className="flex items-center gap-2 mb-1">
<Shield className="w-5 h-5 text-emerald-600" />
<h2 className="text-lg font-semibold text-gray-900"></h2>
</div>
<p className="text-sm text-gray-500">
TOTP
</p>
<div>
<label htmlFor="totp-code" className="block text-sm font-medium text-gray-700 mb-1.5">
TOTP
</label>
<input
id="totp-code"
type="text"
inputMode="numeric"
maxLength={6}
value={totpCode}
onChange={(e) => setTotpCode(e.target.value.replace(/\D/g, ''))}
placeholder="000000"
autoComplete="one-time-code"
autoFocus
className="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm font-mono tracking-widest text-center focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoggingIn}
onKeyDown={(e) => {
if (e.key === 'Enter' && totpCode.length === 6) handleTotpSubmit();
}}
/>
</div>
{displayError && (
<div className="flex items-start gap-2 text-sm text-red-600 bg-red-50 rounded-lg p-3">
<AlertCircle className="w-4 h-4 mt-0.5 flex-shrink-0" />
<span>{displayError}</span>
</div>
)}
<div className="flex gap-2">
<button
type="button"
onClick={handleBackToLogin}
disabled={isLoggingIn}
className="flex-1 flex items-center justify-center gap-2 px-4 py-2.5 text-sm text-gray-600 border border-gray-300 rounded-lg hover:bg-gray-50 transition-colors disabled:opacity-50 cursor-pointer"
>
<ArrowLeft className="w-4 h-4" />
</button>
<button
type="button"
onClick={handleTotpSubmit}
disabled={isLoggingIn || totpCode.length !== 6}
className="flex-1 flex items-center justify-center gap-2 px-4 py-2.5 bg-emerald-500 hover:bg-emerald-600 text-white text-sm font-medium rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
>
{isLoggingIn ? (
<Loader2 className="w-4 h-4 animate-spin" />
) : (
<ShieldCheck className="w-4 h-4" />
)}
</button>
</div>
</div>
) : (
<>
<h2 className="text-lg font-semibold text-gray-900 mb-1">
{isRegister ? '注册 SaaS 账号' : '登录 SaaS 平台'}
</h2>
<p className="text-sm text-gray-500 mb-5">
{isRegister
? '创建账号以使用 ZCLAW 云端服务'
: '连接到 ZCLAW SaaS 平台,解锁云端能力'}
</p>
{/* Tab Switcher */}
<div className="flex mb-5 border-b border-gray-200">
<button
type="button"
onClick={() => handleTabSwitch(false)}
className={`px-4 py-2.5 text-sm font-medium cursor-pointer transition-colors border-b-2 ${
!isRegister
? 'border-emerald-500 text-emerald-600'
: 'border-transparent text-gray-500 hover:text-gray-700'
}`}
>
<span className="flex items-center gap-1.5">
<LogIn className="w-3.5 h-3.5" />
</span>
</button>
{onRegister && (
<button
type="button"
onClick={() => handleTabSwitch(true)}
className={`px-4 py-2.5 text-sm font-medium cursor-pointer transition-colors border-b-2 ${
isRegister
? 'border-emerald-500 text-emerald-600'
: 'border-transparent text-gray-500 hover:text-gray-700'
}`}
>
<span className="flex items-center gap-1.5">
<UserPlus className="w-3.5 h-3.5" />
</span>
</button>
)}
</div>
{/* Form */}
<form onSubmit={handleSubmit} className="space-y-4">
{/* Server URL */}
<div>
<label htmlFor="saas-url" className="block text-sm font-medium text-gray-700 mb-1.5">
</label>
<div className="relative">
<Globe className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-gray-400" />
<input
id="saas-url"
type="url"
value={serverUrl}
onChange={(e) => setServerUrl(e.target.value)}
placeholder="https://saas.zclaw.com"
className="w-full pl-10 pr-3 py-2 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoggingIn}
/>
</div>
</div>
{/* Username */}
<div>
<label htmlFor="saas-username" className="block text-sm font-medium text-gray-700 mb-1.5">
</label>
<input
id="saas-username"
type="text"
value={username}
onChange={(e) => setUsername(e.target.value)}
placeholder="your-username"
autoComplete="username"
className="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoggingIn}
/>
</div>
{/* Email (Register only) */}
{isRegister && (
<div>
<label htmlFor="saas-email" className="block text-sm font-medium text-gray-700 mb-1.5">
</label>
<div className="relative">
<Mail className="absolute left-3 top-1/2 -translate-y-1/2 w-4 h-4 text-gray-400" />
<input
id="saas-email"
type="email"
value={email}
onChange={(e) => setEmail(e.target.value)}
placeholder="you@example.com"
autoComplete="email"
className="w-full pl-10 pr-3 py-2 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoggingIn}
/>
</div>
</div>
)}
{/* Display Name (Register only, optional) */}
{isRegister && (
<div>
<label htmlFor="saas-display-name" className="block text-sm font-medium text-gray-700 mb-1.5">
<span className="text-gray-400 font-normal">()</span>
</label>
<input
id="saas-display-name"
type="text"
value={displayName}
onChange={(e) => setDisplayName(e.target.value)}
placeholder="ZCLAW User"
autoComplete="name"
className="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoggingIn}
/>
</div>
)}
{/* Password */}
<div>
<label htmlFor="saas-password" className="block text-sm font-medium text-gray-700 mb-1.5">
</label>
<div className="relative">
<input
id="saas-password"
type={showPassword ? 'text' : 'password'}
value={password}
onChange={(e) => setPassword(e.target.value)}
placeholder={isRegister ? '至少 6 个字符' : 'Enter password'}
autoComplete={isRegister ? 'new-password' : 'current-password'}
className="w-full px-3 pr-10 py-2 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoggingIn}
/>
<button
type="button"
onClick={() => setShowPassword(!showPassword)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-gray-400 hover:text-gray-600 cursor-pointer"
tabIndex={-1}
>
{showPassword ? <EyeOff className="w-4 h-4" /> : <Eye className="w-4 h-4" />}
</button>
</div>
</div>
{/* Confirm Password (Register only) */}
{isRegister && (
<div>
<label htmlFor="saas-confirm-password" className="block text-sm font-medium text-gray-700 mb-1.5">
</label>
<input
id="saas-confirm-password"
type={showPassword ? 'text' : 'password'}
value={confirmPassword}
onChange={(e) => setConfirmPassword(e.target.value)}
placeholder="Re-enter password"
autoComplete="new-password"
className="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoggingIn}
/>
</div>
)}
{/* Error Display */}
{displayError && (
<div className="flex items-start gap-2 text-sm text-red-600 bg-red-50 rounded-lg p-3">
<AlertCircle className="w-4 h-4 mt-0.5 flex-shrink-0" />
<span>{displayError}</span>
</div>
)}
{/* Submit Button */}
<button
type="submit"
disabled={isLoggingIn}
className="w-full flex items-center justify-center gap-2 px-4 py-2.5 bg-emerald-500 hover:bg-emerald-600 text-white text-sm font-medium rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
>
{isLoggingIn ? (
<>
<Loader2 className="w-4 h-4 animate-spin" />
{isRegister ? '注册中...' : '登录中...'}
</>
) : (
<>
{isRegister ? (
<><UserPlus className="w-4 h-4" /></>
) : (
<><LogIn className="w-4 h-4" /></>
)}
</>
)}
</button>
</form>
</>
)}
</div>
);
}

View File

@@ -0,0 +1,334 @@
import { useState } from 'react';
import { useSaaSStore } from '../../store/saasStore';
import { SaaSLogin } from './SaaSLogin';
import { SaaSStatus } from './SaaSStatus';
import { ConfigMigrationWizard } from './ConfigMigrationWizard';
import { TOTPSettings } from './TOTPSettings';
import { RelayTasksPanel } from './RelayTasksPanel';
import { Cloud, Info, KeyRound } from 'lucide-react';
import { saasClient } from '../../lib/saas-client';
export function SaaSSettings() {
const isLoggedIn = useSaaSStore((s) => s.isLoggedIn);
const account = useSaaSStore((s) => s.account);
const saasUrl = useSaaSStore((s) => s.saasUrl);
const connectionMode = useSaaSStore((s) => s.connectionMode);
const login = useSaaSStore((s) => s.login);
const loginWithTotp = useSaaSStore((s) => s.loginWithTotp);
const register = useSaaSStore((s) => s.register);
const logout = useSaaSStore((s) => s.logout);
const totpRequired = useSaaSStore((s) => s.totpRequired);
const [showLogin, setShowLogin] = useState(!isLoggedIn);
const [loginError, setLoginError] = useState<string | null>(null);
const [isLoggingIn, setIsLoggingIn] = useState(false);
const handleLogin = async (url: string, username: string, password: string) => {
setIsLoggingIn(true);
setLoginError(null);
try {
await login(url, username, password);
if (useSaaSStore.getState().totpRequired) {
return;
}
setShowLogin(false);
} catch (err: unknown) {
const message = err instanceof Error ? err.message : '登录失败';
setLoginError(message);
} finally {
setIsLoggingIn(false);
}
};
const handleLoginWithTotp = async (url: string, username: string, password: string, totpCode: string) => {
setIsLoggingIn(true);
setLoginError(null);
try {
await loginWithTotp(url, username, password, totpCode);
setShowLogin(false);
} catch (err: unknown) {
const message = err instanceof Error ? err.message : 'TOTP 验证失败';
setLoginError(message);
} finally {
setIsLoggingIn(false);
}
};
const handleRegister = async (
url: string,
username: string,
email: string,
password: string,
displayName?: string,
) => {
setIsLoggingIn(true);
setLoginError(null);
try {
await register(url, username, email, password, displayName);
// register auto-logs in, no need for separate login call
setShowLogin(false);
} catch (err: unknown) {
const message = err instanceof Error ? err.message : '注册失败';
setLoginError(message);
} finally {
setIsLoggingIn(false);
}
};
const handleLogout = () => {
logout();
setShowLogin(true);
setLoginError(null);
};
return (
<div className="max-w-2xl">
<div className="flex items-center gap-3 mb-6">
<div className="w-9 h-9 rounded-lg bg-emerald-100 flex items-center justify-center">
<Cloud className="w-5 h-5 text-emerald-600" />
</div>
<div>
<h1 className="text-xl font-bold text-gray-900">SaaS </h1>
<p className="text-sm text-gray-500"> ZCLAW </p>
</div>
</div>
{/* Connection mode info */}
<div className="flex items-start gap-2 text-sm text-gray-500 bg-blue-50 rounded-lg border border-blue-100 p-3 mb-5">
<Info className="w-4 h-4 mt-0.5 text-blue-500 flex-shrink-0" />
<span>
: <strong className="text-gray-700">{connectionMode === 'saas' ? 'SaaS 云端' : connectionMode === 'gateway' ? 'Gateway' : '本地 Tauri'}</strong>
{connectionMode !== 'saas' && '连接 SaaS 平台可解锁云端同步、团队协作等高级功能。'}
</span>
</div>
{/* Login form or status display */}
{!showLogin ? (
<SaaSStatus
isLoggedIn={isLoggedIn}
account={account}
saasUrl={saasUrl}
onLogout={handleLogout}
onLogin={() => setShowLogin(true)}
/>
) : (
<SaaSLogin
onLogin={handleLogin}
onLoginWithTotp={handleLoginWithTotp}
onRegister={handleRegister}
initialUrl={saasUrl}
isLoggingIn={isLoggingIn}
totpRequired={totpRequired}
error={loginError}
/>
)}
{/* Features list when logged in */}
{isLoggedIn && !showLogin && (
<div className="mt-6">
<h2 className="text-sm font-medium text-gray-500 uppercase tracking-wide mb-3">
</h2>
<div className="bg-white rounded-xl border border-gray-200 p-5 shadow-sm">
<div className="space-y-3">
<CloudFeatureRow
name="云端同步"
description="对话记录和配置自动同步到云端"
status="active"
/>
<CloudFeatureRow
name="团队协作"
description="与团队成员共享 Agent 和技能"
status={account?.role === 'admin' || account?.role === 'pro' ? 'active' : 'inactive'}
/>
<CloudFeatureRow
name="高级分析"
description="使用统计和用量分析仪表板"
status={account?.role === 'admin' || account?.role === 'pro' ? 'active' : 'inactive'}
/>
</div>
</div>
</div>
)}
{/* Password change section */}
{isLoggedIn && !showLogin && <ChangePasswordSection />}
{/* TOTP 2FA */}
{isLoggedIn && !showLogin && (
<div className="mt-6">
<h2 className="text-sm font-medium text-gray-500 uppercase tracking-wide mb-3">
</h2>
<TOTPSettings />
</div>
)}
{/* Relay tasks */}
{isLoggedIn && !showLogin && (
<div className="mt-6">
<h2 className="text-sm font-medium text-gray-500 uppercase tracking-wide mb-3">
</h2>
<RelayTasksPanel />
</div>
)}
{/* Config migration wizard */}
{isLoggedIn && !showLogin && (
<div className="mt-6">
<h2 className="text-sm font-medium text-gray-500 uppercase tracking-wide mb-3">
</h2>
<ConfigMigrationWizard onDone={() => {/* no-op: wizard self-contained */}} />
</div>
)}
</div>
);
}
function CloudFeatureRow({
name,
description,
status,
}: {
name: string;
description: string;
status: 'active' | 'inactive';
}) {
return (
<div className="flex items-center justify-between py-1">
<div>
<div className="text-sm font-medium text-gray-900">{name}</div>
<div className="text-xs text-gray-500">{description}</div>
</div>
<span
className={`text-xs px-2 py-0.5 rounded-full font-medium ${
status === 'active'
? 'bg-emerald-100 text-emerald-700'
: 'bg-gray-100 text-gray-500'
}`}
>
{status === 'active' ? '可用' : '需要订阅'}
</span>
</div>
);
}
function ChangePasswordSection() {
const [isOpen, setIsOpen] = useState(false);
const [oldPassword, setOldPassword] = useState('');
const [newPassword, setNewPassword] = useState('');
const [confirmPassword, setConfirmPassword] = useState('');
const [error, setError] = useState<string | null>(null);
const [success, setSuccess] = useState(false);
const [isSubmitting, setIsSubmitting] = useState(false);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setError(null);
setSuccess(false);
if (newPassword.length < 8) {
setError('新密码至少 8 个字符');
return;
}
if (newPassword !== confirmPassword) {
setError('两次输入的新密码不一致');
return;
}
setIsSubmitting(true);
try {
await saasClient.changePassword(oldPassword, newPassword);
setSuccess(true);
setOldPassword('');
setNewPassword('');
setConfirmPassword('');
} catch (err: unknown) {
const message = err instanceof Error ? err.message : '密码修改失败';
setError(message);
} finally {
setIsSubmitting(false);
}
};
return (
<div className="mt-6">
<div
className="flex items-center justify-between cursor-pointer"
onClick={() => setIsOpen(!isOpen)}
>
<h2 className="text-sm font-medium text-gray-500 uppercase tracking-wide">
</h2>
<span className="text-xs text-gray-400">{isOpen ? '收起' : '展开'}</span>
</div>
{isOpen && (
<div className="bg-white rounded-xl border border-gray-200 p-5 shadow-sm mt-3">
<div className="flex items-center gap-2 mb-4">
<KeyRound className="w-4 h-4 text-gray-400" />
<span className="text-sm font-medium text-gray-700"></span>
</div>
<form onSubmit={handleSubmit} className="space-y-3">
<div>
<label className="block text-xs font-medium text-gray-500 mb-1">
</label>
<input
type="password"
value={oldPassword}
onChange={(e) => setOldPassword(e.target.value)}
required
className="w-full px-3 py-2 text-sm border border-gray-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-emerald-500 focus:border-transparent"
/>
</div>
<div>
<label className="block text-xs font-medium text-gray-500 mb-1">
</label>
<input
type="password"
value={newPassword}
onChange={(e) => setNewPassword(e.target.value)}
required
minLength={8}
className="w-full px-3 py-2 text-sm border border-gray-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-emerald-500 focus:border-transparent"
/>
</div>
<div>
<label className="block text-xs font-medium text-gray-500 mb-1">
</label>
<input
type="password"
value={confirmPassword}
onChange={(e) => setConfirmPassword(e.target.value)}
required
minLength={8}
className="w-full px-3 py-2 text-sm border border-gray-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-emerald-500 focus:border-transparent"
/>
</div>
{error && (
<p className="text-xs text-red-500">{error}</p>
)}
{success && (
<p className="text-xs text-emerald-600"></p>
)}
<button
type="submit"
disabled={isSubmitting}
className="w-full py-2 text-sm font-medium text-white bg-emerald-600 rounded-lg hover:bg-emerald-700 disabled:opacity-50 transition-colors"
>
{isSubmitting ? '修改中...' : '修改密码'}
</button>
</form>
</div>
)}
</div>
);
}

View File

@@ -0,0 +1,192 @@
import { useEffect, useState } from 'react';
import { saasClient, type SaaSAccountInfo, type SaaSModelInfo } from '../../lib/saas-client';
import { Cloud, CloudOff, LogOut, RefreshCw, Cpu, CheckCircle, XCircle, Loader2 } from 'lucide-react';
import { useSaaSStore } from '../../store/saasStore';
interface SaaSStatusProps {
isLoggedIn: boolean;
account: SaaSAccountInfo | null;
saasUrl: string;
onLogout: () => void;
onLogin: () => void;
}
export function SaaSStatus({ isLoggedIn, account, saasUrl, onLogout, onLogin }: SaaSStatusProps) {
const availableModels = useSaaSStore((s) => s.availableModels);
const fetchAvailableModels = useSaaSStore((s) => s.fetchAvailableModels);
const [serverReachable, setServerReachable] = useState<boolean>(true);
const [checkingHealth, setCheckingHealth] = useState(false);
const [healthOk, setHealthOk] = useState<boolean | null>(null);
const [showDetails, setShowDetails] = useState(false);
useEffect(() => {
if (isLoggedIn) {
fetchAvailableModels();
}
}, [isLoggedIn, fetchAvailableModels]);
// Poll server reachability every 30s
useEffect(() => {
if (!isLoggedIn) return;
const check = () => {
setServerReachable(saasClient.isServerReachable());
};
check();
const timer = setInterval(check, 30000);
return () => clearInterval(timer);
}, [isLoggedIn]);
async function checkHealth() {
setCheckingHealth(true);
setHealthOk(null);
try {
const response = await fetch(`${saasUrl}/api/health`, {
signal: AbortSignal.timeout(5000),
});
setHealthOk(response.ok);
} catch {
setHealthOk(false);
} finally {
setCheckingHealth(false);
}
}
if (isLoggedIn && account) {
const displayName = account.display_name || account.username;
const initial = displayName[0].toUpperCase();
return (
<div className="space-y-4">
{/* Main status bar */}
<div className="flex items-center justify-between rounded-lg border border-emerald-200 bg-emerald-50 p-4">
<div className="flex items-center gap-3">
<div className="w-9 h-9 rounded-full bg-emerald-500 flex items-center justify-center text-white font-semibold text-sm flex-shrink-0">
{initial}
</div>
<div className="min-w-0">
<div className="font-medium text-gray-900 text-sm">{displayName}</div>
<div className="text-xs text-gray-500 truncate">{saasUrl}</div>
<span className="inline-block mt-0.5 text-xs px-1.5 py-0.5 rounded bg-emerald-100 text-emerald-700 font-medium">
{account.role}
</span>
</div>
</div>
<div className="flex items-center gap-2 flex-shrink-0">
{serverReachable ? (
<div className="flex items-center gap-1.5 text-emerald-600 text-xs">
<Cloud className="w-3.5 h-3.5" />
<span></span>
</div>
) : (
<div className="flex items-center gap-1.5 text-amber-500 text-xs">
<CloudOff className="w-3.5 h-3.5" />
<span>线</span>
</div>
)}
<button
onClick={() => setShowDetails(!showDetails)}
className="px-2 py-1.5 text-xs text-gray-600 border border-gray-300 rounded-lg hover:bg-gray-100 transition-colors cursor-pointer"
>
</button>
<button
onClick={onLogout}
className="flex items-center gap-1.5 px-3 py-1.5 text-xs text-gray-600 border border-gray-300 rounded-lg hover:bg-gray-100 transition-colors cursor-pointer"
>
<LogOut className="w-3.5 h-3.5" />
</button>
</div>
</div>
{/* Expandable details */}
{showDetails && (
<div className="bg-white rounded-xl border border-gray-200 p-5 shadow-sm space-y-4">
{/* Health Check */}
<div className="flex justify-between items-center">
<span className="text-sm text-gray-700"></span>
<div className="flex items-center gap-2">
{healthOk === null && !checkingHealth && (
<span className="text-xs text-gray-400"></span>
)}
{checkingHealth && <Loader2 className="w-4 h-4 animate-spin text-gray-400" />}
{healthOk === true && (
<div className="flex items-center gap-1 text-green-600 text-sm">
<CheckCircle className="w-4 h-4" />
</div>
)}
{healthOk === false && (
<div className="flex items-center gap-1 text-red-500 text-sm">
<XCircle className="w-4 h-4" />
</div>
)}
<button
onClick={checkHealth}
disabled={checkingHealth}
className="p-1 text-gray-400 hover:text-gray-600 cursor-pointer disabled:opacity-50"
>
<RefreshCw className={`w-3.5 h-3.5 ${checkingHealth ? 'animate-spin' : ''}`} />
</button>
</div>
</div>
{/* Available Models */}
<div>
<div className="flex items-center gap-2 mb-2">
<Cpu className="w-4 h-4 text-gray-500" />
<span className="text-sm font-medium text-gray-700">
({availableModels.length})
</span>
</div>
{availableModels.length === 0 ? (
<p className="text-sm text-gray-400 pl-6">
Provider Model
</p>
) : (
<div className="space-y-1 pl-6">
{availableModels.map((model) => (
<ModelRow key={model.id} model={model} />
))}
</div>
)}
</div>
</div>
)}
</div>
);
}
return (
<div className="flex items-center justify-between rounded-lg border border-gray-200 bg-gray-50 p-4">
<div className="flex items-center gap-3">
<CloudOff className="w-5 h-5 text-gray-400" />
<div>
<div className="font-medium text-gray-900 text-sm">SaaS </div>
<div className="text-xs text-gray-500"></div>
</div>
</div>
<button
onClick={onLogin}
className="flex items-center gap-1.5 px-3 py-1.5 text-xs text-white bg-emerald-500 rounded-lg hover:bg-emerald-600 transition-colors cursor-pointer"
>
<Cloud className="w-3.5 h-3.5" />
</button>
</div>
);
}
function ModelRow({ model }: { model: SaaSModelInfo }) {
return (
<div className="flex items-center justify-between py-1.5 px-3 bg-gray-50 rounded-lg">
<span className="text-sm text-gray-800">{model.alias || model.id}</span>
<div className="flex items-center gap-2 text-xs text-gray-400">
{model.supports_streaming && <span></span>}
{model.supports_vision && <span></span>}
<span className="font-mono">{(model.context_window / 1000).toFixed(0)}k</span>
</div>
</div>
);
}

View File

@@ -0,0 +1,285 @@
import { useState } from 'react';
import { useSaaSStore } from '../../store/saasStore';
import { Shield, ShieldCheck, ShieldOff, Copy, Check, Loader2, AlertCircle, X } from 'lucide-react';
export function TOTPSettings() {
const account = useSaaSStore((s) => s.account);
const totpSetupData = useSaaSStore((s) => s.totpSetupData);
const isLoading = useSaaSStore((s) => s.isLoading);
const storeError = useSaaSStore((s) => s.error);
const setupTotp = useSaaSStore((s) => s.setupTotp);
const verifyTotp = useSaaSStore((s) => s.verifyTotp);
const disableTotp = useSaaSStore((s) => s.disableTotp);
const cancelTotpSetup = useSaaSStore((s) => s.cancelTotpSetup);
const [verifyCode, setVerifyCode] = useState('');
const [disablePassword, setDisablePassword] = useState('');
const [showDisable, setShowDisable] = useState(false);
const [localError, setLocalError] = useState<string | null>(null);
const [success, setSuccess] = useState<string | null>(null);
const [copiedSecret, setCopiedSecret] = useState(false);
const displayError = storeError || localError;
const isEnabled = account?.totp_enabled ?? false;
const isSettingUp = !!totpSetupData;
const handleSetup = async () => {
setLocalError(null);
setSuccess(null);
setVerifyCode('');
try {
await setupTotp();
} catch {
// error already in store
}
};
const handleVerify = async () => {
if (verifyCode.length !== 6) return;
setLocalError(null);
setSuccess(null);
try {
await verifyTotp(verifyCode);
setVerifyCode('');
setSuccess('TOTP 已成功启用');
} catch {
// error already in store
}
};
const handleDisable = async () => {
if (!disablePassword) {
setLocalError('请输入密码确认');
return;
}
setLocalError(null);
setSuccess(null);
try {
await disableTotp(disablePassword);
setDisablePassword('');
setShowDisable(false);
setSuccess('TOTP 已成功禁用');
} catch {
// error already in store
}
};
const handleCopySecret = async () => {
if (!totpSetupData) return;
try {
await navigator.clipboard.writeText(totpSetupData.secret);
setCopiedSecret(true);
setTimeout(() => setCopiedSecret(false), 2000);
} catch {
// clipboard API not available
}
};
const handleCancel = () => {
cancelTotpSetup();
setVerifyCode('');
setLocalError(null);
};
// Setup flow: QR code + verify code input
if (isSettingUp) {
return (
<div className="bg-white rounded-xl border border-gray-200 p-5 shadow-sm space-y-4">
<div className="flex items-center justify-between">
<div className="flex items-center gap-2">
<Shield className="w-5 h-5 text-emerald-600" />
<h3 className="text-sm font-semibold text-gray-900"></h3>
</div>
<button
type="button"
onClick={handleCancel}
className="text-gray-400 hover:text-gray-600 cursor-pointer"
>
<X className="w-4 h-4" />
</button>
</div>
<p className="text-sm text-gray-500">
使 Google Authenticator / Authy
</p>
{/* QR Code */}
<div className="flex flex-col items-center gap-3 py-2">
<img
src={`https://api.qrserver.com/v1/create-qr-code/?data=${encodeURIComponent(totpSetupData.otpauth_uri)}&size=200x200`}
alt="TOTP QR Code"
className="w-48 h-48 border border-gray-200 rounded-lg"
/>
</div>
{/* Manual secret */}
<div>
<p className="text-xs text-gray-500 mb-1">:</p>
<div className="flex items-center gap-2">
<code className="flex-1 px-2 py-1 bg-gray-50 rounded text-xs font-mono text-gray-700 break-all">
{totpSetupData.secret}
</code>
<button
type="button"
onClick={handleCopySecret}
className="flex-shrink-0 p-1 text-gray-400 hover:text-emerald-600 cursor-pointer"
title="复制密钥"
>
{copiedSecret ? <Check className="w-4 h-4" /> : <Copy className="w-4 h-4" />}
</button>
</div>
</div>
{/* Verify code input */}
<div>
<label htmlFor="totp-verify-code" className="block text-sm font-medium text-gray-700 mb-1.5">
</label>
<input
id="totp-verify-code"
type="text"
inputMode="numeric"
maxLength={6}
value={verifyCode}
onChange={(e) => setVerifyCode(e.target.value.replace(/\D/g, ''))}
placeholder="输入 6 位验证码"
autoComplete="one-time-code"
autoFocus
className="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm font-mono tracking-widest text-center focus:outline-none focus:ring-2 focus:ring-emerald-500/20 focus:border-emerald-500 bg-white text-gray-900"
disabled={isLoading}
onKeyDown={(e) => {
if (e.key === 'Enter' && verifyCode.length === 6) handleVerify();
}}
/>
</div>
{displayError && (
<div className="flex items-start gap-2 text-sm text-red-600 bg-red-50 rounded-lg p-3">
<AlertCircle className="w-4 h-4 mt-0.5 flex-shrink-0" />
<span>{displayError}</span>
</div>
)}
<div className="flex gap-2">
<button
type="button"
onClick={handleCancel}
disabled={isLoading}
className="flex-1 px-4 py-2 text-sm text-gray-600 border border-gray-300 rounded-lg hover:bg-gray-50 transition-colors disabled:opacity-50 cursor-pointer"
>
</button>
<button
type="button"
onClick={handleVerify}
disabled={isLoading || verifyCode.length !== 6}
className="flex-1 flex items-center justify-center gap-2 px-4 py-2 bg-emerald-500 hover:bg-emerald-600 text-white text-sm font-medium rounded-lg transition-colors disabled:opacity-50 disabled:cursor-not-allowed cursor-pointer"
>
{isLoading ? <Loader2 className="w-4 h-4 animate-spin" /> : <ShieldCheck className="w-4 h-4" />}
</button>
</div>
</div>
);
}
return (
<div className="bg-white rounded-xl border border-gray-200 p-5 shadow-sm space-y-4">
<div className="flex items-center justify-between">
<div className="flex items-center gap-2">
{isEnabled ? (
<ShieldCheck className="w-5 h-5 text-emerald-600" />
) : (
<ShieldOff className="w-5 h-5 text-gray-400" />
)}
<h3 className="text-sm font-semibold text-gray-900"></h3>
</div>
<span className={`text-xs px-2 py-0.5 rounded-full font-medium ${
isEnabled ? 'bg-emerald-100 text-emerald-700' : 'bg-gray-100 text-gray-500'
}`}>
{isEnabled ? '已启用' : '未启用'}
</span>
</div>
<p className="text-sm text-gray-500">
{isEnabled
? '你的账号已启用双因素认证,登录时需要输入 TOTP 验证码。'
: '启用双因素认证可以增强账号安全性。'}
</p>
{displayError && (
<div className="flex items-start gap-2 text-sm text-red-600 bg-red-50 rounded-lg p-3">
<AlertCircle className="w-4 h-4 mt-0.5 flex-shrink-0" />
<span>{displayError}</span>
</div>
)}
{success && (
<div className="flex items-start gap-2 text-sm text-emerald-600 bg-emerald-50 rounded-lg p-3">
<Check className="w-4 h-4 mt-0.5 flex-shrink-0" />
<span>{success}</span>
</div>
)}
{!isEnabled && !showDisable && (
<button
type="button"
onClick={handleSetup}
disabled={isLoading}
className="flex items-center justify-center gap-2 px-4 py-2 bg-emerald-500 hover:bg-emerald-600 text-white text-sm font-medium rounded-lg transition-colors disabled:opacity-50 cursor-pointer"
>
{isLoading ? <Loader2 className="w-4 h-4 animate-spin" /> : <Shield className="w-4 h-4" />}
TOTP
</button>
)}
{isEnabled && !showDisable && (
<button
type="button"
onClick={() => setShowDisable(true)}
className="flex items-center justify-center gap-2 px-4 py-2 text-sm text-red-600 border border-red-300 rounded-lg hover:bg-red-50 transition-colors cursor-pointer"
>
<ShieldOff className="w-4 h-4" />
TOTP
</button>
)}
{showDisable && (
<div className="space-y-3 p-3 bg-red-50 rounded-lg border border-red-200">
<p className="text-sm text-red-700"> TOTP </p>
<input
type="password"
value={disablePassword}
onChange={(e) => setDisablePassword(e.target.value)}
placeholder="输入当前密码"
autoComplete="current-password"
className="w-full px-3 py-2 border border-gray-300 rounded-lg text-sm focus:outline-none focus:ring-2 focus:ring-red-500/20 focus:border-red-500 bg-white text-gray-900"
disabled={isLoading}
onKeyDown={(e) => {
if (e.key === 'Enter') handleDisable();
}}
/>
<div className="flex gap-2">
<button
type="button"
onClick={() => { setShowDisable(false); setDisablePassword(''); setLocalError(null); }}
disabled={isLoading}
className="flex-1 px-3 py-1.5 text-sm text-gray-600 border border-gray-300 rounded-lg hover:bg-gray-50 transition-colors cursor-pointer"
>
</button>
<button
type="button"
onClick={handleDisable}
disabled={isLoading || !disablePassword}
className="flex-1 flex items-center justify-center gap-2 px-3 py-1.5 text-sm text-red-600 border border-red-300 rounded-lg hover:bg-red-100 transition-colors disabled:opacity-50 cursor-pointer"
>
{isLoading ? <Loader2 className="w-3.5 h-3.5 animate-spin" /> : null}
</button>
</div>
</div>
)}
</div>
);
}

View File

@@ -6,18 +6,7 @@ import { useConfigStore } from '../../store/configStore';
import { useChatStore } from '../../store/chatStore';
import { silentErrorHandler } from '../../lib/error-utils';
import { Plus, Pencil, Trash2, Star, Eye, EyeOff, AlertCircle, X, Zap, Check } from 'lucide-react';
// 自定义模型数据结构
interface CustomModel {
id: string;
name: string;
provider: string;
apiKey?: string;
apiProtocol: 'openai' | 'anthropic' | 'custom';
baseUrl?: string;
isDefault?: boolean;
createdAt: string;
}
import type { CustomModel, CustomModelApiProtocol } from '../../types/config';
// Embedding 配置数据结构
interface EmbeddingConfig {
@@ -140,7 +129,7 @@ export function ModelsAPI() {
modelId: 'glm-4-flash',
displayName: '',
apiKey: '',
apiProtocol: 'openai' as 'openai' | 'anthropic' | 'custom',
apiProtocol: 'openai' as CustomModelApiProtocol,
baseUrl: '',
});
@@ -650,7 +639,7 @@ export function ModelsAPI() {
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-2">API </label>
<select
value={formData.apiProtocol}
onChange={(e) => setFormData({ ...formData, apiProtocol: e.target.value as 'openai' | 'anthropic' | 'custom' })}
onChange={(e) => setFormData({ ...formData, apiProtocol: e.target.value as CustomModelApiProtocol })}
className="w-full px-3 py-2 border border-gray-200 dark:border-gray-600 rounded-lg text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-2 focus:ring-orange-500"
>
<option value="openai">OpenAI</option>

View File

@@ -18,6 +18,7 @@ import {
Heart,
Key,
Database,
Cloud,
} from 'lucide-react';
import { silentErrorHandler } from '../../lib/error-utils';
import { General } from './General';
@@ -37,6 +38,7 @@ import { TaskList } from '../TaskList';
import { HeartbeatConfig } from '../HeartbeatConfig';
import { SecureStorage } from './SecureStorage';
import { VikingPanel } from '../VikingPanel';
import { SaaSSettings } from '../SaaS/SaaSSettings';
interface SettingsLayoutProps {
onBack: () => void;
@@ -54,6 +56,7 @@ type SettingsPage =
| 'privacy'
| 'security'
| 'storage'
| 'saas'
| 'viking'
| 'audit'
| 'tasks'
@@ -72,6 +75,7 @@ const menuItems: { id: SettingsPage; label: string; icon: React.ReactNode }[] =
{ id: 'workspace', label: '工作区', icon: <FolderOpen className="w-4 h-4" /> },
{ id: 'privacy', label: '数据与隐私', icon: <Shield className="w-4 h-4" /> },
{ id: 'storage', label: '安全存储', icon: <Key className="w-4 h-4" /> },
{ id: 'saas', label: 'SaaS 平台', icon: <Cloud className="w-4 h-4" /> },
{ id: 'viking', label: '语义记忆', icon: <Database className="w-4 h-4" /> },
{ id: 'security', label: '安全状态', icon: <Shield className="w-4 h-4" /> },
{ id: 'audit', label: '审计日志', icon: <ClipboardList className="w-4 h-4" /> },
@@ -97,6 +101,7 @@ export function SettingsLayout({ onBack }: SettingsLayoutProps) {
case 'workspace': return <Workspace />;
case 'privacy': return <Privacy />;
case 'storage': return <SecureStorage />;
case 'saas': return <SaaSSettings />;
case 'security': return (
<div className="space-y-6">
<div>

View File

@@ -455,10 +455,24 @@ export function clearSecurityLog(): void {
}
/**
* Generate a random API key for testing
* WARNING: Only use for testing purposes
* Generate a random API key for testing.
*
* @internal This function is intended solely for automated tests and
* development tooling. It must never be called in production
* builds because generated keys are not cryptographically secure
* and should never be used to authenticate against real services.
*
* @param type - The API key type to generate a test key for
* @returns A random API key that passes format validation for the given type
* @throws {Error} If called outside of a development or test environment
*/
export function generateTestApiKey(type: ApiKeyType): string {
if (import.meta.env?.DEV !== true && import.meta.env?.MODE !== 'test') {
throw new Error(
'[Security] generateTestApiKey may only be called in development or test environments'
);
}
const rules = KEY_VALIDATION_RULES[type];
const length = rules.minLength + 10;
const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';

View File

@@ -37,13 +37,17 @@ export {
DEFAULT_GATEWAY_URL,
REST_API_URL,
FALLBACK_GATEWAY_URLS,
ZCLAW_GRPC_PORT,
ZCLAW_LEGACY_PORT,
normalizeGatewayUrl,
isLocalhost,
getStoredGatewayUrl,
setStoredGatewayUrl,
getStoredGatewayToken,
setStoredGatewayToken,
detectConnectionMode,
} from './gateway-storage';
export type { ConnectionMode } from './gateway-storage';
// === Internal imports ===
import type {
@@ -69,6 +73,7 @@ import {
isLocalhost,
getStoredGatewayUrl,
getStoredGatewayToken,
detectConnectionMode,
} from './gateway-storage';
import type { GatewayConfigSnapshot, GatewayModelChoice } from './gateway-config';
@@ -273,8 +278,8 @@ export class GatewayClient {
return Promise.resolve();
}
// Check if URL is for ZCLAW (port 4200 or 50051) - use REST mode
if (this.url.includes(':4200') || this.url.includes(':50051')) {
// Check if URL is for ZCLAW (known kernel ports) - use REST mode
if (detectConnectionMode(this.url) === 'rest') {
return this.connectRest();
}

View File

@@ -40,15 +40,47 @@ export function isLocalhost(url: string): boolean {
}
}
// === Port Constants ===
/** Default gRPC/HTTP port used by the ZCLAW kernel */
export const ZCLAW_GRPC_PORT = 50051;
/** Legacy/alternative port used in development or older configurations */
export const ZCLAW_LEGACY_PORT = 4200;
// === Connection Mode ===
/**
* Determines how the client connects to the ZCLAW gateway.
* - `rest`: Kernel exposes an HTTP REST API (gRPC-gateway). Used when the
* URL contains a known kernel port.
* - `ws`: Direct WebSocket connection to the kernel.
*/
export type ConnectionMode = 'rest' | 'ws';
/**
* Decide the connection mode based on the gateway URL.
*
* When the URL contains a known kernel port (gRPC or legacy), the client
* routes requests through the REST adapter instead of opening a raw
* WebSocket.
*/
export function detectConnectionMode(url: string): ConnectionMode {
if (url.includes(`:${ZCLAW_GRPC_PORT}`) || url.includes(`:${ZCLAW_LEGACY_PORT}`)) {
return 'rest';
}
return 'ws';
}
// === URL Constants ===
// ZCLAW endpoints (port 50051 - actual running port)
// Note: REST API uses relative path to leverage Vite proxy for CORS bypass
export const DEFAULT_GATEWAY_URL = `${DEFAULT_WS_PROTOCOL}127.0.0.1:50051/ws`;
export const DEFAULT_GATEWAY_URL = `${DEFAULT_WS_PROTOCOL}127.0.0.1:${ZCLAW_GRPC_PORT}/ws`;
export const REST_API_URL = ''; // Empty = use relative path (Vite proxy)
export const FALLBACK_GATEWAY_URLS = [
DEFAULT_GATEWAY_URL,
`${DEFAULT_WS_PROTOCOL}127.0.0.1:4200/ws`,
`${DEFAULT_WS_PROTOCOL}127.0.0.1:${ZCLAW_LEGACY_PORT}/ws`,
];
const GATEWAY_URL_STORAGE_KEY = 'zclaw_gateway_url';

View File

@@ -18,7 +18,7 @@ import { DEFAULT_MODEL_ID, DEFAULT_OPENAI_BASE_URL } from '../constants/models';
// === Types ===
export type LLMProvider = 'openai' | 'volcengine' | 'gateway' | 'mock';
export type LLMProvider = 'openai' | 'volcengine' | 'gateway' | 'saas' | 'mock';
export interface LLMConfig {
provider: LLMProvider;
@@ -77,6 +77,12 @@ const DEFAULT_CONFIGS: Record<LLMProvider, LLMConfig> = {
temperature: 0.7,
timeout: 60000,
},
saas: {
provider: 'saas',
maxTokens: 4096,
temperature: 0.7,
timeout: 300000, // 5 min for streaming
},
mock: {
provider: 'mock',
maxTokens: 100,
@@ -412,6 +418,85 @@ class GatewayLLMAdapter implements LLMServiceAdapter {
}
}
// === SaaS Relay Adapter (via SaaS backend) ===
class SaasLLMAdapter implements LLMServiceAdapter {
private config: LLMConfig;
constructor(config: LLMConfig) {
this.config = { ...DEFAULT_CONFIGS.saas, ...config };
}
async complete(messages: LLMMessage[], options?: Partial<LLMConfig>): Promise<LLMResponse> {
const config = { ...this.config, ...options };
const startTime = Date.now();
// Dynamic import to avoid circular dependency
const { useSaaSStore } = await import('../store/saasStore');
const { saasUrl, authToken } = useSaaSStore.getState();
if (!saasUrl || !authToken) {
throw new Error('[SaaS] 未登录 SaaS 平台,请先在设置中登录');
}
// Dynamic import of SaaSClient singleton
const { saasClient } = await import('./saas-client');
saasClient.setBaseUrl(saasUrl);
saasClient.setToken(authToken);
const openaiBody = {
model: config.model || 'default',
messages,
max_tokens: config.maxTokens || 4096,
temperature: config.temperature ?? 0.7,
stream: false,
};
const response = await saasClient.chatCompletion(
openaiBody,
AbortSignal.timeout(config.timeout || 300000),
);
if (!response.ok) {
const errorData = await response.json().catch(() => ({
error: 'unknown',
message: `SaaS relay 请求失败 (${response.status})`,
}));
throw new Error(
`[SaaS] ${errorData.message || errorData.error || `请求失败: ${response.status}`}`,
);
}
const data = await response.json();
const latencyMs = Date.now() - startTime;
return {
content: data.choices?.[0]?.message?.content || '',
tokensUsed: {
input: data.usage?.prompt_tokens || 0,
output: data.usage?.completion_tokens || 0,
},
model: data.model,
latencyMs,
};
}
isAvailable(): boolean {
// Check synchronously via localStorage for availability check
// Dynamic import would be async, so we use a simpler check
try {
const token = localStorage.getItem('zclaw-saas-token');
return !!token;
} catch {
return false;
}
}
getProvider(): LLMProvider {
return 'saas';
}
}
// === Factory ===
let cachedAdapter: LLMServiceAdapter | null = null;
@@ -427,6 +512,8 @@ export function createLLMAdapter(config?: Partial<LLMConfig>): LLMServiceAdapter
return new VolcengineLLMAdapter(finalConfig);
case 'gateway':
return new GatewayLLMAdapter(finalConfig);
case 'saas':
return new SaasLLMAdapter(finalConfig);
case 'mock':
default:
return new MockLLMAdapter(finalConfig);

View File

@@ -0,0 +1,763 @@
/**
* ZCLAW SaaS Client
*
* Typed HTTP client for the ZCLAW SaaS backend API (v1).
* Handles authentication, model listing, chat relay, and config management.
*
* API base path: /api/v1/...
* Auth: Bearer token in Authorization header
*
* Security: JWT token is stored via secureStorage (OS keychain or encrypted localStorage).
* URL, account info, and connection mode remain in plain localStorage (non-sensitive).
*/
import { secureStorage } from './secure-storage';
// === Storage Keys ===
const SAASTOKEN_KEY = 'zclaw-saas-token';
const SAASURL_KEY = 'zclaw-saas-url';
const SAASACCOUNT_KEY = 'zclaw-saas-account';
const SAASMODE_KEY = 'zclaw-connection-mode';
// === Types ===
/** Public account info returned by the SaaS backend */
export interface SaaSAccountInfo {
id: string;
username: string;
email: string;
display_name: string;
role: string;
status: string;
totp_enabled: boolean;
created_at: string;
}
/** A model available for relay through the SaaS backend */
export interface SaaSModelInfo {
id: string;
provider_id: string;
alias: string;
context_window: number;
max_output_tokens: number;
supports_streaming: boolean;
supports_vision: boolean;
}
/** Config item from the SaaS backend */
export interface SaaSConfigItem {
id: string;
category: string;
key_path: string;
value_type: string;
current_value: string | null;
default_value: string | null;
source: string;
description: string | null;
requires_restart: boolean;
created_at: string;
updated_at: string;
}
/** SaaS API error shape */
export interface SaaSErrorResponse {
error: string;
message: string;
}
/** Login response from POST /api/v1/auth/login */
export interface SaaSLoginResponse {
token: string;
account: SaaSAccountInfo;
}
/** Refresh response from POST /api/v1/auth/refresh */
interface SaaSRefreshResponse {
token: string;
}
/** TOTP setup response from POST /api/v1/auth/totp/setup */
export interface TotpSetupResponse {
otpauth_uri: string;
secret: string;
issuer: string;
}
/** TOTP verify/disable response */
export interface TotpResultResponse {
ok: boolean;
totp_enabled: boolean;
message: string;
}
/** Device info stored on the SaaS backend */
export interface DeviceInfo {
id: string;
device_id: string;
device_name: string | null;
platform: string | null;
app_version: string | null;
last_seen_at: string;
created_at: string;
}
/** Relay task info from GET /api/v1/relay/tasks */
export interface RelayTaskInfo {
id: string;
account_id: string;
provider_id: string;
model_id: string;
status: string;
priority: number;
attempt_count: number;
max_attempts: number;
input_tokens: number;
output_tokens: number;
error_message: string | null;
queued_at: string;
started_at: string | null;
completed_at: string | null;
created_at: string;
}
/** Config diff request for POST /api/v1/config/diff and /sync */
export interface SyncConfigRequest {
client_fingerprint: string;
action: 'push' | 'merge';
config_keys: string[];
client_values: Record<string, unknown>;
}
/** A single config diff entry */
export interface ConfigDiffItem {
key_path: string;
client_value: string | null;
saas_value: string | null;
conflict: boolean;
}
/** Config diff response */
export interface ConfigDiffResponse {
items: ConfigDiffItem[];
total_keys: number;
conflicts: number;
}
/** Config sync result */
export interface ConfigSyncResult {
updated: number;
created: number;
skipped: number;
}
// === JWT Helpers ===
/**
* Decode a JWT payload without verifying the signature.
* Returns the parsed JSON payload, or null if the token is malformed.
*/
export function decodeJwtPayload<T = Record<string, unknown>>(token: string): T | null {
try {
const parts = token.split('.');
if (parts.length !== 3) return null;
// JWT payload is Base64Url-encoded
const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/');
const json = decodeURIComponent(
atob(base64)
.split('')
.map((c) => '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2))
.join(''),
);
return JSON.parse(json) as T;
} catch {
return null;
}
}
/** JWT payload shape we care about */
interface JwtPayload {
exp?: number;
iat?: number;
sub?: string;
}
/**
* Calculate the delay (ms) until 80% of the token's lifetime has elapsed.
* This is the ideal moment to trigger a proactive refresh.
* Returns null if the token has no exp claim or is already past 80% lifetime.
*/
export function getRefreshDelay(exp: number): number | null {
const now = Math.floor(Date.now() / 1000);
const totalLifetime = exp - now;
if (totalLifetime <= 0) return null; // already expired
// Refresh at 80% of the token's remaining lifetime
const refreshAt = now + Math.floor(totalLifetime * 0.8);
const delayMs = (refreshAt - now) * 1000;
// Minimum 5-second guard to avoid hammering the endpoint
return delayMs > 5000 ? delayMs : 5000;
}
// === Error Class ===
export class SaaSApiError extends Error {
constructor(
public readonly status: number,
public readonly code: string,
message: string,
) {
super(message);
this.name = 'SaaSApiError';
}
}
// === Session Persistence ===
export interface SaaSSession {
token: string;
account: SaaSAccountInfo | null;
saasUrl: string;
}
/**
* Read a value from localStorage with error handling.
*/
function readLegacyLocalStorage(key: string): string | null {
try {
return localStorage.getItem(key);
} catch {
return null;
}
}
/**
* Load a persisted SaaS session using secure storage for the JWT token.
* Falls back to legacy localStorage if secureStorage has no token (migration).
* Returns null if no valid session exists.
*/
export async function loadSaaSSessionAsync(): Promise<SaaSSession | null> {
try {
// Try secure storage first (keychain or encrypted localStorage)
const token = await secureStorage.get(SAASTOKEN_KEY);
// Migration: if secureStorage is empty, try legacy localStorage
const legacyToken = !token ? readLegacyLocalStorage(SAASTOKEN_KEY) : null;
const saasUrl = readLegacyLocalStorage(SAASURL_KEY);
const accountRaw = readLegacyLocalStorage(SAASACCOUNT_KEY);
const effectiveToken = token || legacyToken;
if (!effectiveToken || !saasUrl) {
return null;
}
const account: SaaSAccountInfo | null = accountRaw
? (JSON.parse(accountRaw) as SaaSAccountInfo)
: null;
// If we found a legacy token in localStorage, migrate it to secure storage
if (legacyToken && !token) {
await secureStorage.set(SAASTOKEN_KEY, legacyToken);
// Remove plaintext token from localStorage after migration
try { localStorage.removeItem(SAASTOKEN_KEY); } catch { /* ignore */ }
}
return { token: effectiveToken, account, saasUrl };
} catch {
// Corrupted data - clear all
await clearSaaSSessionAsync();
return null;
}
}
/**
* Persist a SaaS session using secure storage for the JWT token.
* URL and account info remain in localStorage (non-sensitive).
*/
export async function saveSaaSSessionAsync(session: SaaSSession): Promise<void> {
await secureStorage.set(SAASTOKEN_KEY, session.token);
// Remove legacy plaintext token from localStorage
try { localStorage.removeItem(SAASTOKEN_KEY); } catch { /* ignore */ }
localStorage.setItem(SAASURL_KEY, session.saasUrl);
if (session.account) {
localStorage.setItem(SAASACCOUNT_KEY, JSON.stringify(session.account));
}
}
/**
* Clear the persisted SaaS session from both secure storage and localStorage.
*/
export async function clearSaaSSessionAsync(): Promise<void> {
await secureStorage.delete(SAASTOKEN_KEY);
try { localStorage.removeItem(SAASTOKEN_KEY); } catch { /* ignore */ }
try { localStorage.removeItem(SAASURL_KEY); } catch { /* ignore */ }
try { localStorage.removeItem(SAASACCOUNT_KEY); } catch { /* ignore */ }
}
/**
* Persist the connection mode to localStorage.
* Connection mode is non-sensitive -- no need for secure storage.
*/
export function saveConnectionMode(mode: string): void {
localStorage.setItem(SAASMODE_KEY, mode);
}
/**
* Load the connection mode from localStorage.
* Returns null if not set.
*/
export function loadConnectionMode(): string | null {
return localStorage.getItem(SAASMODE_KEY);
}
// === Client Implementation ===
/** Callback invoked when token refresh fails and the session should be terminated. */
export type OnSessionExpired = () => void;
export class SaaSClient {
private baseUrl: string;
private token: string | null = null;
private refreshTimerId: ReturnType<typeof setTimeout> | null = null;
private visibilityHandler: (() => void) | null = null;
private onSessionExpired: OnSessionExpired | null = null;
constructor(baseUrl: string) {
this.baseUrl = baseUrl.replace(/\/+$/, '');
}
/** Update the base URL (e.g. when user changes server address) */
setBaseUrl(url: string): void {
this.baseUrl = url.replace(/\/+$/, '');
}
/** Get the current base URL */
getBaseUrl(): string {
return this.baseUrl;
}
/** Set or clear the auth token. Automatically schedules a proactive refresh. */
setToken(token: string | null): void {
this.token = token;
if (token) {
this.scheduleTokenRefresh();
} else {
this.cancelTokenRefresh();
}
}
/**
* Register a callback invoked when the proactive token refresh fails.
* The caller should use this to trigger a logout/redirect flow.
*/
setOnSessionExpired(handler: OnSessionExpired): void {
this.onSessionExpired = handler;
}
/** Check if the client has an auth token */
isAuthenticated(): boolean {
return !!this.token;
}
/**
* Schedule a proactive token refresh at 80% of the token's remaining lifetime.
* Also registers a visibilitychange listener to re-check when the tab regains focus.
*/
scheduleTokenRefresh(): void {
this.cancelTokenRefresh();
if (!this.token) return;
const payload = decodeJwtPayload<JwtPayload>(this.token);
if (!payload?.exp) return;
const delay = getRefreshDelay(payload.exp);
if (delay === null) {
// Token already expired or too close -- attempt immediate refresh
this.attemptTokenRefresh();
return;
}
this.refreshTimerId = setTimeout(() => {
this.attemptTokenRefresh();
}, delay);
// When the tab becomes visible again, check if we should refresh sooner
if (typeof document !== 'undefined' && !this.visibilityHandler) {
this.visibilityHandler = () => {
if (document.visibilityState === 'visible') {
this.checkAndRefreshToken();
}
};
document.addEventListener('visibilitychange', this.visibilityHandler);
}
}
/**
* Cancel any pending token refresh timer and remove the visibility listener.
*/
cancelTokenRefresh(): void {
if (this.refreshTimerId !== null) {
clearTimeout(this.refreshTimerId);
this.refreshTimerId = null;
}
if (this.visibilityHandler !== null && typeof document !== 'undefined') {
document.removeEventListener('visibilitychange', this.visibilityHandler);
this.visibilityHandler = null;
}
}
/**
* Check if the current token is close to expiry and refresh if needed.
* Called on visibility change to handle clock skew / long background tabs.
*/
private checkAndRefreshToken(): void {
if (!this.token) return;
const payload = decodeJwtPayload<JwtPayload>(this.token);
if (!payload?.exp) return;
const now = Math.floor(Date.now() / 1000);
const remaining = payload.exp - now;
// If less than 20% of lifetime remains, refresh now
if (remaining <= 0) {
this.attemptTokenRefresh();
return;
}
// If the scheduled refresh is more than 60s away and we're within 80%, do it now
const delay = getRefreshDelay(payload.exp);
if (delay !== null && delay < 60_000) {
this.attemptTokenRefresh();
}
}
/**
* Attempt to refresh the token. On failure, invoke the session-expired callback.
* Persists the new token via secureStorage.
*/
private attemptTokenRefresh(): Promise<void> {
return this.refreshToken()
.then(async (newToken) => {
// Persist the new token to secure storage
const existing = await loadSaaSSessionAsync();
if (existing) {
await saveSaaSSessionAsync({ ...existing, token: newToken });
}
})
.catch(() => {
// Refresh failed -- notify the app to log out
this.cancelTokenRefresh();
if (this.onSessionExpired) {
this.onSessionExpired();
}
});
}
// --- Core HTTP ---
/** Track whether the server appears reachable */
private _serverReachable: boolean = true;
/** Check if the SaaS server was last known to be reachable */
isServerReachable(): boolean {
return this._serverReachable;
}
/**
* Make an authenticated request with automatic retry on transient failures.
* Retries up to 2 times with exponential backoff (1s, 2s).
* Throws SaaSApiError on non-ok responses.
*/
public async request<T>(
method: string,
path: string,
body?: unknown,
timeoutMs = 15000,
): Promise<T> {
const maxRetries = 2;
const baseDelay = 1000;
for (let attempt = 0; attempt <= maxRetries; attempt++) {
const headers: Record<string, string> = {
'Content-Type': 'application/json',
};
if (this.token) {
headers['Authorization'] = `Bearer ${this.token}`;
}
try {
const response = await fetch(`${this.baseUrl}${path}`, {
method,
headers,
body: body !== undefined ? JSON.stringify(body) : undefined,
signal: AbortSignal.timeout(timeoutMs),
});
this._serverReachable = true;
// Handle 401 specially - caller may want to trigger re-auth
if (response.status === 401) {
throw new SaaSApiError(401, 'UNAUTHORIZED', '认证已过期,请重新登录');
}
if (!response.ok) {
const errorBody = (await response.json().catch(() => null)) as SaaSErrorResponse | null;
throw new SaaSApiError(
response.status,
errorBody?.error || 'UNKNOWN',
errorBody?.message || `请求失败 (${response.status})`,
);
}
// 204 No Content
if (response.status === 204) {
return undefined as T;
}
return response.json() as Promise<T>;
} catch (err: unknown) {
const isNetworkError = err instanceof TypeError
&& (err.message.includes('Failed to fetch') || err.message.includes('NetworkError'));
if (isNetworkError && attempt < maxRetries) {
this._serverReachable = false;
const delay = baseDelay * Math.pow(2, attempt);
await new Promise((r) => setTimeout(r, delay));
continue;
}
this._serverReachable = false;
if (err instanceof SaaSApiError) throw err;
throw new SaaSApiError(0, 'NETWORK_ERROR', `网络错误: ${err instanceof Error ? err.message : String(err)}`);
}
}
// Unreachable, but TypeScript needs it
throw new SaaSApiError(0, 'UNKNOWN', '请求失败');
}
// --- Health ---
/**
* Quick connectivity check against the SaaS backend.
*/
async healthCheck(): Promise<boolean> {
try {
await this.request<unknown>('GET', '/api/health', undefined, 5000);
return true;
} catch {
return false;
}
}
// --- Auth Endpoints ---
/**
* Login with username and password.
* Auto-sets the client token on success.
*/
async login(username: string, password: string, totpCode?: string): Promise<SaaSLoginResponse> {
const body: Record<string, string> = { username, password };
if (totpCode) body.totp_code = totpCode;
const data = await this.request<SaaSLoginResponse>(
'POST', '/api/v1/auth/login', body,
);
this.token = data.token;
return data;
}
/**
* Register a new account.
* Auto-sets the client token on success.
*/
async register(data: {
username: string;
email: string;
password: string;
display_name?: string;
}): Promise<SaaSLoginResponse> {
const result = await this.request<SaaSLoginResponse>(
'POST', '/api/v1/auth/register', data,
);
this.token = result.token;
return result;
}
/**
* Get the current authenticated user's account info.
*/
async me(): Promise<SaaSAccountInfo> {
return this.request<SaaSAccountInfo>('GET', '/api/v1/auth/me');
}
/**
* Refresh the current token.
* Auto-updates the client token on success.
*/
async refreshToken(): Promise<string> {
const data = await this.request<SaaSRefreshResponse>('POST', '/api/v1/auth/refresh');
this.token = data.token;
return data.token;
}
/**
* Change the current user's password.
*/
async changePassword(oldPassword: string, newPassword: string): Promise<void> {
await this.request<unknown>('PUT', '/api/v1/auth/password', {
old_password: oldPassword,
new_password: newPassword,
});
}
// --- TOTP Endpoints ---
/** Generate a TOTP secret and otpauth URI */
async setupTotp(): Promise<TotpSetupResponse> {
return this.request<TotpSetupResponse>('POST', '/api/v1/auth/totp/setup');
}
/** Verify a TOTP code and enable 2FA */
async verifyTotp(code: string): Promise<TotpResultResponse> {
return this.request<TotpResultResponse>('POST', '/api/v1/auth/totp/verify', { code });
}
/** Disable 2FA (requires password confirmation) */
async disableTotp(password: string): Promise<TotpResultResponse> {
return this.request<TotpResultResponse>('POST', '/api/v1/auth/totp/disable', { password });
}
// --- Device Endpoints ---
/**
* Register or update this device with the SaaS backend.
* Uses UPSERT semantics -- same (account, device_id) updates last_seen_at.
*/
async registerDevice(params: {
device_id: string;
device_name?: string;
platform?: string;
app_version?: string;
}): Promise<void> {
await this.request<unknown>('POST', '/api/v1/devices/register', params);
}
/**
* Send a heartbeat to indicate the device is still active.
*/
async deviceHeartbeat(deviceId: string): Promise<void> {
await this.request<unknown>('POST', '/api/v1/devices/heartbeat', {
device_id: deviceId,
});
}
/**
* List devices registered for the current account.
*/
async listDevices(): Promise<DeviceInfo[]> {
return this.request<DeviceInfo[]>('GET', '/api/v1/devices');
}
// --- Model Endpoints ---
/**
* List available models for relay.
* Only returns enabled models from enabled providers.
*/
async listModels(): Promise<SaaSModelInfo[]> {
return this.request<SaaSModelInfo[]>('GET', '/api/v1/relay/models');
}
// --- Relay Task Management ---
/** List relay tasks for the current user */
async listRelayTasks(query?: { status?: string; page?: number; page_size?: number }): Promise<RelayTaskInfo[]> {
const params = new URLSearchParams();
if (query?.status) params.set('status', query.status);
if (query?.page) params.set('page', String(query.page));
if (query?.page_size) params.set('page_size', String(query.page_size));
const qs = params.toString();
return this.request<RelayTaskInfo[]>('GET', `/api/v1/relay/tasks${qs ? '?' + qs : ''}`);
}
/** Get a single relay task */
async getRelayTask(taskId: string): Promise<RelayTaskInfo> {
return this.request<RelayTaskInfo>('GET', `/api/v1/relay/tasks/${taskId}`);
}
/** Retry a failed relay task (admin only) */
async retryRelayTask(taskId: string): Promise<{ ok: boolean; task_id: string }> {
return this.request<{ ok: boolean; task_id: string }>('POST', `/api/v1/relay/tasks/${taskId}/retry`);
}
// --- Chat Relay ---
/**
* Send a chat completion request via the SaaS relay.
* Returns the raw Response object to support both streaming and non-streaming.
*
* The caller is responsible for:
* - Reading the response body (JSON or SSE stream)
* - Handling errors from the response
*/
async chatCompletion(
body: unknown,
signal?: AbortSignal,
): Promise<Response> {
const headers: Record<string, string> = {
'Content-Type': 'application/json',
};
if (this.token) {
headers['Authorization'] = `Bearer ${this.token}`;
}
// Use caller's AbortSignal if provided, otherwise default 5min timeout
const effectiveSignal = signal ?? AbortSignal.timeout(300_000);
const response = await fetch(
`${this.baseUrl}/api/v1/relay/chat/completions`,
{
method: 'POST',
headers,
body: JSON.stringify(body),
signal: effectiveSignal,
},
);
return response;
}
// --- Config Endpoints ---
/**
* List config items, optionally filtered by category.
*/
async listConfig(category?: string): Promise<SaaSConfigItem[]> {
const qs = category ? `?category=${encodeURIComponent(category)}` : '';
return this.request<SaaSConfigItem[]>('GET', `/api/v1/config/items${qs}`);
}
/** Compute config diff between client and SaaS (read-only) */
async computeConfigDiff(request: SyncConfigRequest): Promise<ConfigDiffResponse> {
return this.request<ConfigDiffResponse>('POST', '/api/v1/config/diff', request);
}
/** Sync config from client to SaaS (push) or merge */
async syncConfig(request: SyncConfigRequest): Promise<ConfigSyncResult> {
return this.request<ConfigSyncResult>('POST', '/api/v1/config/sync', request);
}
}
// === Singleton ===
/**
* Global SaaS client singleton.
* Initialized with a default URL; the URL and token are updated on login.
*/
export const saasClient = new SaaSClient('https://saas.zclaw.com');

View File

@@ -37,18 +37,9 @@ const log = createLogger('ConnectionStore');
// === Custom Models Helpers ===
const CUSTOM_MODELS_STORAGE_KEY = 'zclaw-custom-models';
import type { CustomModel } from '../types/config';
interface CustomModel {
id: string;
name: string;
provider: string;
apiKey?: string;
apiProtocol: 'openai' | 'anthropic' | 'custom';
baseUrl?: string;
isDefault?: boolean;
createdAt: string;
}
const CUSTOM_MODELS_STORAGE_KEY = 'zclaw-custom-models';
/**
* Get custom models from localStorage
@@ -213,6 +204,37 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
try {
set({ error: null });
// === SaaS Relay Mode ===
// Check connection mode from localStorage (set by saasStore).
// This takes priority over Tauri/Gateway when the user has selected SaaS mode.
const savedMode = localStorage.getItem('zclaw-connection-mode');
if (savedMode === 'saas') {
const { loadSaaSSessionAsync, saasClient } = await import('../lib/saas-client');
const session = await loadSaaSSessionAsync();
if (!session || !session.token || !session.saasUrl) {
throw new Error('SaaS 模式未登录,请先在设置中登录 SaaS 平台');
}
log.debug('Using SaaS relay mode:', session.saasUrl);
// Configure the singleton client
saasClient.setBaseUrl(session.saasUrl);
saasClient.setToken(session.token);
// Health check via GET /api/v1/relay/models
try {
await saasClient.listModels();
} catch (err) {
const errMsg = err instanceof Error ? err.message : String(err);
throw new Error(`SaaS 平台连接失败: ${errMsg}`);
}
set({ connectionState: 'connected', gatewayVersion: 'saas-relay' });
log.debug('Connected to SaaS relay');
return;
}
// === Internal Kernel Mode (Tauri) ===
// Check at RUNTIME, not at module load time, to ensure __TAURI_INTERNALS__ is available
const useInternalKernel = isTauriRuntime();

View File

@@ -35,6 +35,10 @@ export type { SessionStore, SessionStateSlice, SessionActionsSlice, Session, Ses
export { useMemoryGraphStore } from './memoryGraphStore';
export type { MemoryGraphStore, GraphNode, GraphEdge, GraphFilter, GraphLayout } from './memoryGraphStore';
// === SaaS Store ===
export { useSaaSStore } from './saasStore';
export type { SaaSStore, SaaSStateSlice, SaaSActionsSlice, ConnectionMode } from './saasStore';
// === Browser Hand Store ===
export { useBrowserHandStore } from './browserHandStore';

View File

@@ -0,0 +1,489 @@
/**
* SaaS Store - SaaS Platform Connection State Management
*
* Manages SaaS login state, account info, connection mode,
* and available models. Persists auth token via secureStorage
* (OS keychain or encrypted localStorage) for security.
*
* Connection modes:
* - 'tauri': Local Kernel via Tauri (default)
* - 'gateway': External Gateway via WebSocket
* - 'saas': SaaS backend relay
*/
import { create } from 'zustand';
import {
saasClient,
SaaSApiError,
loadSaaSSessionAsync,
saveSaaSSessionAsync,
clearSaaSSessionAsync,
saveConnectionMode,
loadConnectionMode,
type SaaSAccountInfo,
type SaaSModelInfo,
type SaaSLoginResponse,
type TotpSetupResponse,
} from '../lib/saas-client';
import { createLogger } from '../lib/logger';
const log = createLogger('SaaSStore');
// === Device ID ===
/** Generate or load a persistent device ID for this browser instance */
function getOrCreateDeviceId(): string {
const KEY = 'zclaw-device-id';
const existing = localStorage.getItem(KEY);
if (existing) return existing;
const newId = crypto.randomUUID();
localStorage.setItem(KEY, newId);
return newId;
}
const DEVICE_ID = getOrCreateDeviceId();
// === Types ===
export type ConnectionMode = 'tauri' | 'gateway' | 'saas';
export interface SaaSStateSlice {
isLoggedIn: boolean;
account: SaaSAccountInfo | null;
saasUrl: string;
authToken: string | null;
connectionMode: ConnectionMode;
availableModels: SaaSModelInfo[];
isLoading: boolean;
error: string | null;
totpRequired: boolean;
totpSetupData: TotpSetupResponse | null;
}
export interface SaaSActionsSlice {
login: (saasUrl: string, username: string, password: string) => Promise<void>;
loginWithTotp: (saasUrl: string, username: string, password: string, totpCode: string) => Promise<void>;
register: (saasUrl: string, username: string, email: string, password: string, displayName?: string) => Promise<void>;
logout: () => Promise<void>;
setConnectionMode: (mode: ConnectionMode) => void;
fetchAvailableModels: () => Promise<void>;
registerCurrentDevice: () => Promise<void>;
clearError: () => void;
restoreSession: () => Promise<void>;
setupTotp: () => Promise<TotpSetupResponse>;
verifyTotp: (code: string) => Promise<void>;
disableTotp: (password: string) => Promise<void>;
cancelTotpSetup: () => void;
}
export type SaaSStore = SaaSStateSlice & SaaSActionsSlice;
// === Constants ===
const DEFAULT_SAAS_URL = 'https://saas.zclaw.com';
// === Helpers ===
/** Determine the initial connection mode from persisted state */
function resolveInitialMode(hasSession: boolean): ConnectionMode {
const persistedMode = loadConnectionMode();
if (persistedMode === 'tauri' || persistedMode === 'gateway' || persistedMode === 'saas') {
return persistedMode;
}
return hasSession ? 'saas' : 'tauri';
}
// === Store Implementation ===
export const useSaaSStore = create<SaaSStore>((set, get) => {
// Determine initial connection mode synchronously from localStorage.
// Session token will be loaded asynchronously via restoreSession().
const persistedMode = loadConnectionMode();
const hasSession = persistedMode === 'saas';
const initialMode = resolveInitialMode(hasSession);
// Kick off async session restoration immediately.
// The store initializes with a "potentially logged in" state based on
// the connection mode, and restoreSession() will either hydrate the token
// or clear the session if secure storage has no token.
loadSaaSSessionAsync().then((session) => {
if (session) {
saasClient.setBaseUrl(session.saasUrl);
saasClient.setToken(session.token);
set({
isLoggedIn: true,
account: session.account,
saasUrl: session.saasUrl,
authToken: session.token,
connectionMode: resolveInitialMode(true),
});
// Fetch models in background after async restore
get().fetchAvailableModels().catch(() => {});
} else if (persistedMode === 'saas') {
// Connection mode was 'saas' but no token found -- reset to tauri
saveConnectionMode('tauri');
set({ connectionMode: 'tauri' });
}
}).catch(() => {
// secureStorage read failed -- keep defaults
});
return {
// === Initial State ===
// Session data will be hydrated by the async restoreSession above.
isLoggedIn: hasSession,
account: null,
saasUrl: DEFAULT_SAAS_URL,
authToken: null,
connectionMode: initialMode,
availableModels: [],
isLoading: false,
error: null,
totpRequired: false,
totpSetupData: null,
// === Actions ===
login: async (saasUrl: string, username: string, password: string) => {
set({ isLoading: true, error: null });
try {
const trimmedUrl = saasUrl.trim();
const trimmedUsername = username.trim();
if (!trimmedUrl) {
throw new Error('请输入服务器地址');
}
if (!trimmedUsername) {
throw new Error('请输入用户名');
}
if (!password) {
throw new Error('请输入密码');
}
const normalizedUrl = trimmedUrl.replace(/\/+$/, '');
// Configure singleton client and attempt login
saasClient.setBaseUrl(normalizedUrl);
const loginData: SaaSLoginResponse = await saasClient.login(trimmedUsername, password);
// Persist session securely
const sessionData = {
token: loginData.token,
account: loginData.account,
saasUrl: normalizedUrl,
};
await saveSaaSSessionAsync(sessionData);
saveConnectionMode('saas');
set({
isLoggedIn: true,
account: loginData.account,
saasUrl: normalizedUrl,
authToken: loginData.token,
connectionMode: 'saas',
isLoading: false,
error: null,
});
// Register device and start heartbeat in background
get().registerCurrentDevice().catch((err: unknown) => {
log.warn('Failed to register device:', err);
});
// Fetch available models in background (non-blocking)
get().fetchAvailableModels().catch((err: unknown) => {
log.warn('Failed to fetch models after login:', err);
});
} catch (err: unknown) {
// Check for TOTP required signal
if (err instanceof SaaSApiError && err.code === 'TOTP_ERROR' && err.status === 400) {
set({ isLoading: false, totpRequired: true, error: null });
return;
}
const message = err instanceof SaaSApiError
? err.message
: err instanceof Error
? err.message
: String(err);
const isNetworkError = message.includes('Failed to fetch')
|| message.includes('NetworkError')
|| message.includes('ECONNREFUSED')
|| message.includes('timeout');
const userMessage = isNetworkError
? `无法连接到 SaaS 服务器: ${get().saasUrl}`
: message;
set({ isLoading: false, error: userMessage });
throw new Error(userMessage);
}
},
loginWithTotp: async (saasUrl: string, username: string, password: string, totpCode: string) => {
set({ isLoading: true, error: null, totpRequired: false });
try {
const normalizedUrl = saasUrl.trim().replace(/\/+$/, '');
saasClient.setBaseUrl(normalizedUrl);
const loginData = await saasClient.login(username.trim(), password, totpCode);
const sessionData = {
token: loginData.token,
account: loginData.account,
saasUrl: normalizedUrl,
};
await saveSaaSSessionAsync(sessionData);
saveConnectionMode('saas');
set({
isLoggedIn: true,
account: loginData.account,
saasUrl: normalizedUrl,
authToken: loginData.token,
connectionMode: 'saas',
isLoading: false,
error: null,
totpRequired: false,
});
get().registerCurrentDevice().catch((err: unknown) => {
log.warn('Failed to register device:', err);
});
get().fetchAvailableModels().catch((err: unknown) => {
log.warn('Failed to fetch models:', err);
});
} catch (err: unknown) {
const message = err instanceof SaaSApiError ? err.message
: err instanceof Error ? err.message : String(err);
set({ isLoading: false, error: message });
throw new Error(message);
}
},
register: async (saasUrl: string, username: string, email: string, password: string, displayName?: string) => {
set({ isLoading: true, error: null });
try {
const trimmedUrl = saasUrl.trim();
if (!trimmedUrl) {
throw new Error('请输入服务器地址');
}
if (!username.trim()) {
throw new Error('请输入用户名');
}
if (!email.trim()) {
throw new Error('请输入邮箱');
}
if (!password) {
throw new Error('请输入密码');
}
const normalizedUrl = trimmedUrl.replace(/\/+$/, '');
saasClient.setBaseUrl(normalizedUrl);
const registerData: SaaSLoginResponse = await saasClient.register({
username: username.trim(),
email: email.trim(),
password,
display_name: displayName,
});
const sessionData = {
token: registerData.token,
account: registerData.account,
saasUrl: normalizedUrl,
};
await saveSaaSSessionAsync(sessionData);
saveConnectionMode('saas');
set({
isLoggedIn: true,
account: registerData.account,
saasUrl: normalizedUrl,
authToken: registerData.token,
connectionMode: 'saas',
isLoading: false,
error: null,
});
get().registerCurrentDevice().catch((err: unknown) => {
log.warn('Failed to register device after register:', err);
});
get().fetchAvailableModels().catch((err: unknown) => {
log.warn('Failed to fetch models after register:', err);
});
} catch (err: unknown) {
const message = err instanceof SaaSApiError
? err.message
: err instanceof Error
? err.message
: String(err);
set({ isLoading: false, error: message });
throw new Error(message);
}
},
logout: async () => {
saasClient.setToken(null);
await clearSaaSSessionAsync();
saveConnectionMode('tauri');
set({
isLoggedIn: false,
account: null,
authToken: null,
connectionMode: 'tauri',
availableModels: [],
error: null,
totpRequired: false,
totpSetupData: null,
});
},
setConnectionMode: (mode: ConnectionMode) => {
const { isLoggedIn } = get();
// Cannot switch to SaaS mode if not logged in
if (mode === 'saas' && !isLoggedIn) {
return;
}
saveConnectionMode(mode);
set({ connectionMode: mode });
},
fetchAvailableModels: async () => {
const { isLoggedIn, authToken, saasUrl } = get();
if (!isLoggedIn || !authToken) {
set({ availableModels: [] });
return;
}
try {
saasClient.setBaseUrl(saasUrl);
saasClient.setToken(authToken);
const models = await saasClient.listModels();
set({ availableModels: models });
} catch (err: unknown) {
log.warn('Failed to fetch available models:', err);
// Do not set error state - model fetch failure is non-critical
set({ availableModels: [] });
}
},
registerCurrentDevice: async () => {
const { isLoggedIn, authToken, saasUrl } = get();
if (!isLoggedIn || !authToken) {
return;
}
try {
saasClient.setBaseUrl(saasUrl);
saasClient.setToken(authToken);
await saasClient.registerDevice({
device_id: DEVICE_ID,
device_name: `${navigator.userAgent.split(' ').slice(0, 3).join(' ')}`,
platform: navigator.platform,
app_version: __APP_VERSION__ || 'unknown',
});
log.info('Device registered successfully');
// Start periodic heartbeat (every 5 minutes)
if (typeof window !== 'undefined' && !get()._heartbeatTimer) {
const timer = window.setInterval(() => {
const state = get();
if (state.isLoggedIn && state.authToken) {
saasClient.deviceHeartbeat(DEVICE_ID).catch(() => {});
} else {
window.clearInterval(timer);
}
}, 5 * 60 * 1000);
set({ _heartbeatTimer: timer } as unknown as Partial<SaaSStore>);
}
} catch (err: unknown) {
log.warn('Failed to register device:', err);
}
},
clearError: () => {
set({ error: null });
},
restoreSession: async () => {
const restored = await loadSaaSSessionAsync();
if (restored) {
saasClient.setBaseUrl(restored.saasUrl);
saasClient.setToken(restored.token);
set({
isLoggedIn: true,
account: restored.account,
saasUrl: restored.saasUrl,
authToken: restored.token,
connectionMode: loadConnectionMode() === 'saas' ? 'saas' : 'tauri',
});
get().fetchAvailableModels().catch(() => {});
}
},
setupTotp: async () => {
set({ isLoading: true, error: null });
try {
const setupData = await saasClient.setupTotp();
set({ totpSetupData: setupData, isLoading: false });
return setupData;
} catch (err: unknown) {
const message = err instanceof SaaSApiError ? err.message
: err instanceof Error ? err.message : String(err);
set({ isLoading: false, error: message });
throw new Error(message);
}
},
verifyTotp: async (code: string) => {
set({ isLoading: true, error: null });
try {
await saasClient.verifyTotp(code);
const account = await saasClient.me();
const { saasUrl, authToken } = get();
if (authToken) {
await saveSaaSSessionAsync({ token: authToken, account, saasUrl });
}
set({ totpSetupData: null, isLoading: false, account });
} catch (err: unknown) {
const message = err instanceof SaaSApiError ? err.message
: err instanceof Error ? err.message : String(err);
set({ isLoading: false, error: message });
throw new Error(message);
}
},
disableTotp: async (password: string) => {
set({ isLoading: true, error: null });
try {
await saasClient.disableTotp(password);
const account = await saasClient.me();
const { saasUrl, authToken } = get();
if (authToken) {
await saveSaaSSessionAsync({ token: authToken, account, saasUrl });
}
set({ isLoading: false, account });
} catch (err: unknown) {
const message = err instanceof SaaSApiError ? err.message
: err instanceof Error ? err.message : String(err);
set({ isLoading: false, error: message });
throw new Error(message);
}
},
cancelTotpSetup: () => {
set({ totpSetupData: null });
},
};
});

View File

@@ -571,3 +571,35 @@ export interface ConfigFileMetadata {
/** Whether the file has unresolved env vars */
hasUnresolvedEnvVars?: boolean;
}
// ============================================================
// Custom Model Types
// ============================================================
/**
* API protocol supported by a custom model provider.
*/
export type CustomModelApiProtocol = 'openai' | 'anthropic' | 'custom';
/**
* User-defined custom model configuration.
* Used by the model settings UI and the connection store.
*/
export interface CustomModel {
/** Unique identifier */
id: string;
/** Human-readable model name */
name: string;
/** Provider / vendor name */
provider: string;
/** API key (optional, stored separately in secure storage) */
apiKey?: string;
/** Which API protocol this provider speaks */
apiProtocol: CustomModelApiProtocol;
/** Base URL for the provider API (optional) */
baseUrl?: string;
/** Whether this model is the user's default */
isDefault?: boolean;
/** ISO-8601 timestamp of when this model was added */
createdAt: string;
}

View File

@@ -141,6 +141,12 @@ export type {
AutomationItem,
} from './automation';
// Custom Model Types
export type {
CustomModel,
CustomModelApiProtocol,
} from './config';
// Automation Constants and Functions
export {
HAND_CATEGORY_MAP,

Some files were not shown because too many files have changed in this diff Show More