Compare commits

1 Commits

Author SHA1 Message Date
iven
44256a511c feat: 增强SaaS后端功能与安全性
refactor: 重构数据库连接使用PostgreSQL替代SQLite
feat(auth): 增加JWT验证的audience和issuer检查
feat(crypto): 添加AES-256-GCM字段加密支持
feat(api): 集成utoipa实现OpenAPI文档
fix(admin): 修复配置项表单验证逻辑
style: 统一代码格式与类型定义
docs: 更新技术栈文档说明PostgreSQL
2026-03-31 00:12:53 +08:00
177 changed files with 9731 additions and 948 deletions

93
.dockerignore Normal file
View File

@@ -0,0 +1,93 @@
# ============================================================
# ZCLAW SaaS Backend - Docker Ignore
# ============================================================
# Build artifacts
target/
# Frontend applications (not needed for SaaS backend)
desktop/
admin/
design-system/
# Node.js
node_modules/
.pnpm-store/
bun.lock
pnpm-lock.yaml
package.json
package-lock.json
# Git
.git/
.gitignore
# IDE and editor
.vscode/
.idea/
*.swp
*.swo
*~
# OS files
.DS_Store
Thumbs.db
# Docker
.docker/
docker-compose*.yml
Dockerfile
.dockerignore
# Documentation
docs/
*.md
!saas-config.toml
CLAUDE.md
CLAUDE*.md
# Environment files (secrets)
.env
.env.*
saas-env.example
# Data files
saas-data/
saas-data.db
saas-data.db-shm
saas-data.db-wal
*.db
*.db-shm
*.db-wal
# Test artifacts
tests/
test-results/
test.rs
*.log
# Temporary files
tmp-screenshot.png
tmp/
temp/
*.tmp
# Claude worktree metadata
.claude/
plans/
pipelines/
scripts/
hands/
skills/
plugins/
config/
extract.js
extract_models.js
extract_privacy.js
start-all.ps1
start.ps1
start.sh
Makefile
PROGRESS.md
CHANGELOG.md
pencil-new.pen

150
Cargo.lock generated
View File

@@ -2800,6 +2800,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]] [[package]]
name = "minimal-lexical" name = "minimal-lexical"
version = "0.2.1" version = "0.2.1"
@@ -4148,6 +4158,41 @@ dependencies = [
"zeroize", "zeroize",
] ]
[[package]]
name = "rust-embed"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04113cb9355a377d83f06ef1f0a45b8ab8cd7d8b1288160717d66df5c7988d27"
dependencies = [
"rust-embed-impl",
"rust-embed-utils",
"walkdir",
]
[[package]]
name = "rust-embed-impl"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da0902e4c7c8e997159ab384e6d0fc91c221375f6894346ae107f47dd0f3ccaa"
dependencies = [
"proc-macro2",
"quote",
"rust-embed-utils",
"shellexpand",
"syn 2.0.117",
"walkdir",
]
[[package]]
name = "rust-embed-utils"
version = "8.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5bcdef0be6fe7f6fa333b1073c949729274b05f123a0ad7efcb8efd878e5c3b1"
dependencies = [
"sha2",
"walkdir",
]
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "2.1.1" version = "2.1.1"
@@ -4617,6 +4662,15 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "shellexpand"
version = "3.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32824fab5e16e6c4d86dc1ba84489390419a39f97699852b66480bb87d297ed8"
dependencies = [
"dirs",
]
[[package]] [[package]]
name = "shlex" name = "shlex"
version = "1.3.0" version = "1.3.0"
@@ -4795,6 +4849,7 @@ dependencies = [
"atoi", "atoi",
"byteorder", "byteorder",
"bytes", "bytes",
"chrono",
"crc", "crc",
"crossbeam-queue", "crossbeam-queue",
"either", "either",
@@ -4855,6 +4910,7 @@ dependencies = [
"sha2", "sha2",
"sqlx-core", "sqlx-core",
"sqlx-mysql", "sqlx-mysql",
"sqlx-postgres",
"sqlx-sqlite", "sqlx-sqlite",
"syn 1.0.109", "syn 1.0.109",
"tempfile", "tempfile",
@@ -4873,6 +4929,7 @@ dependencies = [
"bitflags 2.11.0", "bitflags 2.11.0",
"byteorder", "byteorder",
"bytes", "bytes",
"chrono",
"crc", "crc",
"digest", "digest",
"dotenvy", "dotenvy",
@@ -4914,6 +4971,7 @@ dependencies = [
"base64 0.21.7", "base64 0.21.7",
"bitflags 2.11.0", "bitflags 2.11.0",
"byteorder", "byteorder",
"chrono",
"crc", "crc",
"dotenvy", "dotenvy",
"etcetera", "etcetera",
@@ -4949,6 +5007,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa" checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
dependencies = [ dependencies = [
"atoi", "atoi",
"chrono",
"flume", "flume",
"futures-channel", "futures-channel",
"futures-core", "futures-core",
@@ -5989,6 +6048,12 @@ dependencies = [
"unic-common", "unic-common",
] ]
[[package]]
name = "unicase"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142"
[[package]] [[package]]
name = "unicode-bidi" name = "unicode-bidi"
version = "0.3.18" version = "0.3.18"
@@ -6099,6 +6164,70 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "utoipa"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5afb1a60e207dca502682537fefcfd9921e71d0b83e9576060f09abc6efab23"
dependencies = [
"indexmap 2.13.0",
"serde",
"serde_json",
"utoipa-gen 4.3.1",
]
[[package]]
name = "utoipa"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2fcc29c80c21c31608227e0912b2d7fddba57ad76b606890627ba8ee7964e993"
dependencies = [
"indexmap 2.13.0",
"serde",
"serde_json",
"utoipa-gen 5.4.0",
]
[[package]]
name = "utoipa-gen"
version = "4.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20c24e8ab68ff9ee746aad22d39b5535601e6416d1b0feeabf78be986a5c4392"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 2.0.117",
]
[[package]]
name = "utoipa-gen"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d79d08d92ab8af4c5e8a6da20c47ae3f61a0f1dabc1997cdf2d082b757ca08b"
dependencies = [
"proc-macro2",
"quote",
"regex",
"syn 2.0.117",
]
[[package]]
name = "utoipa-swagger-ui"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f839caa8e09dddc3ff1c3112a91ef7da0601075ba5025d9f33ae99c4cb9b6e51"
dependencies = [
"axum",
"mime_guess",
"regex",
"rust-embed",
"serde",
"serde_json",
"utoipa 4.2.3",
"zip 0.6.6",
]
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.22.0" version = "1.22.0"
@@ -7338,7 +7467,7 @@ dependencies = [
"zclaw-runtime", "zclaw-runtime",
"zclaw-skills", "zclaw-skills",
"zclaw-types", "zclaw-types",
"zip", "zip 2.4.2",
] ]
[[package]] [[package]]
@@ -7432,17 +7561,19 @@ dependencies = [
name = "zclaw-saas" name = "zclaw-saas"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"aes-gcm",
"anyhow", "anyhow",
"argon2", "argon2",
"async-stream",
"axum", "axum",
"axum-extra", "axum-extra",
"bytes",
"chrono", "chrono",
"dashmap", "dashmap",
"data-encoding", "data-encoding",
"futures", "futures",
"hex", "hex",
"jsonwebtoken", "jsonwebtoken",
"libsqlite3-sys",
"rand 0.8.5", "rand 0.8.5",
"reqwest 0.12.28", "reqwest 0.12.28",
"secrecy", "secrecy",
@@ -7461,8 +7592,9 @@ dependencies = [
"tracing-subscriber", "tracing-subscriber",
"url", "url",
"urlencoding", "urlencoding",
"utoipa 5.4.0",
"utoipa-swagger-ui",
"uuid", "uuid",
"zclaw-types",
] ]
[[package]] [[package]]
@@ -7572,6 +7704,18 @@ dependencies = [
"syn 2.0.117", "syn 2.0.117",
] ]
[[package]]
name = "zip"
version = "0.6.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
dependencies = [
"byteorder",
"crc32fast",
"crossbeam-utils",
"flate2",
]
[[package]] [[package]]
name = "zip" name = "zip"
version = "2.4.2" version = "2.4.2"

View File

@@ -57,7 +57,7 @@ chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1", features = ["v4", "v5", "serde"] } uuid = { version = "1", features = ["v4", "v5", "serde"] }
# Database # Database
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite"] } sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "chrono"] }
libsqlite3-sys = { version = "0.27", features = ["bundled"] } libsqlite3-sys = { version = "0.27", features = ["bundled"] }
# HTTP client (for LLM drivers) # HTTP client (for LLM drivers)

83
Dockerfile Normal file
View File

@@ -0,0 +1,83 @@
# ============================================================
# ZCLAW SaaS Backend - Multi-stage Docker Build
# ============================================================
# ---- Stage 1: Builder ----
FROM rust:1.75-bookworm AS builder
# Install build dependencies for sqlx (postgres) and libsqlite3-sys (bundled)
RUN apt-get update && apt-get install -y --no-install-recommends \
pkg-config \
libssl-dev \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /app
# Copy workspace manifests first to leverage Docker layer caching
COPY Cargo.toml Cargo.lock ./
# Create stub source files so cargo can resolve and cache dependencies
# This avoids rebuilding dependencies when only application code changes
RUN mkdir -p crates/zclaw-saas/src \
&& echo 'fn main() {}' > crates/zclaw-saas/src/main.rs \
&& for crate in zclaw-types zclaw-memory zclaw-runtime zclaw-kernel \
zclaw-skills zclaw-hands zclaw-channels zclaw-protocols \
zclaw-pipeline zclaw-growth; do \
mkdir -p crates/$crate/src && echo '' > crates/$crate/src/lib.rs; \
done \
&& mkdir -p desktop/src-tauri/src && echo 'fn main() {}' > desktop/src-tauri/src/main.rs
# Pre-build dependencies (release profile with caching)
RUN cargo build --release --package zclaw-saas 2>/dev/null || true
# Copy actual source code (invalidates stubs, triggers recompile of app code only)
COPY crates/ crates/
COPY desktop/ desktop/
# Touch source files to invalidate the stub timestamps
RUN touch crates/zclaw-saas/src/main.rs \
&& for crate in zclaw-types zclaw-memory zclaw-runtime zclaw-kernel \
zclaw-skills zclaw-hands zclaw-channels zclaw-protocols \
zclaw-pipeline zclaw-growth; do \
touch crates/$crate/src/lib.rs 2>/dev/null || true; \
done \
&& touch desktop/src-tauri/src/main.rs 2>/dev/null || true
# Build the actual binary
RUN cargo build --release --package zclaw-saas
# ---- Stage 2: Runtime ----
FROM debian:bookworm-slim AS runtime
# Install runtime dependencies (ca-certificates for HTTPS, libgcc for Rust runtime)
RUN apt-get update && apt-get install -y --no-install-recommends \
ca-certificates \
libgcc-s \
&& rm -rf /var/lib/apt/lists/* \
&& update-ca-certificates
# Create non-root user for security
RUN groupadd --gid 1000 zclaw \
&& useradd --uid 1000 --gid zclaw --shell /bin/false zclaw
WORKDIR /app
# Copy binary from builder
COPY --from=builder /app/target/release/zclaw-saas /app/zclaw-saas
# Copy configuration file
COPY saas-config.toml /app/saas-config.toml
# Ensure the non-root user owns the application files
RUN chown -R zclaw:zclaw /app
USER zclaw
# Expose the SaaS API port
EXPOSE 8080
# Health check endpoint (matches the saas-config.toml port)
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
CMD ["/app/zclaw-saas", "--healthcheck"] || exit 1
ENTRYPOINT ["/app/zclaw-saas"]

View File

@@ -1,7 +1,9 @@
# ZCLAW Makefile # ZCLAW Makefile
# Cross-platform task runner # Cross-platform task runner
.PHONY: help start start-dev start-no-browser desktop desktop-build setup test clean .PHONY: help start start-dev start-no-browser desktop desktop-build setup test clean \
saas-build saas-run saas-test saas-test-integration saas-clippy saas-migrate \
saas-docker-up saas-docker-down saas-docker-build
help: ## Show this help message help: ## Show this help message
@echo "ZCLAW - AI Agent Desktop Client" @echo "ZCLAW - AI Agent Desktop Client"
@@ -71,3 +73,32 @@ clean-deep: clean ## Deep clean (including pnpm cache)
@rm -rf desktop/pnpm-lock.yaml @rm -rf desktop/pnpm-lock.yaml
@rm -rf pnpm-lock.yaml @rm -rf pnpm-lock.yaml
@echo "Deep clean complete. Run 'pnpm install' to reinstall." @echo "Deep clean complete. Run 'pnpm install' to reinstall."
# === SaaS Backend ===
saas-build: ## Build zclaw-saas crate
@cargo build -p zclaw-saas
saas-run: ## Start SaaS backend (cargo run)
@cargo run -p zclaw-saas
saas-test: ## Run SaaS unit tests
@cargo test -p zclaw-saas -- --test-threads=1
saas-test-integration: ## Run SaaS integration tests (requires PostgreSQL)
@cargo test -p zclaw-saas -- --ignored --test-threads=1
saas-clippy: ## Run clippy on zclaw-saas
@cargo clippy -p zclaw-saas -- -D warnings
saas-migrate: ## Run database migrations
@cargo run -p zclaw-saas -- --migrate
saas-docker-up: ## Start SaaS services (PostgreSQL + backend)
@docker compose up -d
saas-docker-down: ## Stop SaaS services
@docker compose down
saas-docker-build: ## Build SaaS Docker images
@docker compose build

2
admin/.gitignore vendored
View File

@@ -1,2 +1,4 @@
.next/ .next/
node_modules/ node_modules/
.env.local
.env*.local

View File

@@ -1,4 +1,44 @@
/** @type {import('next').NextConfig} */ /** @type {import('next').NextConfig} */
const nextConfig = {} const nextConfig = {
async headers() {
return [
{
source: '/(.*)',
headers: [
{
key: 'X-Frame-Options',
value: 'DENY',
},
{
key: 'X-Content-Type-Options',
value: 'nosniff',
},
{
key: 'Referrer-Policy',
value: 'strict-origin-when-cross-origin',
},
{
key: 'Content-Security-Policy',
value: [
"default-src 'self'",
"script-src 'self' 'unsafe-eval' 'unsafe-inline'",
"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com",
"font-src 'self' https://fonts.gstatic.com",
"img-src 'self' data: blob:",
"connect-src 'self'",
"frame-ancestors 'none'",
"base-uri 'self'",
"form-action 'self'",
].join('; '),
},
{
key: 'Permissions-Policy',
value: 'camera=(), microphone=(), geolocation=()',
},
],
},
]
},
}
module.exports = nextConfig module.exports = nextConfig

View File

@@ -11,10 +11,10 @@
"dependencies": { "dependencies": {
"@radix-ui/react-dialog": "^1.1.14", "@radix-ui/react-dialog": "^1.1.14",
"@radix-ui/react-select": "^2.2.5", "@radix-ui/react-select": "^2.2.5",
"@radix-ui/react-separator": "^1.1.7",
"@radix-ui/react-switch": "^1.2.5", "@radix-ui/react-switch": "^1.2.5",
"@radix-ui/react-tabs": "^1.1.12", "@radix-ui/react-tabs": "^1.1.12",
"@radix-ui/react-tooltip": "^1.2.7", "@radix-ui/react-tooltip": "^1.2.7",
"@radix-ui/react-separator": "^1.1.7",
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"lucide-react": "^0.484.0", "lucide-react": "^0.484.0",
@@ -22,6 +22,7 @@
"react": "^18.3.1", "react": "^18.3.1",
"react-dom": "^18.3.1", "react-dom": "^18.3.1",
"recharts": "^2.15.3", "recharts": "^2.15.3",
"sonner": "^2.0.7",
"tailwind-merge": "^3.0.2" "tailwind-merge": "^3.0.2"
}, },
"devDependencies": { "devDependencies": {

14
admin/pnpm-lock.yaml generated
View File

@@ -47,6 +47,9 @@ importers:
recharts: recharts:
specifier: ^2.15.3 specifier: ^2.15.3
version: 2.15.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1) version: 2.15.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
sonner:
specifier: ^2.0.7
version: 2.0.7(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
tailwind-merge: tailwind-merge:
specifier: ^3.0.2 specifier: ^3.0.2
version: 3.5.0 version: 3.5.0
@@ -1063,6 +1066,12 @@ packages:
scheduler@0.23.2: scheduler@0.23.2:
resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==} resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==}
sonner@2.0.7:
resolution: {integrity: sha512-W6ZN4p58k8aDKA4XPcx2hpIQXBRAgyiWVkYhT7CvK6D3iAu7xjvVyhQHg2/iaKJZ1XVJ4r7XuwGL+WGEK37i9w==}
peerDependencies:
react: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
react-dom: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
source-map-js@1.2.1: source-map-js@1.2.1:
resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==} resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==}
engines: {node: '>=0.10.0'} engines: {node: '>=0.10.0'}
@@ -2052,6 +2061,11 @@ snapshots:
dependencies: dependencies:
loose-envify: 1.4.0 loose-envify: 1.4.0
sonner@2.0.7(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
dependencies:
react: 18.3.1
react-dom: 18.3.1(react@18.3.1)
source-map-js@1.2.1: {} source-map-js@1.2.1: {}
streamsearch@1.1.0: {} streamsearch@1.1.0: {}

View File

@@ -68,6 +68,13 @@ export default function AccountsPage() {
const [total, setTotal] = useState(0) const [total, setTotal] = useState(0)
const [page, setPage] = useState(1) const [page, setPage] = useState(1)
const [search, setSearch] = useState('') const [search, setSearch] = useState('')
// 搜索 debounce: 输入后 300ms 再触发请求
const [debouncedSearchState, setDebouncedSearchState] = useState('')
useEffect(() => {
const timer = setTimeout(() => setDebouncedSearchState(search), 300)
return () => clearTimeout(timer)
}, [search])
const [roleFilter, setRoleFilter] = useState<string>('all') const [roleFilter, setRoleFilter] = useState<string>('all')
const [statusFilter, setStatusFilter] = useState<string>('all') const [statusFilter, setStatusFilter] = useState<string>('all')
const [loading, setLoading] = useState(true) const [loading, setLoading] = useState(true)
@@ -87,7 +94,7 @@ export default function AccountsPage() {
setError('') setError('')
try { try {
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE } const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
if (search.trim()) params.search = search.trim() if (debouncedSearchState.trim()) params.search = debouncedSearchState.trim()
if (roleFilter !== 'all') params.role = roleFilter if (roleFilter !== 'all') params.role = roleFilter
if (statusFilter !== 'all') params.status = statusFilter if (statusFilter !== 'all') params.status = statusFilter
@@ -103,7 +110,7 @@ export default function AccountsPage() {
} finally { } finally {
setLoading(false) setLoading(false)
} }
}, [page, search, roleFilter, statusFilter]) }, [page, debouncedSearchState, roleFilter, statusFilter])
useEffect(() => { useEffect(() => {
fetchAccounts() fetchAccounts()

View File

@@ -88,6 +88,19 @@ export default function ConfigPage() {
async function handleSave() { async function handleSave() {
if (!editTarget) return if (!editTarget) return
// 表单验证
if (editValue.trim() === '') {
setError('配置值不能为空')
return
}
if (editTarget.value_type === 'number' && isNaN(Number(editValue))) {
setError('请输入有效的数字')
return
}
if (editTarget.value_type === 'boolean' && editValue !== 'true' && editValue !== 'false') {
setError('布尔值只能为 true 或 false')
return
}
setSaving(true) setSaving(true)
try { try {
let parsedValue: string | number | boolean = editValue let parsedValue: string | number | boolean = editValue
@@ -96,7 +109,7 @@ export default function ConfigPage() {
} else if (editTarget.value_type === 'boolean') { } else if (editTarget.value_type === 'boolean') {
parsedValue = editValue === 'true' parsedValue = editValue === 'true'
} }
await api.config.update(editTarget.id, { value: parsedValue }) await api.config.update(editTarget.id, { current_value: parsedValue })
setEditTarget(null) setEditTarget(null)
fetchConfigs(activeTab) fetchConfigs(activeTab)
} catch (err) { } catch (err) {

View File

@@ -0,0 +1,125 @@
'use client'
import { useEffect, useState } from 'react'
import { Monitor, Loader2, RefreshCw } from 'lucide-react'
import { Badge } from '@/components/ui/badge'
import {
Table, TableBody, TableCell, TableHead, TableHeader, TableRow,
} from '@/components/ui/table'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
import type { DeviceInfo } from '@/lib/types'
function formatRelativeTime(dateStr: string): string {
const now = Date.now()
const then = new Date(dateStr).getTime()
const diffMs = now - then
const diffMin = Math.floor(diffMs / 60000)
const diffHour = Math.floor(diffMs / 3600000)
const diffDay = Math.floor(diffMs / 86400000)
if (diffMin < 1) return '刚刚'
if (diffMin < 60) return `${diffMin} 分钟前`
if (diffHour < 24) return `${diffHour} 小时前`
return `${diffDay} 天前`
}
function isOnline(lastSeen: string): boolean {
return Date.now() - new Date(lastSeen).getTime() < 5 * 60 * 1000
}
export default function DevicesPage() {
const [devices, setDevices] = useState<DeviceInfo[]>([])
const [loading, setLoading] = useState(true)
const [error, setError] = useState('')
async function fetchDevices() {
setLoading(true)
setError('')
try {
const res = await api.devices.list()
setDevices(res)
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载失败')
} finally {
setLoading(false)
}
}
useEffect(() => { fetchDevices() }, [])
return (
<div className="space-y-4">
<div className="flex items-center justify-between">
<h2 className="text-lg font-semibold text-foreground"></h2>
<button
onClick={fetchDevices}
disabled={loading}
className="flex items-center gap-2 rounded-md border border-border px-3 py-1.5 text-sm text-muted-foreground hover:bg-muted hover:text-foreground transition-colors cursor-pointer disabled:opacity-50"
>
<RefreshCw className={`h-4 w-4 ${loading ? 'animate-spin' : ''}`} />
</button>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
</div>
)}
{loading && !devices.length ? (
<div className="flex items-center justify-center py-12">
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
</div>
) : devices.length === 0 ? (
<div className="flex flex-col items-center justify-center py-12 text-muted-foreground">
<Monitor className="h-10 w-10 mb-3" />
<p className="text-sm"></p>
</div>
) : (
<div className="rounded-md border border-border">
<Table>
<TableHeader>
<TableRow>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
<TableHead></TableHead>
</TableRow>
</TableHeader>
<TableBody>
{devices.map((d) => (
<TableRow key={d.id}>
<TableCell className="font-medium">
{d.device_name || d.device_id}
</TableCell>
<TableCell>
<Badge variant="secondary">{d.platform || 'unknown'}</Badge>
</TableCell>
<TableCell className="text-muted-foreground">
{d.app_version || '-'}
</TableCell>
<TableCell>
<Badge variant={isOnline(d.last_seen_at) ? 'success' : 'outline'}>
{isOnline(d.last_seen_at) ? '在线' : '离线'}
</Badge>
</TableCell>
<TableCell className="text-muted-foreground">
{formatRelativeTime(d.last_seen_at)}
</TableCell>
<TableCell className="text-muted-foreground text-xs">
{new Date(d.created_at).toLocaleString('zh-CN')}
</TableCell>
</TableRow>
))}
</TableBody>
</Table>
</div>
)}
</div>
)
}

View File

@@ -1,6 +1,6 @@
'use client' 'use client'
import { useState, type ReactNode } from 'react' import { useState, useEffect, type ReactNode } from 'react'
import Link from 'next/link' import Link from 'next/link'
import { usePathname, useRouter } from 'next/navigation' import { usePathname, useRouter } from 'next/navigation'
import { import {
@@ -17,46 +17,71 @@ import {
ChevronLeft, ChevronLeft,
Menu, Menu,
Bell, Bell,
UserCog,
ShieldCheck,
Monitor,
} from 'lucide-react' } from 'lucide-react'
import { AuthGuard, useAuth } from '@/components/auth-guard' import { AuthGuard, useAuth } from '@/components/auth-guard'
import { logout } from '@/lib/auth' import { logout } from '@/lib/auth'
import { cn } from '@/lib/utils' import { cn } from '@/lib/utils'
const navItems = [ const navItems = [
{ href: '/', label: '仪表盘', icon: LayoutDashboard }, { href: '/', label: '仪表盘', icon: LayoutDashboard, permission: null },
{ href: '/accounts', label: '账号管理', icon: Users }, { href: '/accounts', label: '账号管理', icon: Users, permission: 'account:admin' },
{ href: '/providers', label: '服务商', icon: Server }, { href: '/providers', label: '服务商', icon: Server, permission: 'model:admin' },
{ href: '/models', label: '模型管理', icon: Cpu }, { href: '/models', label: '模型管理', icon: Cpu, permission: 'model:admin' },
{ href: '/api-keys', label: 'API 密钥', icon: Key }, { href: '/api-keys', label: 'API 密钥', icon: Key, permission: null },
{ href: '/usage', label: '用量统计', icon: BarChart3 }, { href: '/usage', label: '用量统计', icon: BarChart3, permission: null },
{ href: '/relay', label: '中转任务', icon: ArrowLeftRight }, { href: '/relay', label: '中转任务', icon: ArrowLeftRight, permission: 'relay:admin' },
{ href: '/config', label: '系统配置', icon: Settings }, { href: '/config', label: '系统配置', icon: Settings, permission: 'admin:full' },
{ href: '/logs', label: '操作日志', icon: FileText }, { href: '/logs', label: '操作日志', icon: FileText, permission: 'admin:full' },
{ href: '/profile', label: '个人设置', icon: UserCog, permission: null },
{ href: '/security', label: '安全设置', icon: ShieldCheck, permission: null },
{ href: '/devices', label: '设备管理', icon: Monitor, permission: null },
] ]
function Sidebar({ function Sidebar({
collapsed, collapsed,
onToggle, onToggle,
mobileOpen,
onMobileClose,
}: { }: {
collapsed: boolean collapsed: boolean
onToggle: () => void onToggle: () => void
mobileOpen: boolean
onMobileClose: () => void
}) { }) {
const pathname = usePathname() const pathname = usePathname()
const router = useRouter() const router = useRouter()
const { account } = useAuth() const { account } = useAuth()
// 路由变化时关闭移动端菜单
useEffect(() => {
onMobileClose()
}, [pathname, onMobileClose])
function handleLogout() { function handleLogout() {
logout() logout()
router.replace('/login') router.replace('/login')
} }
return ( return (
<aside <>
className={cn( {/* 移动端 overlay */}
'fixed left-0 top-0 z-40 flex h-screen flex-col border-r border-border bg-card transition-all duration-300', {mobileOpen && (
collapsed ? 'w-16' : 'w-64', <div
className="fixed inset-0 z-40 bg-black/50 lg:hidden"
onClick={onMobileClose}
/>
)} )}
> <aside
className={cn(
'fixed left-0 top-0 z-50 flex h-screen flex-col border-r border-border bg-card transition-all duration-300',
collapsed ? 'w-16' : 'w-64',
'lg:z-40',
mobileOpen ? 'translate-x-0' : '-translate-x-full lg:translate-x-0',
)}
>
{/* Logo */} {/* Logo */}
<div className="flex h-14 items-center border-b border-border px-4"> <div className="flex h-14 items-center border-b border-border px-4">
<Link href="/" className="flex items-center gap-2 cursor-pointer"> <Link href="/" className="flex items-center gap-2 cursor-pointer">
@@ -75,7 +100,15 @@ function Sidebar({
{/* 导航 */} {/* 导航 */}
<nav className="flex-1 overflow-y-auto scrollbar-thin py-2 px-2"> <nav className="flex-1 overflow-y-auto scrollbar-thin py-2 px-2">
<ul className="space-y-1"> <ul className="space-y-1">
{navItems.map((item) => { {navItems
.filter((item) => {
if (!item.permission) return true
if (!account) return false
// super_admin 拥有所有权限
if (account.role === 'super_admin') return true
return account.permissions?.includes(item.permission) ?? false
})
.map((item) => {
const isActive = const isActive =
item.href === '/' item.href === '/'
? pathname === '/' ? pathname === '/'
@@ -119,6 +152,19 @@ function Sidebar({
</button> </button>
</div> </div>
{/* 折叠时显示退出按钮 */}
{collapsed && (
<div className="border-t border-border p-2">
<button
onClick={handleLogout}
className="flex w-full items-center justify-center rounded-md px-3 py-2 text-muted-foreground hover:bg-muted hover:text-destructive transition-colors duration-200 cursor-pointer"
title="退出登录"
>
<LogOut className="h-4 w-4" />
</button>
</div>
)}
{/* 用户信息 */} {/* 用户信息 */}
{!collapsed && ( {!collapsed && (
<div className="border-t border-border p-3"> <div className="border-t border-border p-3">
@@ -145,10 +191,11 @@ function Sidebar({
</div> </div>
)} )}
</aside> </aside>
</>
) )
} }
function Header() { function Header({ children }: { children?: ReactNode }) {
const pathname = usePathname() const pathname = usePathname()
const currentNav = navItems.find( const currentNav = navItems.find(
(item) => (item) =>
@@ -160,7 +207,7 @@ function Header() {
return ( return (
<header className="sticky top-0 z-30 flex h-14 items-center border-b border-border bg-background/80 backdrop-blur-sm px-6"> <header className="sticky top-0 z-30 flex h-14 items-center border-b border-border bg-background/80 backdrop-blur-sm px-6">
{/* 移动端菜单按钮 */} {/* 移动端菜单按钮 */}
<MobileMenuButton /> {children}
{/* 页面标题 */} {/* 页面标题 */}
<h1 className="text-lg font-semibold text-foreground"> <h1 className="text-lg font-semibold text-foreground">
@@ -180,10 +227,10 @@ function Header() {
) )
} }
function MobileMenuButton() { function MobileMenuButton({ onClick }: { onClick: () => void }) {
// Placeholder for mobile menu toggle
return ( return (
<button <button
onClick={onClick}
className="mr-3 rounded-md p-2 text-muted-foreground hover:bg-muted hover:text-foreground transition-colors duration-200 lg:hidden cursor-pointer" className="mr-3 rounded-md p-2 text-muted-foreground hover:bg-muted hover:text-foreground transition-colors duration-200 lg:hidden cursor-pointer"
> >
<Menu className="h-5 w-5" /> <Menu className="h-5 w-5" />
@@ -191,28 +238,68 @@ function MobileMenuButton() {
) )
} }
/** 路由级权限守卫:隐藏导航项但用户直接访问 URL 时拦截 */
function PageGuard({ children }: { children: ReactNode }) {
const pathname = usePathname()
const router = useRouter()
const { account } = useAuth()
const matchedNav = navItems.find((item) =>
item.href === '/' ? pathname === '/' : pathname.startsWith(item.href),
)
if (matchedNav?.permission && account) {
if (account.role !== 'super_admin' && !(account.permissions?.includes(matchedNav.permission) ?? false)) {
return (
<div className="flex flex-1 items-center justify-center">
<div className="text-center space-y-3">
<p className="text-lg font-medium text-muted-foreground"></p>
<p className="text-sm text-muted-foreground">访{matchedNav.label}</p>
<button
onClick={() => router.replace('/')}
className="text-sm text-primary hover:underline cursor-pointer"
>
</button>
</div>
</div>
)
}
}
return <>{children}</>
}
export default function DashboardLayout({ children }: { children: ReactNode }) { export default function DashboardLayout({ children }: { children: ReactNode }) {
const [sidebarCollapsed, setSidebarCollapsed] = useState(false) const [sidebarCollapsed, setSidebarCollapsed] = useState(false)
const [mobileOpen, setMobileOpen] = useState(false)
return ( return (
<AuthGuard> <AuthGuard>
<div className="flex min-h-screen"> <PageGuard>
<div className="flex min-h-screen">
<Sidebar <Sidebar
collapsed={sidebarCollapsed} collapsed={sidebarCollapsed}
onToggle={() => setSidebarCollapsed(!sidebarCollapsed)} onToggle={() => setSidebarCollapsed(!sidebarCollapsed)}
mobileOpen={mobileOpen}
onMobileClose={() => setMobileOpen(false)}
/> />
<div <div
className={cn( className={cn(
'flex flex-1 flex-col transition-all duration-300', 'flex flex-1 flex-col transition-all duration-300',
sidebarCollapsed ? 'ml-16' : 'ml-64', 'ml-0 lg:transition-all',
sidebarCollapsed ? 'lg:ml-16' : 'lg:ml-64',
)} )}
> >
<Header /> <Header>
<MobileMenuButton onClick={() => setMobileOpen(true)} />
</Header>
<main className="flex-1 overflow-auto p-6 scrollbar-thin"> <main className="flex-1 overflow-auto p-6 scrollbar-thin">
{children} {children}
</main> </main>
</div> </div>
</div> </div>
</PageGuard>
</AuthGuard> </AuthGuard>
) )
} }

View File

@@ -108,8 +108,8 @@ export default function ModelsPage() {
const fetchProviders = useCallback(async () => { const fetchProviders = useCallback(async () => {
try { try {
const res = await api.providers.list({ page: 1, page_size: 100 }) const res = await api.providers.list()
setProviders(res.items) setProviders(res)
} catch { } catch {
// ignore // ignore
} }

View File

@@ -35,7 +35,7 @@ import { api } from '@/lib/api-client'
import { formatNumber, formatDate } from '@/lib/utils' import { formatNumber, formatDate } from '@/lib/utils'
import type { import type {
DashboardStats, DashboardStats,
UsageRecord, UsageStats,
OperationLog, OperationLog,
} from '@/lib/types' } from '@/lib/types'
@@ -87,7 +87,7 @@ function StatusBadge({ status }: { status: string }) {
export default function DashboardPage() { export default function DashboardPage() {
const [stats, setStats] = useState<DashboardStats | null>(null) const [stats, setStats] = useState<DashboardStats | null>(null)
const [usageData, setUsageData] = useState<UsageRecord[]>([]) const [usageStats, setUsageStats] = useState<UsageStats | null>(null)
const [recentLogs, setRecentLogs] = useState<OperationLog[]>([]) const [recentLogs, setRecentLogs] = useState<OperationLog[]>([])
const [loading, setLoading] = useState(true) const [loading, setLoading] = useState(true)
const [error, setError] = useState('') const [error, setError] = useState('')
@@ -97,15 +97,17 @@ export default function DashboardPage() {
try { try {
const [statsRes, usageRes, logsRes] = await Promise.allSettled([ const [statsRes, usageRes, logsRes] = await Promise.allSettled([
api.stats.dashboard(), api.stats.dashboard(),
api.usage.daily({ days: 30 }), api.usage.get(),
api.logs.list({ page: 1, page_size: 5 }), api.logs.list({ page: 1, page_size: 5 }),
]) ])
if (statsRes.status === 'fulfilled') setStats(statsRes.value) if (statsRes.status === 'fulfilled') setStats(statsRes.value)
if (usageRes.status === 'fulfilled') setUsageData(usageRes.value) if (usageRes.status === 'fulfilled') setUsageStats(usageRes.value)
if (logsRes.status === 'fulfilled') setRecentLogs(logsRes.value.items) if (logsRes.status === 'fulfilled') setRecentLogs(logsRes.value)
} catch (err) {
setError('加载数据失败,请检查后端服务是否启动') if (statsRes.status === 'rejected' && usageRes.status === 'rejected' && logsRes.status === 'rejected') {
setError('加载数据失败,请检查后端服务是否启动')
}
} finally { } finally {
setLoading(false) setLoading(false)
} }
@@ -140,9 +142,9 @@ export default function DashboardPage() {
) )
} }
const chartData = usageData.map((r) => ({ const chartData = (usageStats?.by_day ?? []).map((r) => ({
day: r.day.slice(5), // MM-DD day: r.date.slice(5), // MM-DD
请求量: r.count, 请求量: r.request_count,
Input: r.input_tokens, Input: r.input_tokens,
Output: r.output_tokens, Output: r.output_tokens,
})) }))

View File

@@ -0,0 +1,154 @@
'use client'
import { useState } from 'react'
import { Lock, Loader2, Eye, EyeOff, Check } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client'
export default function ProfilePage() {
const [oldPassword, setOldPassword] = useState('')
const [newPassword, setNewPassword] = useState('')
const [confirmPassword, setConfirmPassword] = useState('')
const [showOld, setShowOld] = useState(false)
const [showNew, setShowNew] = useState(false)
const [showConfirm, setShowConfirm] = useState(false)
const [saving, setSaving] = useState(false)
const [error, setError] = useState('')
const [success, setSuccess] = useState('')
async function handleSubmit(e: React.FormEvent) {
e.preventDefault()
setError('')
setSuccess('')
if (newPassword.length < 8) {
setError('新密码至少 8 个字符')
return
}
if (newPassword !== confirmPassword) {
setError('两次输入的新密码不一致')
return
}
setSaving(true)
try {
await api.auth.changePassword({ old_password: oldPassword, new_password: newPassword })
setSuccess('密码修改成功')
setOldPassword('')
setNewPassword('')
setConfirmPassword('')
} catch (err) {
if (err instanceof ApiRequestError) {
setError(err.body.message || '修改失败')
} else {
setError('网络错误,请稍后重试')
}
} finally {
setSaving(false)
}
}
return (
<div className="max-w-lg">
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<Lock className="h-5 w-5" />
</CardTitle>
<CardDescription></CardDescription>
</CardHeader>
<CardContent>
<form onSubmit={handleSubmit} className="space-y-4">
<div className="space-y-2">
<Label htmlFor="old-password"></Label>
<div className="relative">
<Input
id="old-password"
type={showOld ? 'text' : 'password'}
value={oldPassword}
onChange={(e) => setOldPassword(e.target.value)}
placeholder="请输入当前密码"
required
/>
<button
type="button"
onClick={() => setShowOld(!showOld)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showOld ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
<div className="space-y-2">
<Label htmlFor="new-password"></Label>
<div className="relative">
<Input
id="new-password"
type={showNew ? 'text' : 'password'}
value={newPassword}
onChange={(e) => setNewPassword(e.target.value)}
placeholder="至少 8 个字符"
required
minLength={8}
/>
<button
type="button"
onClick={() => setShowNew(!showNew)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showNew ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
<div className="space-y-2">
<Label htmlFor="confirm-password"></Label>
<div className="relative">
<Input
id="confirm-password"
type={showConfirm ? 'text' : 'password'}
value={confirmPassword}
onChange={(e) => setConfirmPassword(e.target.value)}
placeholder="再次输入新密码"
required
minLength={8}
/>
<button
type="button"
onClick={() => setShowConfirm(!showConfirm)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showConfirm ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
{error}
</div>
)}
{success && (
<div className="rounded-md bg-emerald-500/10 border border-emerald-500/20 px-4 py-3 text-sm text-emerald-500 flex items-center gap-2">
<Check className="h-4 w-4" />
{success}
</div>
)}
<Button type="submit" disabled={saving || !oldPassword || !newPassword || !confirmPassword}>
{saving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</form>
</CardContent>
</Card>
</div>
)
}

View File

@@ -39,7 +39,7 @@ import {
} from '@/components/ui/select' } from '@/components/ui/select'
import { api } from '@/lib/api-client' import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client' import { ApiRequestError } from '@/lib/api-client'
import { formatDate, maskApiKey } from '@/lib/utils' import { formatDate } from '@/lib/utils'
import type { Provider } from '@/lib/types' import type { Provider } from '@/lib/types'
const PAGE_SIZE = 20 const PAGE_SIZE = 20
@@ -49,7 +49,6 @@ interface ProviderForm {
display_name: string display_name: string
base_url: string base_url: string
api_protocol: 'openai' | 'anthropic' api_protocol: 'openai' | 'anthropic'
api_key: string
enabled: boolean enabled: boolean
rate_limit_rpm: string rate_limit_rpm: string
rate_limit_tpm: string rate_limit_tpm: string
@@ -60,7 +59,6 @@ const emptyForm: ProviderForm = {
display_name: '', display_name: '',
base_url: '', base_url: '',
api_protocol: 'openai', api_protocol: 'openai',
api_key: '',
enabled: true, enabled: true,
rate_limit_rpm: '', rate_limit_rpm: '',
rate_limit_tpm: '', rate_limit_tpm: '',
@@ -117,7 +115,6 @@ export default function ProvidersPage() {
display_name: provider.display_name, display_name: provider.display_name,
base_url: provider.base_url, base_url: provider.base_url,
api_protocol: provider.api_protocol, api_protocol: provider.api_protocol,
api_key: provider.api_key || '',
enabled: provider.enabled, enabled: provider.enabled,
rate_limit_rpm: provider.rate_limit_rpm?.toString() || '', rate_limit_rpm: provider.rate_limit_rpm?.toString() || '',
rate_limit_tpm: provider.rate_limit_tpm?.toString() || '', rate_limit_tpm: provider.rate_limit_tpm?.toString() || '',
@@ -134,7 +131,6 @@ export default function ProvidersPage() {
display_name: form.display_name.trim(), display_name: form.display_name.trim(),
base_url: form.base_url.trim(), base_url: form.base_url.trim(),
api_protocol: form.api_protocol, api_protocol: form.api_protocol,
api_key: form.api_key.trim() || undefined,
enabled: form.enabled, enabled: form.enabled,
rate_limit_rpm: form.rate_limit_rpm ? parseInt(form.rate_limit_rpm, 10) : undefined, rate_limit_rpm: form.rate_limit_rpm ? parseInt(form.rate_limit_rpm, 10) : undefined,
rate_limit_tpm: form.rate_limit_tpm ? parseInt(form.rate_limit_tpm, 10) : undefined, rate_limit_tpm: form.rate_limit_tpm ? parseInt(form.rate_limit_tpm, 10) : undefined,
@@ -202,7 +198,6 @@ export default function ProvidersPage() {
<TableHead></TableHead> <TableHead></TableHead>
<TableHead>Base URL</TableHead> <TableHead>Base URL</TableHead>
<TableHead></TableHead> <TableHead></TableHead>
<TableHead>API Key</TableHead>
<TableHead></TableHead> <TableHead></TableHead>
<TableHead>RPM </TableHead> <TableHead>RPM </TableHead>
<TableHead></TableHead> <TableHead></TableHead>
@@ -222,9 +217,6 @@ export default function ProvidersPage() {
{p.api_protocol} {p.api_protocol}
</Badge> </Badge>
</TableCell> </TableCell>
<TableCell className="font-mono text-xs text-muted-foreground">
{maskApiKey(p.api_key)}
</TableCell>
<TableCell> <TableCell>
<Badge variant={p.enabled ? 'success' : 'secondary'}> <Badge variant={p.enabled ? 'success' : 'secondary'}>
{p.enabled ? '是' : '否'} {p.enabled ? '是' : '否'}
@@ -316,15 +308,6 @@ export default function ProvidersPage() {
</SelectContent> </SelectContent>
</Select> </Select>
</div> </div>
<div className="space-y-2">
<Label>API Key</Label>
<Input
type="password"
value={form.api_key}
onChange={(e) => setForm({ ...form, api_key: e.target.value })}
placeholder={editTarget ? '留空则不修改' : 'sk-...'}
/>
</div>
<div className="flex items-center gap-3"> <div className="flex items-center gap-3">
<Switch <Switch
checked={form.enabled} checked={form.enabled}

View File

@@ -2,12 +2,12 @@
import { useEffect, useState, useCallback } from 'react' import { useEffect, useState, useCallback } from 'react'
import { import {
Search,
Loader2, Loader2,
ChevronLeft, ChevronLeft,
ChevronRight, ChevronRight,
ChevronDown, ChevronDown,
ChevronUp, ChevronUp,
RotateCcw,
} from 'lucide-react' } from 'lucide-react'
import { Button } from '@/components/ui/button' import { Button } from '@/components/ui/button'
import { Badge } from '@/components/ui/badge' import { Badge } from '@/components/ui/badge'
@@ -55,6 +55,7 @@ export default function RelayPage() {
const [loading, setLoading] = useState(true) const [loading, setLoading] = useState(true)
const [error, setError] = useState('') const [error, setError] = useState('')
const [expandedId, setExpandedId] = useState<string | null>(null) const [expandedId, setExpandedId] = useState<string | null>(null)
const [retryingId, setRetryingId] = useState<string | null>(null)
const fetchTasks = useCallback(async () => { const fetchTasks = useCallback(async () => {
setLoading(true) setLoading(true)
@@ -83,6 +84,20 @@ export default function RelayPage() {
setExpandedId((prev) => (prev === id ? null : id)) setExpandedId((prev) => (prev === id ? null : id))
} }
async function handleRetry(taskId: string, e: React.MouseEvent) {
e.stopPropagation()
setRetryingId(taskId)
try {
await api.relay.retry(taskId)
fetchTasks()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message)
else setError('重试失败')
} finally {
setRetryingId(null)
}
}
return ( return (
<div className="space-y-4"> <div className="space-y-4">
{/* 筛选 */} {/* 筛选 */}
@@ -131,6 +146,7 @@ export default function RelayPage() {
<TableHead>Output Tokens</TableHead> <TableHead>Output Tokens</TableHead>
<TableHead></TableHead> <TableHead></TableHead>
<TableHead></TableHead> <TableHead></TableHead>
<TableHead className="text-right"></TableHead>
</TableRow> </TableRow>
</TableHeader> </TableHeader>
<TableBody> <TableBody>
@@ -169,10 +185,27 @@ export default function RelayPage() {
<TableCell className="font-mono text-xs text-muted-foreground"> <TableCell className="font-mono text-xs text-muted-foreground">
{formatDate(task.created_at)} {formatDate(task.created_at)}
</TableCell> </TableCell>
<TableCell className="text-right">
{task.status === 'failed' && (
<Button
variant="ghost"
size="icon"
onClick={(e) => handleRetry(task.id, e)}
disabled={retryingId === task.id}
title="重试"
>
{retryingId === task.id ? (
<Loader2 className="h-4 w-4 animate-spin" />
) : (
<RotateCcw className="h-4 w-4" />
)}
</Button>
)}
</TableCell>
</TableRow> </TableRow>
{expandedId === task.id && ( {expandedId === task.id && (
<TableRow key={`${task.id}-detail`}> <TableRow key={`${task.id}-detail`}>
<TableCell colSpan={10} className="bg-muted/20 px-8 py-4"> <TableCell colSpan={11} className="bg-muted/20 px-8 py-4">
<div className="grid grid-cols-2 gap-4 text-sm"> <div className="grid grid-cols-2 gap-4 text-sm">
<div> <div>
<p className="text-muted-foreground"> ID</p> <p className="text-muted-foreground"> ID</p>

View File

@@ -0,0 +1,203 @@
'use client'
import { useState } from 'react'
import { ShieldCheck, Loader2, Eye, EyeOff, QrCode, Key, AlertTriangle } from 'lucide-react'
import { Button } from '@/components/ui/button'
import { Input } from '@/components/ui/input'
import { Label } from '@/components/ui/label'
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
import { Badge } from '@/components/ui/badge'
import { api } from '@/lib/api-client'
import { useAuth } from '@/components/auth-guard'
import { ApiRequestError } from '@/lib/api-client'
export default function SecurityPage() {
const { account } = useAuth()
const totpEnabled = account?.totp_enabled ?? false
// Setup state
const [step, setStep] = useState<'idle' | 'verify' | 'done'>('idle')
const [otpauthUri, setOtpauthUri] = useState('')
const [secret, setSecret] = useState('')
const [verifyCode, setVerifyCode] = useState('')
const [loading, setLoading] = useState(false)
const [error, setError] = useState('')
// Disable state
const [disablePassword, setDisablePassword] = useState('')
const [showDisablePassword, setShowDisablePassword] = useState(false)
const [disabling, setDisabling] = useState(false)
async function handleSetup() {
setLoading(true)
setError('')
try {
const res = await api.auth.totpSetup()
setOtpauthUri(res.otpauth_uri)
setSecret(res.secret)
setStep('verify')
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message || '获取密钥失败')
else setError('网络错误')
} finally {
setLoading(false)
}
}
async function handleVerify() {
if (verifyCode.length !== 6) {
setError('请输入 6 位验证码')
return
}
setLoading(true)
setError('')
try {
await api.auth.totpVerify({ code: verifyCode })
setStep('done')
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message || '验证失败')
else setError('网络错误')
} finally {
setLoading(false)
}
}
async function handleDisable() {
if (!disablePassword) {
setError('请输入密码以确认禁用')
return
}
setDisabling(true)
setError('')
try {
await api.auth.totpDisable({ password: disablePassword })
setDisablePassword('')
window.location.reload()
} catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message || '禁用失败')
else setError('网络错误')
} finally {
setDisabling(false)
}
}
return (
<div className="max-w-lg space-y-6">
{/* TOTP 状态 */}
<Card>
<CardHeader>
<CardTitle className="flex items-center gap-2">
<ShieldCheck className="h-5 w-5" />
(TOTP)
</CardTitle>
<CardDescription>
使 Google Authenticator
</CardDescription>
</CardHeader>
<CardContent>
<div className="flex items-center gap-3 mb-4">
<span className="text-sm text-muted-foreground">:</span>
<Badge variant={totpEnabled ? 'success' : 'secondary'}>
{totpEnabled ? '已启用' : '未启用'}
</Badge>
</div>
{error && (
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive mb-4">
{error}
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer"></button>
</div>
)}
{/* 未启用: 设置流程 */}
{!totpEnabled && step === 'idle' && (
<Button onClick={handleSetup} disabled={loading}>
{loading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
<Key className="mr-2 h-4 w-4" />
</Button>
)}
{!totpEnabled && step === 'verify' && (
<div className="space-y-4">
<div className="rounded-md border border-border p-4 space-y-3">
<div className="flex items-center gap-2 text-sm font-medium">
<QrCode className="h-4 w-4" />
1: 扫描二维码或手动输入密钥
</div>
<div className="bg-muted rounded-md p-3 font-mono text-xs break-all">
{otpauthUri}
</div>
<div className="space-y-1">
<p className="text-xs text-muted-foreground">:</p>
<p className="font-mono text-sm font-medium select-all">{secret}</p>
</div>
</div>
<div className="space-y-2">
<Label>
2: 输入 6
</Label>
<Input
value={verifyCode}
onChange={(e) => setVerifyCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
placeholder="请输入应用中显示的 6 位数字"
maxLength={6}
className="font-mono tracking-widest text-center"
/>
</div>
<div className="flex gap-2">
<Button variant="outline" onClick={() => { setStep('idle'); setVerifyCode('') }}>
</Button>
<Button onClick={handleVerify} disabled={loading || verifyCode.length !== 6}>
{loading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</div>
</div>
)}
{!totpEnabled && step === 'done' && (
<div className="rounded-md bg-emerald-500/10 border border-emerald-500/20 p-4 text-sm text-emerald-500">
</div>
)}
{/* 已启用: 禁用流程 */}
{totpEnabled && (
<div className="space-y-4">
<div className="rounded-md bg-amber-500/10 border border-amber-500/20 p-3 flex items-start gap-2 text-sm text-amber-600">
<AlertTriangle className="h-4 w-4 mt-0.5 shrink-0" />
<span></span>
</div>
<div className="space-y-2">
<Label></Label>
<div className="relative">
<Input
type={showDisablePassword ? 'text' : 'password'}
value={disablePassword}
onChange={(e) => setDisablePassword(e.target.value)}
placeholder="请输入当前密码"
/>
<button
type="button"
onClick={() => setShowDisablePassword(!showDisablePassword)}
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
>
{showDisablePassword ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
</button>
</div>
</div>
<Button variant="destructive" onClick={handleDisable} disabled={disabling || !disablePassword}>
{disabling && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
</Button>
</div>
)}
</CardContent>
</Card>
</div>
)
}

View File

@@ -25,12 +25,11 @@ import {
import { api } from '@/lib/api-client' import { api } from '@/lib/api-client'
import { ApiRequestError } from '@/lib/api-client' import { ApiRequestError } from '@/lib/api-client'
import { formatNumber } from '@/lib/utils' import { formatNumber } from '@/lib/utils'
import type { UsageRecord, UsageByModel } from '@/lib/types' import type { UsageStats } from '@/lib/types'
export default function UsagePage() { export default function UsagePage() {
const [days, setDays] = useState(7) const [days, setDays] = useState(7)
const [dailyData, setDailyData] = useState<UsageRecord[]>([]) const [usageStats, setUsageStats] = useState<UsageStats | null>(null)
const [modelData, setModelData] = useState<UsageByModel[]>([])
const [loading, setLoading] = useState(true) const [loading, setLoading] = useState(true)
const [error, setError] = useState('') const [error, setError] = useState('')
@@ -38,13 +37,11 @@ export default function UsagePage() {
setLoading(true) setLoading(true)
setError('') setError('')
try { try {
const [dailyRes, modelRes] = await Promise.allSettled([ const from = new Date()
api.usage.daily({ days }), from.setDate(from.getDate() - days)
api.usage.byModel({ days }), const fromStr = from.toISOString().slice(0, 10)
]) const res = await api.usage.get({ from: fromStr })
if (dailyRes.status === 'fulfilled') setDailyData(dailyRes.value) setUsageStats(res)
else throw new Error('Failed to fetch daily usage')
if (modelRes.status === 'fulfilled') setModelData(modelRes.value)
} catch (err) { } catch (err) {
if (err instanceof ApiRequestError) setError(err.body.message) if (err instanceof ApiRequestError) setError(err.body.message)
else setError('加载数据失败') else setError('加载数据失败')
@@ -57,22 +54,24 @@ export default function UsagePage() {
fetchData() fetchData()
}, [fetchData]) }, [fetchData])
const lineChartData = dailyData.map((r) => ({ const byDay = usageStats?.by_day ?? []
day: r.day.slice(5),
const lineChartData = byDay.map((r) => ({
day: r.date.slice(5),
Input: r.input_tokens, Input: r.input_tokens,
Output: r.output_tokens, Output: r.output_tokens,
})) }))
const barChartData = modelData.map((r) => ({ const barChartData = (usageStats?.by_model ?? []).map((r) => ({
model: r.model_id, model: r.model_id,
请求量: r.count, 请求量: r.request_count,
Input: r.input_tokens, Input: r.input_tokens,
Output: r.output_tokens, Output: r.output_tokens,
})) }))
const totalInput = dailyData.reduce((s, r) => s + r.input_tokens, 0) const totalInput = byDay.reduce((s, r) => s + r.input_tokens, 0)
const totalOutput = dailyData.reduce((s, r) => s + r.output_tokens, 0) const totalOutput = byDay.reduce((s, r) => s + r.output_tokens, 0)
const totalRequests = dailyData.reduce((s, r) => s + r.count, 0) const totalRequests = byDay.reduce((s, r) => s + r.request_count, 0)
if (loading) { if (loading) {
return ( return (

View File

@@ -1,4 +1,5 @@
import type { Metadata } from 'next' import type { Metadata } from 'next'
import { Toaster } from 'sonner'
import './globals.css' import './globals.css'
export const metadata: Metadata = { export const metadata: Metadata = {
@@ -21,6 +22,7 @@ export default function RootLayout({
</head> </head>
<body className="min-h-screen bg-background font-sans antialiased"> <body className="min-h-screen bg-background font-sans antialiased">
{children} {children}
<Toaster richColors position="top-right" />
</body> </body>
</html> </html>
) )

View File

@@ -2,7 +2,7 @@
import { useState, type FormEvent } from 'react' import { useState, type FormEvent } from 'react'
import { useRouter } from 'next/navigation' import { useRouter } from 'next/navigation'
import { Lock, User, Loader2, Eye, EyeOff } from 'lucide-react' import { Lock, User, Loader2, Eye, EyeOff, ShieldCheck } from 'lucide-react'
import { api } from '@/lib/api-client' import { api } from '@/lib/api-client'
import { login } from '@/lib/auth' import { login } from '@/lib/auth'
import { ApiRequestError } from '@/lib/api-client' import { ApiRequestError } from '@/lib/api-client'
@@ -12,7 +12,8 @@ export default function LoginPage() {
const [username, setUsername] = useState('') const [username, setUsername] = useState('')
const [password, setPassword] = useState('') const [password, setPassword] = useState('')
const [showPassword, setShowPassword] = useState(false) const [showPassword, setShowPassword] = useState(false)
const [remember, setRemember] = useState(false) const [totpCode, setTotpCode] = useState('')
const [showTotp, setShowTotp] = useState(false)
const [loading, setLoading] = useState(false) const [loading, setLoading] = useState(false)
const [error, setError] = useState('') const [error, setError] = useState('')
@@ -31,12 +32,22 @@ export default function LoginPage() {
setLoading(true) setLoading(true)
try { try {
const res = await api.auth.login({ username: username.trim(), password }) const res = await api.auth.login({
username: username.trim(),
password,
totp_code: showTotp ? totpCode.trim() || undefined : undefined,
})
login(res.token, res.account) login(res.token, res.account)
router.replace('/') router.replace('/')
} catch (err) { } catch (err) {
if (err instanceof ApiRequestError) { if (err instanceof ApiRequestError) {
setError(err.body.message || '登录失败,请检查用户名和密码') // 检测 TOTP 错误码,自动显示验证码输入框
if (err.body.error === 'totp_required' || err.body.message?.includes('双因素认证') || err.body.message?.includes('TOTP')) {
setShowTotp(true)
setError(err.body.message || '此账号已启用双因素认证,请输入验证码')
} else {
setError(err.body.message || '登录失败,请检查用户名和密码')
}
} else { } else {
setError('网络错误,请稍后重试') setError('网络错误,请稍后重试')
} }
@@ -152,22 +163,30 @@ export default function LoginPage() {
</div> </div>
</div> </div>
{/* 记住我 */} {/* TOTP 验证码 (仅账号启用 2FA 时显示) */}
<div className="flex items-center gap-2"> {showTotp && (
<input <div className="space-y-2">
id="remember" <label
type="checkbox" htmlFor="totp_code"
checked={remember} className="text-sm font-medium text-foreground"
onChange={(e) => setRemember(e.target.checked)} >
className="h-4 w-4 rounded border-input bg-transparent accent-primary cursor-pointer" <span className="inline-flex items-center gap-1">
/> <ShieldCheck className="h-3.5 w-3.5" />
<label
htmlFor="remember" </span>
className="text-sm text-muted-foreground cursor-pointer select-none" </label>
> <input
id="totp_code"
</label> type="text"
</div> placeholder="请输入 6 位验证码"
value={totpCode}
onChange={(e) => setTotpCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
className="flex h-10 w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm tracking-widest text-center font-mono shadow-sm transition-colors duration-200 placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
maxLength={6}
autoFocus
/>
</div>
)}
{/* 错误信息 */} {/* 错误信息 */}
{error && ( {error && (

View File

@@ -1,29 +1,71 @@
'use client' 'use client'
import { useEffect, useState, type ReactNode } from 'react' import { createContext, useContext, useEffect, useState, useCallback, type ReactNode } from 'react'
import { useRouter } from 'next/navigation' import { useRouter } from 'next/navigation'
import { isAuthenticated, getAccount } from '@/lib/auth' import { isAuthenticated, getAccount, logout as clearCredentials, scheduleTokenRefresh, cancelTokenRefresh, setOnSessionExpired } from '@/lib/auth'
import { api } from '@/lib/api-client'
import type { AccountPublic } from '@/lib/types' import type { AccountPublic } from '@/lib/types'
interface AuthContextValue {
account: AccountPublic | null
loading: boolean
refresh: () => Promise<void>
}
const AuthContext = createContext<AuthContextValue>({
account: null,
loading: true,
refresh: async () => {},
})
export function useAuth() {
return useContext(AuthContext)
}
interface AuthGuardProps { interface AuthGuardProps {
children: ReactNode children: ReactNode
} }
export function AuthGuard({ children }: AuthGuardProps) { export function AuthGuard({ children }: AuthGuardProps) {
const router = useRouter() const router = useRouter()
const [authorized, setAuthorized] = useState(false)
const [account, setAccount] = useState<AccountPublic | null>(null) const [account, setAccount] = useState<AccountPublic | null>(null)
const [loading, setLoading] = useState(true)
const refresh = useCallback(async () => {
try {
const me = await api.auth.me()
setAccount(me)
} catch {
clearCredentials()
router.replace('/login')
}
}, [router])
useEffect(() => { useEffect(() => {
if (!isAuthenticated()) { if (!isAuthenticated()) {
router.replace('/login') router.replace('/login')
return return
} }
setAccount(getAccount()) // 验证 token 有效性并获取最新账号信息
setAuthorized(true) refresh().finally(() => setLoading(false))
}, [router, refresh])
// Set up proactive token refresh with session-expired handler
useEffect(() => {
const handleSessionExpired = () => {
clearCredentials()
router.replace('/login')
}
setOnSessionExpired(handleSessionExpired)
scheduleTokenRefresh()
return () => {
cancelTokenRefresh()
setOnSessionExpired(null)
}
}, [router]) }, [router])
if (!authorized) { if (loading) {
return ( return (
<div className="flex h-screen w-screen items-center justify-center bg-background"> <div className="flex h-screen w-screen items-center justify-center bg-background">
<div className="h-8 w-8 animate-spin rounded-full border-2 border-primary border-t-transparent" /> <div className="h-8 w-8 animate-spin rounded-full border-2 border-primary border-t-transparent" />
@@ -31,18 +73,13 @@ export function AuthGuard({ children }: AuthGuardProps) {
) )
} }
return <>{children}</> if (!account) {
} return null
}
export function useAuth() {
const [account, setAccount] = useState<AccountPublic | null>(null) return (
const [loading, setLoading] = useState(true) <AuthContext.Provider value={{ account, loading, refresh }}>
{children}
useEffect(() => { </AuthContext.Provider>
const acc = getAccount() )
setAccount(acc)
setLoading(false)
}, [])
return { account, loading, isAuthenticated: isAuthenticated() }
} }

View File

@@ -2,13 +2,15 @@
// ZCLAW SaaS Admin — 类型化 HTTP 客户端 // ZCLAW SaaS Admin — 类型化 HTTP 客户端
// ============================================================ // ============================================================
import { getToken, logout } from './auth' import { getToken, logout, refreshToken } from './auth'
import { toast } from 'sonner'
import type { import type {
AccountPublic, AccountPublic,
ApiError, ApiError,
ConfigItem, ConfigItem,
CreateTokenRequest, CreateTokenRequest,
DashboardStats, DashboardStats,
DeviceInfo,
LoginRequest, LoginRequest,
LoginResponse, LoginResponse,
Model, Model,
@@ -18,7 +20,7 @@ import type {
RelayTask, RelayTask,
TokenInfo, TokenInfo,
UsageByModel, UsageByModel,
UsageRecord, UsageStats,
} from './types' } from './types'
// ── 错误类 ──────────────────────────────────────────────── // ── 错误类 ────────────────────────────────────────────────
@@ -36,6 +38,7 @@ export class ApiRequestError extends Error {
// ── 基础请求 ────────────────────────────────────────────── // ── 基础请求 ──────────────────────────────────────────────
const BASE_URL = process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080' const BASE_URL = process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080'
const API_PREFIX = '/api/v1'
async function request<T>( async function request<T>(
method: string, method: string,
@@ -50,13 +53,34 @@ async function request<T>(
headers['Authorization'] = `Bearer ${token}` headers['Authorization'] = `Bearer ${token}`
} }
const res = await fetch(`${BASE_URL}${path}`, { const res = await fetch(`${BASE_URL}${API_PREFIX}${path}`, {
method, method,
headers, headers,
body: body ? JSON.stringify(body) : undefined, body: body ? JSON.stringify(body) : undefined,
}) })
if (res.status === 401) { if (res.status === 401) {
// 尝试刷新 token 后重试
try {
const newToken = await refreshToken()
headers['Authorization'] = `Bearer ${newToken}`
const retryRes = await fetch(`${BASE_URL}${API_PREFIX}${path}`, {
method,
headers,
body: body ? JSON.stringify(body) : undefined,
})
if (retryRes.ok || retryRes.status === 204) {
return retryRes.status === 204 ? (undefined as T) : retryRes.json()
}
// 刷新成功但重试仍失败,走正常错误处理
if (!retryRes.ok) {
let errorBody: ApiError
try { errorBody = await retryRes.json() } catch { errorBody = { error: 'unknown', message: `请求失败 (${retryRes.status})` } }
throw new ApiRequestError(retryRes.status, errorBody)
}
} catch {
// 刷新失败,执行登出
}
logout() logout()
if (typeof window !== 'undefined') { if (typeof window !== 'undefined') {
window.location.href = '/login' window.location.href = '/login'
@@ -71,6 +95,9 @@ async function request<T>(
} catch { } catch {
errorBody = { error: 'unknown', message: `请求失败 (${res.status})` } errorBody = { error: 'unknown', message: `请求失败 (${res.status})` }
} }
if (typeof window !== 'undefined') {
toast.error(errorBody.message || `请求失败 (${res.status})`)
}
throw new ApiRequestError(res.status, errorBody) throw new ApiRequestError(res.status, errorBody)
} }
@@ -88,7 +115,7 @@ export const api = {
// ── 认证 ────────────────────────────────────────────── // ── 认证 ──────────────────────────────────────────────
auth: { auth: {
async login(data: LoginRequest): Promise<LoginResponse> { async login(data: LoginRequest): Promise<LoginResponse> {
return request<LoginResponse>('POST', '/api/auth/login', data) return request<LoginResponse>('POST', '/auth/login', data)
}, },
async register(data: { async register(data: {
@@ -97,11 +124,27 @@ export const api = {
email: string email: string
display_name?: string display_name?: string
}): Promise<LoginResponse> { }): Promise<LoginResponse> {
return request<LoginResponse>('POST', '/api/auth/register', data) return request<LoginResponse>('POST', '/auth/register', data)
}, },
async me(): Promise<AccountPublic> { async me(): Promise<AccountPublic> {
return request<AccountPublic>('GET', '/api/auth/me') return request<AccountPublic>('GET', '/auth/me')
},
async changePassword(data: { old_password: string; new_password: string }): Promise<void> {
return request<void>('PUT', '/auth/password', data)
},
async totpSetup(): Promise<{ otpauth_uri: string; secret: string; issuer: string }> {
return request<{ otpauth_uri: string; secret: string; issuer: string }>('POST', '/auth/totp/setup')
},
async totpVerify(data: { code: string }): Promise<void> {
return request<void>('POST', '/auth/totp/verify', data)
},
async totpDisable(data: { password: string }): Promise<void> {
return request<void>('POST', '/auth/totp/disable', data)
}, },
}, },
@@ -115,25 +158,25 @@ export const api = {
status?: string status?: string
}): Promise<PaginatedResponse<AccountPublic>> { }): Promise<PaginatedResponse<AccountPublic>> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<PaginatedResponse<AccountPublic>>('GET', `/api/accounts${qs}`) return request<PaginatedResponse<AccountPublic>>('GET', `/accounts${qs}`)
}, },
async get(id: string): Promise<AccountPublic> { async get(id: string): Promise<AccountPublic> {
return request<AccountPublic>('GET', `/api/accounts/${id}`) return request<AccountPublic>('GET', `/accounts/${id}`)
}, },
async update( async update(
id: string, id: string,
data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>, data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>,
): Promise<AccountPublic> { ): Promise<AccountPublic> {
return request<AccountPublic>('PATCH', `/api/accounts/${id}`, data) return request<AccountPublic>('PUT', `/accounts/${id}`, data)
}, },
async updateStatus( async updateStatus(
id: string, id: string,
data: { status: AccountPublic['status'] }, data: { status: AccountPublic['status'] },
): Promise<void> { ): Promise<void> {
return request<void>('PATCH', `/api/accounts/${id}/status`, data) return request<void>('PATCH', `/accounts/${id}/status`, data)
}, },
}, },
@@ -144,22 +187,26 @@ export const api = {
page_size?: number page_size?: number
}): Promise<PaginatedResponse<Provider>> { }): Promise<PaginatedResponse<Provider>> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<PaginatedResponse<Provider>>('GET', `/api/providers${qs}`) return request<PaginatedResponse<Provider>>('GET', `/providers${qs}`)
},
async get(id: string): Promise<Provider> {
return request<Provider>('GET', `/providers/${id}`)
}, },
async create(data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>): Promise<Provider> { async create(data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>): Promise<Provider> {
return request<Provider>('POST', '/api/providers', data) return request<Provider>('POST', '/providers', data)
}, },
async update( async update(
id: string, id: string,
data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>, data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>,
): Promise<Provider> { ): Promise<Provider> {
return request<Provider>('PATCH', `/api/providers/${id}`, data) return request<Provider>('PUT', `/providers/${id}`, data)
}, },
async delete(id: string): Promise<void> { async delete(id: string): Promise<void> {
return request<void>('DELETE', `/api/providers/${id}`) return request<void>('DELETE', `/providers/${id}`)
}, },
}, },
@@ -171,19 +218,23 @@ export const api = {
provider_id?: string provider_id?: string
}): Promise<PaginatedResponse<Model>> { }): Promise<PaginatedResponse<Model>> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<PaginatedResponse<Model>>('GET', `/api/models${qs}`) return request<PaginatedResponse<Model>>('GET', `/models${qs}`)
},
async get(id: string): Promise<Model> {
return request<Model>('GET', `/models/${id}`)
}, },
async create(data: Partial<Omit<Model, 'id'>>): Promise<Model> { async create(data: Partial<Omit<Model, 'id'>>): Promise<Model> {
return request<Model>('POST', '/api/models', data) return request<Model>('POST', '/models', data)
}, },
async update(id: string, data: Partial<Omit<Model, 'id'>>): Promise<Model> { async update(id: string, data: Partial<Omit<Model, 'id'>>): Promise<Model> {
return request<Model>('PATCH', `/api/models/${id}`, data) return request<Model>('PUT', `/models/${id}`, data)
}, },
async delete(id: string): Promise<void> { async delete(id: string): Promise<void> {
return request<void>('DELETE', `/api/models/${id}`) return request<void>('DELETE', `/models/${id}`)
}, },
}, },
@@ -194,28 +245,23 @@ export const api = {
page_size?: number page_size?: number
}): Promise<PaginatedResponse<TokenInfo>> { }): Promise<PaginatedResponse<TokenInfo>> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<PaginatedResponse<TokenInfo>>('GET', `/api/tokens${qs}`) return request<PaginatedResponse<TokenInfo>>('GET', `/tokens${qs}`)
}, },
async create(data: CreateTokenRequest): Promise<TokenInfo> { async create(data: CreateTokenRequest): Promise<TokenInfo> {
return request<TokenInfo>('POST', '/api/tokens', data) return request<TokenInfo>('POST', '/tokens', data)
}, },
async revoke(id: string): Promise<void> { async revoke(id: string): Promise<void> {
return request<void>('DELETE', `/api/tokens/${id}`) return request<void>('DELETE', `/tokens/${id}`)
}, },
}, },
// ── 用量统计 ────────────────────────────────────────── // ── 用量统计 ──────────────────────────────────────────
usage: { usage: {
async daily(params?: { days?: number }): Promise<UsageRecord[]> { async get(params?: { from?: string; to?: string; provider_id?: string; model_id?: string }): Promise<UsageStats> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<UsageRecord[]>('GET', `/api/usage/daily${qs}`) return request<UsageStats>('GET', `/usage${qs}`)
},
async byModel(params?: { days?: number }): Promise<UsageByModel[]> {
const qs = buildQueryString(params)
return request<UsageByModel[]>('GET', `/api/usage/by-model${qs}`)
}, },
}, },
@@ -227,11 +273,15 @@ export const api = {
status?: string status?: string
}): Promise<PaginatedResponse<RelayTask>> { }): Promise<PaginatedResponse<RelayTask>> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<PaginatedResponse<RelayTask>>('GET', `/api/relay/tasks${qs}`) return request<PaginatedResponse<RelayTask>>('GET', `/relay/tasks${qs}`)
}, },
async get(id: string): Promise<RelayTask> { async get(id: string): Promise<RelayTask> {
return request<RelayTask>('GET', `/api/relay/tasks/${id}`) return request<RelayTask>('GET', `/relay/tasks/${id}`)
},
async retry(id: string): Promise<void> {
return request<void>('POST', `/relay/tasks/${id}/retry`)
}, },
}, },
@@ -241,11 +291,11 @@ export const api = {
category?: string category?: string
}): Promise<ConfigItem[]> { }): Promise<ConfigItem[]> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<ConfigItem[]>('GET', `/api/config${qs}`) return request<ConfigItem[]>('GET', `/config/items${qs}`)
}, },
async update(id: string, data: { value: string | number | boolean }): Promise<ConfigItem> { async update(id: string, data: { current_value: string | number | boolean }): Promise<ConfigItem> {
return request<ConfigItem>('PATCH', `/api/config/${id}`, data) return request<ConfigItem>('PUT', `/config/items/${id}`, data)
}, },
}, },
@@ -255,16 +305,29 @@ export const api = {
page?: number page?: number
page_size?: number page_size?: number
action?: string action?: string
}): Promise<PaginatedResponse<OperationLog>> { }): Promise<OperationLog[]> {
const qs = buildQueryString(params) const qs = buildQueryString(params)
return request<PaginatedResponse<OperationLog>>('GET', `/api/logs${qs}`) return request<OperationLog[]>('GET', `/logs/operations${qs}`)
}, },
}, },
// ── 仪表盘 ──────────────────────────────────────────── // ── 仪表盘 ────────────────────────────────────────────
stats: { stats: {
async dashboard(): Promise<DashboardStats> { async dashboard(): Promise<DashboardStats> {
return request<DashboardStats>('GET', '/api/stats/dashboard') return request<DashboardStats>('GET', '/stats/dashboard')
},
},
// ── 设备管理 ──────────────────────────────────────────
devices: {
async list(): Promise<DeviceInfo[]> {
return request<DeviceInfo[]>('GET', '/devices')
},
async register(data: { device_id: string; device_name?: string; platform?: string; app_version?: string }) {
return request<{ ok: boolean; device_id: string }>('POST', '/devices/register', data)
},
async heartbeat(data: { device_id: string }) {
return request<{ ok: boolean }>('POST', '/devices/heartbeat', data)
}, },
}, },
} }

View File

@@ -2,21 +2,74 @@
// ZCLAW SaaS Admin — JWT Token 管理 // ZCLAW SaaS Admin — JWT Token 管理
// ============================================================ // ============================================================
import type { AccountPublic } from './types' import type { AccountPublic, LoginResponse } from './types'
const TOKEN_KEY = 'zclaw_admin_token' const TOKEN_KEY = 'zclaw_admin_token'
const ACCOUNT_KEY = 'zclaw_admin_account' const ACCOUNT_KEY = 'zclaw_admin_account'
/** 保存登录凭证 */ // ── JWT 辅助函数 ────────────────────────────────────────────
interface JwtPayload {
exp?: number
iat?: number
sub?: string
}
/**
* Decode a JWT payload without verifying the signature.
* Returns the parsed JSON payload, or null if the token is malformed.
*/
function decodeJwtPayload<T = Record<string, unknown>>(token: string): T | null {
try {
const parts = token.split('.')
if (parts.length !== 3) return null
const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/')
const json = decodeURIComponent(
atob(base64)
.split('')
.map((c) => '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2))
.join(''),
)
return JSON.parse(json) as T
} catch {
return null
}
}
/**
* Calculate the delay (ms) until 80% of the token's remaining lifetime
* has elapsed. Returns null if the token is already past that point.
*/
function getRefreshDelay(exp: number): number | null {
const now = Math.floor(Date.now() / 1000)
const totalLifetime = exp - now
if (totalLifetime <= 0) return null
const refreshAt = now + Math.floor(totalLifetime * 0.8)
const delayMs = (refreshAt - now) * 1000
return delayMs > 5000 ? delayMs : 5000
}
// ── 定时刷新状态 ────────────────────────────────────────────
let refreshTimerId: ReturnType<typeof setTimeout> | null = null
let visibilityHandler: (() => void) | null = null
let sessionExpiredCallback: (() => void) | null = null
// ── 凭证操作 ────────────────────────────────────────────────
/** 保存登录凭证并启动自动刷新 */
export function login(token: string, account: AccountPublic): void { export function login(token: string, account: AccountPublic): void {
if (typeof window === 'undefined') return if (typeof window === 'undefined') return
localStorage.setItem(TOKEN_KEY, token) localStorage.setItem(TOKEN_KEY, token)
localStorage.setItem(ACCOUNT_KEY, JSON.stringify(account)) localStorage.setItem(ACCOUNT_KEY, JSON.stringify(account))
scheduleTokenRefresh()
} }
/** 清除登录凭证 */ /** 清除登录凭证并停止自动刷新 */
export function logout(): void { export function logout(): void {
if (typeof window === 'undefined') return if (typeof window === 'undefined') return
cancelTokenRefresh()
localStorage.removeItem(TOKEN_KEY) localStorage.removeItem(TOKEN_KEY)
localStorage.removeItem(ACCOUNT_KEY) localStorage.removeItem(ACCOUNT_KEY)
} }
@@ -43,3 +96,121 @@ export function getAccount(): AccountPublic | null {
export function isAuthenticated(): boolean { export function isAuthenticated(): boolean {
return !!getToken() return !!getToken()
} }
/** 尝试刷新 token成功则更新 localStorage 并返回新 token */
export async function refreshToken(): Promise<string> {
const res = await fetch(
`${process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080'}/api/v1/auth/refresh`,
{
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${getToken()}`,
},
},
)
if (!res.ok) {
throw new Error('Token 刷新失败')
}
const data: LoginResponse = await res.json()
login(data.token, data.account)
return data.token
}
// ── 自动刷新调度 ────────────────────────────────────────────
/**
* Register a callback invoked when the proactive token refresh fails.
* The caller should use this to trigger a logout/redirect flow.
*/
export function setOnSessionExpired(handler: (() => void) | null): void {
sessionExpiredCallback = handler
}
/**
* Schedule a proactive token refresh at 80% of the token's remaining lifetime.
* Also registers a visibilitychange listener to re-check when the tab regains focus.
*/
export function scheduleTokenRefresh(): void {
cancelTokenRefresh()
const token = getToken()
if (!token) return
const payload = decodeJwtPayload<JwtPayload>(token)
if (!payload?.exp) return
const delay = getRefreshDelay(payload.exp)
if (delay === null) {
attemptTokenRefresh()
return
}
refreshTimerId = setTimeout(() => {
attemptTokenRefresh()
}, delay)
if (typeof document !== 'undefined' && !visibilityHandler) {
visibilityHandler = () => {
if (document.visibilityState === 'visible') {
checkAndRefreshToken()
}
}
document.addEventListener('visibilitychange', visibilityHandler)
}
}
/**
* Cancel any pending token refresh timer and remove the visibility listener.
*/
export function cancelTokenRefresh(): void {
if (refreshTimerId !== null) {
clearTimeout(refreshTimerId)
refreshTimerId = null
}
if (visibilityHandler !== null && typeof document !== 'undefined') {
document.removeEventListener('visibilitychange', visibilityHandler)
visibilityHandler = null
}
}
/**
* Check if the current token is close to expiry and refresh if needed.
* Called on visibility change to handle clock skew / long background tabs.
*/
function checkAndRefreshToken(): void {
const token = getToken()
if (!token) return
const payload = decodeJwtPayload<JwtPayload>(token)
if (!payload?.exp) return
const now = Math.floor(Date.now() / 1000)
const remaining = payload.exp - now
if (remaining <= 0) {
attemptTokenRefresh()
return
}
const delay = getRefreshDelay(payload.exp)
if (delay !== null && delay < 60_000) {
attemptTokenRefresh()
}
}
/**
* Attempt to refresh the token. On success, the new token is persisted via
* login() which also reschedules the next refresh. On failure, invoke the
* session-expired callback.
*/
async function attemptTokenRefresh(): Promise<void> {
try {
await refreshToken()
} catch {
cancelTokenRefresh()
if (sessionExpiredCallback) {
sessionExpiredCallback()
}
}
}

View File

@@ -9,6 +9,7 @@ export interface AccountPublic {
email: string email: string
display_name: string display_name: string
role: 'super_admin' | 'admin' | 'user' role: 'super_admin' | 'admin' | 'user'
permissions: string[]
status: 'active' | 'disabled' | 'suspended' status: 'active' | 'disabled' | 'suspended'
totp_enabled: boolean totp_enabled: boolean
created_at: string created_at: string
@@ -18,6 +19,7 @@ export interface AccountPublic {
export interface LoginRequest { export interface LoginRequest {
username: string username: string
password: string password: string
totp_code?: string
} }
/** 登录响应 */ /** 登录响应 */
@@ -47,7 +49,6 @@ export interface Provider {
id: string id: string
name: string name: string
display_name: string display_name: string
api_key?: string
base_url: string base_url: string
api_protocol: 'openai' | 'anthropic' api_protocol: 'openai' | 'anthropic'
enabled: boolean enabled: boolean
@@ -109,18 +110,28 @@ export interface RelayTask {
created_at: string created_at: string
} }
/** 用量记录 */ /** 用量统计 — 后端返回的完整结构 */
export interface UsageRecord { export interface UsageStats {
day: string total_requests: number
count: number total_input_tokens: number
total_output_tokens: number
by_model: UsageByModel[]
by_day: DailyUsage[]
}
/** 每日用量 */
export interface DailyUsage {
date: string
request_count: number
input_tokens: number input_tokens: number
output_tokens: number output_tokens: number
} }
/** 按模型用量 */ /** 按模型用量 */
export interface UsageByModel { export interface UsageByModel {
provider_id: string
model_id: string model_id: string
count: number request_count: number
input_tokens: number input_tokens: number
output_tokens: number output_tokens: number
} }
@@ -131,21 +142,23 @@ export interface ConfigItem {
category: string category: string
key_path: string key_path: string
value_type: 'string' | 'number' | 'boolean' value_type: 'string' | 'number' | 'boolean'
current_value?: string | number | boolean current_value?: string
default_value?: string | number | boolean default_value?: string
source: 'default' | 'env' | 'db' source: 'default' | 'env' | 'db'
description?: string description?: string
requires_restart: boolean requires_restart: boolean
created_at: string
updated_at: string
} }
/** 操作日志 */ /** 操作日志 */
export interface OperationLog { export interface OperationLog {
id: string id: number
account_id: string account_id: string
action: string action: string
target_type: string target_type: string
target_id: string target_id: string
details?: string details?: Record<string, unknown>
ip_address?: string ip_address?: string
created_at: string created_at: string
} }
@@ -161,6 +174,17 @@ export interface DashboardStats {
tokens_today_output: number tokens_today_output: number
} }
/** 设备信息 */
export interface DeviceInfo {
id: string
device_id: string
device_name?: string
platform?: string
app_version?: string
last_seen_at: string
created_at: string
}
/** API 错误响应 */ /** API 错误响应 */
export interface ApiError { export interface ApiError {
error: string error: string

View File

@@ -9,8 +9,6 @@ name = "zclaw-saas"
path = "src/main.rs" path = "src/main.rs"
[dependencies] [dependencies]
zclaw-types = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
futures = { workspace = true } futures = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
@@ -23,7 +21,6 @@ chrono = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }
tracing-subscriber = { workspace = true } tracing-subscriber = { workspace = true }
sqlx = { workspace = true } sqlx = { workspace = true }
libsqlite3-sys = { workspace = true }
reqwest = { workspace = true } reqwest = { workspace = true }
secrecy = { workspace = true } secrecy = { workspace = true }
sha2 = { workspace = true } sha2 = { workspace = true }
@@ -34,6 +31,8 @@ url = "2"
axum = { workspace = true } axum = { workspace = true }
axum-extra = { workspace = true } axum-extra = { workspace = true }
bytes = { workspace = true }
async-stream = { workspace = true }
tower = { workspace = true } tower = { workspace = true }
tower-http = { workspace = true } tower-http = { workspace = true }
jsonwebtoken = { workspace = true } jsonwebtoken = { workspace = true }
@@ -41,6 +40,9 @@ argon2 = { workspace = true }
totp-rs = { workspace = true } totp-rs = { workspace = true }
urlencoding = "2" urlencoding = "2"
data-encoding = "2" data-encoding = "2"
aes-gcm = { workspace = true }
utoipa = { version = "5", features = ["axum_extras"] }
utoipa-swagger-ui = { version = "5", features = ["axum"] }
[dev-dependencies] [dev-dependencies]
tempfile = { workspace = true } tempfile = { workspace = true }

View File

@@ -121,23 +121,43 @@ pub async fn list_operation_logs(
let page: i64 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1); let page: i64 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1);
let page_size: i64 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50); let page_size: i64 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50);
let offset = (page - 1) * page_size; let offset = (page - 1) * page_size;
let action_filter = params.get("action").map(|s| s.as_str());
let target_type_filter = params.get("target_type").map(|s| s.as_str());
let rows: Vec<(i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, String)> = let mut sql = String::from(
sqlx::query_as( "SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at FROM operation_logs"
FROM operation_logs ORDER BY created_at DESC LIMIT ?1 OFFSET ?2" );
) let mut param_idx: usize = 1;
.bind(page_size) if action_filter.is_some() || target_type_filter.is_some() {
.bind(offset) sql.push_str(" WHERE 1=1");
.fetch_all(&state.db) if action_filter.is_some() {
.await?; sql.push_str(&format!(" AND action = ${}", param_idx));
param_idx += 1;
}
if target_type_filter.is_some() {
sql.push_str(&format!(" AND target_type = ${}", param_idx));
param_idx += 1;
}
}
sql.push_str(&format!(" ORDER BY created_at DESC LIMIT ${} OFFSET ${}", param_idx, param_idx + 1));
let mut query = sqlx::query_as::<_, (i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>)>(&sql);
if let Some(action) = action_filter {
query = query.bind(action);
}
if let Some(target_type) = target_type_filter {
query = query.bind(target_type);
}
query = query.bind(page_size).bind(offset);
let rows = query.fetch_all(&state.db).await?;
let items: Vec<serde_json::Value> = rows.into_iter().map(|(id, account_id, action, target_type, target_id, details, ip_address, created_at)| { let items: Vec<serde_json::Value> = rows.into_iter().map(|(id, account_id, action, target_type, target_id, details, ip_address, created_at)| {
serde_json::json!({ serde_json::json!({
"id": id, "account_id": account_id, "action": action, "id": id, "account_id": account_id, "action": action,
"target_type": target_type, "target_id": target_id, "target_type": target_type, "target_id": target_id,
"details": details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()), "details": details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
"ip_address": ip_address, "created_at": created_at, "ip_address": ip_address, "created_at": created_at.to_rfc3339(),
}) })
}).collect(); }).collect();
@@ -151,32 +171,27 @@ pub async fn dashboard_stats(
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<serde_json::Value>> {
require_admin(&ctx)?; require_admin(&ctx)?;
let total_accounts: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM accounts") let row: (i64, i64, i64, i64, i64, i64, i64) = sqlx::query_as(
.fetch_one(&state.db).await?; "SELECT
let active_accounts: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM accounts WHERE status = 'active'") (SELECT COUNT(*) FROM accounts),
.fetch_one(&state.db).await?; (SELECT COUNT(*) FROM accounts WHERE status = 'active'),
let tasks_today: (i64,) = sqlx::query_as( (SELECT COUNT(*) FROM relay_tasks WHERE DATE(created_at) = CURRENT_DATE),
"SELECT COUNT(*) FROM relay_tasks WHERE date(created_at) = date('now')" (SELECT COUNT(*) FROM providers WHERE enabled = true),
).fetch_one(&state.db).await?; (SELECT COUNT(*) FROM models WHERE enabled = true),
let active_providers: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM providers WHERE enabled = 1") (SELECT COALESCE(SUM(input_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE),
.fetch_one(&state.db).await?; (SELECT COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE)"
let active_models: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM models WHERE enabled = 1") )
.fetch_one(&state.db).await?; .fetch_one(&state.db)
let tokens_today_input: (i64,) = sqlx::query_as( .await?;
"SELECT COALESCE(SUM(input_tokens), 0) FROM usage_records WHERE date(created_at) = date('now')"
).fetch_one(&state.db).await?;
let tokens_today_output: (i64,) = sqlx::query_as(
"SELECT COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE date(created_at) = date('now')"
).fetch_one(&state.db).await?;
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"total_accounts": total_accounts.0, "total_accounts": row.0,
"active_accounts": active_accounts.0, "active_accounts": row.1,
"tasks_today": tasks_today.0, "tasks_today": row.2,
"active_providers": active_providers.0, "active_providers": row.3,
"active_models": active_models.0, "active_models": row.4,
"tokens_today_input": tokens_today_input.0, "tokens_today_input": row.5,
"tokens_today_output": tokens_today_output.0, "tokens_today_output": row.6,
}))) })))
} }
@@ -186,59 +201,48 @@ pub async fn dashboard_stats(
pub async fn register_device( pub async fn register_device(
State(state): State<AppState>, State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>, Extension(ctx): Extension<AuthContext>,
Json(req): Json<serde_json::Value>, Json(req): Json<super::types::RegisterDeviceRequest>,
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<serde_json::Value>> {
let device_id = req.get("device_id") let now = chrono::Utc::now();
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
let device_name = req.get("device_name").and_then(|v| v.as_str()).unwrap_or("Unknown");
let platform = req.get("platform").and_then(|v| v.as_str()).unwrap_or("unknown");
let app_version = req.get("app_version").and_then(|v| v.as_str()).unwrap_or("");
let now = chrono::Utc::now().to_rfc3339();
let device_uuid = uuid::Uuid::new_v4().to_string(); let device_uuid = uuid::Uuid::new_v4().to_string();
// UPSERT: 已存在则更新 last_seen_at不存在则插入 // UPSERT: 已存在则更新 last_seen_at不存在则插入
sqlx::query( sqlx::query(
"INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at) "INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?7) VALUES ($1, $2, $3, $4, $5, $6, $7, $7)
ON CONFLICT(account_id, device_id) DO UPDATE SET ON CONFLICT(account_id, device_id) DO UPDATE SET
device_name = ?4, platform = ?5, app_version = ?6, last_seen_at = ?7" device_name = EXCLUDED.device_name, platform = EXCLUDED.platform, app_version = EXCLUDED.app_version, last_seen_at = EXCLUDED.last_seen_at"
) )
.bind(&device_uuid) .bind(&device_uuid)
.bind(&ctx.account_id) .bind(&ctx.account_id)
.bind(device_id) .bind(&req.device_id)
.bind(device_name) .bind(&req.device_name)
.bind(platform) .bind(&req.platform)
.bind(app_version) .bind(&req.app_version)
.bind(&now) .bind(&now)
.execute(&state.db) .execute(&state.db)
.await?; .await?;
log_operation(&state.db, &ctx.account_id, "device.register", "device", device_id, log_operation(&state.db, &ctx.account_id, "device.register", "device", &req.device_id,
Some(serde_json::json!({"device_name": device_name, "platform": platform})), Some(serde_json::json!({"device_name": req.device_name, "platform": req.platform})),
ctx.client_ip.as_deref()).await?; ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true, "device_id": device_id}))) Ok(Json(serde_json::json!({"ok": true, "device_id": req.device_id})))
} }
/// POST /api/v1/devices/heartbeat — 设备心跳 /// POST /api/v1/devices/heartbeat — 设备心跳
pub async fn device_heartbeat( pub async fn device_heartbeat(
State(state): State<AppState>, State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>, Extension(ctx): Extension<AuthContext>,
Json(req): Json<serde_json::Value>, Json(req): Json<super::types::DeviceHeartbeatRequest>,
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<serde_json::Value>> {
let device_id = req.get("device_id") let now = chrono::Utc::now();
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
let now = chrono::Utc::now().to_rfc3339();
let result = sqlx::query( let result = sqlx::query(
"UPDATE devices SET last_seen_at = ?1 WHERE account_id = ?2 AND device_id = ?3" "UPDATE devices SET last_seen_at = $1 WHERE account_id = $2 AND device_id = $3"
) )
.bind(&now) .bind(&now)
.bind(&ctx.account_id) .bind(&ctx.account_id)
.bind(device_id) .bind(&req.device_id)
.execute(&state.db) .execute(&state.db)
.await?; .await?;
@@ -253,22 +257,22 @@ pub async fn device_heartbeat(
pub async fn list_devices( pub async fn list_devices(
State(state): State<AppState>, State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>, Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> { ) -> SaasResult<Json<Vec<super::types::DeviceInfo>>> {
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, String, String)> = let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at "SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
FROM devices WHERE account_id = ?1 ORDER BY last_seen_at DESC" FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC"
) )
.bind(&ctx.account_id) .bind(&ctx.account_id)
.fetch_all(&state.db) .fetch_all(&state.db)
.await?; .await?;
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| { let items: Vec<super::types::DeviceInfo> = rows.into_iter().map(|r| {
serde_json::json!({ super::types::DeviceInfo {
"id": r.0, "device_id": r.1, id: r.0, device_id: r.1,
"device_name": r.2, "platform": r.3, "app_version": r.4, device_name: r.2, platform: r.3, app_version: r.4,
"last_seen_at": r.5, "created_at": r.6, last_seen_at: r.5.to_rfc3339(), created_at: r.6.to_rfc3339(),
}) }
}).collect(); }).collect();
Ok(Json(items)) Ok(Json(items))

View File

@@ -1,11 +1,11 @@
//! 账号管理业务逻辑 //! 账号管理业务逻辑
use sqlx::SqlitePool; use sqlx::PgPool;
use crate::error::{SaasError, SaasResult}; use crate::error::{SaasError, SaasResult};
use super::types::*; use super::types::*;
pub async fn list_accounts( pub async fn list_accounts(
db: &SqlitePool, db: &PgPool,
query: &ListAccountsQuery, query: &ListAccountsQuery,
) -> SaasResult<PaginatedResponse<serde_json::Value>> { ) -> SaasResult<PaginatedResponse<serde_json::Value>> {
let page = query.page.unwrap_or(1).max(1); let page = query.page.unwrap_or(1).max(1);
@@ -14,21 +14,25 @@ pub async fn list_accounts(
let mut where_clauses = Vec::new(); let mut where_clauses = Vec::new();
let mut params: Vec<String> = Vec::new(); let mut params: Vec<String> = Vec::new();
let mut param_idx: usize = 1;
if let Some(role) = &query.role { if let Some(role) = &query.role {
where_clauses.push("role = ?".to_string()); where_clauses.push(format!("role = ${}", param_idx));
params.push(role.clone()); params.push(role.clone());
param_idx += 1;
} }
if let Some(status) = &query.status { if let Some(status) = &query.status {
where_clauses.push("status = ?".to_string()); where_clauses.push(format!("status = ${}", param_idx));
params.push(status.clone()); params.push(status.clone());
param_idx += 1;
} }
if let Some(search) = &query.search { if let Some(search) = &query.search {
where_clauses.push("(username LIKE ? OR email LIKE ? OR display_name LIKE ?)".to_string()); where_clauses.push(format!("(username LIKE ${} OR email LIKE ${} OR display_name LIKE ${})", param_idx, param_idx + 1, param_idx + 2));
let pattern = format!("%{}%", search); let pattern = format!("%{}%", search);
params.push(pattern.clone()); params.push(pattern.clone());
params.push(pattern.clone()); params.push(pattern.clone());
params.push(pattern); params.push(pattern);
param_idx += 3;
} }
let where_sql = if where_clauses.is_empty() { let where_sql = if where_clauses.is_empty() {
@@ -46,10 +50,10 @@ pub async fn list_accounts(
let data_sql = format!( let data_sql = format!(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at "SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts {} ORDER BY created_at DESC LIMIT ? OFFSET ?", FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
where_sql where_sql, param_idx, param_idx + 1
); );
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<String>, String)>(&data_sql); let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(&data_sql);
for p in &params { for p in &params {
data_query = data_query.bind(p); data_query = data_query.bind(p);
} }
@@ -61,7 +65,7 @@ pub async fn list_accounts(
serde_json::json!({ serde_json::json!({
"id": id, "username": username, "email": email, "display_name": display_name, "id": id, "username": username, "email": email, "display_name": display_name,
"role": role, "status": status, "totp_enabled": totp_enabled, "role": role, "status": status, "totp_enabled": totp_enabled,
"last_login_at": last_login_at, "created_at": created_at, "last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
}) })
}) })
.collect(); .collect();
@@ -69,11 +73,11 @@ pub async fn list_accounts(
Ok(PaginatedResponse { items, total, page, page_size }) Ok(PaginatedResponse { items, total, page, page_size })
} }
pub async fn get_account(db: &SqlitePool, account_id: &str) -> SaasResult<serde_json::Value> { pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
let row: Option<(String, String, String, String, String, String, bool, Option<String>, String)> = let row: Option<(String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at "SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
FROM accounts WHERE id = ?1" FROM accounts WHERE id = $1"
) )
.bind(account_id) .bind(account_id)
.fetch_optional(db) .fetch_optional(db)
@@ -85,43 +89,45 @@ pub async fn get_account(db: &SqlitePool, account_id: &str) -> SaasResult<serde_
Ok(serde_json::json!({ Ok(serde_json::json!({
"id": id, "username": username, "email": email, "display_name": display_name, "id": id, "username": username, "email": email, "display_name": display_name,
"role": role, "status": status, "totp_enabled": totp_enabled, "role": role, "status": status, "totp_enabled": totp_enabled,
"last_login_at": last_login_at, "created_at": created_at, "last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
})) }))
} }
pub async fn update_account( pub async fn update_account(
db: &SqlitePool, db: &PgPool,
account_id: &str, account_id: &str,
req: &UpdateAccountRequest, req: &UpdateAccountRequest,
) -> SaasResult<serde_json::Value> { ) -> SaasResult<serde_json::Value> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let mut updates = Vec::new(); let mut updates = Vec::new();
let mut params: Vec<String> = Vec::new(); let mut params: Vec<String> = Vec::new();
let mut param_idx: usize = 1;
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(v.clone()); } if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.email { updates.push("email = ?"); params.push(v.clone()); } if let Some(ref v) = req.email { updates.push(format!("email = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.role { updates.push("role = ?"); params.push(v.clone()); } if let Some(ref v) = req.role { updates.push(format!("role = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.avatar_url { updates.push("avatar_url = ?"); params.push(v.clone()); } if let Some(ref v) = req.avatar_url { updates.push(format!("avatar_url = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if updates.is_empty() { if updates.is_empty() {
return get_account(db, account_id).await; return get_account(db, account_id).await;
} }
updates.push("updated_at = ?"); updates.push(format!("updated_at = ${}", param_idx));
params.push(now.clone()); param_idx += 1;
params.push(account_id.to_string()); params.push(account_id.to_string());
let sql = format!("UPDATE accounts SET {} WHERE id = ?", updates.join(", ")); let sql = format!("UPDATE accounts SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql); let mut query = sqlx::query(&sql);
for p in &params { for p in &params {
query = query.bind(p); query = query.bind(p);
} }
query = query.bind(now);
query.execute(db).await?; query.execute(db).await?;
get_account(db, account_id).await get_account(db, account_id).await
} }
pub async fn update_account_status( pub async fn update_account_status(
db: &SqlitePool, db: &PgPool,
account_id: &str, account_id: &str,
status: &str, status: &str,
) -> SaasResult<()> { ) -> SaasResult<()> {
@@ -129,8 +135,8 @@ pub async fn update_account_status(
if !valid.contains(&status) { if !valid.contains(&status) {
return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid))); return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid)));
} }
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let result = sqlx::query("UPDATE accounts SET status = ?1, updated_at = ?2 WHERE id = ?3") let result = sqlx::query("UPDATE accounts SET status = $1, updated_at = $2 WHERE id = $3")
.bind(status).bind(&now).bind(account_id) .bind(status).bind(&now).bind(account_id)
.execute(db).await?; .execute(db).await?;
@@ -141,7 +147,7 @@ pub async fn update_account_status(
} }
pub async fn create_api_token( pub async fn create_api_token(
db: &SqlitePool, db: &PgPool,
account_id: &str, account_id: &str,
req: &CreateTokenRequest, req: &CreateTokenRequest,
) -> SaasResult<TokenInfo> { ) -> SaasResult<TokenInfo> {
@@ -154,16 +160,18 @@ pub async fn create_api_token(
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes())); let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
let token_prefix = raw_token[..8].to_string(); let token_prefix = raw_token[..8].to_string();
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let now_str = now.to_rfc3339();
let expires_at = req.expires_days.map(|d| { let expires_at = req.expires_days.map(|d| {
(chrono::Utc::now() + chrono::Duration::days(d)).to_rfc3339() chrono::Utc::now() + chrono::Duration::days(d)
}); });
let expires_at_str = expires_at.as_ref().map(|t| t.to_rfc3339());
let permissions = serde_json::to_string(&req.permissions)?; let permissions = serde_json::to_string(&req.permissions)?;
let token_id = uuid::Uuid::new_v4().to_string(); let token_id = uuid::Uuid::new_v4().to_string();
sqlx::query( sqlx::query(
"INSERT INTO api_tokens (id, account_id, name, token_hash, token_prefix, permissions, created_at, expires_at) "INSERT INTO api_tokens (id, account_id, name, token_hash, token_prefix, permissions, created_at, expires_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
) )
.bind(&token_id) .bind(&token_id)
.bind(account_id) .bind(account_id)
@@ -182,20 +190,20 @@ pub async fn create_api_token(
token_prefix, token_prefix,
permissions: req.permissions.clone(), permissions: req.permissions.clone(),
last_used_at: None, last_used_at: None,
expires_at, expires_at: expires_at_str,
created_at: now, created_at: now_str,
token: Some(raw_token), token: Some(raw_token),
}) })
} }
pub async fn list_api_tokens( pub async fn list_api_tokens(
db: &SqlitePool, db: &PgPool,
account_id: &str, account_id: &str,
) -> SaasResult<Vec<TokenInfo>> { ) -> SaasResult<Vec<TokenInfo>> {
let rows: Vec<(String, String, String, String, Option<String>, Option<String>, String)> = let rows: Vec<(String, String, String, String, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at "SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
FROM api_tokens WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC" FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
) )
.bind(account_id) .bind(account_id)
.fetch_all(db) .fetch_all(db)
@@ -203,14 +211,14 @@ pub async fn list_api_tokens(
Ok(rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| { Ok(rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default(); let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used, expires_at: expires, created_at: created, token: None, } TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used.map(|t| t.to_rfc3339()), expires_at: expires.map(|t| t.to_rfc3339()), created_at: created.to_rfc3339(), token: None, }
}).collect()) }).collect())
} }
pub async fn revoke_api_token(db: &SqlitePool, token_id: &str, account_id: &str) -> SaasResult<()> { pub async fn revoke_api_token(db: &PgPool, token_id: &str, account_id: &str) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let result = sqlx::query( let result = sqlx::query(
"UPDATE api_tokens SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL" "UPDATE api_tokens SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
) )
.bind(&now).bind(token_id).bind(account_id) .bind(&now).bind(token_id).bind(account_id)
.execute(db).await?; .execute(db).await?;

View File

@@ -2,7 +2,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateAccountRequest { pub struct UpdateAccountRequest {
pub display_name: Option<String>, pub display_name: Option<String>,
pub email: Option<String>, pub email: Option<String>,
@@ -10,12 +10,12 @@ pub struct UpdateAccountRequest {
pub avatar_url: Option<String>, pub avatar_url: Option<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateStatusRequest { pub struct UpdateStatusRequest {
pub status: String, pub status: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct ListAccountsQuery { pub struct ListAccountsQuery {
pub page: Option<u32>, pub page: Option<u32>,
pub page_size: Option<u32>, pub page_size: Option<u32>,
@@ -32,14 +32,28 @@ pub struct PaginatedResponse<T: Serialize> {
pub page_size: u32, pub page_size: u32,
} }
#[derive(Debug, Deserialize)] /// Concrete type alias for OpenAPI schema generation.
///
/// NOTE: This is intentionally a concrete (non-generic) type because utoipa
/// requires concrete types for schema generation. It is functionally
/// identical to `Paginated<AccountPublic>`.
#[derive(Debug, Serialize, utoipa::ToSchema)]
#[allow(clippy::manual_non_exhaustive)] // kept for OpenAPI schema
pub struct AccountPublicPaginatedResponse {
pub items: Vec<crate::auth::types::AccountPublic>,
pub total: i64,
pub page: u32,
pub page_size: u32,
}
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateTokenRequest { pub struct CreateTokenRequest {
pub name: String, pub name: String,
pub permissions: Vec<String>, pub permissions: Vec<String>,
pub expires_days: Option<i64>, pub expires_days: Option<i64>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct TokenInfo { pub struct TokenInfo {
pub id: String, pub id: String,
pub name: String, pub name: String,
@@ -51,3 +65,35 @@ pub struct TokenInfo {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub token: Option<String>, pub token: Option<String>,
} }
// ============ Device Types ============
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RegisterDeviceRequest {
pub device_id: String,
#[serde(default = "default_device_name")]
pub device_name: String,
#[serde(default = "default_platform")]
pub platform: String,
#[serde(default)]
pub app_version: String,
}
fn default_device_name() -> String { "Unknown".into() }
fn default_platform() -> String { "unknown".into() }
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct DeviceHeartbeatRequest {
pub device_id: String,
}
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct DeviceInfo {
pub id: String,
pub device_id: String,
pub device_name: Option<String>,
pub platform: Option<String>,
pub app_version: Option<String>,
pub last_seen_at: String,
pub created_at: String,
}

View File

@@ -16,16 +16,24 @@ pub async fn register(
State(state): State<AppState>, State(state): State<AppState>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(req): Json<RegisterRequest>, Json(req): Json<RegisterRequest>,
) -> SaasResult<(StatusCode, Json<AccountPublic>)> { ) -> SaasResult<(StatusCode, Json<LoginResponse>)> {
if req.username.len() < 3 { // 4.6: 用户名格式验证 — 3-32 字符,仅允许字母数字下划线
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into())); if req.username.len() < 3 || req.username.len() > 32 {
return Err(SaasError::InvalidInput("用户名长度需在 3-32 个字符之间".into()));
}
if !req.username.chars().all(|c| c.is_alphanumeric() || c == '_') {
return Err(SaasError::InvalidInput("用户名仅允许字母、数字和下划线".into()));
}
// 4.7: 邮箱格式验证
if !req.email.contains('@') || !req.email.split('@').nth(1).map_or(false, |d| d.contains('.')) {
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
} }
if req.password.len() < 8 { if req.password.len() < 8 {
return Err(SaasError::InvalidInput("密码至少 8 个字符".into())); return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
} }
let existing: Vec<(String,)> = sqlx::query_as( let existing: Vec<(String,)> = sqlx::query_as(
"SELECT id FROM accounts WHERE username = ?1 OR email = ?2" "SELECT id FROM accounts WHERE username = $1 OR email = $2"
) )
.bind(&req.username) .bind(&req.username)
.bind(&req.email) .bind(&req.email)
@@ -40,11 +48,11 @@ pub async fn register(
let account_id = uuid::Uuid::new_v4().to_string(); let account_id = uuid::Uuid::new_v4().to_string();
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配 let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
let display_name = req.display_name.unwrap_or_default(); let display_name = req.display_name.unwrap_or_default();
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
sqlx::query( sqlx::query(
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at) "INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'active', ?7, ?7)" VALUES ($1, $2, $3, $4, $5, $6, 'active', $7, $7)"
) )
.bind(&account_id) .bind(&account_id)
.bind(&req.username) .bind(&req.username)
@@ -52,22 +60,33 @@ pub async fn register(
.bind(&password_hash) .bind(&password_hash)
.bind(&display_name) .bind(&display_name)
.bind(&role) .bind(&role)
.bind(&now) .bind(now)
.execute(&state.db) .execute(&state.db)
.await?; .await?;
let client_ip = addr.ip().to_string(); let client_ip = addr.ip().to_string();
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?; log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
Ok((StatusCode::CREATED, Json(AccountPublic { // Generate JWT token for auto-login after registration
id: account_id, let config = state.config.read().await;
username: req.username, let token = create_token(
email: req.email, &account_id, &role, vec![],
display_name, state.jwt_secret.expose_secret(), config.auth.jwt_expiration_hours,
role, )?;
status: "active".into(),
totp_enabled: false, Ok((StatusCode::CREATED, Json(LoginResponse {
created_at: now, token,
account: AccountPublic {
id: account_id,
username: req.username,
email: req.email,
display_name,
role,
permissions: vec![],
status: "active".into(),
totp_enabled: false,
created_at: now.to_rfc3339(),
},
}))) })))
} }
@@ -77,10 +96,10 @@ pub async fn login(
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
Json(req): Json<LoginRequest>, Json(req): Json<LoginRequest>,
) -> SaasResult<Json<LoginResponse>> { ) -> SaasResult<Json<LoginResponse>> {
let row: Option<(String, String, String, String, String, String, bool, String)> = let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at "SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE username = ?1 OR email = ?1" FROM accounts WHERE username = $1 OR email = $1"
) )
.bind(&req.username) .bind(&req.username)
.fetch_optional(&state.db) .fetch_optional(&state.db)
@@ -88,13 +107,14 @@ pub async fn login(
let (id, username, email, display_name, role, status, totp_enabled, created_at) = let (id, username, email, display_name, role, status, totp_enabled, created_at) =
row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?; row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
let created_at = created_at.to_rfc3339();
if status != "active" { if status != "active" {
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", status))); return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", status)));
} }
let (password_hash,): (String,) = sqlx::query_as( let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = ?1" "SELECT password_hash FROM accounts WHERE id = $1"
) )
.bind(&id) .bind(&id)
.fetch_one(&state.db) .fetch_one(&state.db)
@@ -110,7 +130,7 @@ pub async fn login(
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?; .ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
let (totp_secret,): (Option<String>,) = sqlx::query_as( let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = ?1" "SELECT totp_secret FROM accounts WHERE id = $1"
) )
.bind(&id) .bind(&id)
.fetch_one(&state.db) .fetch_one(&state.db)
@@ -120,7 +140,10 @@ pub async fn login(
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into()) SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
})?; })?;
if !super::totp::verify_totp_code(&secret, code) { // 解密 TOTP 密钥(兼容迁移期间的明文数据)
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
if !super::totp::verify_totp_code(&decrypted_secret, code) {
return Err(SaasError::Totp("TOTP 码错误或已过期".into())); return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
} }
} }
@@ -133,9 +156,9 @@ pub async fn login(
config.auth.jwt_expiration_hours, config.auth.jwt_expiration_hours,
)?; )?;
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET last_login_at = ?1 WHERE id = ?2") sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
.bind(&now).bind(&id) .bind(now).bind(&id)
.execute(&state.db).await?; .execute(&state.db).await?;
let client_ip = addr.ip().to_string(); let client_ip = addr.ip().to_string();
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?; log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
@@ -143,7 +166,7 @@ pub async fn login(
Ok(Json(LoginResponse { Ok(Json(LoginResponse {
token, token,
account: AccountPublic { account: AccountPublic {
id, username, email, display_name, role, status, totp_enabled, created_at, id, username, email, display_name, role, permissions, status, totp_enabled, created_at,
}, },
})) }))
} }
@@ -152,14 +175,30 @@ pub async fn login(
pub async fn refresh( pub async fn refresh(
State(state): State<AppState>, State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>, axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<LoginResponse>> {
let config = state.config.read().await; let config = state.config.read().await;
let token = create_token( let token = create_token(
&ctx.account_id, &ctx.role, ctx.permissions.clone(), &ctx.account_id, &ctx.role, ctx.permissions.clone(),
state.jwt_secret.expose_secret(), state.jwt_secret.expose_secret(),
config.auth.jwt_expiration_hours, config.auth.jwt_expiration_hours,
)?; )?;
Ok(Json(serde_json::json!({ "token": token })))
// 查询账号信息以返回完整 LoginResponse
let row = sqlx::query_as::<_, (String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)>(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE id = $1"
)
.bind(&ctx.account_id)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
let (id, username, email, display_name, role, status, totp_enabled, created_at) = row;
let created_at = created_at.to_rfc3339();
Ok(Json(LoginResponse {
token,
account: AccountPublic { id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at },
}))
} }
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息 /// GET /api/v1/auth/me — 返回当前认证用户的公开信息
@@ -167,10 +206,10 @@ pub async fn me(
State(state): State<AppState>, State(state): State<AppState>,
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>, axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
) -> SaasResult<Json<AccountPublic>> { ) -> SaasResult<Json<AccountPublic>> {
let row: Option<(String, String, String, String, String, String, bool, String)> = let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at "SELECT id, username, email, display_name, role, status, totp_enabled, created_at
FROM accounts WHERE id = ?1" FROM accounts WHERE id = $1"
) )
.bind(&ctx.account_id) .bind(&ctx.account_id)
.fetch_optional(&state.db) .fetch_optional(&state.db)
@@ -178,9 +217,10 @@ pub async fn me(
let (id, username, email, display_name, role, status, totp_enabled, created_at) = let (id, username, email, display_name, role, status, totp_enabled, created_at) =
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?; row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
let created_at = created_at.to_rfc3339();
Ok(Json(AccountPublic { Ok(Json(AccountPublic {
id, username, email, display_name, role, status, totp_enabled, created_at, id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at,
})) }))
} }
@@ -196,7 +236,7 @@ pub async fn change_password(
// 获取当前密码哈希 // 获取当前密码哈希
let (password_hash,): (String,) = sqlx::query_as( let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = ?1" "SELECT password_hash FROM accounts WHERE id = $1"
) )
.bind(&ctx.account_id) .bind(&ctx.account_id)
.fetch_one(&state.db) .fetch_one(&state.db)
@@ -209,10 +249,10 @@ pub async fn change_password(
// 更新密码 // 更新密码
let new_hash = hash_password(&req.new_password)?; let new_hash = hash_password(&req.new_password)?;
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET password_hash = ?1, updated_at = ?2 WHERE id = ?3") sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
.bind(&new_hash) .bind(&new_hash)
.bind(&now) .bind(now)
.bind(&ctx.account_id) .bind(&ctx.account_id)
.execute(&state.db) .execute(&state.db)
.await?; .await?;
@@ -223,16 +263,16 @@ pub async fn change_password(
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"}))) Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
} }
pub(crate) async fn get_role_permissions(db: &sqlx::SqlitePool, role: &str) -> SaasResult<Vec<String>> { pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
let row: Option<(String,)> = sqlx::query_as( let row: Option<(String,)> = sqlx::query_as(
"SELECT permissions FROM roles WHERE id = ?1" "SELECT permissions FROM roles WHERE id = $1"
) )
.bind(role) .bind(role)
.fetch_optional(db) .fetch_optional(db)
.await?; .await?;
let permissions_str = row let permissions_str = row
.ok_or_else(|| SaasError::Internal(format!("角色 {} 不存在", role)))? .ok_or_else(|| SaasError::Forbidden(format!("角色 {} 不存在或无权限", role)))?
.0; .0;
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?; let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
@@ -252,7 +292,7 @@ pub fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
/// 记录操作日志 /// 记录操作日志
pub async fn log_operation( pub async fn log_operation(
db: &sqlx::SqlitePool, db: &sqlx::PgPool,
account_id: &str, account_id: &str,
action: &str, action: &str,
target_type: &str, target_type: &str,
@@ -260,10 +300,10 @@ pub async fn log_operation(
details: Option<serde_json::Value>, details: Option<serde_json::Value>,
ip_address: Option<&str>, ip_address: Option<&str>,
) -> SaasResult<()> { ) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
sqlx::query( sqlx::query(
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at) "INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)" VALUES ($1, $2, $3, $4, $5, $6, $7)"
) )
.bind(account_id) .bind(account_id)
.bind(action) .bind(action)
@@ -271,8 +311,54 @@ pub async fn log_operation(
.bind(target_id) .bind(target_id)
.bind(details.map(|d| d.to_string())) .bind(details.map(|d| d.to_string()))
.bind(ip_address) .bind(ip_address)
.bind(&now) .bind(now)
.execute(db) .execute(db)
.await?; .await?;
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::types::AuthContext;
fn ctx(permissions: Vec<&str>) -> AuthContext {
AuthContext {
account_id: "test-id".into(),
role: "user".into(),
permissions: permissions.into_iter().map(String::from).collect(),
client_ip: None,
}
}
#[test]
fn test_check_permission_admin_full() {
let c = ctx(vec!["admin:full"]);
assert!(check_permission(&c, "config:write").is_ok());
assert!(check_permission(&c, "account:admin").is_ok());
assert!(check_permission(&c, "any:permission").is_ok());
}
#[test]
fn test_check_permission_has_permission() {
let c = ctx(vec!["config:write", "model:read"]);
assert!(check_permission(&c, "config:write").is_ok());
assert!(check_permission(&c, "model:read").is_ok());
}
#[test]
fn test_check_permission_missing() {
let c = ctx(vec!["model:read"]);
let result = check_permission(&c, "config:write");
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("config:write"));
}
#[test]
fn test_check_permission_empty_list() {
let c = ctx(vec![]);
assert!(check_permission(&c, "config:write").is_err());
assert!(check_permission(&c, "admin:full").is_err());
}
}

View File

@@ -10,17 +10,24 @@ use crate::error::SaasResult;
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Claims { pub struct Claims {
pub sub: String, pub sub: String,
pub aud: String,
pub iss: String,
pub role: String, pub role: String,
pub permissions: Vec<String>, pub permissions: Vec<String>,
pub iat: i64, pub iat: i64,
pub exp: i64, pub exp: i64,
} }
const JWT_AUDIENCE: &str = "zclaw-saas";
const JWT_ISSUER: &str = "zclaw-saas";
impl Claims { impl Claims {
pub fn new(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self { pub fn new(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
let now = Utc::now(); let now = Utc::now();
Self { Self {
sub: account_id.to_string(), sub: account_id.to_string(),
aud: JWT_AUDIENCE.to_string(),
iss: JWT_ISSUER.to_string(),
role: role.to_string(), role: role.to_string(),
permissions, permissions,
iat: now.timestamp(), iat: now.timestamp(),
@@ -48,10 +55,14 @@ pub fn create_token(
/// 验证 JWT Token /// 验证 JWT Token
pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> { pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> {
let mut validation = Validation::default();
validation.set_audience(&[JWT_AUDIENCE]);
validation.set_issuer(&[JWT_ISSUER]);
let token_data = decode::<Claims>( let token_data = decode::<Claims>(
token, token,
&DecodingKey::from_secret(secret.as_bytes()), &DecodingKey::from_secret(secret.as_bytes()),
&Validation::default(), &validation,
)?; )?;
Ok(token_data.claims) Ok(token_data.claims)
} }

View File

@@ -29,7 +29,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
let row: Option<(String, Option<String>, String)> = sqlx::query_as( let row: Option<(String, Option<String>, String)> = sqlx::query_as(
"SELECT account_id, expires_at, permissions FROM api_tokens "SELECT account_id, expires_at, permissions FROM api_tokens
WHERE token_hash = ?1 AND revoked_at IS NULL" WHERE token_hash = $1 AND revoked_at IS NULL"
) )
.bind(&token_hash) .bind(&token_hash)
.fetch_optional(&state.db) .fetch_optional(&state.db)
@@ -50,7 +50,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
// 查询关联账号的角色 // 查询关联账号的角色
let (role,): (String,) = sqlx::query_as( let (role,): (String,) = sqlx::query_as(
"SELECT role FROM accounts WHERE id = ?1 AND status = 'active'" "SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
) )
.bind(&account_id) .bind(&account_id)
.fetch_optional(&state.db) .fetch_optional(&state.db)
@@ -70,9 +70,9 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
// 异步更新 last_used_at不阻塞请求 // 异步更新 last_used_at不阻塞请求
let db = state.db.clone(); let db = state.db.clone();
tokio::spawn(async move { tokio::spawn(async move {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = ?1 WHERE token_hash = ?2") let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
.bind(&now).bind(&token_hash) .bind(now).bind(&token_hash)
.execute(&db).await; .execute(&db).await;
}); });
@@ -84,23 +84,11 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
}) })
} }
/// 从请求中提取客户端 IP /// 从请求中提取客户端 IP(仅信任直连 IP不信任可伪造的 proxy header
fn extract_client_ip(req: &Request) -> Option<String> { fn extract_client_ip(req: &Request) -> Option<String> {
// 优先从 ConnectInfo 获取 req.extensions()
if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() { .get::<ConnectInfo<SocketAddr>>()
return Some(addr.ip().to_string()); .map(|addr| addr.ip().to_string())
}
// 回退到 X-Forwarded-For / X-Real-IP
if let Some(forwarded) = req.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
{
return Some(forwarded.split(',').next()?.trim().to_string());
}
req.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
} }
/// 认证中间件: 从 JWT 或 API Token 提取身份 /// 认证中间件: 从 JWT 或 API Token 提取身份

View File

@@ -94,7 +94,7 @@ pub async fn setup_totp(
) -> SaasResult<Json<TotpSetupResponse>> { ) -> SaasResult<Json<TotpSetupResponse>> {
// 如果已启用 TOTP先清除旧密钥 // 如果已启用 TOTP先清除旧密钥
let (username,): (String,) = sqlx::query_as( let (username,): (String,) = sqlx::query_as(
"SELECT username FROM accounts WHERE id = ?1" "SELECT username FROM accounts WHERE id = $1"
) )
.bind(&ctx.account_id) .bind(&ctx.account_id)
.fetch_one(&state.db) .fetch_one(&state.db)
@@ -103,9 +103,10 @@ pub async fn setup_totp(
let config = state.config.read().await; let config = state.config.read().await;
let setup = generate_totp_secret(&config.auth.totp_issuer, &username); let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
// 存储密钥 (但不启用,需要 /verify 确认) // 加密 TOTP 密钥后存储 (但不启用,需要 /verify 确认)
sqlx::query("UPDATE accounts SET totp_secret = ?1 WHERE id = ?2") let encrypted_secret = state.field_encryption.encrypt(&setup.secret)?;
.bind(&setup.secret) sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
.bind(&encrypted_secret)
.bind(&ctx.account_id) .bind(&ctx.account_id)
.execute(&state.db) .execute(&state.db)
.await?; .await?;
@@ -130,7 +131,7 @@ pub async fn verify_totp(
// 获取存储的密钥 // 获取存储的密钥
let (totp_secret,): (Option<String>,) = sqlx::query_as( let (totp_secret,): (Option<String>,) = sqlx::query_as(
"SELECT totp_secret FROM accounts WHERE id = ?1" "SELECT totp_secret FROM accounts WHERE id = $1"
) )
.bind(&ctx.account_id) .bind(&ctx.account_id)
.fetch_one(&state.db) .fetch_one(&state.db)
@@ -140,14 +141,17 @@ pub async fn verify_totp(
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into()) SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
})?; })?;
if !verify_totp_code(&secret, code) { // 解密 TOTP 密钥(兼容迁移期间的明文数据)
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
if !verify_totp_code(&decrypted_secret, code) {
return Err(SaasError::Totp("TOTP 码验证失败".into())); return Err(SaasError::Totp("TOTP 码验证失败".into()));
} }
// 验证成功 → 启用 TOTP // 验证成功 → 启用 TOTP
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET totp_enabled = 1, updated_at = ?1 WHERE id = ?2") sqlx::query("UPDATE accounts SET totp_enabled = true, updated_at = $1 WHERE id = $2")
.bind(&now) .bind(now)
.bind(&ctx.account_id) .bind(&ctx.account_id)
.execute(&state.db) .execute(&state.db)
.await?; .await?;
@@ -167,7 +171,7 @@ pub async fn disable_totp(
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<serde_json::Value>> {
// 验证密码 // 验证密码
let (password_hash,): (String,) = sqlx::query_as( let (password_hash,): (String,) = sqlx::query_as(
"SELECT password_hash FROM accounts WHERE id = ?1" "SELECT password_hash FROM accounts WHERE id = $1"
) )
.bind(&ctx.account_id) .bind(&ctx.account_id)
.fetch_one(&state.db) .fetch_one(&state.db)
@@ -178,9 +182,9 @@ pub async fn disable_totp(
} }
// 清除 TOTP // 清除 TOTP
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
sqlx::query("UPDATE accounts SET totp_enabled = 0, totp_secret = NULL, updated_at = ?1 WHERE id = ?2") sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
.bind(&now) .bind(now)
.bind(&ctx.account_id) .bind(&ctx.account_id)
.execute(&state.db) .execute(&state.db)
.await?; .await?;
@@ -190,3 +194,65 @@ pub async fn disable_totp(
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"}))) Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_generate_totp_secret_format() {
let result = generate_totp_secret("TestIssuer", "user@example.com");
assert!(result.otpauth_uri.starts_with("otpauth://totp/"));
assert!(result.otpauth_uri.contains("secret="));
assert!(result.otpauth_uri.contains("issuer=TestIssuer"));
assert!(result.otpauth_uri.contains("algorithm=SHA1"));
assert!(result.otpauth_uri.contains("digits=6"));
assert!(result.otpauth_uri.contains("period=30"));
// Base32 编码的 20 字节 = 32 字符
assert_eq!(result.secret.len(), 32);
assert_eq!(result.issuer, "TestIssuer");
}
#[test]
fn test_generate_totp_secret_special_chars() {
let result = generate_totp_secret("My App", "user@domain:8080");
// 特殊字符应被 URL 编码
assert!(!result.otpauth_uri.contains("user@domain:8080"));
assert!(result.otpauth_uri.contains("user%40domain"));
}
#[test]
fn test_verify_totp_code_valid() {
// 使用 generate_random_secret 创建合法 secret然后生成并验证码
let secret = generate_random_secret();
let secret_bytes = data_encoding::BASE32.decode(secret.as_bytes()).unwrap();
let totp = totp_rs::TOTP::new(
totp_rs::Algorithm::SHA1, 6, 1, 30, secret_bytes,
).unwrap();
let valid_code = totp.generate(chrono::Utc::now().timestamp() as u64);
assert!(verify_totp_code(&secret, &valid_code));
}
#[test]
fn test_verify_totp_code_invalid() {
let secret = generate_random_secret();
assert!(!verify_totp_code(&secret, "000000"));
assert!(!verify_totp_code(&secret, "999999"));
assert!(!verify_totp_code(&secret, "abcdef"));
}
#[test]
fn test_verify_totp_code_invalid_secret() {
assert!(!verify_totp_code("not-valid-base32!!!", "123456"));
assert!(!verify_totp_code("", "123456"));
assert!(!verify_totp_code("", "123456"));
}
#[test]
fn test_verify_totp_code_empty() {
let secret = "JBSWY3DPEHPK3PXP";
assert!(!verify_totp_code(secret, ""));
assert!(!verify_totp_code(secret, "12345"));
assert!(!verify_totp_code(secret, "1234567"));
}
}

View File

@@ -3,7 +3,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// 登录请求 /// 登录请求
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct LoginRequest { pub struct LoginRequest {
pub username: String, pub username: String,
pub password: String, pub password: String,
@@ -11,14 +11,14 @@ pub struct LoginRequest {
} }
/// 登录响应 /// 登录响应
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct LoginResponse { pub struct LoginResponse {
pub token: String, pub token: String,
pub account: AccountPublic, pub account: AccountPublic,
} }
/// 注册请求 /// 注册请求
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RegisterRequest { pub struct RegisterRequest {
pub username: String, pub username: String,
pub email: String, pub email: String,
@@ -27,20 +27,21 @@ pub struct RegisterRequest {
} }
/// 修改密码请求 /// 修改密码请求
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct ChangePasswordRequest { pub struct ChangePasswordRequest {
pub old_password: String, pub old_password: String,
pub new_password: String, pub new_password: String,
} }
/// 公开账号信息 (无敏感数据) /// 公开账号信息 (无敏感数据)
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct AccountPublic { pub struct AccountPublic {
pub id: String, pub id: String,
pub username: String, pub username: String,
pub email: String, pub email: String,
pub display_name: String, pub display_name: String,
pub role: String, pub role: String,
pub permissions: Vec<String>,
pub status: String, pub status: String,
pub totp_enabled: bool, pub totp_enabled: bool,
pub created_at: String, pub created_at: String,

View File

@@ -45,10 +45,13 @@ pub struct AuthConfig {
/// 中转服务配置 /// 中转服务配置
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelayConfig { pub struct RelayConfig {
#[doc(hidden)]
#[serde(default = "default_max_queue")] #[serde(default = "default_max_queue")]
pub max_queue_size: usize, pub max_queue_size: usize,
#[doc(hidden)]
#[serde(default = "default_max_concurrent")] #[serde(default = "default_max_concurrent")]
pub max_concurrent_per_provider: usize, pub max_concurrent_per_provider: usize,
#[doc(hidden)]
#[serde(default = "default_batch_window")] #[serde(default = "default_batch_window")]
pub batch_window_ms: u64, pub batch_window_ms: u64,
#[serde(default = "default_retry_delay")] #[serde(default = "default_retry_delay")]
@@ -59,7 +62,22 @@ pub struct RelayConfig {
fn default_host() -> String { "0.0.0.0".into() } fn default_host() -> String { "0.0.0.0".into() }
fn default_port() -> u16 { 8080 } fn default_port() -> u16 { 8080 }
fn default_db_url() -> String { "sqlite:./saas-data.db".into() } fn default_db_url() -> String {
// 无默认值:生产环境必须通过 DATABASE_URL 或配置文件设置
// 开发环境可设置 ZCLAW_SAAS_DEV=true 使用 postgres://localhost:5432/zclaw
std::env::var("DATABASE_URL")
.unwrap_or_else(|_| {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if is_dev {
"postgres://localhost:5432/zclaw".into()
} else {
tracing::error!("DATABASE_URL 未设置且非开发环境");
String::new()
}
})
}
fn default_jwt_hours() -> i64 { 24 } fn default_jwt_hours() -> i64 { 24 }
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() } fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
fn default_max_queue() -> usize { 1000 } fn default_max_queue() -> usize { 1000 }
@@ -155,6 +173,16 @@ impl SaaSConfig {
SaaSConfig::default() SaaSConfig::default()
}; };
// 验证数据库 URL 已配置
if config.database.url.is_empty() {
anyhow::bail!(
"数据库 URL 未配置。请通过以下方式之一设置:\n\
1. 在配置文件中设置 [database].url\n\
2. 设置 DATABASE_URL 环境变量\n\
开发环境可设置 ZCLAW_SAAS_DEV=true 使用默认值。"
);
}
Ok(config) Ok(config)
} }
@@ -182,3 +210,94 @@ impl SaaSConfig {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_config_has_expected_values() {
let config = SaaSConfig::default();
assert_eq!(config.server.host, "0.0.0.0");
assert_eq!(config.server.port, 8080);
assert!(config.server.cors_origins.is_empty());
assert_eq!(config.auth.jwt_expiration_hours, 24);
assert_eq!(config.auth.totp_issuer, "ZCLAW SaaS");
assert_eq!(config.rate_limit.requests_per_minute, 60);
assert_eq!(config.rate_limit.burst, 10);
assert_eq!(config.relay.max_queue_size, 1000);
assert_eq!(config.relay.max_concurrent_per_provider, 5);
assert_eq!(config.relay.max_attempts, 3);
}
#[test]
fn rate_limit_default_matches_manual() {
let config = SaaSConfig::default();
assert_eq!(config.rate_limit.requests_per_minute, 60);
assert_eq!(config.rate_limit.burst, 10);
}
#[test]
fn parse_minimal_config_toml() {
let toml_str = r#"
[server]
host = "127.0.0.1"
port = 9090
[database]
url = "postgres://localhost/zclaw"
[auth]
jwt_expiration_hours = 48
[relay]
max_queue_size = 500
"#;
let config: SaaSConfig = toml::from_str(toml_str).expect("parse should succeed");
assert_eq!(config.server.host, "127.0.0.1");
assert_eq!(config.server.port, 9090);
assert_eq!(config.database.url, "postgres://localhost/zclaw");
assert_eq!(config.auth.jwt_expiration_hours, 48);
assert_eq!(config.relay.max_queue_size, 500);
// defaults should fill in
assert_eq!(config.rate_limit.requests_per_minute, 60);
assert_eq!(config.relay.max_attempts, 3);
}
#[test]
fn parse_full_config_with_rate_limit() {
let toml_str = r#"
[server]
host = "0.0.0.0"
port = 8080
cors_origins = ["http://localhost:3000", "http://admin.example.com"]
[database]
url = "postgres://db:5432/zclaw"
[auth]
jwt_expiration_hours = 12
totp_issuer = "MyCorp"
[relay]
max_queue_size = 2000
max_concurrent_per_provider = 10
batch_window_ms = 100
retry_delay_ms = 2000
max_attempts = 5
[rate_limit]
requests_per_minute = 120
burst = 20
"#;
let config: SaaSConfig = toml::from_str(toml_str).expect("parse should succeed");
assert_eq!(config.server.cors_origins.len(), 2);
assert_eq!(config.auth.jwt_expiration_hours, 12);
assert_eq!(config.auth.totp_issuer, "MyCorp");
assert_eq!(config.relay.max_concurrent_per_provider, 10);
assert_eq!(config.relay.retry_delay_ms, 2000);
assert_eq!(config.relay.max_attempts, 5);
assert_eq!(config.rate_limit.requests_per_minute, 120);
assert_eq!(config.rate_limit.burst, 20);
}
}

View File

@@ -0,0 +1,277 @@
//! AES-256-GCM 字段级加密
//!
//! 用于加密数据库中存储的敏感字段(如 API Key
//! 每次加密生成随机 12 字节 nonce密文格式: `base64(nonce || ciphertext || tag)`。
use aes_gcm::aead::{AeadInPlace, KeyInit, OsRng};
use aes_gcm::{Aes256Gcm, AeadCore, Nonce};
use data_encoding::BASE64;
use std::fmt;
use crate::error::{SaasError, SaasResult};
/// AES-256-GCM 密钥字节长度
const KEY_LEN: usize = 32;
/// GCM nonce 字节长度 (96-bit推荐值)
const NONCE_LEN: usize = 12;
/// 字段加密器,持有 AES-256-GCM 密钥
///
/// 线程安全,可通过 `Arc` 在多任务间共享。
pub struct FieldEncryption {
cipher: Aes256Gcm,
}
impl fmt::Debug for FieldEncryption {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FieldEncryption")
.field("cipher", &"<redacted>")
.finish()
}
}
impl FieldEncryption {
/// 从环境变量加载或生成加密密钥
///
/// - **生产环境**: 必须设置 `ZCLAW_SAAS_FIELD_ENCRYPTION_KEY`32 字节 hex 编码)
/// - **开发环境** (`ZCLAW_SAAS_DEV=true`): 自动生成随机密钥并输出警告
pub fn new() -> anyhow::Result<Self> {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
let key_bytes = match std::env::var("ZCLAW_SAAS_FIELD_ENCRYPTION_KEY") {
Ok(hex_key) => {
let bytes = hex::decode(&hex_key).map_err(|e| {
anyhow::anyhow!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 格式无效 (期望 64 字符 hex): {e}"
)
})?;
if bytes.len() != KEY_LEN {
anyhow::bail!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 长度错误: 期望 {KEY_LEN} 字节, 实际 {} 字节",
bytes.len()
);
}
tracing::info!("Field encryption key loaded from environment");
bytes
}
Err(_) => {
if is_dev {
let random_key: [u8; KEY_LEN] = rand::random();
let hex_key = hex::encode(random_key);
tracing::warn!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 未设置,已生成随机密钥 (仅限开发环境):\n {hex_key}\n\
生产环境必须设置此环境变量!"
);
random_key.to_vec()
} else {
anyhow::bail!(
"ZCLAW_SAAS_FIELD_ENCRYPTION_KEY 环境变量未设置。\n\
请设置一个 32 字节 hex 编码密钥 (64 字符)。\n\
生成方式: openssl rand -hex 32\n\
开发环境可设置 ZCLAW_SAAS_DEV=true 自动生成。"
);
}
}
};
let key = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes);
let cipher = Aes256Gcm::new(key);
Ok(Self { cipher })
}
/// 加密明文,返回 base64 编码密文
///
/// 密文格式: `base64(nonce_12bytes || ciphertext || gcm_tag_16bytes)`
pub fn encrypt(&self, plaintext: &str) -> SaasResult<String> {
let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
let payload = plaintext.as_bytes();
// AeadInPlace::encrypt_in_place_append_tag 会在 payload 后面追加 16 字节 tag
let mut buffer = payload.to_vec();
self.cipher
.encrypt_in_place(&nonce, &[], &mut buffer)
.map_err(|e| SaasError::Encryption(format!("加密失败: {e}")))?;
// 构造输出: nonce (12) || ciphertext + tag
let mut output = Vec::with_capacity(NONCE_LEN + buffer.len());
output.extend_from_slice(&nonce);
output.extend_from_slice(&buffer);
Ok(BASE64.encode(&output))
}
/// 解密 base64 编码密文,返回原始明文
///
/// 输入格式: `base64(nonce_12bytes || ciphertext || gcm_tag_16bytes)`
pub fn decrypt(&self, ciphertext: &str) -> SaasResult<String> {
let raw = BASE64
.decode(ciphertext.as_bytes())
.map_err(|e| SaasError::Encryption(format!("Base64 解码失败: {e}")))?;
if raw.len() < NONCE_LEN {
return Err(SaasError::Encryption(
"密文长度不足: 无法提取 nonce".to_string(),
));
}
let (nonce_bytes, encrypted) = raw.split_at(NONCE_LEN);
let nonce = Nonce::from_slice(nonce_bytes);
let mut buffer = encrypted.to_vec();
self.cipher
.decrypt_in_place(nonce, &[], &mut buffer)
.map_err(|e| SaasError::Encryption(format!("解密失败 (密文可能已损坏或密钥不匹配): {e}")))?;
String::from_utf8(buffer)
.map_err(|e| SaasError::Encryption(format!("解密结果非有效 UTF-8: {e}")))
}
/// 尝试解密,失败时返回原始明文(用于迁移期间兼容未加密的旧数据)
///
/// 在字段加密上线前,数据库中可能已存在未加密的明文数据。
/// 此方法先尝试解密若解密失败Base64 解码失败、GCM 认证失败等),
/// 则假设数据是旧版明文,直接返回原值。
pub fn decrypt_or_plaintext(&self, value: &str) -> String {
self.decrypt(value).unwrap_or_else(|_| value.to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
/// 辅助: 用固定密钥创建 FieldEncryption测试专用
fn test_encryption() -> FieldEncryption {
// 固定 32 字节密钥,仅用于测试
let key_bytes: [u8; KEY_LEN] = [
0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08,
0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10,
0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18,
0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20,
];
let key = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes);
let cipher = Aes256Gcm::new(key);
FieldEncryption { cipher }
}
#[test]
fn encrypt_produces_base64_output() {
let enc = test_encryption();
let result = enc.encrypt("hello world");
assert!(result.is_ok());
let ciphertext = result.unwrap();
// base64 输出应该能被 BASE64 解码
assert!(BASE64.decode(ciphertext.as_bytes()).is_ok());
}
#[test]
fn encrypt_decrypt_roundtrip() {
let enc = test_encryption();
let plaintext = "sk-proj-abc123SECRET_API_KEY_!@#$%";
let ciphertext = enc.encrypt(plaintext).expect("encrypt should succeed");
let decrypted = enc.decrypt(&ciphertext).expect("decrypt should succeed");
assert_eq!(decrypted, plaintext);
}
#[test]
fn encrypt_decrypt_roundtrip_chinese() {
let enc = test_encryption();
let plaintext = "这是一个包含中文的敏感字段测试";
let ciphertext = enc.encrypt(plaintext).expect("encrypt should succeed");
let decrypted = enc.decrypt(&ciphertext).expect("decrypt should succeed");
assert_eq!(decrypted, plaintext);
}
#[test]
fn different_encryptions_produce_different_ciphertexts() {
let enc = test_encryption();
let plaintext = "same-plaintext";
let ct1 = enc.encrypt(plaintext).unwrap();
let ct2 = enc.encrypt(plaintext).unwrap();
// 由于随机 nonce相同明文的密文应该不同
assert_ne!(ct1, ct2);
}
#[test]
fn decrypt_wrong_key_fails() {
let enc1 = test_encryption();
// 用不同密钥创建另一个加密器
let key_bytes2: [u8; KEY_LEN] = [
0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, 0xf8,
0xf7, 0xf6, 0xf5, 0xf4, 0xf3, 0xf2, 0xf1, 0xf0,
0xef, 0xee, 0xed, 0xec, 0xeb, 0xea, 0xe9, 0xe8,
0xe7, 0xe6, 0xe5, 0xe4, 0xe3, 0xe2, 0xe1, 0xe0,
];
let key2 = aes_gcm::Key::<Aes256Gcm>::from_slice(&key_bytes2);
let cipher2 = Aes256Gcm::new(key2);
let enc2 = FieldEncryption { cipher: cipher2 };
let ciphertext = enc1.encrypt("secret").unwrap();
let result = enc2.decrypt(&ciphertext);
assert!(result.is_err());
}
#[test]
fn decrypt_invalid_base64_fails() {
let enc = test_encryption();
let result = enc.decrypt("not-valid-base64!!!");
assert!(result.is_err());
}
#[test]
fn decrypt_too_short_ciphertext_fails() {
let enc = test_encryption();
// 构造一个短于 12 字节 nonce 的有效 base64 字符串
let short = BASE64.encode(&[0x01, 0x02, 0x03]);
let result = enc.decrypt(&short);
assert!(result.is_err());
}
#[test]
fn decrypt_tampered_ciphertext_fails() {
let enc = test_encryption();
let ciphertext = enc.encrypt("sensitive-data").unwrap();
// 解码、篡改、重新编码
let mut raw = BASE64.decode(ciphertext.as_bytes()).unwrap();
// 翻转 nonce 后的一个字节
let tamper_pos = NONCE_LEN + 2;
if tamper_pos < raw.len() {
raw[tamper_pos] ^= 0xff;
}
let tampered = BASE64.encode(&raw);
let result = enc.decrypt(&tampered);
assert!(result.is_err());
}
#[test]
fn encrypt_empty_string_roundtrip() {
let enc = test_encryption();
let ciphertext = enc.encrypt("").unwrap();
let decrypted = enc.decrypt(&ciphertext).unwrap();
assert_eq!(decrypted, "");
}
#[test]
fn ciphertext_format_has_nonce_prefix() {
let enc = test_encryption();
let ciphertext = enc.encrypt("test").unwrap();
let raw = BASE64.decode(ciphertext.as_bytes()).unwrap();
// raw 应该 = nonce(12) + ciphertext + tag(16)
// 至少 12 + 16 = 28 字节(明文 4 字节加密后 4 字节 + 16 字节 tag
assert!(raw.len() >= NONCE_LEN + 16);
}
}

View File

@@ -0,0 +1,243 @@
//! CSRF 防护: Origin 校验中间件
//!
//! 对所有状态变更请求 (POST/PUT/PATCH/DELETE) 校验 `Origin` 请求头,
//! 确保其与 `server.cors_origins` 白名单中的某项匹配。
//!
//! - GET / HEAD / OPTIONS 请求跳过校验 (安全方法)
//! - 缺少 Origin 头时拒绝 (403)
//! - Origin 不匹配白名单时拒绝 (403)
//! - `ZCLAW_SAAS_DEV=true` 时跳过校验
//!
//! 这是 Bearer Token API 最合适的 CSRF 防护方案。
//! 如果未来迁移到 Cookie 认证,需要升级为 CSRF Token 方案。
use axum::{
extract::{Request, State},
http::{header, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
use tracing::warn;
use crate::state::AppState;
/// 需要进行 Origin 校验的 HTTP 方法
const CSRF_UNSAFE_METHODS: &[&str] = &["POST", "PUT", "PATCH", "DELETE"];
/// Origin 校验中间件
///
/// 在 auth_middleware 之后、rate_limit_middleware 之前执行。
/// 已认证的请求若缺少或不匹配 Origin 头,返回 403 Forbidden。
pub async fn origin_check_middleware(
State(state): State<AppState>,
req: Request,
next: Next,
) -> Response {
// 开发模式跳过校验
if is_dev_mode() {
return next.run(req).await;
}
// 安全方法跳过校验
let method = req.method().as_str().to_uppercase();
if !CSRF_UNSAFE_METHODS.contains(&method.as_str()) {
return next.run(req).await;
}
// 获取 Origin 头
let origin_header = match req.headers().get(header::ORIGIN) {
Some(value) => match value.to_str() {
Ok(origin) => origin,
Err(_) => {
warn!("CSRF: Origin header contains invalid UTF-8");
return csrf_reject("ORIGIN_INVALID", "Origin 请求头格式无效");
}
},
None => {
warn!("CSRF: Missing Origin header on {} {}", method, req.uri());
return csrf_reject("ORIGIN_MISSING", "缺少 Origin 请求头");
}
};
// 从配置读取白名单
let allowed_origins = {
let config = state.config.read().await;
config.server.cors_origins.clone()
};
// 白名单为空时不校验 (生产环境已在 main.rs 中强制要求配置)
if allowed_origins.is_empty() {
return next.run(req).await;
}
// 校验 Origin 是否在白名单中
if !origin_matches_whitelist(origin_header, &allowed_origins) {
warn!(
"CSRF: Origin '{}' not in whitelist for {} {}",
origin_header,
method,
req.uri()
);
return csrf_reject("ORIGIN_NOT_ALLOWED", "Origin 不在允许列表中");
}
next.run(req).await
}
/// 判断是否为开发模式
fn is_dev_mode() -> bool {
std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false)
}
/// 校验 Origin 是否匹配白名单中的某项
///
/// 匹配规则: 精确匹配 (scheme + host + port)。
/// 例如白名单 `https://admin.zclaw.com` 只匹配该 Origin
/// 不匹配 `https://evil.zclaw.com`。
fn origin_matches_whitelist(origin: &str, whitelist: &[String]) -> bool {
// 使用 url::Url 进行规范化比较,避免字符串拼接攻击
let parsed_origin = match url::Url::parse(origin) {
Ok(url) => url,
Err(_) => return false,
};
for allowed in whitelist {
if let Ok(allowed_url) = url::Url::parse(allowed) {
if origins_equal(&parsed_origin, &allowed_url) {
return true;
}
} else {
// 白名单条目本身无法解析,降级为字符串比较
if origin == allowed {
return true;
}
}
}
false
}
/// 比较两个 Origin URL 是否相等 (scheme + host + port)
///
/// 同时拒绝包含路径的 URL: 真实的 Origin 头永远不会包含路径。
/// 如果传入的 origin 字符串包含路径,视为不合法的 Origin。
fn origins_equal(a: &url::Url, b: &url::Url) -> bool {
// scheme 必须完全一致
if a.scheme() != b.scheme() {
return false;
}
// host 必须完全一致
if a.host_str() != b.host_str() {
return false;
}
// port 必须完全一致 (url::Url 会规范化默认端口: 80/HTTP, 443/HTTPS)
if a.port() != b.port() {
return false;
}
// 防御性检查: 合法的 Origin 不应包含路径、query string 或 fragment
// 如果任一 URL 的 path 不是 "/" 或有 query/fragment视为可疑请求
if a.path() != "/" || b.path() != "/" {
return false;
}
if a.query().is_some() || b.query().is_some() {
return false;
}
if a.fragment().is_some() || b.fragment().is_some() {
return false;
}
true
}
/// 返回 403 拒绝响应
fn csrf_reject(error_code: &str, message: &str) -> Response {
(
StatusCode::FORBIDDEN,
[("Content-Type", "application/json")],
axum::Json(serde_json::json!({
"error": error_code,
"message": message,
})),
)
.into_response()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_origin_matches_whitelist_exact() {
let whitelist = vec![
"https://admin.zclaw.com".to_string(),
"http://localhost:3000".to_string(),
];
assert!(origin_matches_whitelist("https://admin.zclaw.com", &whitelist));
assert!(origin_matches_whitelist("http://localhost:3000", &whitelist));
assert!(!origin_matches_whitelist("https://evil.zclaw.com", &whitelist));
// url::Url normalizes port 443 for HTTPS to None, so these match
assert!(origin_matches_whitelist("https://admin.zclaw.com:443", &whitelist));
assert!(!origin_matches_whitelist("http://localhost:3001", &whitelist));
}
#[test]
fn test_origin_matches_whitelist_empty() {
let whitelist: Vec<String> = vec![];
assert!(!origin_matches_whitelist("https://example.com", &whitelist));
}
#[test]
fn test_origin_matches_whitelist_with_path() {
let whitelist = vec!["https://admin.zclaw.com".to_string()];
// 标准 Origin 不包含路径,应该匹配
assert!(origin_matches_whitelist("https://admin.zclaw.com", &whitelist));
// 包含路径的 Origin 不合法 (浏览器永远不会发送带路径的 Origin)
assert!(!origin_matches_whitelist("https://admin.zclaw.com/evil", &whitelist));
// 带查询字符串的 Origin 也不合法
assert!(!origin_matches_whitelist("https://admin.zclaw.com/?evil=1", &whitelist));
}
#[test]
fn test_origin_matches_whitelist_invalid_origin() {
let whitelist = vec!["https://admin.zclaw.com".to_string()];
assert!(!origin_matches_whitelist("not-a-url", &whitelist));
assert!(!origin_matches_whitelist("", &whitelist));
}
#[test]
fn test_origins_equal() {
let a = url::Url::parse("https://admin.zclaw.com").unwrap();
let b = url::Url::parse("https://admin.zclaw.com").unwrap();
assert!(origins_equal(&a, &b));
// Different scheme
let c = url::Url::parse("http://admin.zclaw.com").unwrap();
assert!(!origins_equal(&a, &c));
// Different host
let d = url::Url::parse("https://evil.zclaw.com").unwrap();
assert!(!origins_equal(&a, &d));
// Different port
let e = url::Url::parse("https://admin.zclaw.com:8443").unwrap();
assert!(!origins_equal(&a, &e));
// Explicit default port vs implicit
let f = url::Url::parse("https://admin.zclaw.com:443").unwrap();
// url::Url normalizes 443 for HTTPS, so both have None port
assert!(origins_equal(&a, &f));
}
#[test]
fn test_is_dev_mode() {
// Don't modify env in tests; just verify the function signature works
// Actual env-var-based behavior tested in integration tests
let _ = is_dev_mode();
}
}

View File

@@ -1,9 +1,9 @@
//! 数据库初始化与 Schema //! 数据库初始化与 Schema (PostgreSQL)
use sqlx::SqlitePool; use sqlx::PgPool;
use crate::error::SaasResult; use crate::error::SaasResult;
const SCHEMA_VERSION: i32 = 1; const SCHEMA_VERSION: i32 = 2;
const SCHEMA_SQL: &str = r#" const SCHEMA_SQL: &str = r#"
CREATE TABLE IF NOT EXISTS saas_schema_version ( CREATE TABLE IF NOT EXISTS saas_schema_version (
@@ -20,10 +20,10 @@ CREATE TABLE IF NOT EXISTS accounts (
role TEXT NOT NULL DEFAULT 'user', role TEXT NOT NULL DEFAULT 'user',
status TEXT NOT NULL DEFAULT 'active', status TEXT NOT NULL DEFAULT 'active',
totp_secret TEXT, totp_secret TEXT,
totp_enabled INTEGER NOT NULL DEFAULT 0, totp_enabled BOOLEAN NOT NULL DEFAULT false,
last_login_at TEXT, last_login_at TIMESTAMPTZ,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_accounts_email ON accounts(email); CREATE INDEX IF NOT EXISTS idx_accounts_email ON accounts(email);
CREATE INDEX IF NOT EXISTS idx_accounts_role ON accounts(role); CREATE INDEX IF NOT EXISTS idx_accounts_role ON accounts(role);
@@ -35,10 +35,10 @@ CREATE TABLE IF NOT EXISTS api_tokens (
token_hash TEXT NOT NULL, token_hash TEXT NOT NULL,
token_prefix TEXT NOT NULL, token_prefix TEXT NOT NULL,
permissions TEXT NOT NULL DEFAULT '[]', permissions TEXT NOT NULL DEFAULT '[]',
last_used_at TEXT, last_used_at TIMESTAMPTZ,
expires_at TEXT, expires_at TIMESTAMPTZ,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TEXT, revoked_at TIMESTAMPTZ,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS idx_api_tokens_account ON api_tokens(account_id); CREATE INDEX IF NOT EXISTS idx_api_tokens_account ON api_tokens(account_id);
@@ -46,32 +46,23 @@ CREATE INDEX IF NOT EXISTS idx_api_tokens_hash ON api_tokens(token_hash);
CREATE TABLE IF NOT EXISTS roles ( CREATE TABLE IF NOT EXISTS roles (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
name TEXT NOT NULL, name TEXT NOT NULL UNIQUE,
description TEXT, description TEXT,
permissions TEXT NOT NULL DEFAULT '[]', permissions TEXT NOT NULL DEFAULT '[]',
is_system INTEGER NOT NULL DEFAULT 0, is_system BOOLEAN NOT NULL DEFAULT false,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE TABLE IF NOT EXISTS permission_templates (
id TEXT PRIMARY KEY,
name TEXT NOT NULL,
description TEXT,
permissions TEXT NOT NULL DEFAULT '[]',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
); );
CREATE TABLE IF NOT EXISTS operation_logs ( CREATE TABLE IF NOT EXISTS operation_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT, id BIGSERIAL PRIMARY KEY,
account_id TEXT, account_id TEXT,
action TEXT NOT NULL, action TEXT NOT NULL,
target_type TEXT, target_type TEXT,
target_id TEXT, target_id TEXT,
details TEXT, details TEXT,
ip_address TEXT, ip_address TEXT,
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_op_logs_account ON operation_logs(account_id); CREATE INDEX IF NOT EXISTS idx_op_logs_account ON operation_logs(account_id);
CREATE INDEX IF NOT EXISTS idx_op_logs_action ON operation_logs(action); CREATE INDEX IF NOT EXISTS idx_op_logs_action ON operation_logs(action);
@@ -84,12 +75,12 @@ CREATE TABLE IF NOT EXISTS providers (
api_key TEXT, api_key TEXT,
base_url TEXT NOT NULL, base_url TEXT NOT NULL,
api_protocol TEXT NOT NULL DEFAULT 'openai', api_protocol TEXT NOT NULL DEFAULT 'openai',
enabled INTEGER NOT NULL DEFAULT 1, enabled BOOLEAN NOT NULL DEFAULT true,
rate_limit_rpm INTEGER, rate_limit_rpm INTEGER,
rate_limit_tpm INTEGER, rate_limit_tpm INTEGER,
config_json TEXT DEFAULT '{}', config_json TEXT DEFAULT '{}',
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE TABLE IF NOT EXISTS models ( CREATE TABLE IF NOT EXISTS models (
@@ -99,13 +90,13 @@ CREATE TABLE IF NOT EXISTS models (
alias TEXT NOT NULL, alias TEXT NOT NULL,
context_window INTEGER NOT NULL DEFAULT 8192, context_window INTEGER NOT NULL DEFAULT 8192,
max_output_tokens INTEGER NOT NULL DEFAULT 4096, max_output_tokens INTEGER NOT NULL DEFAULT 4096,
supports_streaming INTEGER NOT NULL DEFAULT 1, supports_streaming BOOLEAN NOT NULL DEFAULT true,
supports_vision INTEGER NOT NULL DEFAULT 0, supports_vision BOOLEAN NOT NULL DEFAULT false,
enabled INTEGER NOT NULL DEFAULT 1, enabled BOOLEAN NOT NULL DEFAULT true,
pricing_input REAL DEFAULT 0, pricing_input DOUBLE PRECISION DEFAULT 0,
pricing_output REAL DEFAULT 0, pricing_output DOUBLE PRECISION DEFAULT 0,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(provider_id, model_id), UNIQUE(provider_id, model_id),
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
); );
@@ -118,18 +109,18 @@ CREATE TABLE IF NOT EXISTS account_api_keys (
key_value TEXT NOT NULL, key_value TEXT NOT NULL,
key_label TEXT, key_label TEXT,
permissions TEXT NOT NULL DEFAULT '[]', permissions TEXT NOT NULL DEFAULT '[]',
enabled INTEGER NOT NULL DEFAULT 1, enabled BOOLEAN NOT NULL DEFAULT true,
last_used_at TEXT, last_used_at TIMESTAMPTZ,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TEXT, revoked_at TIMESTAMPTZ,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE, FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id); CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id);
CREATE TABLE IF NOT EXISTS usage_records ( CREATE TABLE IF NOT EXISTS usage_records (
id INTEGER PRIMARY KEY AUTOINCREMENT, id BIGSERIAL PRIMARY KEY,
account_id TEXT NOT NULL, account_id TEXT NOT NULL,
provider_id TEXT NOT NULL, provider_id TEXT NOT NULL,
model_id TEXT NOT NULL, model_id TEXT NOT NULL,
@@ -138,10 +129,12 @@ CREATE TABLE IF NOT EXISTS usage_records (
latency_ms INTEGER, latency_ms INTEGER,
status TEXT NOT NULL DEFAULT 'success', status TEXT NOT NULL DEFAULT 'success',
error_message TEXT, error_message TEXT,
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id); CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id);
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at); CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
CREATE INDEX IF NOT EXISTS idx_usage_provider ON usage_records(provider_id);
CREATE INDEX IF NOT EXISTS idx_usage_model ON usage_records(model_id);
CREATE TABLE IF NOT EXISTS relay_tasks ( CREATE TABLE IF NOT EXISTS relay_tasks (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
@@ -158,14 +151,15 @@ CREATE TABLE IF NOT EXISTS relay_tasks (
input_tokens INTEGER DEFAULT 0, input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0, output_tokens INTEGER DEFAULT 0,
error_message TEXT, error_message TEXT,
queued_at TEXT NOT NULL, queued_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
started_at TEXT, started_at TIMESTAMPTZ,
completed_at TEXT, completed_at TIMESTAMPTZ,
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status); CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status);
CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id); CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id);
CREATE INDEX IF NOT EXISTS idx_relay_provider ON relay_tasks(provider_id); CREATE INDEX IF NOT EXISTS idx_relay_provider ON relay_tasks(provider_id);
CREATE INDEX IF NOT EXISTS idx_relay_account_status ON relay_tasks(account_id, status);
CREATE TABLE IF NOT EXISTS config_items ( CREATE TABLE IF NOT EXISTS config_items (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
@@ -176,15 +170,15 @@ CREATE TABLE IF NOT EXISTS config_items (
default_value TEXT, default_value TEXT,
source TEXT NOT NULL DEFAULT 'local', source TEXT NOT NULL DEFAULT 'local',
description TEXT, description TEXT,
requires_restart INTEGER NOT NULL DEFAULT 0, requires_restart BOOLEAN NOT NULL DEFAULT false,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(category, key_path) UNIQUE(category, key_path)
); );
CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category); CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category);
CREATE TABLE IF NOT EXISTS config_sync_log ( CREATE TABLE IF NOT EXISTS config_sync_log (
id INTEGER PRIMARY KEY AUTOINCREMENT, id BIGSERIAL PRIMARY KEY,
account_id TEXT NOT NULL, account_id TEXT NOT NULL,
client_fingerprint TEXT NOT NULL, client_fingerprint TEXT NOT NULL,
action TEXT NOT NULL, action TEXT NOT NULL,
@@ -192,7 +186,7 @@ CREATE TABLE IF NOT EXISTS config_sync_log (
client_values TEXT, client_values TEXT,
saas_values TEXT, saas_values TEXT,
resolution TEXT, resolution TEXT,
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_sync_account ON config_sync_log(account_id); CREATE INDEX IF NOT EXISTS idx_sync_account ON config_sync_log(account_id);
@@ -203,8 +197,8 @@ CREATE TABLE IF NOT EXISTS devices (
device_name TEXT, device_name TEXT,
platform TEXT, platform TEXT,
app_version TEXT, app_version TEXT,
last_seen_at TEXT NOT NULL, last_seen_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS idx_devices_account ON devices(account_id); CREATE INDEX IF NOT EXISTS idx_devices_account ON devices(account_id);
@@ -213,55 +207,76 @@ CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_unique ON devices(account_id, devi
"#; "#;
const SEED_ROLES: &str = r#" const SEED_ROLES: &str = r#"
INSERT OR IGNORE INTO roles (id, name, description, permissions, is_system, created_at, updated_at) INSERT INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
VALUES VALUES
('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write"]', 1, datetime('now'), datetime('now')), ('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write"]', true, NOW(), NOW()),
('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write"]', 1, datetime('now'), datetime('now')), ('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write"]', true, NOW(), NOW()),
('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read"]', 1, datetime('now'), datetime('now')); ('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read"]', true, NOW(), NOW())
ON CONFLICT (id) DO NOTHING;
"#; "#;
/// 初始化数据库 /// PostgreSQL 不支持在单条 prepared statement 中执行多条 SQL 命令,
pub async fn init_db(database_url: &str) -> SaasResult<SqlitePool> { /// 因此需要拆分后逐条执行。
if database_url.starts_with("sqlite:") { async fn execute_multi_statements(pool: &PgPool, sql: &str) -> SaasResult<()> {
let path_part = database_url.strip_prefix("sqlite:").unwrap_or(""); for stmt in sql.split(';') {
if path_part != ":memory:" { let trimmed = stmt.trim();
if let Some(parent) = std::path::Path::new(path_part).parent() { if trimmed.is_empty() {
if !parent.as_os_str().is_empty() && !parent.exists() { continue;
std::fs::create_dir_all(parent)?; }
} if let Err(e) = sqlx::query(trimmed).execute(pool).await {
let err_str = e.to_string();
// 忽略 "已存在" 类错误 (并发初始化或重复调用)
let is_already_exists = err_str.contains("already exists")
|| err_str.contains("已经存在")
|| err_str.contains("重复键");
if !is_already_exists {
return Err(e.into());
} }
} }
} }
Ok(())
}
let pool = SqlitePool::connect(database_url).await?; /// 初始化数据库
sqlx::query("PRAGMA journal_mode=WAL;") pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
.execute(&pool) tracing::info!("Connecting to database: {}", database_url);
.await?; let pool = PgPool::connect(database_url).await?;
sqlx::query(SCHEMA_SQL).execute(&pool).await?; execute_multi_statements(&pool, SCHEMA_SQL).await?;
sqlx::query("INSERT OR IGNORE INTO saas_schema_version (version) VALUES (?1)") execute_multi_statements(&pool, SEED_ROLES).await?;
.bind(SCHEMA_VERSION)
.execute(&pool)
.await?;
sqlx::query(SEED_ROLES).execute(&pool).await?;
seed_admin_account(&pool).await?; seed_admin_account(&pool).await?;
tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION); tracing::info!("Database initialized (schema v{})", SCHEMA_VERSION);
Ok(pool) Ok(pool)
} }
/// 创建内存数据库 (测试用) /// 创建测试数据库 (连接到真实 PG 实例)
pub async fn init_memory_db() -> SaasResult<SqlitePool> { /// 测试前清空所有数据,确保每次从干净状态开始
let pool = SqlitePool::connect("sqlite::memory:").await?; pub async fn init_test_db() -> SaasResult<PgPool> {
sqlx::query(SCHEMA_SQL).execute(&pool).await?; let url = std::env::var("ZCLAW_TEST_DATABASE_URL")
sqlx::query("INSERT OR IGNORE INTO saas_schema_version (version) VALUES (?1)") .unwrap_or_else(|_| "postgres://localhost:5432/zclaw_test".to_string());
.bind(SCHEMA_VERSION) let pool = PgPool::connect(&url).await?;
.execute(&pool) execute_multi_statements(&pool, SCHEMA_SQL).await?;
.await?; clean_test_data(&pool).await?;
sqlx::query(SEED_ROLES).execute(&pool).await?; execute_multi_statements(&pool, SEED_ROLES).await?;
Ok(pool) Ok(pool)
} }
/// 清空所有表数据 (按外键依赖顺序,使用 DELETE 而非 TRUNCATE)
/// DELETE 不获取 ACCESS EXCLUSIVE 锁,对并发更友好
pub async fn clean_test_data(pool: &PgPool) -> SaasResult<()> {
let tables_to_clean = [
"config_sync_log", "config_items", "usage_records", "relay_tasks",
"account_api_keys", "models", "providers", "operation_logs",
"api_tokens", "devices", "roles", "accounts",
];
for table in &tables_to_clean {
let _ = sqlx::query(&format!("DELETE FROM {}", table))
.execute(pool).await;
}
Ok(())
}
/// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号 /// 如果 accounts 表为空且环境变量已设置,自动创建 super_admin 账号
async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> { async fn seed_admin_account(pool: &PgPool) -> SaasResult<()> {
let has_accounts: (bool,) = sqlx::query_as( let has_accounts: (bool,) = sqlx::query_as(
"SELECT EXISTS(SELECT 1 FROM accounts LIMIT 1) as has" "SELECT EXISTS(SELECT 1 FROM accounts LIMIT 1) as has"
) )
@@ -291,18 +306,16 @@ async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
let password_hash = hash_password(&admin_password)?; let password_hash = hash_password(&admin_password)?;
let account_id = uuid::Uuid::new_v4().to_string(); let account_id = uuid::Uuid::new_v4().to_string();
let email = format!("{}@zclaw.local", admin_username); let email = format!("{}@zclaw.local", admin_username);
let now = chrono::Utc::now().to_rfc3339();
sqlx::query( sqlx::query(
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at) "INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, 'super_admin', 'active', ?6, ?6)" VALUES ($1, $2, $3, $4, $5, 'super_admin', 'active', NOW(), NOW())"
) )
.bind(&account_id) .bind(&account_id)
.bind(&admin_username) .bind(&admin_username)
.bind(&email) .bind(&email)
.bind(&password_hash) .bind(&password_hash)
.bind(&admin_username) .bind(&admin_username)
.bind(&now)
.execute(pool) .execute(pool)
.await?; .await?;
@@ -316,13 +329,35 @@ async fn seed_admin_account(pool: &SqlitePool) -> SaasResult<()> {
mod tests { mod tests {
use super::*; use super::*;
/// 全局 Mutex 用于序列化所有数据库测试,避免并行测试之间的数据竞争
static TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
/// 共享测试连接池,避免每次测试都创建新连接
static TEST_POOL: tokio::sync::OnceCell<PgPool> = tokio::sync::OnceCell::const_new();
/// 获取测试连接池(异步初始化,避免嵌套 runtime 问题)
async fn get_test_pool() -> &'static PgPool {
TEST_POOL.get_or_init(|| async {
init_test_db().await.expect("init_test_db failed")
}).await
}
/// 每个测试前清理数据,确保隔离
async fn clean_before_test(pool: &PgPool) {
clean_test_data(pool).await.expect("clean_test_data failed");
execute_multi_statements(pool, SEED_ROLES).await.expect("seed roles failed");
}
#[tokio::test] #[tokio::test]
async fn test_init_memory_db() { async fn test_init_test_db() {
let pool = init_memory_db().await.unwrap(); // 获取全局锁,确保测试串行执行
let _guard = TEST_LOCK.lock().unwrap();
let pool = get_test_pool().await;
clean_before_test(pool).await;
let roles: Vec<(String,)> = sqlx::query_as( let roles: Vec<(String,)> = sqlx::query_as(
"SELECT id FROM roles WHERE is_system = 1" "SELECT id FROM roles WHERE is_system = true"
) )
.fetch_all(&pool) .fetch_all(pool)
.await .await
.unwrap(); .unwrap();
assert_eq!(roles.len(), 3); assert_eq!(roles.len(), 3);
@@ -330,17 +365,20 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_schema_tables_exist() { async fn test_schema_tables_exist() {
let pool = init_memory_db().await.unwrap(); let _guard = TEST_LOCK.lock().unwrap();
let pool = get_test_pool().await;
clean_before_test(pool).await;
let tables = [ let tables = [
"accounts", "api_tokens", "roles", "permission_templates", "accounts", "api_tokens", "roles",
"operation_logs", "providers", "models", "account_api_keys", "operation_logs", "providers", "models", "account_api_keys",
"usage_records", "relay_tasks", "config_items", "config_sync_log", "devices", "usage_records", "relay_tasks", "config_items", "config_sync_log", "devices",
]; ];
for table in tables { for table in tables {
let count: (i64,) = sqlx::query_as(&format!( let count: (i64,) = sqlx::query_as(&format!(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='{}'", table "SELECT COUNT(*) FROM information_schema.tables WHERE table_schema='public' AND table_name='{}'", table
)) ))
.fetch_one(&pool) .fetch_one(pool)
.await .await
.unwrap(); .unwrap();
assert_eq!(count.0, 1, "Table {} should exist", table); assert_eq!(count.0, 1, "Table {} should exist", table);

View File

@@ -127,3 +127,62 @@ impl IntoResponse for SaasError {
/// Result 类型别名 /// Result 类型别名
pub type SaasResult<T> = std::result::Result<T, SaasError>; pub type SaasResult<T> = std::result::Result<T, SaasError>;
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn status_code_maps_correctly() {
assert_eq!(SaasError::NotFound("x".into()).status_code(), StatusCode::NOT_FOUND);
assert_eq!(SaasError::Forbidden("x".into()).status_code(), StatusCode::FORBIDDEN);
assert_eq!(SaasError::Unauthorized.status_code(), StatusCode::UNAUTHORIZED);
assert_eq!(SaasError::InvalidInput("x".into()).status_code(), StatusCode::BAD_REQUEST);
assert_eq!(SaasError::AlreadyExists("x".into()).status_code(), StatusCode::CONFLICT);
assert_eq!(SaasError::RateLimited("x".into()).status_code(), StatusCode::TOO_MANY_REQUESTS);
assert_eq!(SaasError::Relay("x".into()).status_code(), StatusCode::BAD_GATEWAY);
assert_eq!(SaasError::Totp("x".into()).status_code(), StatusCode::BAD_REQUEST);
assert_eq!(SaasError::Internal("x".into()).status_code(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(SaasError::AuthError("x".into()).status_code(), StatusCode::UNAUTHORIZED);
}
#[test]
fn error_code_returns_expected_strings() {
assert_eq!(SaasError::NotFound("x".into()).error_code(), "NOT_FOUND");
assert_eq!(SaasError::RateLimited("x".into()).error_code(), "RATE_LIMITED");
assert_eq!(SaasError::Unauthorized.error_code(), "UNAUTHORIZED");
assert_eq!(SaasError::Encryption("x".into()).error_code(), "ENCRYPTION_ERROR");
}
#[tokio::test]
async fn into_response_hides_internal_errors() {
// 内部错误不应泄露细节
let err = SaasError::Internal("secret database password exposed".into());
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
let body_bytes = axum::body::to_bytes(resp.into_body(), 1024)
.await
.expect("body should be readable");
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"], "INTERNAL_ERROR");
assert_eq!(body["message"], "服务内部错误");
assert!(!body["message"].as_str().unwrap().contains("secret"));
}
#[tokio::test]
async fn into_response_shows_user_facing_errors() {
let err = SaasError::InvalidInput("用户名不能为空".into());
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body_bytes = axum::body::to_bytes(resp.into_body(), 1024)
.await
.expect("body should be readable");
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"], "INVALID_INPUT");
// InvalidInput includes the "无效输入: " prefix from Display impl
let msg = body["message"].as_str().unwrap();
assert!(msg.contains("用户名不能为空"));
}
}

View File

@@ -3,9 +3,12 @@
//! 独立的 SaaS 后端服务,提供账号权限管理、模型配置、请求中转和配置迁移。 //! 独立的 SaaS 后端服务,提供账号权限管理、模型配置、请求中转和配置迁移。
pub mod config; pub mod config;
pub mod crypto;
pub mod csrf;
pub mod db; pub mod db;
pub mod error; pub mod error;
pub mod middleware; pub mod middleware;
pub mod openapi;
pub mod state; pub mod state;
pub mod auth; pub mod auth;

View File

@@ -1,7 +1,16 @@
//! ZCLAW SaaS 服务入口 //! ZCLAW SaaS 服务入口
use std::time::{Duration, Instant};
use tracing::info; use tracing::info;
use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState}; use zclaw_saas::{config::SaaSConfig, db::init_db, state::AppState};
use axum::{extract::State, Json};
async fn health_handler(State(_state): State<AppState>) -> Json<serde_json::Value> {
Json(serde_json::json!({
"status": "ok",
"service": "zclaw-saas",
}))
}
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
@@ -19,7 +28,63 @@ async fn main() -> anyhow::Result<()> {
info!("Database initialized"); info!("Database initialized");
let state = AppState::new(db, config.clone())?; let state = AppState::new(db, config.clone())?;
let app = build_router(state);
// SEC-14: 后台清理 rate_limit_entries DashMap防止不活跃账号条目无限增长。
// 中间件仅在被请求命中时清理对应 entry不活跃的 account 永远不会被回收。
// 此任务每 5 分钟扫描一次,移除所有时间戳均已超过 2 分钟的 entry
// (滑动窗口为 1 分钟2 分钟是安全的 2x 余量)。
{
let rate_limit_entries = state.rate_limit_entries.clone();
tokio::spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(5 * 60)).await;
let cutoff = Instant::now() - Duration::from_secs(2 * 60);
let mut removed = 0usize;
rate_limit_entries.retain(|_account_id, timestamps| {
timestamps.retain(|&ts| ts > cutoff);
let keep = !timestamps.is_empty();
if !keep {
removed += 1;
}
keep
});
if removed > 0 {
info!(
removed,
remaining = rate_limit_entries.len(),
"rate limiter cleanup: removed stale entries"
);
}
}
});
}
// CORS 安全检查:生产环境必须配置 cors_origins
if config.server.cors_origins.is_empty() {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if !is_dev {
anyhow::bail!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
}
}
let app = build_router(state, &config);
// Swagger UI / OpenAPI 文档
// TODO: 启用 Swagger UI 后取消注释 (需要 utoipa / utoipa-swagger-ui 版本对齐)
// let app = {
// use utoipa_swagger_ui::SwaggerUi;
// use utoipa::OpenApi;
// let openapi = zclaw_saas::openapi::ApiDoc::openapi();
// app.merge(
// SwaggerUi::new("/api-docs/openapi.json")
// .url("/api-docs/openapi.json", openapi),
// )
// };
let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.host, config.server.port)) let listener = tokio::net::TcpListener::bind(format!("{}:{}", config.server.host, config.server.port))
.await?; .await?;
@@ -29,27 +94,19 @@ async fn main() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
fn build_router(state: AppState) -> axum::Router { fn build_router(state: AppState, config: &SaaSConfig) -> axum::Router {
use axum::middleware; use axum::middleware;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer; use tower_http::trace::TraceLayer;
use axum::http::HeaderValue; use axum::http::HeaderValue;
let cors = { let cors = {
let config = state.config.blocking_read();
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if config.server.cors_origins.is_empty() { if config.server.cors_origins.is_empty() {
if is_dev { // 开发环境允许任意 origin生产环境已在 main 中拦截)
CorsLayer::new() CorsLayer::new()
.allow_origin(Any) .allow_origin(Any)
.allow_methods(Any) .allow_methods(Any)
.allow_headers(Any) .allow_headers(Any)
} else {
tracing::error!("生产环境必须配置 server.cors_origins不能使用 allow_origin(Any)");
panic!("生产环境必须配置 server.cors_origins 白名单。开发环境可设置 ZCLAW_SAAS_DEV=true 绕过。");
}
} else { } else {
let origins: Vec<HeaderValue> = config.server.cors_origins.iter() let origins: Vec<HeaderValue> = config.server.cors_origins.iter()
.filter_map(|o: &String| o.parse::<HeaderValue>().ok()) .filter_map(|o: &String| o.parse::<HeaderValue>().ok())
@@ -72,14 +129,20 @@ fn build_router(state: AppState) -> axum::Router {
state.clone(), state.clone(),
zclaw_saas::middleware::rate_limit_middleware, zclaw_saas::middleware::rate_limit_middleware,
)) ))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::csrf::origin_check_middleware,
))
.layer(middleware::from_fn_with_state( .layer(middleware::from_fn_with_state(
state.clone(), state.clone(),
zclaw_saas::auth::auth_middleware, zclaw_saas::auth::auth_middleware,
)); ));
axum::Router::new() axum::Router::new()
.route("/api/health", axum::routing::get(health_handler))
.merge(public_routes) .merge(public_routes)
.merge(protected_routes) .merge(protected_routes)
.layer(axum::extract::DefaultBodyLimit::max(10 * 1024 * 1024)) // 10MB 请求体限制,防止 DoS
.layer(TraceLayer::new_for_http()) .layer(TraceLayer::new_for_http())
.layer(cors) .layer(cors)
.with_state(state) .with_state(state)

View File

@@ -10,6 +10,58 @@ use std::time::Instant;
use crate::state::AppState; use crate::state::AppState;
/// 速率限制检查结果
#[derive(Debug, PartialEq)]
pub(crate) enum RateLimitResult {
/// 允许通过
Allowed,
/// 被限制,附带 Retry-After 秒数
Limited { retry_after_secs: u64 },
}
/// 滑动窗口速率限制核心逻辑(纯函数,便于测试)
///
/// 返回 `RateLimitResult::Allowed` 表示未超限(已记录本次请求),
/// `RateLimitResult::Limited` 表示超限。
pub(crate) fn check_rate_limit(
entries: &mut Vec<Instant>,
now: Instant,
window_duration: std::time::Duration,
max_requests: u64,
) -> RateLimitResult {
let window_start = now - window_duration;
// 清理过期条目
entries.retain(|&ts| ts > window_start);
let count = entries.len() as u64;
if count < max_requests {
entries.push(now);
RateLimitResult::Allowed
} else {
// 计算最早条目的过期时间作为 Retry-After
entries.sort();
let earliest = *entries.first().unwrap_or(&now);
let elapsed = now.duration_since(earliest).as_secs();
let retry_after = window_duration.as_secs().saturating_sub(elapsed);
RateLimitResult::Limited {
retry_after_secs: retry_after,
}
}
}
#[cfg(test)]
/// 清理过期条目并移除空 entry
fn cleanup_stale_entries(
map: &dashmap::DashMap<String, Vec<Instant>>,
cutoff: Instant,
) {
map.retain(|_, entries| {
entries.retain(|&ts| ts > cutoff);
!entries.is_empty()
});
}
/// 滑动窗口速率限制中间件 /// 滑动窗口速率限制中间件
/// ///
/// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。 /// 按 account_id (从 AuthContext 提取) 做 per-minute 限流。
@@ -37,45 +89,186 @@ pub async fn rate_limit_middleware(
drop(config); drop(config);
let now = Instant::now(); let now = Instant::now();
let window_start = now - std::time::Duration::from_secs(60); let window = std::time::Duration::from_secs(60);
// 滑动窗口: 清理过期条目 + 计数
let current_count = { let current_count = {
let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default(); let mut entries = state.rate_limit_entries.entry(account_id.clone()).or_default();
entries.retain(|&ts| ts > window_start); let result = check_rate_limit(&mut entries, now, window, max_requests);
let count = entries.len() as u64; if let RateLimitResult::Limited { retry_after_secs } = result {
if count < max_requests { if let Some(entries) = state.rate_limit_entries.get_mut(&account_id) {
entries.push(now); if entries.is_empty() {
0 // 未超限 drop(entries);
} else { state.rate_limit_entries.remove(&account_id);
count }
}
return (
StatusCode::TOO_MANY_REQUESTS,
[
("Retry-After", retry_after_secs.to_string()),
("Content-Type", "application/json".to_string()),
],
axum::Json(serde_json::json!({
"error": "RATE_LIMITED",
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after_secs),
})),
)
.into_response();
} }
entries.len() as u64
}; };
if current_count >= max_requests { // 清理空 entry (不再活跃的用户)
// 计算最早条目的过期时间作为 Retry-After if current_count == 0 {
let retry_after = if let Some(mut entries) = state.rate_limit_entries.get_mut(&account_id) { if let Some(entries) = state.rate_limit_entries.get_mut(&account_id) {
entries.sort(); if entries.is_empty() {
let earliest = *entries.first().unwrap_or(&now); drop(entries);
let elapsed = now.duration_since(earliest).as_secs(); state.rate_limit_entries.remove(&account_id);
60u64.saturating_sub(elapsed) }
} else { }
60
};
return (
StatusCode::TOO_MANY_REQUESTS,
[
("Retry-After", retry_after.to_string()),
("Content-Type", "application/json".to_string()),
],
axum::Json(serde_json::json!({
"error": "RATE_LIMITED",
"message": format!("请求过于频繁,请在 {} 秒后重试", retry_after),
})),
)
.into_response();
} }
next.run(req).await next.run(req).await
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn allows_under_limit() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
for i in 0..5 {
let result = check_rate_limit(&mut entries, now, window, 10);
assert_eq!(result, RateLimitResult::Allowed, "request {} should be allowed", i);
}
assert_eq!(entries.len() as u64, 5);
}
#[test]
fn blocks_at_limit() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let limit: u64 = 3;
// 填到限额
for _ in 0..limit {
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Allowed);
}
assert_eq!(entries.len() as u64, limit);
// 下一个应该被限流
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
// 不应该增加新条目
assert_eq!(entries.len() as u64, limit);
}
#[test]
fn expired_entries_are_cleaned() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
// 插入一个 61 秒前的旧条目
entries.push(now - std::time::Duration::from_secs(61));
assert_eq!(entries.len(), 1);
// 旧条目应该被清理,然后允许新请求
let result = check_rate_limit(&mut entries, now, window, 1);
assert_eq!(result, RateLimitResult::Allowed);
}
#[test]
fn retry_after_reflects_earliest_entry() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let limit: u64 = 2;
// 第一个请求在 10 秒前
let first_time = now - std::time::Duration::from_secs(10);
entries.push(first_time);
// 第二个请求现在
entries.push(now);
assert_eq!(entries.len() as u64, limit);
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 50 });
}
#[test]
fn burst_allows_extra_requests() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let rpm: u64 = 5;
let burst: u64 = 3;
let max = rpm + burst; // 8
// 前 8 个请求应该全部通过
for _ in 0..max {
let result = check_rate_limit(&mut entries, now, window, max);
assert_eq!(result, RateLimitResult::Allowed);
}
// 第 9 个被限流
let result = check_rate_limit(&mut entries, now, window, max);
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
}
#[test]
fn cleanup_removes_expired_and_empty() {
let map: dashmap::DashMap<String, Vec<Instant>> = dashmap::DashMap::new();
let now = Instant::now();
let cutoff = now - std::time::Duration::from_secs(120);
// 活跃用户
map.insert("active".to_string(), vec![now]);
// 过期用户
map.insert(
"expired".to_string(),
vec![now - std::time::Duration::from_secs(200)],
);
// 空用户
map.insert("empty".to_string(), vec![]);
cleanup_stale_entries(&map, cutoff);
assert!(map.contains_key("active"));
assert!(!map.contains_key("expired"));
assert!(!map.contains_key("empty"));
}
#[test]
fn empty_entries_allowed() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(60);
let result = check_rate_limit(&mut entries, now, window, 0);
// limit=0 means always limited
assert_eq!(result, RateLimitResult::Limited { retry_after_secs: 60 });
}
#[test]
fn single_request_with_large_window() {
let mut entries: Vec<Instant> = vec![];
let now = Instant::now();
let window = std::time::Duration::from_secs(3600);
let limit: u64 = 100;
for _ in 0..limit {
let result = check_rate_limit(&mut entries, now, window, limit);
assert_eq!(result, RateLimitResult::Allowed);
}
assert_eq!(entries.len() as u64, limit);
let result = check_rate_limit(&mut entries, now, window, limit);
assert!(matches!(result, RateLimitResult::Limited { .. }));
}
}

View File

@@ -7,7 +7,7 @@ use axum::{
use crate::state::AppState; use crate::state::AppState;
use crate::error::SaasResult; use crate::error::SaasResult;
use crate::auth::types::AuthContext; use crate::auth::types::AuthContext;
use crate::auth::handlers::check_permission; use crate::auth::handlers::{check_permission, log_operation};
use super::{types::*, service}; use super::{types::*, service};
/// GET /api/v1/config/items?category=xxx&source=xxx /// GET /api/v1/config/items?category=xxx&source=xxx
@@ -36,6 +36,9 @@ pub async fn create_config_item(
) -> SaasResult<(StatusCode, Json<ConfigItemInfo>)> { ) -> SaasResult<(StatusCode, Json<ConfigItemInfo>)> {
check_permission(&ctx, "config:write")?; check_permission(&ctx, "config:write")?;
let item = service::create_config_item(&state.db, &req).await?; let item = service::create_config_item(&state.db, &req).await?;
log_operation(&state.db, &ctx.account_id, "config.create", "config_item", &item.id,
Some(serde_json::json!({"category": req.category, "key_path": req.key_path})),
ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(item))) Ok((StatusCode::CREATED, Json(item)))
} }
@@ -47,7 +50,10 @@ pub async fn update_config_item(
Json(req): Json<UpdateConfigItemRequest>, Json(req): Json<UpdateConfigItemRequest>,
) -> SaasResult<Json<ConfigItemInfo>> { ) -> SaasResult<Json<ConfigItemInfo>> {
check_permission(&ctx, "config:write")?; check_permission(&ctx, "config:write")?;
service::update_config_item(&state.db, &id, &req).await.map(Json) let item = service::update_config_item(&state.db, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "config.update", "config_item", &id, None,
ctx.client_ip.as_deref()).await?;
Ok(Json(item))
} }
/// DELETE /api/v1/config/items/:id (admin only) /// DELETE /api/v1/config/items/:id (admin only)
@@ -58,6 +64,8 @@ pub async fn delete_config_item(
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "config:write")?; check_permission(&ctx, "config:write")?;
service::delete_config_item(&state.db, &id).await?; service::delete_config_item(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "config.delete", "config_item", &id, None,
ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true}))) Ok(Json(serde_json::json!({"ok": true})))
} }
@@ -76,16 +84,24 @@ pub async fn seed_config(
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "config:write")?; check_permission(&ctx, "config:write")?;
let count = service::seed_default_config_items(&state.db).await?; let count = service::seed_default_config_items(&state.db).await?;
log_operation(&state.db, &ctx.account_id, "config.seed", "config_items", "batch",
Some(serde_json::json!({"created": count})),
ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"created": count}))) Ok(Json(serde_json::json!({"created": count})))
} }
/// POST /api/v1/config/sync /// POST /api/v1/config/sync (admin only)
pub async fn sync_config( pub async fn sync_config(
State(state): State<AppState>, State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>, Extension(ctx): Extension<AuthContext>,
Json(req): Json<SyncConfigRequest>, Json(req): Json<SyncConfigRequest>,
) -> SaasResult<Json<super::service::ConfigSyncResult>> { ) -> SaasResult<Json<super::service::ConfigSyncResult>> {
super::service::sync_config(&state.db, &ctx.account_id, &req).await.map(Json) check_permission(&ctx, "config:write")?;
let result = super::service::sync_config(&state.db, &ctx.account_id, &req).await?;
log_operation(&state.db, &ctx.account_id, "config.sync", "config_sync", &ctx.account_id,
Some(serde_json::json!({"action": req.action, "updated": result.updated, "created": result.created, "skipped": result.skipped})),
ctx.client_ip.as_deref()).await?;
Ok(Json(result))
} }
/// POST /api/v1/config/diff /// POST /api/v1/config/diff

View File

@@ -1,6 +1,6 @@
//! 配置迁移业务逻辑 //! 配置迁移业务逻辑
use sqlx::SqlitePool; use sqlx::PgPool;
use crate::error::{SaasError, SaasResult}; use crate::error::{SaasError, SaasResult};
use super::types::*; use super::types::*;
use serde::Serialize; use serde::Serialize;
@@ -8,20 +8,20 @@ use serde::Serialize;
// ============ Config Items ============ // ============ Config Items ============
pub async fn list_config_items( pub async fn list_config_items(
db: &SqlitePool, query: &ConfigQuery, db: &PgPool, query: &ConfigQuery,
) -> SaasResult<Vec<ConfigItemInfo>> { ) -> SaasResult<Vec<ConfigItemInfo>> {
let sql = match (&query.category, &query.source) { let sql = match (&query.category, &query.source) {
(Some(_), Some(_)) => { (Some(_), Some(_)) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at "SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE category = ?1 AND source = ?2 ORDER BY category, key_path" FROM config_items WHERE category = $1 AND source = $2 ORDER BY category, key_path"
} }
(Some(_), None) => { (Some(_), None) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at "SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE category = ?1 ORDER BY key_path" FROM config_items WHERE category = $1 ORDER BY key_path"
} }
(None, Some(_)) => { (None, Some(_)) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at "SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE source = ?1 ORDER BY category, key_path" FROM config_items WHERE source = $1 ORDER BY category, key_path"
} }
(None, None) => { (None, None) => {
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at "SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
@@ -29,7 +29,7 @@ pub async fn list_config_items(
} }
}; };
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)>(sql); let mut query_builder = sqlx::query_as::<_, (String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>(sql);
if let Some(cat) = &query.category { if let Some(cat) = &query.category {
query_builder = query_builder.bind(cat); query_builder = query_builder.bind(cat);
@@ -40,15 +40,15 @@ pub async fn list_config_items(
let rows = query_builder.fetch_all(db).await?; let rows = query_builder.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| { Ok(rows.into_iter().map(|(id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)| {
ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at } ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
}).collect()) }).collect())
} }
pub async fn get_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<ConfigItemInfo> { pub async fn get_config_item(db: &PgPool, item_id: &str) -> SaasResult<ConfigItemInfo> {
let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, String, String)> = let row: Option<(String, String, String, String, Option<String>, Option<String>, String, Option<String>, bool, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at "SELECT id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at
FROM config_items WHERE id = ?1" FROM config_items WHERE id = $1"
) )
.bind(item_id) .bind(item_id)
.fetch_optional(db) .fetch_optional(db)
@@ -57,20 +57,20 @@ pub async fn get_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<Confi
let (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) = let (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?; row.ok_or_else(|| SaasError::NotFound(format!("配置项 {} 不存在", item_id)))?;
Ok(ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at }) Ok(ConfigItemInfo { id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
} }
pub async fn create_config_item( pub async fn create_config_item(
db: &SqlitePool, req: &CreateConfigItemRequest, db: &PgPool, req: &CreateConfigItemRequest,
) -> SaasResult<ConfigItemInfo> { ) -> SaasResult<ConfigItemInfo> {
let id = uuid::Uuid::new_v4().to_string(); let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let source = req.source.as_deref().unwrap_or("local"); let source = req.source.as_deref().unwrap_or("local");
let requires_restart = req.requires_restart.unwrap_or(false); let requires_restart = req.requires_restart.unwrap_or(false);
// 检查唯一性 // 检查唯一性
let existing: Option<(String,)> = sqlx::query_as( let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM config_items WHERE category = ?1 AND key_path = ?2" "SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
) )
.bind(&req.category).bind(&req.key_path) .bind(&req.category).bind(&req.key_path)
.fetch_optional(db).await?; .fetch_optional(db).await?;
@@ -83,7 +83,7 @@ pub async fn create_config_item(
sqlx::query( sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) "INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?10)" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $10)"
) )
.bind(&id).bind(&req.category).bind(&req.key_path).bind(&req.value_type) .bind(&id).bind(&req.category).bind(&req.key_path).bind(&req.value_type)
.bind(&req.current_value).bind(&req.default_value).bind(source) .bind(&req.current_value).bind(&req.default_value).bind(source)
@@ -94,36 +94,38 @@ pub async fn create_config_item(
} }
pub async fn update_config_item( pub async fn update_config_item(
db: &SqlitePool, item_id: &str, req: &UpdateConfigItemRequest, db: &PgPool, item_id: &str, req: &UpdateConfigItemRequest,
) -> SaasResult<ConfigItemInfo> { ) -> SaasResult<ConfigItemInfo> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let mut updates = Vec::new(); let mut updates = Vec::new();
let mut params: Vec<String> = Vec::new(); let mut params: Vec<String> = Vec::new();
let mut param_idx: i32 = 1;
if let Some(ref v) = req.current_value { updates.push("current_value = ?"); params.push(v.clone()); } if let Some(ref v) = req.current_value { updates.push(format!("current_value = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.source { updates.push("source = ?"); params.push(v.clone()); } if let Some(ref v) = req.source { updates.push(format!("source = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if let Some(ref v) = req.description { updates.push("description = ?"); params.push(v.clone()); } if let Some(ref v) = req.description { updates.push(format!("description = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
if updates.is_empty() { if updates.is_empty() {
return get_config_item(db, item_id).await; return get_config_item(db, item_id).await;
} }
updates.push("updated_at = ?"); updates.push(format!("updated_at = ${}", param_idx));
params.push(now); param_idx += 1;
params.push(item_id.to_string()); params.push(item_id.to_string());
let sql = format!("UPDATE config_items SET {} WHERE id = ?", updates.join(", ")); let sql = format!("UPDATE config_items SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql); let mut query = sqlx::query(&sql);
for p in &params { for p in &params {
query = query.bind(p); query = query.bind(p);
} }
query = query.bind(now);
query.execute(db).await?; query.execute(db).await?;
get_config_item(db, item_id).await get_config_item(db, item_id).await
} }
pub async fn delete_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<()> { pub async fn delete_config_item(db: &PgPool, item_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM config_items WHERE id = ?1") let result = sqlx::query("DELETE FROM config_items WHERE id = $1")
.bind(item_id).execute(db).await?; .bind(item_id).execute(db).await?;
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
return Err(SaasError::NotFound(format!("配置项 {} 不存在", item_id))); return Err(SaasError::NotFound(format!("配置项 {} 不存在", item_id)));
@@ -133,7 +135,7 @@ pub async fn delete_config_item(db: &SqlitePool, item_id: &str) -> SaasResult<()
// ============ Config Analysis ============ // ============ Config Analysis ============
pub async fn analyze_config(db: &SqlitePool) -> SaasResult<ConfigAnalysis> { pub async fn analyze_config(db: &PgPool) -> SaasResult<ConfigAnalysis> {
let items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?; let items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
let mut categories: std::collections::HashMap<String, (i64, i64)> = std::collections::HashMap::new(); let mut categories: std::collections::HashMap<String, (i64, i64)> = std::collections::HashMap::new();
@@ -157,7 +159,7 @@ pub async fn analyze_config(db: &SqlitePool) -> SaasResult<ConfigAnalysis> {
} }
/// 种子默认配置项 /// 种子默认配置项
pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> { pub async fn seed_default_config_items(db: &PgPool) -> SaasResult<usize> {
let defaults = [ let defaults = [
("server", "server.host", "string", Some("127.0.0.1"), Some("127.0.0.1"), "服务器监听地址"), ("server", "server.host", "string", Some("127.0.0.1"), Some("127.0.0.1"), "服务器监听地址"),
("server", "server.port", "integer", Some("4200"), Some("4200"), "服务器端口"), ("server", "server.port", "integer", Some("4200"), Some("4200"), "服务器端口"),
@@ -175,11 +177,11 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
]; ];
let mut created = 0; let mut created = 0;
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
for (category, key_path, value_type, default_value, current_value, description) in defaults { for (category, key_path, value_type, default_value, current_value, description) in defaults {
let existing: Option<(String,)> = sqlx::query_as( let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM config_items WHERE category = ?1 AND key_path = ?2" "SELECT id FROM config_items WHERE category = $1 AND key_path = $2"
) )
.bind(category).bind(key_path) .bind(category).bind(key_path)
.fetch_optional(db) .fetch_optional(db)
@@ -189,7 +191,7 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
let id = uuid::Uuid::new_v4().to_string(); let id = uuid::Uuid::new_v4().to_string();
sqlx::query( sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at) "INSERT INTO config_items (id, category, key_path, value_type, current_value, default_value, source, description, requires_restart, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'local', ?7, 0, ?8, ?8)" VALUES ($1, $2, $3, $4, $5, $6, 'local', $7, false, $8, $8)"
) )
.bind(&id).bind(category).bind(key_path).bind(value_type) .bind(&id).bind(category).bind(key_path).bind(value_type)
.bind(current_value).bind(default_value).bind(description).bind(&now) .bind(current_value).bind(default_value).bind(description).bind(&now)
@@ -204,21 +206,20 @@ pub async fn seed_default_config_items(db: &SqlitePool) -> SaasResult<usize> {
// ============ Config Sync ============ // ============ Config Sync ============
/// 计算客户端与 SaaS 端的配置差异 /// 纯函数:计算客户端与 SaaS 配置项的差异(不依赖数据库)
pub async fn compute_config_diff( pub fn compute_diff_items(
db: &SqlitePool, req: &SyncConfigRequest, config_keys: &[String],
) -> SaasResult<ConfigDiffResponse> { client_values: &serde_json::Value,
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?; saas_items: &[ConfigItemInfo],
) -> (Vec<ConfigDiffItem>, usize) {
let mut items = Vec::new(); let mut items = Vec::new();
let mut conflicts = 0usize; let mut conflicts = 0usize;
for key in &req.config_keys { for key in config_keys {
let client_val = req.client_values.get(key) let client_val = client_values.get(key)
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.map(|s| s.to_string()); .map(|s| s.to_string());
// 查找 SaaS 端的值
let saas_item = saas_items.iter().find(|item| item.key_path == *key); let saas_item = saas_items.iter().find(|item| item.key_path == *key);
let saas_val = saas_item.and_then(|item| item.current_value.clone()); let saas_val = saas_item.and_then(|item| item.current_value.clone());
@@ -239,6 +240,17 @@ pub async fn compute_config_diff(
}); });
} }
(items, conflicts)
}
/// 计算客户端与 SaaS 端的配置差异
pub async fn compute_config_diff(
db: &PgPool, req: &SyncConfigRequest,
) -> SaasResult<ConfigDiffResponse> {
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
let (items, conflicts) = compute_diff_items(&req.config_keys, &req.client_values, &saas_items);
Ok(ConfigDiffResponse { Ok(ConfigDiffResponse {
total_keys: items.len(), total_keys: items.len(),
conflicts, conflicts,
@@ -248,16 +260,16 @@ pub async fn compute_config_diff(
/// 执行配置同步 (实际写入 config_items) /// 执行配置同步 (实际写入 config_items)
pub async fn sync_config( pub async fn sync_config(
db: &SqlitePool, account_id: &str, req: &SyncConfigRequest, db: &PgPool, account_id: &str, req: &SyncConfigRequest,
) -> SaasResult<ConfigSyncResult> { ) -> SaasResult<ConfigSyncResult> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let config_keys_str = serde_json::to_string(&req.config_keys)?; let config_keys_str = serde_json::to_string(&req.config_keys)?;
let client_values_str = Some(serde_json::to_string(&req.client_values)?); let client_values_str = Some(serde_json::to_string(&req.client_values)?);
// 获取 SaaS 端的配置值 // 获取 SaaS 端的配置值
let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?; let saas_items = list_config_items(db, &ConfigQuery { category: None, source: None }).await?;
let mut updated = 0i64; let mut updated = 0i64;
let created = 0i64; let mut created = 0i64;
let mut skipped = 0i64; let mut skipped = 0i64;
for key in &req.config_keys { for key in &req.config_keys {
@@ -273,13 +285,20 @@ pub async fn sync_config(
if let Some(val) = &client_val { if let Some(val) = &client_val {
if let Some(item) = saas_item { if let Some(item) = saas_item {
// 更新已有配置项 // 更新已有配置项
sqlx::query("UPDATE config_items SET current_value = ?1, source = 'local', updated_at = ?2 WHERE id = ?3") sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
.bind(val).bind(&now).bind(&item.id) .bind(val).bind(&now).bind(&item.id)
.execute(db).await?; .execute(db).await?;
updated += 1; updated += 1;
} else { } else {
// 推送时如果 SaaS 不存在该 key,记录跳过 // SaaS 不存在该 key → 自动创建
skipped += 1; let new_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, source, requires_restart, created_at, updated_at)
VALUES ($1, 'imported', $2, 'string', $3, 'local', false, $4, $4)"
)
.bind(&new_id).bind(key).bind(val).bind(&now)
.execute(db).await?;
created += 1;
} }
} }
} }
@@ -288,7 +307,7 @@ pub async fn sync_config(
if let Some(val) = &client_val { if let Some(val) = &client_val {
if let Some(item) = saas_item { if let Some(item) = saas_item {
if item.current_value.is_none() || item.current_value.as_deref() == Some("") { if item.current_value.is_none() || item.current_value.as_deref() == Some("") {
sqlx::query("UPDATE config_items SET current_value = ?1, source = 'local', updated_at = ?2 WHERE id = ?3") sqlx::query("UPDATE config_items SET current_value = $1, source = 'local', updated_at = $2 WHERE id = $3")
.bind(val).bind(&now).bind(&item.id) .bind(val).bind(&now).bind(&item.id)
.execute(db).await?; .execute(db).await?;
updated += 1; updated += 1;
@@ -296,9 +315,17 @@ pub async fn sync_config(
// 冲突: SaaS 有值 → 保留 SaaS 值 // 冲突: SaaS 有值 → 保留 SaaS 值
skipped += 1; skipped += 1;
} }
} else {
// SaaS 完全没有该 key → 创建
let new_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO config_items (id, category, key_path, value_type, current_value, source, requires_restart, created_at, updated_at)
VALUES ($1, 'imported', $2, 'string', $3, 'local', false, $4, $4)"
)
.bind(&new_id).bind(key).bind(val).bind(&now)
.execute(db).await?;
created += 1;
} }
// 客户端有但 SaaS 完全没有的 key → 不自动创建 (需要管理员先创建)
skipped += 1;
} }
} }
_ => { _ => {
@@ -323,7 +350,7 @@ pub async fn sync_config(
sqlx::query( sqlx::query(
"INSERT INTO config_sync_log (account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at) "INSERT INTO config_sync_log (account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
) )
.bind(account_id).bind(&req.client_fingerprint) .bind(account_id).bind(&req.client_fingerprint)
.bind(&req.action).bind(&config_keys_str).bind(&client_values_str) .bind(&req.action).bind(&config_keys_str).bind(&client_values_str)
@@ -343,18 +370,126 @@ pub struct ConfigSyncResult {
} }
pub async fn list_sync_logs( pub async fn list_sync_logs(
db: &SqlitePool, account_id: &str, db: &PgPool, account_id: &str,
) -> SaasResult<Vec<ConfigSyncLogInfo>> { ) -> SaasResult<Vec<ConfigSyncLogInfo>> {
let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, String)> = let rows: Vec<(i64, String, String, String, String, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at "SELECT id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at
FROM config_sync_log WHERE account_id = ?1 ORDER BY created_at DESC LIMIT 50" FROM config_sync_log WHERE account_id = $1 ORDER BY created_at DESC LIMIT 50"
) )
.bind(account_id) .bind(account_id)
.fetch_all(db) .fetch_all(db)
.await?; .await?;
Ok(rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| { Ok(rows.into_iter().map(|(id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at)| {
ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at } ConfigSyncLogInfo { id, account_id, client_fingerprint, action, config_keys, client_values, saas_values, resolution, created_at: created_at.to_rfc3339() }
}).collect()) }).collect())
} }
#[cfg(test)]
mod tests {
use super::*;
fn make_saas_item(key: &str, value: Option<&str>) -> ConfigItemInfo {
ConfigItemInfo {
id: "test-id".into(),
category: "test".into(),
key_path: key.into(),
value_type: "string".into(),
current_value: value.map(String::from),
default_value: None,
source: "local".into(),
description: None,
requires_restart: false,
created_at: "2026-01-01T00:00:00Z".into(),
updated_at: "2026-01-01T00:00:00Z".into(),
}
}
#[test]
fn test_diff_identical_values() {
let keys = vec!["server.host".into(), "server.port".into()];
let client = serde_json::json!({"server.host": "127.0.0.1", "server.port": "8080"});
let saas = vec![
make_saas_item("server.host", Some("127.0.0.1")),
make_saas_item("server.port", Some("8080")),
];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 0);
assert_eq!(items.len(), 2);
assert!(!items[0].conflict);
assert!(!items[1].conflict);
}
#[test]
fn test_diff_conflict() {
let keys = vec!["server.host".into()];
let client = serde_json::json!({"server.host": "0.0.0.0"});
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 1);
assert!(items[0].conflict);
assert_eq!(items[0].client_value.as_deref(), Some("0.0.0.0"));
assert_eq!(items[0].saas_value.as_deref(), Some("127.0.0.1"));
}
#[test]
fn test_diff_client_only_key() {
let keys = vec!["new.key".into()];
let client = serde_json::json!({"new.key": "value1"});
let saas = vec![]; // SaaS 没有这个 key
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 0);
assert_eq!(items[0].client_value.as_deref(), Some("value1"));
assert!(items[0].saas_value.is_none());
}
#[test]
fn test_diff_missing_client_value() {
let keys = vec!["server.host".into()];
let client = serde_json::json!({}); // 客户端没有这个 key
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(conflicts, 0); // 一方为 null 不算冲突
assert!(items[0].client_value.is_none());
assert_eq!(items[0].saas_value.as_deref(), Some("127.0.0.1"));
}
#[test]
fn test_diff_empty_keys() {
let keys: Vec<String> = vec![];
let client = serde_json::json!({"server.host": "127.0.0.1"});
let saas = vec![make_saas_item("server.host", Some("127.0.0.1"))];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert!(items.is_empty());
assert_eq!(conflicts, 0);
}
#[test]
fn test_diff_mixed() {
let keys = vec!["same".into(), "conflict".into(), "client_only".into(), "saas_only".into()];
let client = serde_json::json!({
"same": "val1",
"conflict": "client-val",
"client_only": "new-val",
});
let saas = vec![
make_saas_item("same", Some("val1")),
make_saas_item("conflict", Some("saas-val")),
make_saas_item("saas_only", Some("only-here")),
];
let (items, conflicts) = compute_diff_items(&keys, &client, &saas);
assert_eq!(items.len(), 4);
assert_eq!(conflicts, 1);
// same: no conflict
assert!(!items[0].conflict);
// conflict: has conflict
assert!(items[1].conflict);
// client_only: SaaS has no such key
assert!(items[2].saas_value.is_none());
assert_eq!(items[2].client_value.as_deref(), Some("new-val"));
// saas_only: client has no such key
assert!(items[3].client_value.is_none());
assert_eq!(items[3].saas_value.as_deref(), Some("only-here"));
}
}

View File

@@ -3,7 +3,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// 配置项信息 /// 配置项信息
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ConfigItemInfo { pub struct ConfigItemInfo {
pub id: String, pub id: String,
pub category: String, pub category: String,
@@ -19,7 +19,7 @@ pub struct ConfigItemInfo {
} }
/// 创建配置项请求 /// 创建配置项请求
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateConfigItemRequest { pub struct CreateConfigItemRequest {
pub category: String, pub category: String,
pub key_path: String, pub key_path: String,
@@ -32,7 +32,7 @@ pub struct CreateConfigItemRequest {
} }
/// 更新配置项请求 /// 更新配置项请求
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateConfigItemRequest { pub struct UpdateConfigItemRequest {
pub current_value: Option<String>, pub current_value: Option<String>,
pub source: Option<String>, pub source: Option<String>,
@@ -40,7 +40,7 @@ pub struct UpdateConfigItemRequest {
} }
/// 配置同步日志 /// 配置同步日志
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct ConfigSyncLogInfo { pub struct ConfigSyncLogInfo {
pub id: i64, pub id: i64,
pub account_id: String, pub account_id: String,
@@ -54,14 +54,14 @@ pub struct ConfigSyncLogInfo {
} }
/// 配置分析结果 /// 配置分析结果
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ConfigAnalysis { pub struct ConfigAnalysis {
pub total_items: i64, pub total_items: i64,
pub categories: Vec<CategorySummary>, pub categories: Vec<CategorySummary>,
pub items: Vec<ConfigItemInfo>, pub items: Vec<ConfigItemInfo>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct CategorySummary { pub struct CategorySummary {
pub category: String, pub category: String,
pub count: i64, pub count: i64,
@@ -69,10 +69,10 @@ pub struct CategorySummary {
} }
/// 配置同步请求 /// 配置同步请求
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct SyncConfigRequest { pub struct SyncConfigRequest {
pub client_fingerprint: String, pub client_fingerprint: String,
/// 同步方向: "push", "pull", "merge" /// 同步方向: "push", "merge"
#[serde(default = "default_sync_action")] #[serde(default = "default_sync_action")]
pub action: String, pub action: String,
pub config_keys: Vec<String>, pub config_keys: Vec<String>,
@@ -82,7 +82,7 @@ pub struct SyncConfigRequest {
fn default_sync_action() -> String { "push".to_string() } fn default_sync_action() -> String { "push".to_string() }
/// 配置差异项 /// 配置差异项
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct ConfigDiffItem { pub struct ConfigDiffItem {
pub key_path: String, pub key_path: String,
pub client_value: Option<String>, pub client_value: Option<String>,
@@ -91,7 +91,7 @@ pub struct ConfigDiffItem {
} }
/// 配置差异响应 /// 配置差异响应
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ConfigDiffResponse { pub struct ConfigDiffResponse {
pub items: Vec<ConfigDiffItem>, pub items: Vec<ConfigDiffItem>,
pub total_keys: usize, pub total_keys: usize,
@@ -99,7 +99,7 @@ pub struct ConfigDiffResponse {
} }
/// 配置查询参数 /// 配置查询参数
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct ConfigQuery { pub struct ConfigQuery {
pub category: Option<String>, pub category: Option<String>,
pub source: Option<String>, pub source: Option<String>,

View File

@@ -36,7 +36,7 @@ pub async fn create_provider(
Json(req): Json<CreateProviderRequest>, Json(req): Json<CreateProviderRequest>,
) -> SaasResult<(StatusCode, Json<ProviderInfo>)> { ) -> SaasResult<(StatusCode, Json<ProviderInfo>)> {
check_permission(&ctx, "provider:manage")?; check_permission(&ctx, "provider:manage")?;
let provider = service::create_provider(&state.db, &req).await?; let provider = service::create_provider(&state.db, &state.field_encryption, &req).await?;
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id, log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?; Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(provider))) Ok((StatusCode::CREATED, Json(provider)))
@@ -50,7 +50,7 @@ pub async fn update_provider(
Json(req): Json<UpdateProviderRequest>, Json(req): Json<UpdateProviderRequest>,
) -> SaasResult<Json<ProviderInfo>> { ) -> SaasResult<Json<ProviderInfo>> {
check_permission(&ctx, "provider:manage")?; check_permission(&ctx, "provider:manage")?;
let provider = service::update_provider(&state.db, &id, &req).await?; let provider = service::update_provider(&state.db, &state.field_encryption, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?; log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(provider)) Ok(Json(provider))
} }
@@ -135,7 +135,7 @@ pub async fn list_api_keys(
Query(params): Query<std::collections::HashMap<String, String>>, Query(params): Query<std::collections::HashMap<String, String>>,
) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> { ) -> SaasResult<Json<Vec<AccountApiKeyInfo>>> {
let provider_id = params.get("provider_id").map(|s| s.as_str()); let provider_id = params.get("provider_id").map(|s| s.as_str());
service::list_account_api_keys(&state.db, &ctx.account_id, provider_id).await.map(Json) service::list_account_api_keys(&state.db, &state.field_encryption, &ctx.account_id, provider_id).await.map(Json)
} }
/// POST /api/v1/keys /// POST /api/v1/keys
@@ -144,7 +144,7 @@ pub async fn create_api_key(
Extension(ctx): Extension<AuthContext>, Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateAccountApiKeyRequest>, Json(req): Json<CreateAccountApiKeyRequest>,
) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> { ) -> SaasResult<(StatusCode, Json<AccountApiKeyInfo>)> {
let key = service::create_account_api_key(&state.db, &ctx.account_id, &req).await?; let key = service::create_account_api_key(&state.db, &state.field_encryption, &ctx.account_id, &req).await?;
log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id, log_operation(&state.db, &ctx.account_id, "api_key.create", "api_key", &key.id,
Some(serde_json::json!({"provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?; Some(serde_json::json!({"provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
Ok((StatusCode::CREATED, Json(key))) Ok((StatusCode::CREATED, Json(key)))
@@ -157,7 +157,7 @@ pub async fn rotate_api_key(
Extension(ctx): Extension<AuthContext>, Extension(ctx): Extension<AuthContext>,
Json(req): Json<RotateApiKeyRequest>, Json(req): Json<RotateApiKeyRequest>,
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<Json<serde_json::Value>> {
service::rotate_account_api_key(&state.db, &id, &ctx.account_id, &req.new_key_value).await?; service::rotate_account_api_key(&state.db, &state.field_encryption, &id, &ctx.account_id, &req.new_key_value).await?;
log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, ctx.client_ip.as_deref()).await?; log_operation(&state.db, &ctx.account_id, "api_key.rotate", "api_key", &id, None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true}))) Ok(Json(serde_json::json!({"ok": true})))
} }

View File

@@ -1,13 +1,15 @@
//! 模型配置业务逻辑 //! 模型配置业务逻辑
use sqlx::SqlitePool; use sqlx::PgPool;
use std::sync::Arc;
use crate::crypto::FieldEncryption;
use crate::error::{SaasError, SaasResult}; use crate::error::{SaasError, SaasResult};
use super::types::*; use super::types::*;
// ============ Providers ============ // ============ Providers ============
pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> { pub async fn list_providers(db: &PgPool) -> SaasResult<Vec<ProviderInfo>> {
let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> = let rows: Vec<(String, String, String, String, String, bool, Option<i64>, Option<i64>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers ORDER BY name" FROM providers ORDER BY name"
@@ -16,15 +18,15 @@ pub async fn list_providers(db: &SqlitePool) -> SaasResult<Vec<ProviderInfo>> {
.await?; .await?;
Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| { Ok(rows.into_iter().map(|(id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at)| {
ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at } ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
}).collect()) }).collect())
} }
pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<ProviderInfo> { pub async fn get_provider(db: &PgPool, provider_id: &str) -> SaasResult<ProviderInfo> {
let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, String, String)> = let row: Option<(String, String, String, String, String, bool, Option<i64>, Option<i64>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at "SELECT id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at
FROM providers WHERE id = ?1" FROM providers WHERE id = $1"
) )
.bind(provider_id) .bind(provider_id)
.fetch_optional(db) .fetch_optional(db)
@@ -33,25 +35,33 @@ pub async fn get_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<Prov
let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) = let (id, name, display_name, base_url, api_protocol, enabled, rpm, tpm, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?; row.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", provider_id)))?;
Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at, updated_at }) Ok(ProviderInfo { id, name, display_name, base_url, api_protocol, enabled, rate_limit_rpm: rpm, rate_limit_tpm: tpm, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
} }
pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> SaasResult<ProviderInfo> { pub async fn create_provider(
db: &PgPool, encryption: &Arc<FieldEncryption>, req: &CreateProviderRequest,
) -> SaasResult<ProviderInfo> {
let id = uuid::Uuid::new_v4().to_string(); let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
// 检查名称唯一性 // 检查名称唯一性
let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = ?1") let existing: Option<(String,)> = sqlx::query_as("SELECT id FROM providers WHERE name = $1")
.bind(&req.name).fetch_optional(db).await?; .bind(&req.name).fetch_optional(db).await?;
if existing.is_some() { if existing.is_some() {
return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name))); return Err(SaasError::AlreadyExists(format!("Provider '{}' 已存在", req.name)));
} }
// 加密 API Key 后存储
let encrypted_api_key: Option<String> = match &req.api_key {
Some(key) => Some(encryption.encrypt(key)?),
None => None,
};
sqlx::query( sqlx::query(
"INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at) "INSERT INTO providers (id, name, display_name, api_key, base_url, api_protocol, enabled, rate_limit_rpm, rate_limit_tpm, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?8, ?9, ?9)" VALUES ($1, $2, $3, $4, $5, $6, true, $7, $8, $9, $9)"
) )
.bind(&id).bind(&req.name).bind(&req.display_name).bind(&req.api_key) .bind(&id).bind(&req.name).bind(&req.display_name).bind(&encrypted_api_key)
.bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now) .bind(&req.base_url).bind(&req.api_protocol).bind(&req.rate_limit_rpm).bind(&req.rate_limit_tpm).bind(&now)
.execute(db).await?; .execute(db).await?;
@@ -59,40 +69,48 @@ pub async fn create_provider(db: &SqlitePool, req: &CreateProviderRequest) -> Sa
} }
pub async fn update_provider( pub async fn update_provider(
db: &SqlitePool, provider_id: &str, req: &UpdateProviderRequest, db: &PgPool, encryption: &Arc<FieldEncryption>, provider_id: &str, req: &UpdateProviderRequest,
) -> SaasResult<ProviderInfo> { ) -> SaasResult<ProviderInfo> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let mut updates = Vec::new(); let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new(); let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx: i32 = 1;
if let Some(ref v) = req.display_name { updates.push("display_name = ?"); params.push(Box::new(v.clone())); } if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.base_url { updates.push("base_url = ?"); params.push(Box::new(v.clone())); } if let Some(ref v) = req.base_url { updates.push(format!("base_url = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_protocol { updates.push("api_protocol = ?"); params.push(Box::new(v.clone())); } if let Some(ref v) = req.api_protocol { updates.push(format!("api_protocol = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(ref v) = req.api_key { updates.push("api_key = ?"); params.push(Box::new(v.clone())); } if let Some(ref v) = req.api_key {
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); } // 加密 API Key 后存储
if let Some(v) = req.rate_limit_rpm { updates.push("rate_limit_rpm = ?"); params.push(Box::new(v)); } let encrypted = encryption.encrypt(v)?;
if let Some(v) = req.rate_limit_tpm { updates.push("rate_limit_tpm = ?"); params.push(Box::new(v)); } updates.push(format!("api_key = ${}", param_idx));
params.push(Box::new(encrypted));
param_idx += 1;
}
if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_rpm { updates.push(format!("rate_limit_rpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.rate_limit_tpm { updates.push(format!("rate_limit_tpm = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if updates.is_empty() { if updates.is_empty() {
return get_provider(db, provider_id).await; return get_provider(db, provider_id).await;
} }
updates.push("updated_at = ?"); updates.push(format!("updated_at = ${}", param_idx));
params.push(Box::new(now.clone())); param_idx += 1;
params.push(Box::new(provider_id.to_string())); params.push(Box::new(provider_id.to_string()));
let sql = format!("UPDATE providers SET {} WHERE id = ?", updates.join(", ")); let sql = format!("UPDATE providers SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql); let mut query = sqlx::query(&sql);
for p in &params { for p in &params {
query = query.bind(format!("{}", p)); query = query.bind(format!("{}", p));
} }
query = query.bind(now);
query.execute(db).await?; query.execute(db).await?;
get_provider(db, provider_id).await get_provider(db, provider_id).await
} }
pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<()> { pub async fn delete_provider(db: &PgPool, provider_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM providers WHERE id = ?1") let result = sqlx::query("DELETE FROM providers WHERE id = $1")
.bind(provider_id).execute(db).await?; .bind(provider_id).execute(db).await?;
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
@@ -103,36 +121,36 @@ pub async fn delete_provider(db: &SqlitePool, provider_id: &str) -> SaasResult<(
// ============ Models ============ // ============ Models ============
pub async fn list_models(db: &SqlitePool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> { pub async fn list_models(db: &PgPool, provider_id: Option<&str>) -> SaasResult<Vec<ModelInfo>> {
let sql = if provider_id.is_some() { let sql = if provider_id.is_some() {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE provider_id = ?1 ORDER BY alias" FROM models WHERE provider_id = $1 ORDER BY alias"
} else { } else {
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models ORDER BY provider_id, alias" FROM models ORDER BY provider_id, alias"
}; };
let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)>(sql); let mut query = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool, f64, f64, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)>(sql);
if let Some(pid) = provider_id { if let Some(pid) = provider_id {
query = query.bind(pid); query = query.bind(pid);
} }
let rows = query.fetch_all(db).await?; let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| { Ok(rows.into_iter().map(|(id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at)| {
ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at } ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() }
}).collect()) }).collect())
} }
pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResult<ModelInfo> { pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<ModelInfo> {
// 验证 provider 存在 // 验证 provider 存在
let provider = get_provider(db, &req.provider_id).await?; let provider = get_provider(db, &req.provider_id).await?;
let id = uuid::Uuid::new_v4().to_string(); let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
// 检查 model 唯一性 // 检查 model 唯一性
let existing: Option<(String,)> = sqlx::query_as( let existing: Option<(String,)> = sqlx::query_as(
"SELECT id FROM models WHERE provider_id = ?1 AND model_id = ?2" "SELECT id FROM models WHERE provider_id = $1 AND model_id = $2"
) )
.bind(&req.provider_id).bind(&req.model_id) .bind(&req.provider_id).bind(&req.model_id)
.fetch_optional(db).await?; .fetch_optional(db).await?;
@@ -152,7 +170,7 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
sqlx::query( sqlx::query(
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at) "INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, 1, ?9, ?10, ?11, ?11)" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
) )
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias) .bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(&req.alias)
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now) .bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
@@ -161,11 +179,11 @@ pub async fn create_model(db: &SqlitePool, req: &CreateModelRequest) -> SaasResu
get_model(db, &id).await get_model(db, &id).await
} }
pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo> { pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, String, String)> = let row: Option<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at
FROM models WHERE id = ?1" FROM models WHERE id = $1"
) )
.bind(model_id) .bind(model_id)
.fetch_optional(db) .fetch_optional(db)
@@ -174,45 +192,47 @@ pub async fn get_model(db: &SqlitePool, model_id: &str) -> SaasResult<ModelInfo>
let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) = let (id, provider_id, model_id, alias, ctx, max_out, streaming, vision, enabled, pi, po, created_at, updated_at) =
row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?; row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at, updated_at }) Ok(ModelInfo { id, provider_id, model_id, alias, context_window: ctx, max_output_tokens: max_out, supports_streaming: streaming, supports_vision: vision, enabled, pricing_input: pi, pricing_output: po, created_at: created_at.to_rfc3339(), updated_at: updated_at.to_rfc3339() })
} }
pub async fn update_model( pub async fn update_model(
db: &SqlitePool, model_id: &str, req: &UpdateModelRequest, db: &PgPool, model_id: &str, req: &UpdateModelRequest,
) -> SaasResult<ModelInfo> { ) -> SaasResult<ModelInfo> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let mut updates = Vec::new(); let mut updates = Vec::new();
let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new(); let mut params: Vec<Box<dyn std::fmt::Display + Send + Sync>> = Vec::new();
let mut param_idx: i32 = 1;
if let Some(ref v) = req.alias { updates.push("alias = ?"); params.push(Box::new(v.clone())); } if let Some(ref v) = req.alias { updates.push(format!("alias = ${}", param_idx)); params.push(Box::new(v.clone())); param_idx += 1; }
if let Some(v) = req.context_window { updates.push("context_window = ?"); params.push(Box::new(v)); } if let Some(v) = req.context_window { updates.push(format!("context_window = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.max_output_tokens { updates.push("max_output_tokens = ?"); params.push(Box::new(v)); } if let Some(v) = req.max_output_tokens { updates.push(format!("max_output_tokens = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_streaming { updates.push("supports_streaming = ?"); params.push(Box::new(v)); } if let Some(v) = req.supports_streaming { updates.push(format!("supports_streaming = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.supports_vision { updates.push("supports_vision = ?"); params.push(Box::new(v)); } if let Some(v) = req.supports_vision { updates.push(format!("supports_vision = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.enabled { updates.push("enabled = ?"); params.push(Box::new(v)); } if let Some(v) = req.enabled { updates.push(format!("enabled = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_input { updates.push("pricing_input = ?"); params.push(Box::new(v)); } if let Some(v) = req.pricing_input { updates.push(format!("pricing_input = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if let Some(v) = req.pricing_output { updates.push("pricing_output = ?"); params.push(Box::new(v)); } if let Some(v) = req.pricing_output { updates.push(format!("pricing_output = ${}", param_idx)); params.push(Box::new(v)); param_idx += 1; }
if updates.is_empty() { if updates.is_empty() {
return get_model(db, model_id).await; return get_model(db, model_id).await;
} }
updates.push("updated_at = ?"); updates.push(format!("updated_at = ${}", param_idx));
params.push(Box::new(now.clone())); param_idx += 1;
params.push(Box::new(model_id.to_string())); params.push(Box::new(model_id.to_string()));
let sql = format!("UPDATE models SET {} WHERE id = ?", updates.join(", ")); let sql = format!("UPDATE models SET {} WHERE id = ${}", updates.join(", "), param_idx);
let mut query = sqlx::query(&sql); let mut query = sqlx::query(&sql);
for p in &params { for p in &params {
query = query.bind(format!("{}", p)); query = query.bind(format!("{}", p));
} }
query = query.bind(now);
query.execute(db).await?; query.execute(db).await?;
get_model(db, model_id).await get_model(db, model_id).await
} }
pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> { pub async fn delete_model(db: &PgPool, model_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM models WHERE id = ?1") let result = sqlx::query("DELETE FROM models WHERE id = $1")
.bind(model_id).execute(db).await?; .bind(model_id).execute(db).await?;
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
@@ -224,17 +244,17 @@ pub async fn delete_model(db: &SqlitePool, model_id: &str) -> SaasResult<()> {
// ============ Account API Keys ============ // ============ Account API Keys ============
pub async fn list_account_api_keys( pub async fn list_account_api_keys(
db: &SqlitePool, account_id: &str, provider_id: Option<&str>, db: &PgPool, encryption: &Arc<FieldEncryption>, account_id: &str, provider_id: Option<&str>,
) -> SaasResult<Vec<AccountApiKeyInfo>> { ) -> SaasResult<Vec<AccountApiKeyInfo>> {
let sql = if provider_id.is_some() { let sql = if provider_id.is_some() {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = ?1 AND provider_id = ?2 AND revoked_at IS NULL ORDER BY created_at DESC" FROM account_api_keys WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL ORDER BY created_at DESC"
} else { } else {
"SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value "SELECT id, provider_id, key_label, permissions, enabled, last_used_at, created_at, key_value
FROM account_api_keys WHERE account_id = ?1 AND revoked_at IS NULL ORDER BY created_at DESC" FROM account_api_keys WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
}; };
let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<String>, String, String)>(sql) let mut query = sqlx::query_as::<_, (String, String, Option<String>, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>, String)>(sql)
.bind(account_id); .bind(account_id);
if let Some(pid) = provider_id { if let Some(pid) = provider_id {
query = query.bind(pid); query = query.bind(pid);
@@ -243,26 +263,32 @@ pub async fn list_account_api_keys(
let rows = query.fetch_all(db).await?; let rows = query.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| { Ok(rows.into_iter().map(|(id, provider_id, key_label, perms, enabled, last_used, created_at, key_value)| {
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default(); let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
let masked = mask_api_key(&key_value); // 解密 key_value 后再做掩码处理(兼容迁移期间的明文数据)
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used, created_at, masked_key: masked } let decrypted = encryption.decrypt_or_plaintext(&key_value);
let masked = mask_api_key(&decrypted);
AccountApiKeyInfo { id, provider_id, key_label, permissions, enabled, last_used_at: last_used.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339(), masked_key: masked }
}).collect()) }).collect())
} }
pub async fn create_account_api_key( pub async fn create_account_api_key(
db: &SqlitePool, account_id: &str, req: &CreateAccountApiKeyRequest, db: &PgPool, encryption: &Arc<FieldEncryption>, account_id: &str, req: &CreateAccountApiKeyRequest,
) -> SaasResult<AccountApiKeyInfo> { ) -> SaasResult<AccountApiKeyInfo> {
// 验证 provider 存在 // 验证 provider 存在
get_provider(db, &req.provider_id).await?; get_provider(db, &req.provider_id).await?;
let id = uuid::Uuid::new_v4().to_string(); let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let now_str = now.to_rfc3339();
let permissions = serde_json::to_string(&req.permissions)?; let permissions = serde_json::to_string(&req.permissions)?;
// 加密 key_value 后存储
let encrypted_key_value = encryption.encrypt(&req.key_value)?;
sqlx::query( sqlx::query(
"INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at) "INSERT INTO account_api_keys (id, account_id, provider_id, key_value, key_label, permissions, enabled, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 1, ?7, ?7)" VALUES ($1, $2, $3, $4, $5, $6, true, $7, $7)"
) )
.bind(&id).bind(account_id).bind(&req.provider_id).bind(&req.key_value) .bind(&id).bind(account_id).bind(&req.provider_id).bind(&encrypted_key_value)
.bind(&req.key_label).bind(&permissions).bind(&now) .bind(&req.key_label).bind(&permissions).bind(&now)
.execute(db).await?; .execute(db).await?;
@@ -270,18 +296,20 @@ pub async fn create_account_api_key(
Ok(AccountApiKeyInfo { Ok(AccountApiKeyInfo {
id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(), id, provider_id: req.provider_id.clone(), key_label: req.key_label.clone(),
permissions: req.permissions.clone(), enabled: true, last_used_at: None, permissions: req.permissions.clone(), enabled: true, last_used_at: None,
created_at: now, masked_key: masked, created_at: now_str, masked_key: masked,
}) })
} }
pub async fn rotate_account_api_key( pub async fn rotate_account_api_key(
db: &SqlitePool, key_id: &str, account_id: &str, new_key_value: &str, db: &PgPool, encryption: &Arc<FieldEncryption>, key_id: &str, account_id: &str, new_key_value: &str,
) -> SaasResult<()> { ) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
// 加密新 key_value 后存储
let encrypted_key = encryption.encrypt(new_key_value)?;
let result = sqlx::query( let result = sqlx::query(
"UPDATE account_api_keys SET key_value = ?1, updated_at = ?2 WHERE id = ?3 AND account_id = ?4 AND revoked_at IS NULL" "UPDATE account_api_keys SET key_value = $1, updated_at = $2 WHERE id = $3 AND account_id = $4 AND revoked_at IS NULL"
) )
.bind(new_key_value).bind(&now).bind(key_id).bind(account_id) .bind(&encrypted_key).bind(&now).bind(key_id).bind(account_id)
.execute(db).await?; .execute(db).await?;
if result.rows_affected() == 0 { if result.rows_affected() == 0 {
@@ -291,11 +319,11 @@ pub async fn rotate_account_api_key(
} }
pub async fn revoke_account_api_key( pub async fn revoke_account_api_key(
db: &SqlitePool, key_id: &str, account_id: &str, db: &PgPool, key_id: &str, account_id: &str,
) -> SaasResult<()> { ) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let result = sqlx::query( let result = sqlx::query(
"UPDATE account_api_keys SET revoked_at = ?1 WHERE id = ?2 AND account_id = ?3 AND revoked_at IS NULL" "UPDATE account_api_keys SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
) )
.bind(&now).bind(key_id).bind(account_id) .bind(&now).bind(key_id).bind(account_id)
.execute(db).await?; .execute(db).await?;
@@ -309,25 +337,30 @@ pub async fn revoke_account_api_key(
// ============ Usage Statistics ============ // ============ Usage Statistics ============
pub async fn get_usage_stats( pub async fn get_usage_stats(
db: &SqlitePool, account_id: &str, query: &UsageQuery, db: &PgPool, account_id: &str, query: &UsageQuery,
) -> SaasResult<UsageStats> { ) -> SaasResult<UsageStats> {
let mut where_clauses = vec!["account_id = ?".to_string()]; let mut param_idx: i32 = 1;
let mut where_clauses = vec![format!("account_id = ${}", param_idx)];
param_idx += 1;
let mut params: Vec<String> = vec![account_id.to_string()]; let mut params: Vec<String> = vec![account_id.to_string()];
if let Some(ref from) = query.from { if let Some(ref from) = query.from {
where_clauses.push("created_at >= ?".to_string()); where_clauses.push(format!("created_at >= ${}", param_idx));
param_idx += 1;
params.push(from.clone()); params.push(from.clone());
} }
if let Some(ref to) = query.to { if let Some(ref to) = query.to {
where_clauses.push("created_at <= ?".to_string()); where_clauses.push(format!("created_at <= ${}", param_idx));
param_idx += 1;
params.push(to.clone()); params.push(to.clone());
} }
if let Some(ref pid) = query.provider_id { if let Some(ref pid) = query.provider_id {
where_clauses.push("provider_id = ?".to_string()); where_clauses.push(format!("provider_id = ${}", param_idx));
param_idx += 1;
params.push(pid.clone()); params.push(pid.clone());
} }
if let Some(ref mid) = query.model_id { if let Some(ref mid) = query.model_id {
where_clauses.push("model_id = ?".to_string()); where_clauses.push(format!("model_id = ${}", param_idx));
params.push(mid.clone()); params.push(mid.clone());
} }
@@ -361,10 +394,10 @@ pub async fn get_usage_stats(
}).collect(); }).collect();
// 按天统计 (最近 30 天) // 按天统计 (最近 30 天)
let from_30d = (chrono::Utc::now() - chrono::Duration::days(30)).to_rfc3339(); let from_30d = chrono::Utc::now() - chrono::Duration::days(30);
let daily_sql = format!( let daily_sql = format!(
"SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0) "SELECT DATE(created_at) as day, COUNT(*), COALESCE(SUM(input_tokens), 0), COALESCE(SUM(output_tokens), 0)
FROM usage_records WHERE account_id = ?1 AND created_at >= ?2 FROM usage_records WHERE account_id = $1 AND created_at >= $2
GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30" GROUP BY DATE(created_at) ORDER BY day DESC LIMIT 30"
); );
let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql) let daily_rows: Vec<(String, i64, i64, i64)> = sqlx::query_as(&daily_sql)
@@ -385,14 +418,14 @@ pub async fn get_usage_stats(
} }
pub async fn record_usage( pub async fn record_usage(
db: &SqlitePool, account_id: &str, provider_id: &str, model_id: &str, db: &PgPool, account_id: &str, provider_id: &str, model_id: &str,
input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>, input_tokens: i64, output_tokens: i64, latency_ms: Option<i64>,
status: &str, error_message: Option<&str>, status: &str, error_message: Option<&str>,
) -> SaasResult<()> { ) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
sqlx::query( sqlx::query(
"INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at) "INSERT INTO usage_records (account_id, provider_id, model_id, input_tokens, output_tokens, latency_ms, status, error_message, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9)" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
) )
.bind(account_id).bind(provider_id).bind(model_id) .bind(account_id).bind(provider_id).bind(model_id)
.bind(input_tokens).bind(output_tokens).bind(latency_ms) .bind(input_tokens).bind(output_tokens).bind(latency_ms)
@@ -409,3 +442,73 @@ fn mask_api_key(key: &str) -> String {
} }
format!("{}...{}", &key[..4], &key[key.len()-4..]) format!("{}...{}", &key[..4], &key[key.len()-4..])
} }
#[cfg(test)]
mod tests {
use super::*;
// ---- mask_api_key ----
#[test]
fn mask_key_long_key() {
let key = "sk-abcdefghijklmnopqrstuvwxyz123456";
let masked = mask_api_key(key);
assert_eq!(masked, "sk-a...3456");
}
#[test]
fn mask_key_exactly_8_chars() {
// keys <= 8 chars are fully masked
let key = "12345678";
let masked = mask_api_key(key);
assert_eq!(masked, "********");
}
#[test]
fn mask_key_7_chars() {
let key = "abcdefg";
let masked = mask_api_key(key);
assert_eq!(masked, "*******");
}
#[test]
fn mask_key_1_char() {
let key = "a";
let masked = mask_api_key(key);
assert_eq!(masked, "*");
}
#[test]
fn mask_key_empty() {
let key = "";
let masked = mask_api_key(key);
assert_eq!(masked, "");
}
#[test]
fn mask_key_9_chars_boundary() {
// 9 chars is the first that uses prefix...suffix format
let key = "abcdefghi";
let masked = mask_api_key(key);
assert_eq!(masked, "abcd...fghi");
}
#[test]
fn mask_key_standard_openai_format() {
let key = "sk-proj-abcdefghijklmnopqrstuvwx";
let masked = mask_api_key(key);
assert_eq!(masked, "sk-p...uvwx");
}
#[test]
fn mask_key_no_ellipsis_for_short() {
let masked = mask_api_key("short");
assert!(!masked.contains("..."));
}
#[test]
fn mask_key_has_ellipsis_for_long() {
let masked = mask_api_key("this_is_a_very_long_key_value");
assert!(masked.contains("..."));
}
}

View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
// --- Provider --- // --- Provider ---
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ProviderInfo { pub struct ProviderInfo {
pub id: String, pub id: String,
pub name: String, pub name: String,
@@ -18,7 +18,7 @@ pub struct ProviderInfo {
pub updated_at: String, pub updated_at: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateProviderRequest { pub struct CreateProviderRequest {
pub name: String, pub name: String,
pub display_name: String, pub display_name: String,
@@ -32,7 +32,7 @@ pub struct CreateProviderRequest {
fn default_protocol() -> String { "openai".into() } fn default_protocol() -> String { "openai".into() }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateProviderRequest { pub struct UpdateProviderRequest {
pub display_name: Option<String>, pub display_name: Option<String>,
pub base_url: Option<String>, pub base_url: Option<String>,
@@ -45,7 +45,7 @@ pub struct UpdateProviderRequest {
// --- Model --- // --- Model ---
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct ModelInfo { pub struct ModelInfo {
pub id: String, pub id: String,
pub provider_id: String, pub provider_id: String,
@@ -62,7 +62,7 @@ pub struct ModelInfo {
pub updated_at: String, pub updated_at: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateModelRequest { pub struct CreateModelRequest {
pub provider_id: String, pub provider_id: String,
pub model_id: String, pub model_id: String,
@@ -75,7 +75,7 @@ pub struct CreateModelRequest {
pub pricing_output: Option<f64>, pub pricing_output: Option<f64>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct UpdateModelRequest { pub struct UpdateModelRequest {
pub alias: Option<String>, pub alias: Option<String>,
pub context_window: Option<i64>, pub context_window: Option<i64>,
@@ -89,7 +89,7 @@ pub struct UpdateModelRequest {
// --- Account API Key --- // --- Account API Key ---
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize, utoipa::ToSchema)]
pub struct AccountApiKeyInfo { pub struct AccountApiKeyInfo {
pub id: String, pub id: String,
pub provider_id: String, pub provider_id: String,
@@ -101,7 +101,7 @@ pub struct AccountApiKeyInfo {
pub masked_key: String, pub masked_key: String,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct CreateAccountApiKeyRequest { pub struct CreateAccountApiKeyRequest {
pub provider_id: String, pub provider_id: String,
pub key_value: String, pub key_value: String,
@@ -110,14 +110,14 @@ pub struct CreateAccountApiKeyRequest {
pub permissions: Vec<String>, pub permissions: Vec<String>,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct RotateApiKeyRequest { pub struct RotateApiKeyRequest {
pub new_key_value: String, pub new_key_value: String,
} }
// --- Usage --- // --- Usage ---
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct UsageStats { pub struct UsageStats {
pub total_requests: i64, pub total_requests: i64,
pub total_input_tokens: i64, pub total_input_tokens: i64,
@@ -126,7 +126,7 @@ pub struct UsageStats {
pub by_day: Vec<DailyUsage>, pub by_day: Vec<DailyUsage>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ModelUsage { pub struct ModelUsage {
pub provider_id: String, pub provider_id: String,
pub model_id: String, pub model_id: String,
@@ -135,7 +135,7 @@ pub struct ModelUsage {
pub output_tokens: i64, pub output_tokens: i64,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct DailyUsage { pub struct DailyUsage {
pub date: String, pub date: String,
pub request_count: i64, pub request_count: i64,
@@ -143,7 +143,7 @@ pub struct DailyUsage {
pub output_tokens: i64, pub output_tokens: i64,
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct UsageQuery { pub struct UsageQuery {
pub from: Option<String>, pub from: Option<String>,
pub to: Option<String>, pub to: Option<String>,
@@ -151,22 +151,3 @@ pub struct UsageQuery {
pub model_id: Option<String>, pub model_id: Option<String>,
} }
// --- Seed Data ---
#[derive(Debug, Deserialize)]
pub struct SeedProvider {
pub name: String,
pub display_name: String,
pub base_url: String,
pub models: Vec<SeedModel>,
}
#[derive(Debug, Deserialize)]
pub struct SeedModel {
pub id: String,
pub alias: String,
pub context_window: Option<i64>,
pub max_output_tokens: Option<i64>,
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
}

View File

@@ -0,0 +1,790 @@
//! OpenAPI / Swagger 文档定义
//!
//! 聚合所有模块的 schema并在 build_router 中通过 utoipa-swagger-ui 暴露文档。
use utoipa::OpenApi;
/// ZCLAW SaaS API 根 OpenApi 定义
#[derive(OpenApi)]
#[openapi(
info(
title = "ZCLAW SaaS API",
version = "0.1.0",
description = "ZCLAW SaaS 后端服务 API -- 账号权限管理、模型配置、请求中转和配置迁移",
license(name = "Apache-2.0 OR MIT")
),
tags(
(name = "auth", description = "认证 (登录 / 注册 / TOTP)"),
(name = "accounts", description = "账号管理"),
(name = "providers", description = "模型供应商"),
(name = "models", description = "模型配置"),
(name = "keys", description = "API Key 管理"),
(name = "usage", description = "用量统计"),
(name = "relay", description = "请求中转"),
(name = "config", description = "配置迁移"),
),
paths(
crate::openapi::paths::auth::register,
crate::openapi::paths::auth::login,
crate::openapi::paths::auth::refresh,
crate::openapi::paths::auth::me,
crate::openapi::paths::auth::change_password,
crate::openapi::paths::auth::totp_setup,
crate::openapi::paths::auth::totp_verify,
crate::openapi::paths::auth::totp_disable,
crate::openapi::paths::accounts::list_accounts,
crate::openapi::paths::accounts::get_account,
crate::openapi::paths::accounts::update_account,
crate::openapi::paths::accounts::update_status,
crate::openapi::paths::accounts::list_tokens,
crate::openapi::paths::accounts::create_token,
crate::openapi::paths::accounts::revoke_token,
crate::openapi::paths::accounts::list_devices,
crate::openapi::paths::accounts::register_device,
crate::openapi::paths::accounts::device_heartbeat,
crate::openapi::paths::accounts::list_operation_logs,
crate::openapi::paths::accounts::dashboard_stats,
crate::openapi::paths::providers::list_providers,
crate::openapi::paths::providers::get_provider,
crate::openapi::paths::providers::create_provider,
crate::openapi::paths::providers::update_provider,
crate::openapi::paths::providers::delete_provider,
crate::openapi::paths::providers::list_provider_models,
crate::openapi::paths::models::list_models,
crate::openapi::paths::models::get_model,
crate::openapi::paths::models::create_model,
crate::openapi::paths::models::update_model,
crate::openapi::paths::models::delete_model,
crate::openapi::paths::keys::list_api_keys,
crate::openapi::paths::keys::create_api_key,
crate::openapi::paths::keys::revoke_api_key,
crate::openapi::paths::keys::rotate_api_key,
crate::openapi::paths::usage::get_usage,
crate::openapi::paths::relay::chat_completions,
crate::openapi::paths::relay::list_tasks,
crate::openapi::paths::relay::get_task,
crate::openapi::paths::relay::retry_task,
crate::openapi::paths::relay::list_available_models,
crate::openapi::paths::config::list_config_items,
crate::openapi::paths::config::get_config_item,
crate::openapi::paths::config::create_config_item,
crate::openapi::paths::config::update_config_item,
crate::openapi::paths::config::delete_config_item,
crate::openapi::paths::config::analyze_config,
crate::openapi::paths::config::seed_config,
crate::openapi::paths::config::sync_config,
crate::openapi::paths::config::config_diff,
crate::openapi::paths::config::list_sync_logs,
),
components(schemas(
crate::auth::types::LoginRequest,
crate::auth::types::LoginResponse,
crate::auth::types::RegisterRequest,
crate::auth::types::ChangePasswordRequest,
crate::auth::types::AccountPublic,
crate::account::types::UpdateAccountRequest,
crate::account::types::UpdateStatusRequest,
crate::account::types::ListAccountsQuery,
crate::account::types::AccountPublicPaginatedResponse,
crate::account::types::CreateTokenRequest,
crate::account::types::TokenInfo,
crate::account::types::RegisterDeviceRequest,
crate::account::types::DeviceHeartbeatRequest,
crate::account::types::DeviceInfo,
crate::model_config::types::ProviderInfo,
crate::model_config::types::CreateProviderRequest,
crate::model_config::types::UpdateProviderRequest,
crate::model_config::types::ModelInfo,
crate::model_config::types::CreateModelRequest,
crate::model_config::types::UpdateModelRequest,
crate::model_config::types::AccountApiKeyInfo,
crate::model_config::types::CreateAccountApiKeyRequest,
crate::model_config::types::RotateApiKeyRequest,
crate::model_config::types::UsageStats,
crate::model_config::types::ModelUsage,
crate::model_config::types::DailyUsage,
crate::model_config::types::UsageQuery,
crate::relay::types::RelayTaskInfo,
crate::relay::types::RelayTaskQuery,
crate::migration::types::ConfigItemInfo,
crate::migration::types::CreateConfigItemRequest,
crate::migration::types::UpdateConfigItemRequest,
crate::migration::types::ConfigSyncLogInfo,
crate::migration::types::ConfigAnalysis,
crate::migration::types::CategorySummary,
crate::migration::types::SyncConfigRequest,
crate::migration::types::ConfigDiffItem,
crate::migration::types::ConfigDiffResponse,
crate::migration::types::ConfigQuery,
)),
modifiers(&SecurityAddon)
)]
pub struct ApiDoc;
struct SecurityAddon;
impl utoipa::Modify for SecurityAddon {
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
if let Some(components) = openapi.components.as_mut() {
components.add_security_scheme(
"bearer_auth",
utoipa::openapi::security::SecurityScheme::Http(
utoipa::openapi::security::Http::new(
utoipa::openapi::security::HttpAuthScheme::Bearer,
),
),
);
}
}
}
/// Path stubs for OpenAPI documentation generation.
/// These functions are never called at runtime -- they exist solely so that
/// `utoipa::path` can produce the correct OpenAPI spec entries.
pub mod paths {
pub mod auth {
#[utoipa::path(
post,
path = "/api/v1/auth/register",
tag = "auth",
request_body = crate::auth::types::RegisterRequest,
responses(
(status = 201, description = "注册成功", body = crate::auth::types::LoginResponse),
(status = 409, description = "用户已存在"),
)
)]
pub async fn register() {}
#[utoipa::path(
post,
path = "/api/v1/auth/login",
tag = "auth",
request_body = crate::auth::types::LoginRequest,
responses(
(status = 200, description = "登录成功", body = crate::auth::types::LoginResponse),
(status = 401, description = "认证失败"),
)
)]
pub async fn login() {}
#[utoipa::path(
post,
path = "/api/v1/auth/refresh",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "刷新 token 成功", body = crate::auth::types::LoginResponse),
(status = 401, description = "认证失败"),
)
)]
pub async fn refresh() {}
#[utoipa::path(
get,
path = "/api/v1/auth/me",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "当前用户信息", body = crate::auth::types::AccountPublic),
(status = 401, description = "未认证"),
)
)]
pub async fn me() {}
#[utoipa::path(
put,
path = "/api/v1/auth/password",
tag = "auth",
security(("bearer_auth" = [])),
request_body = crate::auth::types::ChangePasswordRequest,
responses(
(status = 200, description = "密码修改成功"),
(status = 400, description = "旧密码不正确"),
(status = 401, description = "未认证"),
)
)]
pub async fn change_password() {}
#[utoipa::path(
post,
path = "/api/v1/auth/totp/setup",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "TOTP 设置信息(含 secret 和 QR URI"),
(status = 401, description = "未认证"),
(status = 409, description = "TOTP 已启用"),
)
)]
pub async fn totp_setup() {}
#[utoipa::path(
post,
path = "/api/v1/auth/totp/verify",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "验证成功TOTP 已启用"),
(status = 401, description = "验证码错误"),
)
)]
pub async fn totp_verify() {}
#[utoipa::path(
post,
path = "/api/v1/auth/totp/disable",
tag = "auth",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "TOTP 已禁用"),
(status = 401, description = "密码错误"),
)
)]
pub async fn totp_disable() {}
}
pub mod accounts {
#[utoipa::path(
get,
path = "/api/v1/accounts",
tag = "accounts",
security(("bearer_auth" = [])),
params(crate::account::types::ListAccountsQuery),
responses(
(status = 200, description = "账号列表", body = crate::account::types::AccountPublicPaginatedResponse),
)
)]
pub async fn list_accounts() {}
#[utoipa::path(
get,
path = "/api/v1/accounts/{id}",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "账号 ID")),
responses(
(status = 200, description = "账号详情", body = crate::auth::types::AccountPublic),
(status = 404, description = "账号不存在"),
)
)]
pub async fn get_account() {}
#[utoipa::path(
put,
path = "/api/v1/accounts/{id}",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "账号 ID")),
request_body = crate::account::types::UpdateAccountRequest,
responses(
(status = 200, description = "更新成功", body = crate::auth::types::AccountPublic),
(status = 404, description = "账号不存在"),
)
)]
pub async fn update_account() {}
#[utoipa::path(
patch,
path = "/api/v1/accounts/{id}/status",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "账号 ID")),
request_body = crate::account::types::UpdateStatusRequest,
responses(
(status = 200, description = "状态更新成功"),
)
)]
pub async fn update_status() {}
#[utoipa::path(
get,
path = "/api/v1/tokens",
tag = "accounts",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "Token 列表", body = Vec<crate::account::types::TokenInfo>),
)
)]
pub async fn list_tokens() {}
#[utoipa::path(
post,
path = "/api/v1/tokens",
tag = "accounts",
security(("bearer_auth" = [])),
request_body = crate::account::types::CreateTokenRequest,
responses(
(status = 201, description = "创建成功", body = crate::account::types::TokenInfo),
)
)]
pub async fn create_token() {}
#[utoipa::path(
delete,
path = "/api/v1/tokens/{id}",
tag = "accounts",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "Token ID")),
responses(
(status = 204, description = "撤销成功"),
)
)]
pub async fn revoke_token() {}
#[utoipa::path(
get,
path = "/api/v1/devices",
tag = "accounts",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "设备列表", body = Vec<crate::account::types::DeviceInfo>),
)
)]
pub async fn list_devices() {}
#[utoipa::path(
post,
path = "/api/v1/devices/register",
tag = "accounts",
security(("bearer_auth" = [])),
request_body = crate::account::types::RegisterDeviceRequest,
responses(
(status = 201, description = "注册成功", body = crate::account::types::DeviceInfo),
)
)]
pub async fn register_device() {}
#[utoipa::path(
post,
path = "/api/v1/devices/heartbeat",
tag = "accounts",
security(("bearer_auth" = [])),
request_body = crate::account::types::DeviceHeartbeatRequest,
responses(
(status = 200, description = "心跳更新成功"),
)
)]
pub async fn device_heartbeat() {}
#[utoipa::path(
get,
path = "/api/v1/logs/operations",
tag = "accounts",
security(("bearer_auth" = [])),
params(
("page" = Option<i32>, Query, description = "页码"),
("page_size" = Option<i32>, Query, description = "每页数量"),
("action" = Option<String>, Query, description = "操作类型过滤"),
("account_id" = Option<String>, Query, description = "账号 ID 过滤"),
),
responses(
(status = 200, description = "操作日志列表"),
)
)]
pub async fn list_operation_logs() {}
#[utoipa::path(
get,
path = "/api/v1/stats/dashboard",
tag = "accounts",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "仪表盘统计数据"),
)
)]
pub async fn dashboard_stats() {}
}
pub mod providers {
#[utoipa::path(
get,
path = "/api/v1/providers",
tag = "providers",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "供应商列表", body = Vec<crate::model_config::types::ProviderInfo>),
)
)]
pub async fn list_providers() {}
#[utoipa::path(
get,
path = "/api/v1/providers/{id}",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
responses(
(status = 200, description = "供应商详情", body = crate::model_config::types::ProviderInfo),
(status = 404, description = "供应商不存在"),
)
)]
pub async fn get_provider() {}
#[utoipa::path(
post,
path = "/api/v1/providers",
tag = "providers",
security(("bearer_auth" = [])),
request_body = crate::model_config::types::CreateProviderRequest,
responses(
(status = 201, description = "创建成功", body = crate::model_config::types::ProviderInfo),
)
)]
pub async fn create_provider() {}
#[utoipa::path(
put,
path = "/api/v1/providers/{id}",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
request_body = crate::model_config::types::UpdateProviderRequest,
responses(
(status = 200, description = "更新成功", body = crate::model_config::types::ProviderInfo),
(status = 404, description = "供应商不存在"),
)
)]
pub async fn update_provider() {}
#[utoipa::path(
delete,
path = "/api/v1/providers/{id}",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
responses(
(status = 204, description = "删除成功"),
(status = 404, description = "供应商不存在"),
)
)]
pub async fn delete_provider() {}
#[utoipa::path(
get,
path = "/api/v1/providers/{id}/models",
tag = "providers",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "供应商 ID")),
responses(
(status = 200, description = "供应商下的模型列表", body = Vec<crate::model_config::types::ModelInfo>),
)
)]
pub async fn list_provider_models() {}
}
pub mod models {
#[utoipa::path(
get,
path = "/api/v1/models",
tag = "models",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "模型列表", body = Vec<crate::model_config::types::ModelInfo>),
)
)]
pub async fn list_models() {}
#[utoipa::path(
get,
path = "/api/v1/models/{id}",
tag = "models",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "模型 ID")),
responses(
(status = 200, description = "模型详情", body = crate::model_config::types::ModelInfo),
(status = 404, description = "模型不存在"),
)
)]
pub async fn get_model() {}
#[utoipa::path(
post,
path = "/api/v1/models",
tag = "models",
security(("bearer_auth" = [])),
request_body = crate::model_config::types::CreateModelRequest,
responses(
(status = 201, description = "创建成功", body = crate::model_config::types::ModelInfo),
)
)]
pub async fn create_model() {}
#[utoipa::path(
put,
path = "/api/v1/models/{id}",
tag = "models",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "模型 ID")),
request_body = crate::model_config::types::UpdateModelRequest,
responses(
(status = 200, description = "更新成功", body = crate::model_config::types::ModelInfo),
(status = 404, description = "模型不存在"),
)
)]
pub async fn update_model() {}
#[utoipa::path(
delete,
path = "/api/v1/models/{id}",
tag = "models",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "模型 ID")),
responses(
(status = 204, description = "删除成功"),
(status = 404, description = "模型不存在"),
)
)]
pub async fn delete_model() {}
}
pub mod keys {
#[utoipa::path(
get,
path = "/api/v1/keys",
tag = "keys",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "API Key 列表", body = Vec<crate::model_config::types::AccountApiKeyInfo>),
)
)]
pub async fn list_api_keys() {}
#[utoipa::path(
post,
path = "/api/v1/keys",
tag = "keys",
security(("bearer_auth" = [])),
request_body = crate::model_config::types::CreateAccountApiKeyRequest,
responses(
(status = 201, description = "创建成功", body = crate::model_config::types::AccountApiKeyInfo),
)
)]
pub async fn create_api_key() {}
#[utoipa::path(
delete,
path = "/api/v1/keys/{id}",
tag = "keys",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "Key ID")),
responses(
(status = 204, description = "撤销成功"),
)
)]
pub async fn revoke_api_key() {}
#[utoipa::path(
post,
path = "/api/v1/keys/{id}/rotate",
tag = "keys",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "Key ID")),
request_body = crate::model_config::types::RotateApiKeyRequest,
responses(
(status = 200, description = "轮换成功", body = crate::model_config::types::AccountApiKeyInfo),
)
)]
pub async fn rotate_api_key() {}
}
pub mod usage {
#[utoipa::path(
get,
path = "/api/v1/usage",
tag = "usage",
security(("bearer_auth" = [])),
params(crate::model_config::types::UsageQuery),
responses(
(status = 200, description = "用量统计", body = crate::model_config::types::UsageStats),
)
)]
pub async fn get_usage() {}
}
pub mod relay {
#[utoipa::path(
post,
path = "/api/v1/relay/chat/completions",
tag = "relay",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "聊天补全响应JSON 或 SSE 流)"),
(status = 402, description = "上游服务错误"),
(status = 404, description = "模型不存在或未启用"),
)
)]
pub async fn chat_completions() {}
#[utoipa::path(
get,
path = "/api/v1/relay/tasks",
tag = "relay",
security(("bearer_auth" = [])),
params(crate::relay::types::RelayTaskQuery),
responses(
(status = 200, description = "中转任务列表", body = Vec<crate::relay::types::RelayTaskInfo>),
)
)]
pub async fn list_tasks() {}
#[utoipa::path(
get,
path = "/api/v1/relay/tasks/{id}",
tag = "relay",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "任务 ID")),
responses(
(status = 200, description = "任务详情", body = crate::relay::types::RelayTaskInfo),
(status = 404, description = "任务不存在"),
)
)]
pub async fn get_task() {}
#[utoipa::path(
post,
path = "/api/v1/relay/tasks/{id}/retry",
tag = "relay",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "任务 ID")),
responses(
(status = 200, description = "重试成功", body = crate::relay::types::RelayTaskInfo),
(status = 404, description = "任务不存在"),
)
)]
pub async fn retry_task() {}
#[utoipa::path(
get,
path = "/api/v1/relay/models",
tag = "relay",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "可用模型列表", body = Vec<crate::model_config::types::ModelInfo>),
)
)]
pub async fn list_available_models() {}
}
pub mod config {
#[utoipa::path(
get,
path = "/api/v1/config/items",
tag = "config",
security(("bearer_auth" = [])),
params(crate::migration::types::ConfigQuery),
responses(
(status = 200, description = "配置项列表", body = Vec<crate::migration::types::ConfigItemInfo>),
)
)]
pub async fn list_config_items() {}
#[utoipa::path(
get,
path = "/api/v1/config/items/{id}",
tag = "config",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "配置项 ID")),
responses(
(status = 200, description = "配置项详情", body = crate::migration::types::ConfigItemInfo),
(status = 404, description = "配置项不存在"),
)
)]
pub async fn get_config_item() {}
#[utoipa::path(
post,
path = "/api/v1/config/items",
tag = "config",
security(("bearer_auth" = [])),
request_body = crate::migration::types::CreateConfigItemRequest,
responses(
(status = 201, description = "创建成功", body = crate::migration::types::ConfigItemInfo),
)
)]
pub async fn create_config_item() {}
#[utoipa::path(
put,
path = "/api/v1/config/items/{id}",
tag = "config",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "配置项 ID")),
request_body = crate::migration::types::UpdateConfigItemRequest,
responses(
(status = 200, description = "更新成功", body = crate::migration::types::ConfigItemInfo),
(status = 404, description = "配置项不存在"),
)
)]
pub async fn update_config_item() {}
#[utoipa::path(
delete,
path = "/api/v1/config/items/{id}",
tag = "config",
security(("bearer_auth" = [])),
params(("id" = String, Path, description = "配置项 ID")),
responses(
(status = 204, description = "删除成功"),
(status = 404, description = "配置项不存在"),
)
)]
pub async fn delete_config_item() {}
#[utoipa::path(
get,
path = "/api/v1/config/analysis",
tag = "config",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "配置分析结果", body = crate::migration::types::ConfigAnalysis),
)
)]
pub async fn analyze_config() {}
#[utoipa::path(
post,
path = "/api/v1/config/seed",
tag = "config",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "种子数据初始化成功"),
)
)]
pub async fn seed_config() {}
#[utoipa::path(
post,
path = "/api/v1/config/sync",
tag = "config",
security(("bearer_auth" = [])),
request_body = crate::migration::types::SyncConfigRequest,
responses(
(status = 200, description = "同步成功"),
)
)]
pub async fn sync_config() {}
#[utoipa::path(
post,
path = "/api/v1/config/diff",
tag = "config",
security(("bearer_auth" = [])),
request_body = crate::migration::types::SyncConfigRequest,
responses(
(status = 200, description = "配置差异", body = crate::migration::types::ConfigDiffResponse),
)
)]
pub async fn config_diff() {}
#[utoipa::path(
get,
path = "/api/v1/config/sync-logs",
tag = "config",
security(("bearer_auth" = [])),
responses(
(status = 200, description = "同步日志列表", body = Vec<crate::migration::types::ConfigSyncLogInfo>),
)
)]
pub async fn list_sync_logs() {}
}
}

View File

@@ -1,5 +1,9 @@
//! 中转服务 HTTP 处理器 //! 中转服务 HTTP 处理器
use std::sync::Arc;
use tokio::sync::Mutex;
use axum::body::Bytes;
use axum::{ use axum::{
extract::{Extension, Path, Query, State}, extract::{Extension, Path, Query, State},
http::{HeaderMap, StatusCode}, http::{HeaderMap, StatusCode},
@@ -31,33 +35,70 @@ pub async fn chat_completions(
.and_then(|v| v.as_bool()) .and_then(|v| v.as_bool())
.unwrap_or(false); .unwrap_or(false);
// 查找 model 对应的 provider // 查找 model 对应的 provider (直接 SQL 查询,避免全量加载)
let models = model_service::list_models(&state.db, None).await?; let target_model = sqlx::query_as::<_, (String, String, String, String, i64, i64, bool, bool, bool)>(
let target_model = models.iter().find(|m| m.model_id == model_name && m.enabled) "SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?; FROM models WHERE model_id = $1 AND enabled = true"
)
.bind(model_name)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
let (_model_id, provider_id, model_name_db, _, _, _, _, _, _) = target_model;
// 获取 provider 信息 // 获取 provider 信息
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?; let provider = model_service::get_provider(&state.db, &provider_id).await?;
if !provider.enabled { if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name))); return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
} }
// 获取 provider 的 API key (从数据库直接查询) // 优先使用用户级 account_api_key回退到 provider 级 key
let provider_api_key: Option<String> = sqlx::query_scalar( let account_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = ?1" "SELECT key_value FROM account_api_keys
WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true
ORDER BY created_at DESC LIMIT 1"
) )
.bind(&target_model.provider_id) .bind(&ctx.account_id)
.bind(&provider_id)
.fetch_optional(&state.db) .fetch_optional(&state.db)
.await? .await?
.flatten(); .flatten();
let api_key: Option<String> = if let Some(encrypted) = account_key_encrypted {
// 更新 last_used_at
let _ = sqlx::query(
"UPDATE account_api_keys SET last_used_at = NOW() WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true"
)
.bind(&ctx.account_id)
.bind(&provider_id)
.execute(&state.db)
.await;
Some(state.field_encryption.decrypt_or_plaintext(&encrypted))
} else {
// 回退到 provider 级 key
let provider_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = $1"
)
.bind(&provider_id)
.fetch_optional(&state.db)
.await?
.flatten();
provider_key_encrypted.map(|k| state.field_encryption.decrypt_or_plaintext(&k))
};
if api_key.is_none() {
return Err(SaasError::Internal(format!(
"Provider {} 没有可用的 API Key", provider.name
)));
}
let request_body = serde_json::to_string(&req)?; let request_body = serde_json::to_string(&req)?;
// 创建中转任务 // 创建中转任务
let config = state.config.read().await; let config = state.config.read().await;
let task = service::create_relay_task( let task = service::create_relay_task(
&state.db, &ctx.account_id, &target_model.provider_id, &state.db, &ctx.account_id, &provider_id,
&target_model.model_id, &request_body, 0, &model_name_db, &request_body, 0,
config.relay.max_attempts, config.relay.max_attempts,
).await?; ).await?;
@@ -66,8 +107,9 @@ pub async fn chat_completions(
// 执行中转 (带重试) // 执行中转 (带重试)
let response = service::execute_relay( let response = service::execute_relay(
&state.db, &task.id, &provider.base_url, &state.db, &task.id, &ctx.account_id, &provider_id, &model_name_db,
provider_api_key.as_deref(), &request_body, stream, &provider.base_url,
api_key.as_deref(), &request_body, stream,
config.relay.max_attempts, config.relay.max_attempts,
config.relay.retry_delay_ms, config.relay.retry_delay_ms,
).await; ).await;
@@ -86,34 +128,35 @@ pub async fn chat_completions(
.unwrap_or(0); .unwrap_or(0);
model_service::record_usage( model_service::record_usage(
&state.db, &ctx.account_id, &target_model.provider_id, &state.db, &ctx.account_id, &provider_id,
&target_model.model_id, input_tokens, output_tokens, &model_name_db, input_tokens, output_tokens,
None, "success", None, None, "success", None,
).await?; ).await?;
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response()) Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
} }
Ok(service::RelayResponse::Sse(body)) => { Ok(service::RelayResponse::SseWithUsage { body, task_id: relay_task_id, account_id: relay_account_id, provider_id: relay_provider_id, model_id: relay_model_id }) => {
model_service::record_usage( // 流式响应: 使用 async_stream 包装器提取 SSE 末尾的 usage
&state.db, &ctx.account_id, &target_model.provider_id, let wrapped = sse_usage_wrapper(
&target_model.model_id, 0, 0, state.db.clone(),
None, "success", None, relay_task_id, relay_account_id, relay_provider_id, relay_model_id,
).await?; body,
);
let wrapped_body = axum::body::Body::from_stream(wrapped);
// 流式响应: 直接转发 axum::body::Body
let response = axum::response::Response::builder() let response = axum::response::Response::builder()
.status(StatusCode::OK) .status(StatusCode::OK)
.header(axum::http::header::CONTENT_TYPE, "text/event-stream") .header(axum::http::header::CONTENT_TYPE, "text/event-stream")
.header("Cache-Control", "no-cache") .header("Cache-Control", "no-cache")
.header("Connection", "keep-alive") .header("Connection", "keep-alive")
.body(body) .body(wrapped_body)
.unwrap(); .map_err(|e| SaasError::Internal(format!("SSE 响应构建失败: {}", e)))?;
Ok(response) Ok(response)
} }
Err(e) => { Err(e) => {
model_service::record_usage( model_service::record_usage(
&state.db, &ctx.account_id, &target_model.provider_id, &state.db, &ctx.account_id, &provider_id,
&target_model.model_id, 0, 0, &model_name_db, 0, 0,
None, "failed", Some(&e.to_string()), None, "failed", Some(&e.to_string()),
).await?; ).await?;
Err(e) Err(e)
@@ -179,7 +222,7 @@ pub async fn retry_task(
State(state): State<AppState>, State(state): State<AppState>,
Path(id): Path<String>, Path(id): Path<String>,
Extension(ctx): Extension<AuthContext>, Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> { ) -> SaasResult<(StatusCode, Json<serde_json::Value>)> {
check_permission(&ctx, "relay:admin")?; check_permission(&ctx, "relay:admin")?;
let task = service::get_relay_task(&state.db, &id).await?; let task = service::get_relay_task(&state.db, &id).await?;
@@ -191,17 +234,35 @@ pub async fn retry_task(
// 获取 provider 信息 // 获取 provider 信息
let provider = model_service::get_provider(&state.db, &task.provider_id).await?; let provider = model_service::get_provider(&state.db, &task.provider_id).await?;
let provider_api_key: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = ?1" // 重试时使用原始任务所属用户的 account key回退到 provider key
let account_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT key_value FROM account_api_keys
WHERE account_id = $1 AND provider_id = $2 AND revoked_at IS NULL AND enabled = true
ORDER BY created_at DESC LIMIT 1"
) )
.bind(&task.account_id)
.bind(&task.provider_id) .bind(&task.provider_id)
.fetch_optional(&state.db) .fetch_optional(&state.db)
.await? .await?
.flatten(); .flatten();
let api_key: Option<String> = if let Some(encrypted) = account_key_encrypted {
Some(state.field_encryption.decrypt_or_plaintext(&encrypted))
} else {
let provider_key_encrypted: Option<String> = sqlx::query_scalar(
"SELECT api_key FROM providers WHERE id = $1"
)
.bind(&task.provider_id)
.fetch_optional(&state.db)
.await?
.flatten();
provider_key_encrypted.map(|k| state.field_encryption.decrypt_or_plaintext(&k))
};
// 读取原始请求体 // 读取原始请求体
let request_body: Option<String> = sqlx::query_scalar( let request_body: Option<String> = sqlx::query_scalar(
"SELECT request_body FROM relay_tasks WHERE id = ?1" "SELECT request_body FROM relay_tasks WHERE id = $1"
) )
.bind(&id) .bind(&id)
.fetch_optional(&state.db) .fetch_optional(&state.db)
@@ -222,7 +283,7 @@ pub async fn retry_task(
// 重置任务状态为 queued 以允许新的 processing // 重置任务状态为 queued 以允许新的 processing
sqlx::query( sqlx::query(
"UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = ?1" "UPDATE relay_tasks SET status = 'queued', error_message = NULL, started_at = NULL, completed_at = NULL WHERE id = $1"
) )
.bind(&id) .bind(&id)
.execute(&state.db) .execute(&state.db)
@@ -231,10 +292,14 @@ pub async fn retry_task(
// 异步执行重试 // 异步执行重试
let db = state.db.clone(); let db = state.db.clone();
let task_id = id.clone(); let task_id = id.clone();
let retry_account_id = ctx.account_id.clone();
let retry_provider_id = task.provider_id.clone();
let retry_model_id = task.model_id.clone();
tokio::spawn(async move { tokio::spawn(async move {
match service::execute_relay( match service::execute_relay(
&db, &task_id, &provider.base_url, &db, &task_id, &retry_account_id, &retry_provider_id, &retry_model_id,
provider_api_key.as_deref(), &body, stream, &provider.base_url,
api_key.as_deref(), &body, stream,
max_attempts, base_delay_ms, max_attempts, base_delay_ms,
).await { ).await {
Ok(_) => tracing::info!("Relay task {} 重试成功", task_id), Ok(_) => tracing::info!("Relay task {} 重试成功", task_id),
@@ -245,5 +310,101 @@ pub async fn retry_task(
log_operation(&state.db, &ctx.account_id, "relay.retry", "relay_task", &id, log_operation(&state.db, &ctx.account_id, "relay.retry", "relay_task", &id,
None, ctx.client_ip.as_deref()).await?; None, ctx.client_ip.as_deref()).await?;
Ok(Json(serde_json::json!({"ok": true, "task_id": id}))) Ok((StatusCode::ACCEPTED, Json(serde_json::json!({"ok": true, "task_id": id}))))
}
/// 包装 SSE 流,提取末尾的 usage 数据并异步记录
///
/// 支持客户端断连检测:当 body stream 返回错误(通常表示客户端提前断开连接),
/// 记录日志并将任务标记为 "cancelled" 而非 "completed"。
fn sse_usage_wrapper(
db: sqlx::PgPool,
task_id: String,
account_id: String,
provider_id: String,
model_id: String,
body: axum::body::Body,
) -> impl futures::Stream<Item = Result<Bytes, std::io::Error>> + Send {
use futures::StreamExt;
let last_usage: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
let mut saw_done = false;
async_stream::stream! {
let mut data_stream = std::pin::pin!(body.into_data_stream().map(|r| r.map_err(std::io::Error::other)));
loop {
match StreamExt::next(&mut data_stream).await {
Some(Ok(chunk)) => {
let text = String::from_utf8_lossy(&chunk);
for line in text.lines() {
if let Some(data) = line.strip_prefix("data: ") {
let trimmed = data.trim();
if trimmed == "[DONE]" {
saw_done = true;
let usage_str = last_usage.lock().await.take();
if let Some(s) = usage_str {
let (input, output) = service::extract_token_usage(&s);
if input > 0 || output > 0 {
let db2 = db.clone();
let tid = task_id.clone();
let aid = account_id.clone();
let pid = provider_id.clone();
let mid = model_id.clone();
tokio::spawn(async move {
let _ = service::update_task_status(
&db2, &tid, "completed",
Some(input), Some(output), None
).await;
let _ = model_service::record_usage(
&db2, &aid, &pid, &mid,
input, output, None, "success", None,
).await;
});
}
}
} else if serde_json::from_str::<serde_json::Value>(trimmed)
.ok()
.and_then(|v| if v.get("usage").is_some() { Some(trimmed.to_string()) } else { None })
.is_some()
{
*last_usage.lock().await = Some(trimmed.to_owned());
}
}
}
yield Ok(chunk);
}
Some(Err(e)) => {
// 客户端断连或上游连接中断
if !saw_done {
tracing::warn!(
"SSE stream error for task {} (client disconnected): {}",
task_id, e
);
// 将任务标记为 cancelled区别于 completed 和 failed
let db2 = db.clone();
let tid = task_id.clone();
tokio::spawn(async move {
let _ = service::update_task_status(
&db2, &tid, "cancelled",
None, None, Some("客户端断开连接"),
).await;
});
}
break;
}
None => {
// Stream 正常结束(上游发送完毕)
if !saw_done {
// 上游关闭但未发送 [DONE],仍记录完成
tracing::info!(
"SSE stream ended without [DONE] for task {}",
task_id,
);
}
break;
}
}
}
}
} }

View File

@@ -1,6 +1,6 @@
//! 中转服务核心逻辑 //! 中转服务核心逻辑
use sqlx::SqlitePool; use sqlx::PgPool;
use crate::error::{SaasError, SaasResult}; use crate::error::{SaasError, SaasResult};
use super::types::*; use super::types::*;
use futures::StreamExt; use futures::StreamExt;
@@ -18,35 +18,34 @@ fn is_retryable_error(e: &reqwest::Error) -> bool {
// ============ Relay Task Management ============ // ============ Relay Task Management ============
pub async fn create_relay_task( pub async fn create_relay_task(
db: &SqlitePool, db: &PgPool,
account_id: &str, account_id: &str,
provider_id: &str, provider_id: &str,
model_id: &str, model_id: &str,
request_body: &str, request_body: &str,
priority: i64, _priority: i64,
max_attempts: u32, max_attempts: u32,
) -> SaasResult<RelayTaskInfo> { ) -> SaasResult<RelayTaskInfo> {
let id = uuid::Uuid::new_v4().to_string(); let id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let request_hash = hash_request(request_body);
let max_attempts = max_attempts.max(1).min(5); let max_attempts = max_attempts.max(1).min(5);
sqlx::query( sqlx::query(
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at) "INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, 'queued', ?7, 0, ?8, ?9, ?9)" VALUES ($1, $2, $3, $4, '', $5, 'queued', 0, 0, $6, $7, $7)"
) )
.bind(&id).bind(account_id).bind(provider_id).bind(model_id) .bind(&id).bind(account_id).bind(provider_id).bind(model_id)
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now) .bind(request_body).bind(max_attempts as i64).bind(&now)
.execute(db).await?; .execute(db).await?;
get_relay_task(db, &id).await get_relay_task(db, &id).await
} }
pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayTaskInfo> { pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)> = let row: Option<(String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, chrono::DateTime<chrono::Utc>, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
sqlx::query_as( sqlx::query_as(
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE id = ?1" FROM relay_tasks WHERE id = $1"
) )
.bind(task_id) .bind(task_id)
.fetch_optional(db) .fetch_optional(db)
@@ -58,12 +57,12 @@ pub async fn get_relay_task(db: &SqlitePool, task_id: &str) -> SaasResult<RelayT
Ok(RelayTaskInfo { Ok(RelayTaskInfo {
id, account_id, provider_id, model_id, status, priority, id, account_id, provider_id, model_id, status, priority,
attempt_count, max_attempts, input_tokens, output_tokens, attempt_count, max_attempts, input_tokens, output_tokens,
error_message, queued_at, started_at, completed_at, created_at, error_message, queued_at: queued_at.to_rfc3339(), started_at: started_at.map(|t| t.to_rfc3339()), completed_at: completed_at.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339(),
}) })
} }
pub async fn list_relay_tasks( pub async fn list_relay_tasks(
db: &SqlitePool, account_id: &str, query: &RelayTaskQuery, db: &PgPool, account_id: &str, query: &RelayTaskQuery,
) -> SaasResult<Vec<RelayTaskInfo>> { ) -> SaasResult<Vec<RelayTaskInfo>> {
let page = query.page.unwrap_or(1).max(1); let page = query.page.unwrap_or(1).max(1);
let page_size = query.page_size.unwrap_or(20).min(100); let page_size = query.page_size.unwrap_or(20).min(100);
@@ -71,13 +70,13 @@ pub async fn list_relay_tasks(
let sql = if query.status.is_some() { let sql = if query.status.is_some() {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = ?1 AND status = ?2 ORDER BY created_at DESC LIMIT ?3 OFFSET ?4" FROM relay_tasks WHERE account_id = $1 AND status = $2 ORDER BY created_at DESC LIMIT $3 OFFSET $4"
} else { } else {
"SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at "SELECT id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at
FROM relay_tasks WHERE account_id = ?1 ORDER BY created_at DESC LIMIT ?2 OFFSET ?3" FROM relay_tasks WHERE account_id = $1 ORDER BY created_at DESC LIMIT $2 OFFSET $3"
}; };
let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, String, Option<String>, Option<String>, String)>(sql) let mut query_builder = sqlx::query_as::<_, (String, String, String, String, String, i64, i64, i64, i64, i64, Option<String>, chrono::DateTime<chrono::Utc>, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(sql)
.bind(account_id); .bind(account_id);
if let Some(ref status) = query.status { if let Some(ref status) = query.status {
@@ -88,31 +87,32 @@ pub async fn list_relay_tasks(
let rows = query_builder.fetch_all(db).await?; let rows = query_builder.fetch_all(db).await?;
Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| { Ok(rows.into_iter().map(|(id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at)| {
RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at } RelayTaskInfo { id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at: queued_at.to_rfc3339(), started_at: started_at.map(|t| t.to_rfc3339()), completed_at: completed_at.map(|t| t.to_rfc3339()), created_at: created_at.to_rfc3339() }
}).collect()) }).collect())
} }
pub async fn update_task_status( pub async fn update_task_status(
db: &SqlitePool, task_id: &str, status: &str, db: &PgPool, task_id: &str, status: &str,
input_tokens: Option<i64>, output_tokens: Option<i64>, input_tokens: Option<i64>, output_tokens: Option<i64>,
error_message: Option<&str>, error_message: Option<&str>,
) -> SaasResult<()> { ) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339(); let now = chrono::Utc::now();
let update_sql = match status { let update_sql = match status {
"processing" => "started_at = ?1, status = 'processing', attempt_count = attempt_count + 1", "processing" => "started_at = $1, status = 'processing', attempt_count = attempt_count + 1",
"completed" => "completed_at = ?1, status = 'completed', input_tokens = COALESCE(?2, input_tokens), output_tokens = COALESCE(?3, output_tokens)", "completed" => "completed_at = $1, status = 'completed', input_tokens = COALESCE($2, input_tokens), output_tokens = COALESCE($3, output_tokens)",
"failed" => "completed_at = ?1, status = 'failed', error_message = ?2", "failed" => "completed_at = $1, status = 'failed', error_message = $2",
"cancelled" => "completed_at = $1, status = 'cancelled', error_message = $2",
_ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))), _ => return Err(SaasError::InvalidInput(format!("无效任务状态: {}", status))),
}; };
let sql = format!("UPDATE relay_tasks SET {} WHERE id = ?4", update_sql); let sql = format!("UPDATE relay_tasks SET {} WHERE id = $4", update_sql);
let mut query = sqlx::query(&sql).bind(&now); let mut query = sqlx::query(&sql).bind(&now);
if status == "completed" { if status == "completed" {
query = query.bind(input_tokens).bind(output_tokens); query = query.bind(input_tokens).bind(output_tokens);
} }
if status == "failed" { if status == "failed" || status == "cancelled" {
query = query.bind(error_message); query = query.bind(error_message);
} }
query = query.bind(task_id); query = query.bind(task_id);
@@ -124,8 +124,11 @@ pub async fn update_task_status(
// ============ Relay Execution ============ // ============ Relay Execution ============
pub async fn execute_relay( pub async fn execute_relay(
db: &SqlitePool, db: &PgPool,
task_id: &str, task_id: &str,
account_id: &str,
provider_id: &str,
model_id: &str,
provider_base_url: &str, provider_base_url: &str,
provider_api_key: Option<&str>, provider_api_key: Option<&str>,
request_body: &str, request_body: &str,
@@ -135,6 +138,31 @@ pub async fn execute_relay(
) -> SaasResult<RelayResponse> { ) -> SaasResult<RelayResponse> {
validate_provider_url(provider_base_url)?; validate_provider_url(provider_base_url)?;
// DNS Rebinding 防护: 解析 host 并验证所有 resolved IP 非私有
let parsed_url: url::Url = provider_base_url.trim_end_matches('/').parse()
.map_err(|_| SaasError::InvalidInput(format!("无效的 provider URL: {}", provider_base_url)))?;
let host = parsed_url.host_str()
.ok_or_else(|| SaasError::InvalidInput("provider URL 缺少 host".into()))?;
// 仅对非 IP 的 host 做 DNS 解析(纯 IP 已在 validate_provider_url 中检查)
if host.parse::<std::net::IpAddr>().is_err() {
let port = parsed_url.port_or_known_default().unwrap_or(443);
let addr_str = format!("{}:{}", host, port);
let addrs: Vec<std::net::SocketAddr> = std::net::ToSocketAddrs::to_socket_addrs(&addr_str)
.map_err(|e| SaasError::InvalidInput(format!("DNS 解析失败: {}", e)))?
.collect();
if addrs.is_empty() {
return Err(SaasError::InvalidInput(format!("DNS 解析无结果: {}", host)));
}
for addr in &addrs {
if is_private_ip(&addr.ip()) {
return Err(SaasError::InvalidInput(format!(
"provider URL {} 解析到私有 IP: {}", host, addr.ip()
)));
}
}
}
let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/')); let url = format!("{}/chat/completions", provider_base_url.trim_end_matches('/'));
let client = reqwest::Client::builder() let client = reqwest::Client::builder()
@@ -167,8 +195,14 @@ pub async fn execute_relay(
let byte_stream = resp.bytes_stream() let byte_stream = resp.bytes_stream()
.map(|result| result.map_err(std::io::Error::other)); .map(|result| result.map_err(std::io::Error::other));
let body = axum::body::Body::from_stream(byte_stream); let body = axum::body::Body::from_stream(byte_stream);
update_task_status(db, task_id, "completed", None, None, None).await?; update_task_status(db, task_id, "completed", None, None, None).await?;
return Ok(RelayResponse::Sse(body)); return Ok(RelayResponse::SseWithUsage {
body,
task_id: task_id.to_string(),
account_id: account_id.to_string(),
provider_id: provider_id.to_string(),
model_id: model_id.to_string(),
});
} else { } else {
let body = resp.text().await.unwrap_or_default(); let body = resp.text().await.unwrap_or_default();
let (input_tokens, output_tokens) = extract_token_usage(&body); let (input_tokens, output_tokens) = extract_token_usage(&body);
@@ -182,7 +216,12 @@ pub async fn execute_relay(
if !is_retryable_status(status) || attempt + 1 >= max_attempts { if !is_retryable_status(status) || attempt + 1 >= max_attempts {
// 4xx 客户端错误或已达最大重试次数 → 立即失败 // 4xx 客户端错误或已达最大重试次数 → 立即失败
let body = resp.text().await.unwrap_or_default(); let body = resp.text().await.unwrap_or_default();
let err_msg = format!("上游返回 HTTP {}: {}", status, &body[..body.len().min(500)]); // 仅记录日志,不将上游错误体暴露给客户端(可能含敏感信息如 API key
tracing::warn!(
"Relay task {} 上游返回 HTTP {} (body truncated): {}",
task_id, status, &body[..body.len().min(200)]
);
let err_msg = format!("上游服务返回错误 (HTTP {})", status);
update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?; update_task_status(db, task_id, "failed", None, None, Some(&err_msg)).await?;
return Err(SaasError::Relay(err_msg)); return Err(SaasError::Relay(err_msg));
} }
@@ -218,17 +257,19 @@ pub async fn execute_relay(
#[derive(Debug)] #[derive(Debug)]
pub enum RelayResponse { pub enum RelayResponse {
Json(String), Json(String),
Sse(axum::body::Body), /// SSE 流式响应 + 上下文信息
SseWithUsage {
body: axum::body::Body,
task_id: String,
account_id: String,
provider_id: String,
model_id: String,
},
} }
// ============ Helpers ============ // ============ Helpers ============
fn hash_request(body: &str) -> String { pub fn extract_token_usage(body: &str) -> (i64, i64) {
use sha2::{Sha256, Digest};
hex::encode(Sha256::digest(body.as_bytes()))
}
fn extract_token_usage(body: &str) -> (i64, i64) {
let parsed: serde_json::Value = match serde_json::from_str(body) { let parsed: serde_json::Value = match serde_json::from_str(body) {
Ok(v) => v, Ok(v) => v,
Err(_) => return (0, 0), Err(_) => return (0, 0),
@@ -273,6 +314,9 @@ fn validate_provider_url(url: &str) -> SaasResult<()> {
Some(h) => h, Some(h) => h,
None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())), None => return Err(SaasError::InvalidInput("provider URL 缺少 host".into())),
}; };
// url crate 的 host_str() 对 IPv6 地址保留方括号 (如 "[::1]")
// 需要去掉方括号才能与阻止列表匹配和解析为 IpAddr
let host = host.trim_start_matches('[').trim_end_matches(']');
// 精确匹配的阻止列表 // 精确匹配的阻止列表
let blocked_exact = [ let blocked_exact = [
@@ -335,3 +379,302 @@ fn is_private_ip(ip: &std::net::IpAddr) -> bool {
} }
} }
} }
#[cfg(test)]
mod tests {
use super::*;
// ---- is_retryable_status ----
#[test]
fn retryable_status_429() {
assert!(is_retryable_status(429));
}
#[test]
fn retryable_status_5xx_range() {
for code in 500u16..600 {
assert!(is_retryable_status(code), "expected {code} to be retryable");
}
}
#[test]
fn not_retryable_status_200() {
assert!(!is_retryable_status(200));
}
#[test]
fn not_retryable_status_400() {
assert!(!is_retryable_status(400));
}
#[test]
fn not_retryable_status_404() {
assert!(!is_retryable_status(404));
}
#[test]
fn not_retryable_status_422() {
assert!(!is_retryable_status(422));
}
// ---- extract_token_usage ----
#[test]
fn extract_usage_normal() {
let body = r#"{"usage":{"prompt_tokens":100,"completion_tokens":50}}"#;
assert_eq!(extract_token_usage(body), (100, 50));
}
#[test]
fn extract_usage_no_usage_field() {
let body = r#"{"id":"chatcmpl-abc","object":"chat.completion"}"#;
assert_eq!(extract_token_usage(body), (0, 0));
}
#[test]
fn extract_usage_invalid_json() {
assert_eq!(extract_token_usage("not json at all"), (0, 0));
}
#[test]
fn extract_usage_empty_body() {
assert_eq!(extract_token_usage(""), (0, 0));
}
#[test]
fn extract_usage_partial_tokens() {
// only prompt_tokens present, completion_tokens missing
let body = r#"{"usage":{"prompt_tokens":200}}"#;
assert_eq!(extract_token_usage(body), (200, 0));
}
#[test]
fn extract_usage_completion_only() {
let body = r#"{"usage":{"completion_tokens":75}}"#;
assert_eq!(extract_token_usage(body), (0, 75));
}
#[test]
fn extract_usage_zero_tokens() {
let body = r#"{"usage":{"prompt_tokens":0,"completion_tokens":0}}"#;
assert_eq!(extract_token_usage(body), (0, 0));
}
#[test]
fn extract_usage_string_instead_of_int() {
// non-integer token values should fall back to 0
let body = r#"{"usage":{"prompt_tokens":"abc","completion_tokens":null}}"#;
assert_eq!(extract_token_usage(body), (0, 0));
}
// ---- is_private_ip ----
#[test]
fn private_ip_10_range() {
let ip: std::net::IpAddr = "10.0.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_172_16_range() {
let ip: std::net::IpAddr = "172.16.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_172_31_range() {
let ip: std::net::IpAddr = "172.31.255.255".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_172_15_not_private() {
// 172.15.x.x is NOT in the private range (starts at 172.16)
let ip: std::net::IpAddr = "172.15.255.255".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn private_ip_172_32_not_private() {
// 172.32.x.x is NOT in the private range (ends at 172.31)
let ip: std::net::IpAddr = "172.32.0.0".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn private_ip_192_168_range() {
let ip: std::net::IpAddr = "192.168.1.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_127_loopback() {
let ip: std::net::IpAddr = "127.0.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_127_any() {
let ip: std::net::IpAddr = "127.255.255.255".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_169_254_link_local() {
let ip: std::net::IpAddr = "169.254.1.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_0_0_0_0() {
let ip: std::net::IpAddr = "0.0.0.0".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_loopback() {
let ip: std::net::IpAddr = "::1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_link_local() {
let ip: std::net::IpAddr = "fe80::1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_mapped_ipv4_loopback() {
let ip: std::net::IpAddr = "::ffff:127.0.0.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn private_ip_v6_mapped_ipv4_private() {
let ip: std::net::IpAddr = "::ffff:192.168.1.1".parse().unwrap();
assert!(is_private_ip(&ip));
}
#[test]
fn public_ip_8_8_8_8() {
let ip: std::net::IpAddr = "8.8.8.8".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn public_ip_1_1_1_1() {
let ip: std::net::IpAddr = "1.1.1.1".parse().unwrap();
assert!(!is_private_ip(&ip));
}
#[test]
fn public_ip_v6_google() {
let ip: std::net::IpAddr = "2001:4860:4860::8888".parse().unwrap();
assert!(!is_private_ip(&ip));
}
// ---- validate_provider_url ----
#[test]
fn validate_url_https_valid() {
assert!(validate_provider_url("https://api.openai.com").is_ok());
}
#[test]
fn validate_url_https_with_path() {
assert!(validate_provider_url("https://api.openai.com/v1").is_ok());
}
#[test]
fn validate_url_https_with_port() {
assert!(validate_provider_url("https://api.openai.com:443").is_ok());
}
#[test]
fn validate_url_blocks_localhost() {
assert!(validate_provider_url("https://localhost").is_err());
}
#[test]
fn validate_url_blocks_127_0_0_1() {
assert!(validate_provider_url("https://127.0.0.1").is_err());
}
#[test]
fn validate_url_blocks_0_0_0_0() {
assert!(validate_provider_url("https://0.0.0.0").is_err());
}
#[test]
fn validate_url_blocks_169_254_169_254() {
assert!(validate_provider_url("https://169.254.169.254").is_err());
}
#[test]
fn validate_url_blocks_metadata_google_internal() {
assert!(validate_provider_url("https://metadata.google.internal").is_err());
}
#[test]
fn validate_url_blocks_private_ip_10() {
assert!(validate_provider_url("https://10.0.0.1").is_err());
}
#[test]
fn validate_url_blocks_private_ip_172_16() {
assert!(validate_provider_url("https://172.16.0.1").is_err());
}
#[test]
fn validate_url_blocks_private_ip_192_168() {
assert!(validate_provider_url("https://192.168.0.1").is_err());
}
#[test]
fn validate_url_blocks_numeric_host() {
// decimal IP representation (e.g. 2130706433 = 127.0.0.1)
assert!(validate_provider_url("https://2130706433").is_err());
}
#[test]
fn validate_url_blocks_subdomain_localhost() {
assert!(validate_provider_url("https://evil.localhost").is_err());
}
#[test]
fn validate_url_blocks_subdomain_internal() {
assert!(validate_provider_url("https://app.internal").is_err());
}
#[test]
fn validate_url_blocks_subdomain_local() {
assert!(validate_provider_url("https://myapp.local").is_err());
}
#[test]
fn validate_url_blocks_ipv6_loopback() {
assert!(validate_provider_url("https://[::1]").is_err());
}
#[test]
fn validate_url_invalid_format() {
assert!(validate_provider_url("not a url").is_err());
}
#[test]
fn validate_url_missing_host() {
assert!(validate_provider_url("https://").is_err());
}
#[test]
fn validate_url_blocks_ftp_scheme() {
assert!(validate_provider_url("ftp://api.openai.com").is_err());
}
#[test]
fn validate_url_blocks_http_in_production() {
// In CI / default env, ZCLAW_SAAS_DEV is not set, so http is blocked
assert!(validate_provider_url("http://api.openai.com").is_err());
}
}

View File

@@ -2,27 +2,8 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
/// 中转请求 (OpenAI 兼容格式)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelayChatRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
#[serde(default)]
pub temperature: Option<f64>,
#[serde(default)]
pub max_tokens: Option<u32>,
#[serde(default)]
pub stream: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: serde_json::Value,
}
/// 中转任务信息 /// 中转任务信息
#[derive(Debug, Clone, Serialize)] #[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
pub struct RelayTaskInfo { pub struct RelayTaskInfo {
pub id: String, pub id: String,
pub account_id: String, pub account_id: String,
@@ -42,18 +23,10 @@ pub struct RelayTaskInfo {
} }
/// 中转任务查询 /// 中转任务查询
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
pub struct RelayTaskQuery { pub struct RelayTaskQuery {
pub status: Option<String>, pub status: Option<String>,
pub page: Option<i64>, pub page: Option<i64>,
pub page_size: Option<i64>, pub page_size: Option<i64>,
} }
/// Provider 速率限制状态
#[derive(Debug, Clone)]
pub struct RateLimitState {
pub rpm: i64,
pub tpm: i64,
pub concurrent: usize,
pub max_concurrent: usize,
}

View File

@@ -1,31 +1,36 @@
//! 应用状态 //! 应用状态
use sqlx::SqlitePool; use sqlx::PgPool;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::config::SaaSConfig; use crate::config::SaaSConfig;
use crate::crypto::FieldEncryption;
/// 全局应用状态,通过 Axum State 共享 /// 全局应用状态,通过 Axum State 共享
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
/// 数据库连接池 /// 数据库连接池
pub db: SqlitePool, pub db: PgPool,
/// 服务器配置 (可热更新) /// 服务器配置 (可热更新)
pub config: Arc<RwLock<SaaSConfig>>, pub config: Arc<RwLock<SaaSConfig>>,
/// JWT 密钥 /// JWT 密钥
pub jwt_secret: secrecy::SecretString, pub jwt_secret: secrecy::SecretString,
/// 字段级加密器 (AES-256-GCM)
pub field_encryption: Arc<FieldEncryption>,
/// 速率限制: account_id → 请求时间戳列表 /// 速率限制: account_id → 请求时间戳列表
pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>, pub rate_limit_entries: Arc<dashmap::DashMap<String, Vec<Instant>>>,
} }
impl AppState { impl AppState {
pub fn new(db: SqlitePool, config: SaaSConfig) -> anyhow::Result<Self> { pub fn new(db: PgPool, config: SaaSConfig) -> anyhow::Result<Self> {
let jwt_secret = config.jwt_secret()?; let jwt_secret = config.jwt_secret()?;
let field_encryption = Arc::new(FieldEncryption::new()?);
Ok(Self { Ok(Self {
db, db,
config: Arc::new(RwLock::new(config)), config: Arc::new(RwLock::new(config)),
jwt_secret, jwt_secret,
field_encryption,
rate_limit_entries: Arc::new(dashmap::DashMap::new()), rate_limit_entries: Arc::new(dashmap::DashMap::new()),
}) })
} }

View File

@@ -1,4 +1,6 @@
//! 集成测试 (Phase 1 + Phase 2) //! 集成测试 (Phase 1 + Phase 2)
//!
//! 所有测试通过全局 Mutex 串行执行,避免共享数据库导致的 UNIQUE 约束冲突和数据竞争。
use axum::{ use axum::{
body::Body, body::Body,
@@ -9,8 +11,16 @@ use tower::ServiceExt;
const MAX_BODY_SIZE: usize = 1024 * 1024; // 1MB const MAX_BODY_SIZE: usize = 1024 * 1024; // 1MB
/// 全局 Mutex 用于序列化所有集成测试
/// tokio::test 默认并行执行,但共享数据库要求串行访问
static INTEGRATION_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
async fn build_test_app() -> axum::Router { async fn build_test_app() -> axum::Router {
use zclaw_saas::{config::SaaSConfig, db::init_memory_db, state::AppState}; let _ = tracing_subscriber::fmt()
.with_env_filter("error")
.with_test_writer()
.try_init();
use zclaw_saas::{config::SaaSConfig, db::init_test_db, state::AppState};
use axum::extract::ConnectInfo; use axum::extract::ConnectInfo;
use std::net::SocketAddr; use std::net::SocketAddr;
@@ -18,7 +28,7 @@ async fn build_test_app() -> axum::Router {
std::env::set_var("ZCLAW_SAAS_DEV", "true"); std::env::set_var("ZCLAW_SAAS_DEV", "true");
std::env::set_var("ZCLAW_SAAS_JWT_SECRET", "test-secret-for-integration-tests-only"); std::env::set_var("ZCLAW_SAAS_JWT_SECRET", "test-secret-for-integration-tests-only");
let db = init_memory_db().await.unwrap(); let db = init_test_db().await.unwrap();
let mut config = SaaSConfig::default(); let mut config = SaaSConfig::default();
config.auth.jwt_expiration_hours = 24; config.auth.jwt_expiration_hours = 24;
let state = AppState::new(db, config).expect("测试环境 AppState 初始化失败"); let state = AppState::new(db, config).expect("测试环境 AppState 初始化失败");
@@ -85,6 +95,7 @@ fn auth_header(token: &str) -> String {
#[tokio::test] #[tokio::test]
async fn test_register_and_login() { async fn test_register_and_login() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "testuser", "test@example.com").await; let token = register_and_login(&app, "testuser", "test@example.com").await;
assert!(!token.is_empty()); assert!(!token.is_empty());
@@ -92,6 +103,7 @@ async fn test_register_and_login() {
#[tokio::test] #[tokio::test]
async fn test_register_duplicate_fails() { async fn test_register_duplicate_fails() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let body = json!({ let body = json!({
@@ -123,6 +135,7 @@ async fn test_register_duplicate_fails() {
#[tokio::test] #[tokio::test]
async fn test_unauthorized_access() { async fn test_unauthorized_access() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let req = Request::builder() let req = Request::builder()
@@ -137,6 +150,7 @@ async fn test_unauthorized_access() {
#[tokio::test] #[tokio::test]
async fn test_login_wrong_password() { async fn test_login_wrong_password() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
register_and_login(&app, "wrongpwd", "wrongpwd@example.com").await; register_and_login(&app, "wrongpwd", "wrongpwd@example.com").await;
@@ -156,6 +170,7 @@ async fn test_login_wrong_password() {
#[tokio::test] #[tokio::test]
async fn test_full_authenticated_flow() { async fn test_full_authenticated_flow() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "fulltest", "full@example.com").await; let token = register_and_login(&app, "fulltest", "full@example.com").await;
@@ -204,6 +219,7 @@ async fn test_full_authenticated_flow() {
#[tokio::test] #[tokio::test]
async fn test_providers_crud() { async fn test_providers_crud() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
// 注册 super_admin 角色用户 (通过直接插入角色权限) // 注册 super_admin 角色用户 (通过直接插入角色权限)
let token = register_and_login(&app, "adminprov", "adminprov@example.com").await; let token = register_and_login(&app, "adminprov", "adminprov@example.com").await;
@@ -239,6 +255,7 @@ async fn test_providers_crud() {
#[tokio::test] #[tokio::test]
async fn test_models_list_and_usage() { async fn test_models_list_and_usage() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "modeluser", "modeluser@example.com").await; let token = register_and_login(&app, "modeluser", "modeluser@example.com").await;
@@ -274,6 +291,7 @@ async fn test_models_list_and_usage() {
#[tokio::test] #[tokio::test]
async fn test_api_keys_lifecycle() { async fn test_api_keys_lifecycle() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "keyuser", "keyuser@example.com").await; let token = register_and_login(&app, "keyuser", "keyuser@example.com").await;
@@ -309,6 +327,7 @@ async fn test_api_keys_lifecycle() {
#[tokio::test] #[tokio::test]
async fn test_relay_models_list() { async fn test_relay_models_list() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "relayuser", "relayuser@example.com").await; let token = register_and_login(&app, "relayuser", "relayuser@example.com").await;
@@ -329,6 +348,7 @@ async fn test_relay_models_list() {
#[tokio::test] #[tokio::test]
async fn test_relay_chat_no_model() { async fn test_relay_chat_no_model() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "relayfail", "relayfail@example.com").await; let token = register_and_login(&app, "relayfail", "relayfail@example.com").await;
@@ -351,6 +371,7 @@ async fn test_relay_chat_no_model() {
#[tokio::test] #[tokio::test]
async fn test_relay_tasks_list() { async fn test_relay_tasks_list() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "relaytasks", "relaytasks@example.com").await; let token = register_and_login(&app, "relaytasks", "relaytasks@example.com").await;
@@ -369,6 +390,7 @@ async fn test_relay_tasks_list() {
#[tokio::test] #[tokio::test]
async fn test_config_analysis_empty() { async fn test_config_analysis_empty() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "cfguser", "cfguser@example.com").await; let token = register_and_login(&app, "cfguser", "cfguser@example.com").await;
@@ -389,6 +411,7 @@ async fn test_config_analysis_empty() {
#[tokio::test] #[tokio::test]
async fn test_config_seed_and_list() { async fn test_config_seed_and_list() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "cfgseed", "cfgseed@example.com").await; let token = register_and_login(&app, "cfgseed", "cfgseed@example.com").await;
@@ -423,6 +446,7 @@ async fn test_config_seed_and_list() {
#[tokio::test] #[tokio::test]
async fn test_device_register_and_list() { async fn test_device_register_and_list() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "devuser", "devuser@example.com").await; let token = register_and_login(&app, "devuser", "devuser@example.com").await;
@@ -463,6 +487,7 @@ async fn test_device_register_and_list() {
#[tokio::test] #[tokio::test]
async fn test_device_upsert_on_reregister() { async fn test_device_upsert_on_reregister() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "upsertdev", "upsertdev@example.com").await; let token = register_and_login(&app, "upsertdev", "upsertdev@example.com").await;
@@ -516,6 +541,7 @@ async fn test_device_upsert_on_reregister() {
#[tokio::test] #[tokio::test]
async fn test_device_heartbeat() { async fn test_device_heartbeat() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "hbuser", "hbuser@example.com").await; let token = register_and_login(&app, "hbuser", "hbuser@example.com").await;
@@ -563,6 +589,7 @@ async fn test_device_heartbeat() {
#[tokio::test] #[tokio::test]
async fn test_device_register_missing_id() { async fn test_device_register_missing_id() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "baddev", "baddev@example.com").await; let token = register_and_login(&app, "baddev", "baddev@example.com").await;
@@ -578,11 +605,12 @@ async fn test_device_register_missing_id() {
.unwrap(); .unwrap();
let resp = app.oneshot(req).await.unwrap(); let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST); assert!(resp.status() == StatusCode::BAD_REQUEST || resp.status() == StatusCode::UNPROCESSABLE_ENTITY);
} }
#[tokio::test] #[tokio::test]
async fn test_change_password() { async fn test_change_password() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "pwduser", "pwduser@example.com").await; let token = register_and_login(&app, "pwduser", "pwduser@example.com").await;
@@ -632,6 +660,7 @@ async fn test_change_password() {
#[tokio::test] #[tokio::test]
async fn test_change_password_wrong_old() { async fn test_change_password_wrong_old() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "wrongold", "wrongold@example.com").await; let token = register_and_login(&app, "wrongold", "wrongold@example.com").await;
@@ -655,6 +684,7 @@ async fn test_change_password_wrong_old() {
#[tokio::test] #[tokio::test]
async fn test_e2e_full_lifecycle() { async fn test_e2e_full_lifecycle() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
// 1. 注册 // 1. 注册
@@ -771,6 +801,7 @@ async fn test_e2e_full_lifecycle() {
#[tokio::test] #[tokio::test]
async fn test_config_sync() { async fn test_config_sync() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "cfgsync", "cfgsync@example.com").await; let token = register_and_login(&app, "cfgsync", "cfgsync@example.com").await;
@@ -808,6 +839,7 @@ async fn test_config_sync() {
#[tokio::test] #[tokio::test]
async fn test_totp_setup_and_verify() { async fn test_totp_setup_and_verify() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "totpuser", "totp@example.com").await; let token = register_and_login(&app, "totpuser", "totp@example.com").await;
@@ -825,7 +857,7 @@ async fn test_totp_setup_and_verify() {
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert!(body["otpauth_uri"].is_string()); assert!(body["otpauth_uri"].is_string());
assert!(body["secret"].is_string()); assert!(body["secret"].is_string());
let secret = body["secret"].as_str().unwrap(); let _secret = body["secret"].as_str().unwrap();
// 2. Verify with wrong code → 400 // 2. Verify with wrong code → 400
let bad_verify = Request::builder() let bad_verify = Request::builder()
@@ -868,6 +900,7 @@ async fn test_totp_setup_and_verify() {
#[tokio::test] #[tokio::test]
async fn test_totp_disabled_login_without_code() { async fn test_totp_disabled_login_without_code() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "nototp", "nototp@example.com").await; let token = register_and_login(&app, "nototp", "nototp@example.com").await;
@@ -913,6 +946,7 @@ async fn test_totp_disabled_login_without_code() {
#[tokio::test] #[tokio::test]
async fn test_totp_disable_wrong_password() { async fn test_totp_disable_wrong_password() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "totpwrong", "totpwrong@example.com").await; let token = register_and_login(&app, "totpwrong", "totpwrong@example.com").await;
@@ -932,6 +966,7 @@ async fn test_totp_disable_wrong_password() {
#[tokio::test] #[tokio::test]
async fn test_config_diff() { async fn test_config_diff() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "diffuser", "diffuser@example.com").await; let token = register_and_login(&app, "diffuser", "diffuser@example.com").await;
@@ -959,6 +994,7 @@ async fn test_config_diff() {
#[tokio::test] #[tokio::test]
async fn test_config_sync_push() { async fn test_config_sync_push() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "syncpush", "syncpush@example.com").await; let token = register_and_login(&app, "syncpush", "syncpush@example.com").await;
@@ -987,6 +1023,7 @@ async fn test_config_sync_push() {
#[tokio::test] #[tokio::test]
async fn test_relay_retry_unauthorized() { async fn test_relay_retry_unauthorized() {
let _guard = INTEGRATION_TEST_LOCK.lock().unwrap();
let app = build_test_app().await; let app = build_test_app().await;
let token = register_and_login(&app, "retryuser", "retryuser@example.com").await; let token = register_and_login(&app, "retryuser", "retryuser@example.com").await;

View File

@@ -0,0 +1,206 @@
# Design System Master File
> **LOGIC:** When building a specific page, first check `design-system/pages/[page-name].md`.
> If that file exists, its rules **override** this Master file.
> If not, strictly follow the rules below.
---
**Project:** ZCLAW Admin
**Generated:** 2026-03-27 13:52:31
**Category:** Financial Dashboard
---
## Global Rules
### Color Palette
| Role | Hex | CSS Variable |
|------|-----|--------------|
| Primary | `#0F172A` | `--color-primary` |
| Secondary | `#1E293B` | `--color-secondary` |
| CTA/Accent | `#22C55E` | `--color-cta` |
| Background | `#020617` | `--color-background` |
| Text | `#F8FAFC` | `--color-text` |
**Color Notes:** Dark bg + green positive indicators
### Typography
- **Heading Font:** Fira Code
- **Body Font:** Fira Sans
- **Mood:** dashboard, data, analytics, code, technical, precise
- **Google Fonts:** [Fira Code + Fira Sans](https://fonts.google.com/share?selection.family=Fira+Code:wght@400;500;600;700|Fira+Sans:wght@300;400;500;600;700)
**CSS Import:**
```css
@import url('https://fonts.googleapis.com/css2?family=Fira+Code:wght@400;500;600;700&family=Fira+Sans:wght@300;400;500;600;700&display=swap');
```
### Spacing Variables
| Token | Value | Usage |
|-------|-------|-------|
| `--space-xs` | `4px` / `0.25rem` | Tight gaps |
| `--space-sm` | `8px` / `0.5rem` | Icon gaps, inline spacing |
| `--space-md` | `16px` / `1rem` | Standard padding |
| `--space-lg` | `24px` / `1.5rem` | Section padding |
| `--space-xl` | `32px` / `2rem` | Large gaps |
| `--space-2xl` | `48px` / `3rem` | Section margins |
| `--space-3xl` | `64px` / `4rem` | Hero padding |
### Shadow Depths
| Level | Value | Usage |
|-------|-------|-------|
| `--shadow-sm` | `0 1px 2px rgba(0,0,0,0.05)` | Subtle lift |
| `--shadow-md` | `0 4px 6px rgba(0,0,0,0.1)` | Cards, buttons |
| `--shadow-lg` | `0 10px 15px rgba(0,0,0,0.1)` | Modals, dropdowns |
| `--shadow-xl` | `0 20px 25px rgba(0,0,0,0.15)` | Hero images, featured cards |
---
## Component Specs
### Buttons
```css
/* Primary Button */
.btn-primary {
background: #22C55E;
color: white;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
transition: all 200ms ease;
cursor: pointer;
}
.btn-primary:hover {
opacity: 0.9;
transform: translateY(-1px);
}
/* Secondary Button */
.btn-secondary {
background: transparent;
color: #0F172A;
border: 2px solid #0F172A;
padding: 12px 24px;
border-radius: 8px;
font-weight: 600;
transition: all 200ms ease;
cursor: pointer;
}
```
### Cards
```css
.card {
background: #020617;
border-radius: 12px;
padding: 24px;
box-shadow: var(--shadow-md);
transition: all 200ms ease;
cursor: pointer;
}
.card:hover {
box-shadow: var(--shadow-lg);
transform: translateY(-2px);
}
```
### Inputs
```css
.input {
padding: 12px 16px;
border: 1px solid #E2E8F0;
border-radius: 8px;
font-size: 16px;
transition: border-color 200ms ease;
}
.input:focus {
border-color: #0F172A;
outline: none;
box-shadow: 0 0 0 3px #0F172A20;
}
```
### Modals
```css
.modal-overlay {
background: rgba(0, 0, 0, 0.5);
backdrop-filter: blur(4px);
}
.modal {
background: white;
border-radius: 16px;
padding: 32px;
box-shadow: var(--shadow-xl);
max-width: 500px;
width: 90%;
}
```
---
## Style Guidelines
**Style:** Dark Mode (OLED)
**Keywords:** Dark theme, low light, high contrast, deep black, midnight blue, eye-friendly, OLED, night mode, power efficient
**Best For:** Night-mode apps, coding platforms, entertainment, eye-strain prevention, OLED devices, low-light
**Key Effects:** Minimal glow (text-shadow: 0 0 10px), dark-to-light transitions, low white emission, high readability, visible focus
### Page Pattern
**Pattern Name:** Horizontal Scroll Journey
- **Conversion Strategy:** Immersive product discovery. High engagement. Keep navigation visible.
28,Bento Grid Showcase,bento, grid, features, modular, apple-style, showcase", 1. Hero, 2. Bento Grid (Key Features), 3. Detail Cards, 4. Tech Specs, 5. CTA, Floating Action Button or Bottom of Grid, Card backgrounds: #F5F5F7 or Glass. Icons: Vibrant brand colors. Text: Dark., Hover card scale (1.02), video inside cards, tilt effect, staggered reveal, Scannable value props. High information density without clutter. Mobile stack.
29,Interactive 3D Configurator,3d, configurator, customizer, interactive, product", 1. Hero (Configurator), 2. Feature Highlight (synced), 3. Price/Specs, 4. Purchase, Inside Configurator UI + Sticky Bottom Bar, Neutral studio background. Product: Realistic materials. UI: Minimal overlay., Real-time rendering, material swap animation, camera rotate/zoom, light reflection, Increases ownership feeling. 360 view reduces return rates. Direct add-to-cart.
30,AI-Driven Dynamic Landing,ai, dynamic, personalized, adaptive, generative", 1. Prompt/Input Hero, 2. Generated Result Preview, 3. How it Works, 4. Value Prop, Input Field (Hero) + 'Try it' Buttons, Adaptive to user input. Dark mode for compute feel. Neon accents., Typing text effects, shimmering generation loaders, morphing layouts, Immediate value demonstration. 'Show, don't tell'. Low friction start.
- **CTA Placement:** Floating Sticky CTA or End of Horizontal Track
- **Section Order:** 1. Intro (Vertical), 2. The Journey (Horizontal Track), 3. Detail Reveal, 4. Vertical Footer
---
## Anti-Patterns (Do NOT Use)
- ❌ Light mode default
- ❌ Slow rendering
### Additional Forbidden Patterns
-**Emojis as icons** — Use SVG icons (Heroicons, Lucide, Simple Icons)
-**Missing cursor:pointer** — All clickable elements must have cursor:pointer
-**Layout-shifting hovers** — Avoid scale transforms that shift layout
-**Low contrast text** — Maintain 4.5:1 minimum contrast ratio
-**Instant state changes** — Always use transitions (150-300ms)
-**Invisible focus states** — Focus states must be visible for a11y
---
## Pre-Delivery Checklist
Before delivering any UI code, verify:
- [ ] No emojis used as icons (use SVG instead)
- [ ] All icons from consistent icon set (Heroicons/Lucide)
- [ ] `cursor-pointer` on all clickable elements
- [ ] Hover states with smooth transitions (150-300ms)
- [ ] Light mode: text contrast 4.5:1 minimum
- [ ] Focus states visible for keyboard navigation
- [ ] `prefers-reduced-motion` respected
- [ ] Responsive: 375px, 768px, 1024px, 1440px
- [ ] No content hidden behind fixed navbars
- [ ] No horizontal scroll on mobile

View File

@@ -102,8 +102,8 @@ export function ConfigMigrationWizard({ onDone }: { onDone: () => void }) {
if (direction === 'local-to-saas' && localModels.length > 0) { if (direction === 'local-to-saas' && localModels.length > 0) {
// Push local models as config items // Push local models as config items
for (const model of localModels) { for (const model of localModels) {
const exists = saasConfigs.some((c) => c.key_path === `models.${model.id}`); const existingItem = saasConfigs.find((c) => c.key_path === `models.${model.id}`);
if (exists && !selectedKeys.has(model.id)) continue; if (existingItem && !selectedKeys.has(model.id)) continue;
const body = { const body = {
category: 'model', category: 'model',
@@ -114,8 +114,8 @@ export function ConfigMigrationWizard({ onDone }: { onDone: () => void }) {
description: `从桌面端同步: ${model.name}`, description: `从桌面端同步: ${model.name}`,
}; };
if (exists) { if (existingItem) {
await saasClient.request<unknown>('PUT', `/api/v1/config/items/${exists}`, body); await saasClient.request<unknown>('PUT', `/api/v1/config/items/${existingItem.id}`, body);
} else { } else {
await saasClient.request<unknown>('POST', '/api/v1/config/items', body); await saasClient.request<unknown>('POST', '/api/v1/config/items', body);
} }

View File

@@ -1,5 +1,5 @@
import { useEffect, useState } from 'react'; import { useEffect, useState } from 'react';
import type { SaaSAccountInfo, SaaSModelInfo } from '../../lib/saas-client'; import { saasClient, type SaaSAccountInfo, type SaaSModelInfo } from '../../lib/saas-client';
import { Cloud, CloudOff, LogOut, RefreshCw, Cpu, CheckCircle, XCircle, Loader2 } from 'lucide-react'; import { Cloud, CloudOff, LogOut, RefreshCw, Cpu, CheckCircle, XCircle, Loader2 } from 'lucide-react';
import { useSaaSStore } from '../../store/saasStore'; import { useSaaSStore } from '../../store/saasStore';

View File

@@ -6,18 +6,7 @@ import { useConfigStore } from '../../store/configStore';
import { useChatStore } from '../../store/chatStore'; import { useChatStore } from '../../store/chatStore';
import { silentErrorHandler } from '../../lib/error-utils'; import { silentErrorHandler } from '../../lib/error-utils';
import { Plus, Pencil, Trash2, Star, Eye, EyeOff, AlertCircle, X, Zap, Check } from 'lucide-react'; import { Plus, Pencil, Trash2, Star, Eye, EyeOff, AlertCircle, X, Zap, Check } from 'lucide-react';
import type { CustomModel, CustomModelApiProtocol } from '../../types/config';
// 自定义模型数据结构
interface CustomModel {
id: string;
name: string;
provider: string;
apiKey?: string;
apiProtocol: 'openai' | 'anthropic' | 'custom';
baseUrl?: string;
isDefault?: boolean;
createdAt: string;
}
// Embedding 配置数据结构 // Embedding 配置数据结构
interface EmbeddingConfig { interface EmbeddingConfig {
@@ -140,7 +129,7 @@ export function ModelsAPI() {
modelId: 'glm-4-flash', modelId: 'glm-4-flash',
displayName: '', displayName: '',
apiKey: '', apiKey: '',
apiProtocol: 'openai' as 'openai' | 'anthropic' | 'custom', apiProtocol: 'openai' as CustomModelApiProtocol,
baseUrl: '', baseUrl: '',
}); });
@@ -650,7 +639,7 @@ export function ModelsAPI() {
<label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-2">API </label> <label className="block text-sm font-medium text-gray-700 dark:text-gray-300 mb-2">API </label>
<select <select
value={formData.apiProtocol} value={formData.apiProtocol}
onChange={(e) => setFormData({ ...formData, apiProtocol: e.target.value as 'openai' | 'anthropic' | 'custom' })} onChange={(e) => setFormData({ ...formData, apiProtocol: e.target.value as CustomModelApiProtocol })}
className="w-full px-3 py-2 border border-gray-200 dark:border-gray-600 rounded-lg text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-2 focus:ring-orange-500" className="w-full px-3 py-2 border border-gray-200 dark:border-gray-600 rounded-lg text-sm bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-2 focus:ring-orange-500"
> >
<option value="openai">OpenAI</option> <option value="openai">OpenAI</option>

View File

@@ -455,10 +455,24 @@ export function clearSecurityLog(): void {
} }
/** /**
* Generate a random API key for testing * Generate a random API key for testing.
* WARNING: Only use for testing purposes *
* @internal This function is intended solely for automated tests and
* development tooling. It must never be called in production
* builds because generated keys are not cryptographically secure
* and should never be used to authenticate against real services.
*
* @param type - The API key type to generate a test key for
* @returns A random API key that passes format validation for the given type
* @throws {Error} If called outside of a development or test environment
*/ */
export function generateTestApiKey(type: ApiKeyType): string { export function generateTestApiKey(type: ApiKeyType): string {
if (import.meta.env?.DEV !== true && import.meta.env?.MODE !== 'test') {
throw new Error(
'[Security] generateTestApiKey may only be called in development or test environments'
);
}
const rules = KEY_VALIDATION_RULES[type]; const rules = KEY_VALIDATION_RULES[type];
const length = rules.minLength + 10; const length = rules.minLength + 10;
const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'; const chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789';

View File

@@ -37,13 +37,17 @@ export {
DEFAULT_GATEWAY_URL, DEFAULT_GATEWAY_URL,
REST_API_URL, REST_API_URL,
FALLBACK_GATEWAY_URLS, FALLBACK_GATEWAY_URLS,
ZCLAW_GRPC_PORT,
ZCLAW_LEGACY_PORT,
normalizeGatewayUrl, normalizeGatewayUrl,
isLocalhost, isLocalhost,
getStoredGatewayUrl, getStoredGatewayUrl,
setStoredGatewayUrl, setStoredGatewayUrl,
getStoredGatewayToken, getStoredGatewayToken,
setStoredGatewayToken, setStoredGatewayToken,
detectConnectionMode,
} from './gateway-storage'; } from './gateway-storage';
export type { ConnectionMode } from './gateway-storage';
// === Internal imports === // === Internal imports ===
import type { import type {
@@ -69,6 +73,7 @@ import {
isLocalhost, isLocalhost,
getStoredGatewayUrl, getStoredGatewayUrl,
getStoredGatewayToken, getStoredGatewayToken,
detectConnectionMode,
} from './gateway-storage'; } from './gateway-storage';
import type { GatewayConfigSnapshot, GatewayModelChoice } from './gateway-config'; import type { GatewayConfigSnapshot, GatewayModelChoice } from './gateway-config';
@@ -273,8 +278,8 @@ export class GatewayClient {
return Promise.resolve(); return Promise.resolve();
} }
// Check if URL is for ZCLAW (port 4200 or 50051) - use REST mode // Check if URL is for ZCLAW (known kernel ports) - use REST mode
if (this.url.includes(':4200') || this.url.includes(':50051')) { if (detectConnectionMode(this.url) === 'rest') {
return this.connectRest(); return this.connectRest();
} }

View File

@@ -40,15 +40,47 @@ export function isLocalhost(url: string): boolean {
} }
} }
// === Port Constants ===
/** Default gRPC/HTTP port used by the ZCLAW kernel */
export const ZCLAW_GRPC_PORT = 50051;
/** Legacy/alternative port used in development or older configurations */
export const ZCLAW_LEGACY_PORT = 4200;
// === Connection Mode ===
/**
* Determines how the client connects to the ZCLAW gateway.
* - `rest`: Kernel exposes an HTTP REST API (gRPC-gateway). Used when the
* URL contains a known kernel port.
* - `ws`: Direct WebSocket connection to the kernel.
*/
export type ConnectionMode = 'rest' | 'ws';
/**
* Decide the connection mode based on the gateway URL.
*
* When the URL contains a known kernel port (gRPC or legacy), the client
* routes requests through the REST adapter instead of opening a raw
* WebSocket.
*/
export function detectConnectionMode(url: string): ConnectionMode {
if (url.includes(`:${ZCLAW_GRPC_PORT}`) || url.includes(`:${ZCLAW_LEGACY_PORT}`)) {
return 'rest';
}
return 'ws';
}
// === URL Constants === // === URL Constants ===
// ZCLAW endpoints (port 50051 - actual running port) // ZCLAW endpoints (port 50051 - actual running port)
// Note: REST API uses relative path to leverage Vite proxy for CORS bypass // Note: REST API uses relative path to leverage Vite proxy for CORS bypass
export const DEFAULT_GATEWAY_URL = `${DEFAULT_WS_PROTOCOL}127.0.0.1:50051/ws`; export const DEFAULT_GATEWAY_URL = `${DEFAULT_WS_PROTOCOL}127.0.0.1:${ZCLAW_GRPC_PORT}/ws`;
export const REST_API_URL = ''; // Empty = use relative path (Vite proxy) export const REST_API_URL = ''; // Empty = use relative path (Vite proxy)
export const FALLBACK_GATEWAY_URLS = [ export const FALLBACK_GATEWAY_URLS = [
DEFAULT_GATEWAY_URL, DEFAULT_GATEWAY_URL,
`${DEFAULT_WS_PROTOCOL}127.0.0.1:4200/ws`, `${DEFAULT_WS_PROTOCOL}127.0.0.1:${ZCLAW_LEGACY_PORT}/ws`,
]; ];
const GATEWAY_URL_STORAGE_KEY = 'zclaw_gateway_url'; const GATEWAY_URL_STORAGE_KEY = 'zclaw_gateway_url';

View File

@@ -6,8 +6,13 @@
* *
* API base path: /api/v1/... * API base path: /api/v1/...
* Auth: Bearer token in Authorization header * Auth: Bearer token in Authorization header
*
* Security: JWT token is stored via secureStorage (OS keychain or encrypted localStorage).
* URL, account info, and connection mode remain in plain localStorage (non-sensitive).
*/ */
import { secureStorage } from './secure-storage';
// === Storage Keys === // === Storage Keys ===
const SAASTOKEN_KEY = 'zclaw-saas-token'; const SAASTOKEN_KEY = 'zclaw-saas-token';
@@ -146,6 +151,55 @@ export interface ConfigSyncResult {
skipped: number; skipped: number;
} }
// === JWT Helpers ===
/**
* Decode a JWT payload without verifying the signature.
* Returns the parsed JSON payload, or null if the token is malformed.
*/
export function decodeJwtPayload<T = Record<string, unknown>>(token: string): T | null {
try {
const parts = token.split('.');
if (parts.length !== 3) return null;
// JWT payload is Base64Url-encoded
const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/');
const json = decodeURIComponent(
atob(base64)
.split('')
.map((c) => '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2))
.join(''),
);
return JSON.parse(json) as T;
} catch {
return null;
}
}
/** JWT payload shape we care about */
interface JwtPayload {
exp?: number;
iat?: number;
sub?: string;
}
/**
* Calculate the delay (ms) until 80% of the token's lifetime has elapsed.
* This is the ideal moment to trigger a proactive refresh.
* Returns null if the token has no exp claim or is already past 80% lifetime.
*/
export function getRefreshDelay(exp: number): number | null {
const now = Math.floor(Date.now() / 1000);
const totalLifetime = exp - now;
if (totalLifetime <= 0) return null; // already expired
// Refresh at 80% of the token's remaining lifetime
const refreshAt = now + Math.floor(totalLifetime * 0.8);
const delayMs = (refreshAt - now) * 1000;
// Minimum 5-second guard to avoid hammering the endpoint
return delayMs > 5000 ? delayMs : 5000;
}
// === Error Class === // === Error Class ===
export class SaaSApiError extends Error { export class SaaSApiError extends Error {
@@ -168,16 +222,35 @@ export interface SaaSSession {
} }
/** /**
* Load a persisted SaaS session from localStorage. * Read a value from localStorage with error handling.
*/
function readLegacyLocalStorage(key: string): string | null {
try {
return localStorage.getItem(key);
} catch {
return null;
}
}
/**
* Load a persisted SaaS session using secure storage for the JWT token.
* Falls back to legacy localStorage if secureStorage has no token (migration).
* Returns null if no valid session exists. * Returns null if no valid session exists.
*/ */
export function loadSaaSSession(): SaaSSession | null { export async function loadSaaSSessionAsync(): Promise<SaaSSession | null> {
try { try {
const token = localStorage.getItem(SAASTOKEN_KEY); // Try secure storage first (keychain or encrypted localStorage)
const saasUrl = localStorage.getItem(SAASURL_KEY); const token = await secureStorage.get(SAASTOKEN_KEY);
const accountRaw = localStorage.getItem(SAASACCOUNT_KEY);
if (!token || !saasUrl) { // Migration: if secureStorage is empty, try legacy localStorage
const legacyToken = !token ? readLegacyLocalStorage(SAASTOKEN_KEY) : null;
const saasUrl = readLegacyLocalStorage(SAASURL_KEY);
const accountRaw = readLegacyLocalStorage(SAASACCOUNT_KEY);
const effectiveToken = token || legacyToken;
if (!effectiveToken || !saasUrl) {
return null; return null;
} }
@@ -185,19 +258,30 @@ export function loadSaaSSession(): SaaSSession | null {
? (JSON.parse(accountRaw) as SaaSAccountInfo) ? (JSON.parse(accountRaw) as SaaSAccountInfo)
: null; : null;
return { token, account, saasUrl }; // If we found a legacy token in localStorage, migrate it to secure storage
if (legacyToken && !token) {
await secureStorage.set(SAASTOKEN_KEY, legacyToken);
// Remove plaintext token from localStorage after migration
try { localStorage.removeItem(SAASTOKEN_KEY); } catch { /* ignore */ }
}
return { token: effectiveToken, account, saasUrl };
} catch { } catch {
// Corrupted data - clear all // Corrupted data - clear all
clearSaaSSession(); await clearSaaSSessionAsync();
return null; return null;
} }
} }
/** /**
* Persist a SaaS session to localStorage. * Persist a SaaS session using secure storage for the JWT token.
* URL and account info remain in localStorage (non-sensitive).
*/ */
export function saveSaaSSession(session: SaaSSession): void { export async function saveSaaSSessionAsync(session: SaaSSession): Promise<void> {
localStorage.setItem(SAASTOKEN_KEY, session.token); await secureStorage.set(SAASTOKEN_KEY, session.token);
// Remove legacy plaintext token from localStorage
try { localStorage.removeItem(SAASTOKEN_KEY); } catch { /* ignore */ }
localStorage.setItem(SAASURL_KEY, session.saasUrl); localStorage.setItem(SAASURL_KEY, session.saasUrl);
if (session.account) { if (session.account) {
localStorage.setItem(SAASACCOUNT_KEY, JSON.stringify(session.account)); localStorage.setItem(SAASACCOUNT_KEY, JSON.stringify(session.account));
@@ -205,16 +289,18 @@ export function saveSaaSSession(session: SaaSSession): void {
} }
/** /**
* Clear the persisted SaaS session from localStorage. * Clear the persisted SaaS session from both secure storage and localStorage.
*/ */
export function clearSaaSSession(): void { export async function clearSaaSSessionAsync(): Promise<void> {
localStorage.removeItem(SAASTOKEN_KEY); await secureStorage.delete(SAASTOKEN_KEY);
localStorage.removeItem(SAASURL_KEY); try { localStorage.removeItem(SAASTOKEN_KEY); } catch { /* ignore */ }
localStorage.removeItem(SAASACCOUNT_KEY); try { localStorage.removeItem(SAASURL_KEY); } catch { /* ignore */ }
try { localStorage.removeItem(SAASACCOUNT_KEY); } catch { /* ignore */ }
} }
/** /**
* Persist the connection mode to localStorage. * Persist the connection mode to localStorage.
* Connection mode is non-sensitive -- no need for secure storage.
*/ */
export function saveConnectionMode(mode: string): void { export function saveConnectionMode(mode: string): void {
localStorage.setItem(SAASMODE_KEY, mode); localStorage.setItem(SAASMODE_KEY, mode);
@@ -230,9 +316,15 @@ export function loadConnectionMode(): string | null {
// === Client Implementation === // === Client Implementation ===
/** Callback invoked when token refresh fails and the session should be terminated. */
export type OnSessionExpired = () => void;
export class SaaSClient { export class SaaSClient {
private baseUrl: string; private baseUrl: string;
private token: string | null = null; private token: string | null = null;
private refreshTimerId: ReturnType<typeof setTimeout> | null = null;
private visibilityHandler: (() => void) | null = null;
private onSessionExpired: OnSessionExpired | null = null;
constructor(baseUrl: string) { constructor(baseUrl: string) {
this.baseUrl = baseUrl.replace(/\/+$/, ''); this.baseUrl = baseUrl.replace(/\/+$/, '');
@@ -248,9 +340,22 @@ export class SaaSClient {
return this.baseUrl; return this.baseUrl;
} }
/** Set or clear the auth token */ /** Set or clear the auth token. Automatically schedules a proactive refresh. */
setToken(token: string | null): void { setToken(token: string | null): void {
this.token = token; this.token = token;
if (token) {
this.scheduleTokenRefresh();
} else {
this.cancelTokenRefresh();
}
}
/**
* Register a callback invoked when the proactive token refresh fails.
* The caller should use this to trigger a logout/redirect flow.
*/
setOnSessionExpired(handler: OnSessionExpired): void {
this.onSessionExpired = handler;
} }
/** Check if the client has an auth token */ /** Check if the client has an auth token */
@@ -258,6 +363,102 @@ export class SaaSClient {
return !!this.token; return !!this.token;
} }
/**
* Schedule a proactive token refresh at 80% of the token's remaining lifetime.
* Also registers a visibilitychange listener to re-check when the tab regains focus.
*/
scheduleTokenRefresh(): void {
this.cancelTokenRefresh();
if (!this.token) return;
const payload = decodeJwtPayload<JwtPayload>(this.token);
if (!payload?.exp) return;
const delay = getRefreshDelay(payload.exp);
if (delay === null) {
// Token already expired or too close -- attempt immediate refresh
this.attemptTokenRefresh();
return;
}
this.refreshTimerId = setTimeout(() => {
this.attemptTokenRefresh();
}, delay);
// When the tab becomes visible again, check if we should refresh sooner
if (typeof document !== 'undefined' && !this.visibilityHandler) {
this.visibilityHandler = () => {
if (document.visibilityState === 'visible') {
this.checkAndRefreshToken();
}
};
document.addEventListener('visibilitychange', this.visibilityHandler);
}
}
/**
* Cancel any pending token refresh timer and remove the visibility listener.
*/
cancelTokenRefresh(): void {
if (this.refreshTimerId !== null) {
clearTimeout(this.refreshTimerId);
this.refreshTimerId = null;
}
if (this.visibilityHandler !== null && typeof document !== 'undefined') {
document.removeEventListener('visibilitychange', this.visibilityHandler);
this.visibilityHandler = null;
}
}
/**
* Check if the current token is close to expiry and refresh if needed.
* Called on visibility change to handle clock skew / long background tabs.
*/
private checkAndRefreshToken(): void {
if (!this.token) return;
const payload = decodeJwtPayload<JwtPayload>(this.token);
if (!payload?.exp) return;
const now = Math.floor(Date.now() / 1000);
const remaining = payload.exp - now;
// If less than 20% of lifetime remains, refresh now
if (remaining <= 0) {
this.attemptTokenRefresh();
return;
}
// If the scheduled refresh is more than 60s away and we're within 80%, do it now
const delay = getRefreshDelay(payload.exp);
if (delay !== null && delay < 60_000) {
this.attemptTokenRefresh();
}
}
/**
* Attempt to refresh the token. On failure, invoke the session-expired callback.
* Persists the new token via secureStorage.
*/
private attemptTokenRefresh(): Promise<void> {
return this.refreshToken()
.then(async (newToken) => {
// Persist the new token to secure storage
const existing = await loadSaaSSessionAsync();
if (existing) {
await saveSaaSSessionAsync({ ...existing, token: newToken });
}
})
.catch(() => {
// Refresh failed -- notify the app to log out
this.cancelTokenRefresh();
if (this.onSessionExpired) {
this.onSessionExpired();
}
});
}
// --- Core HTTP --- // --- Core HTTP ---
/** Track whether the server appears reachable */ /** Track whether the server appears reachable */
@@ -436,7 +637,7 @@ export class SaaSClient {
/** /**
* Register or update this device with the SaaS backend. * Register or update this device with the SaaS backend.
* Uses UPSERT semantics same (account, device_id) updates last_seen_at. * Uses UPSERT semantics -- same (account, device_id) updates last_seen_at.
*/ */
async registerDevice(params: { async registerDevice(params: {
device_id: string; device_id: string;

View File

@@ -37,18 +37,9 @@ const log = createLogger('ConnectionStore');
// === Custom Models Helpers === // === Custom Models Helpers ===
const CUSTOM_MODELS_STORAGE_KEY = 'zclaw-custom-models'; import type { CustomModel } from '../types/config';
interface CustomModel { const CUSTOM_MODELS_STORAGE_KEY = 'zclaw-custom-models';
id: string;
name: string;
provider: string;
apiKey?: string;
apiProtocol: 'openai' | 'anthropic' | 'custom';
baseUrl?: string;
isDefault?: boolean;
createdAt: string;
}
/** /**
* Get custom models from localStorage * Get custom models from localStorage
@@ -218,8 +209,8 @@ export const useConnectionStore = create<ConnectionStore>((set, get) => {
// This takes priority over Tauri/Gateway when the user has selected SaaS mode. // This takes priority over Tauri/Gateway when the user has selected SaaS mode.
const savedMode = localStorage.getItem('zclaw-connection-mode'); const savedMode = localStorage.getItem('zclaw-connection-mode');
if (savedMode === 'saas') { if (savedMode === 'saas') {
const { loadSaaSSession, saasClient } = await import('../lib/saas-client'); const { loadSaaSSessionAsync, saasClient } = await import('../lib/saas-client');
const session = loadSaaSSession(); const session = await loadSaaSSessionAsync();
if (!session || !session.token || !session.saasUrl) { if (!session || !session.token || !session.saasUrl) {
throw new Error('SaaS 模式未登录,请先在设置中登录 SaaS 平台'); throw new Error('SaaS 模式未登录,请先在设置中登录 SaaS 平台');

View File

@@ -2,8 +2,8 @@
* SaaS Store - SaaS Platform Connection State Management * SaaS Store - SaaS Platform Connection State Management
* *
* Manages SaaS login state, account info, connection mode, * Manages SaaS login state, account info, connection mode,
* and available models. Persists auth state to localStorage * and available models. Persists auth token via secureStorage
* via saas-client helpers. * (OS keychain or encrypted localStorage) for security.
* *
* Connection modes: * Connection modes:
* - 'tauri': Local Kernel via Tauri (default) * - 'tauri': Local Kernel via Tauri (default)
@@ -15,9 +15,9 @@ import { create } from 'zustand';
import { import {
saasClient, saasClient,
SaaSApiError, SaaSApiError,
loadSaaSSession, loadSaaSSessionAsync,
saveSaaSSession, saveSaaSSessionAsync,
clearSaaSSession, clearSaaSSessionAsync,
saveConnectionMode, saveConnectionMode,
loadConnectionMode, loadConnectionMode,
type SaaSAccountInfo, type SaaSAccountInfo,
@@ -64,12 +64,12 @@ export interface SaaSActionsSlice {
login: (saasUrl: string, username: string, password: string) => Promise<void>; login: (saasUrl: string, username: string, password: string) => Promise<void>;
loginWithTotp: (saasUrl: string, username: string, password: string, totpCode: string) => Promise<void>; loginWithTotp: (saasUrl: string, username: string, password: string, totpCode: string) => Promise<void>;
register: (saasUrl: string, username: string, email: string, password: string, displayName?: string) => Promise<void>; register: (saasUrl: string, username: string, email: string, password: string, displayName?: string) => Promise<void>;
logout: () => void; logout: () => Promise<void>;
setConnectionMode: (mode: ConnectionMode) => void; setConnectionMode: (mode: ConnectionMode) => void;
fetchAvailableModels: () => Promise<void>; fetchAvailableModels: () => Promise<void>;
registerCurrentDevice: () => Promise<void>; registerCurrentDevice: () => Promise<void>;
clearError: () => void; clearError: () => void;
restoreSession: () => void; restoreSession: () => Promise<void>;
setupTotp: () => Promise<TotpSetupResponse>; setupTotp: () => Promise<TotpSetupResponse>;
verifyTotp: (code: string) => Promise<void>; verifyTotp: (code: string) => Promise<void>;
disableTotp: (password: string) => Promise<void>; disableTotp: (password: string) => Promise<void>;
@@ -85,33 +85,56 @@ const DEFAULT_SAAS_URL = 'https://saas.zclaw.com';
// === Helpers === // === Helpers ===
/** Determine the initial connection mode from persisted state */ /** Determine the initial connection mode from persisted state */
function resolveInitialMode(session: ReturnType<typeof loadSaaSSession>): ConnectionMode { function resolveInitialMode(hasSession: boolean): ConnectionMode {
const persistedMode = loadConnectionMode(); const persistedMode = loadConnectionMode();
if (persistedMode === 'tauri' || persistedMode === 'gateway' || persistedMode === 'saas') { if (persistedMode === 'tauri' || persistedMode === 'gateway' || persistedMode === 'saas') {
return persistedMode; return persistedMode;
} }
return session ? 'saas' : 'tauri'; return hasSession ? 'saas' : 'tauri';
} }
// === Store Implementation === // === Store Implementation ===
export const useSaaSStore = create<SaaSStore>((set, get) => { export const useSaaSStore = create<SaaSStore>((set, get) => {
// Restore session from localStorage on init // Determine initial connection mode synchronously from localStorage.
const session = loadSaaSSession(); // Session token will be loaded asynchronously via restoreSession().
const initialMode = resolveInitialMode(session); const persistedMode = loadConnectionMode();
const hasSession = persistedMode === 'saas';
const initialMode = resolveInitialMode(hasSession);
// If session exists, configure the singleton client // Kick off async session restoration immediately.
if (session) { // The store initializes with a "potentially logged in" state based on
saasClient.setBaseUrl(session.saasUrl); // the connection mode, and restoreSession() will either hydrate the token
saasClient.setToken(session.token); // or clear the session if secure storage has no token.
} loadSaaSSessionAsync().then((session) => {
if (session) {
saasClient.setBaseUrl(session.saasUrl);
saasClient.setToken(session.token);
set({
isLoggedIn: true,
account: session.account,
saasUrl: session.saasUrl,
authToken: session.token,
connectionMode: resolveInitialMode(true),
});
// Fetch models in background after async restore
get().fetchAvailableModels().catch(() => {});
} else if (persistedMode === 'saas') {
// Connection mode was 'saas' but no token found -- reset to tauri
saveConnectionMode('tauri');
set({ connectionMode: 'tauri' });
}
}).catch(() => {
// secureStorage read failed -- keep defaults
});
return { return {
// === Initial State === // === Initial State ===
isLoggedIn: session !== null, // Session data will be hydrated by the async restoreSession above.
account: session?.account ?? null, isLoggedIn: hasSession,
saasUrl: session?.saasUrl ?? DEFAULT_SAAS_URL, account: null,
authToken: session?.token ?? null, saasUrl: DEFAULT_SAAS_URL,
authToken: null,
connectionMode: initialMode, connectionMode: initialMode,
availableModels: [], availableModels: [],
isLoading: false, isLoading: false,
@@ -144,13 +167,13 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
saasClient.setBaseUrl(normalizedUrl); saasClient.setBaseUrl(normalizedUrl);
const loginData: SaaSLoginResponse = await saasClient.login(trimmedUsername, password); const loginData: SaaSLoginResponse = await saasClient.login(trimmedUsername, password);
// Persist session // Persist session securely
const sessionData = { const sessionData = {
token: loginData.token, token: loginData.token,
account: loginData.account, account: loginData.account,
saasUrl: normalizedUrl, saasUrl: normalizedUrl,
}; };
saveSaaSSession(sessionData); await saveSaaSSessionAsync(sessionData);
saveConnectionMode('saas'); saveConnectionMode('saas');
set({ set({
@@ -212,7 +235,7 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
account: loginData.account, account: loginData.account,
saasUrl: normalizedUrl, saasUrl: normalizedUrl,
}; };
saveSaaSSession(sessionData); await saveSaaSSessionAsync(sessionData);
saveConnectionMode('saas'); saveConnectionMode('saas');
set({ set({
@@ -273,7 +296,7 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
account: registerData.account, account: registerData.account,
saasUrl: normalizedUrl, saasUrl: normalizedUrl,
}; };
saveSaaSSession(sessionData); await saveSaaSSessionAsync(sessionData);
saveConnectionMode('saas'); saveConnectionMode('saas');
set({ set({
@@ -305,9 +328,9 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
} }
}, },
logout: () => { logout: async () => {
saasClient.setToken(null); saasClient.setToken(null);
clearSaaSSession(); await clearSaaSSessionAsync();
saveConnectionMode('tauri'); saveConnectionMode('tauri');
set({ set({
@@ -393,8 +416,8 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
set({ error: null }); set({ error: null });
}, },
restoreSession: () => { restoreSession: async () => {
const restored = loadSaaSSession(); const restored = await loadSaaSSessionAsync();
if (restored) { if (restored) {
saasClient.setBaseUrl(restored.saasUrl); saasClient.setBaseUrl(restored.saasUrl);
saasClient.setToken(restored.token); saasClient.setToken(restored.token);
@@ -430,7 +453,7 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
const account = await saasClient.me(); const account = await saasClient.me();
const { saasUrl, authToken } = get(); const { saasUrl, authToken } = get();
if (authToken) { if (authToken) {
saveSaaSSession({ token: authToken, account, saasUrl }); await saveSaaSSessionAsync({ token: authToken, account, saasUrl });
} }
set({ totpSetupData: null, isLoading: false, account }); set({ totpSetupData: null, isLoading: false, account });
} catch (err: unknown) { } catch (err: unknown) {
@@ -448,7 +471,7 @@ export const useSaaSStore = create<SaaSStore>((set, get) => {
const account = await saasClient.me(); const account = await saasClient.me();
const { saasUrl, authToken } = get(); const { saasUrl, authToken } = get();
if (authToken) { if (authToken) {
saveSaaSSession({ token: authToken, account, saasUrl }); await saveSaaSSessionAsync({ token: authToken, account, saasUrl });
} }
set({ isLoading: false, account }); set({ isLoading: false, account });
} catch (err: unknown) { } catch (err: unknown) {

View File

@@ -571,3 +571,35 @@ export interface ConfigFileMetadata {
/** Whether the file has unresolved env vars */ /** Whether the file has unresolved env vars */
hasUnresolvedEnvVars?: boolean; hasUnresolvedEnvVars?: boolean;
} }
// ============================================================
// Custom Model Types
// ============================================================
/**
* API protocol supported by a custom model provider.
*/
export type CustomModelApiProtocol = 'openai' | 'anthropic' | 'custom';
/**
* User-defined custom model configuration.
* Used by the model settings UI and the connection store.
*/
export interface CustomModel {
/** Unique identifier */
id: string;
/** Human-readable model name */
name: string;
/** Provider / vendor name */
provider: string;
/** API key (optional, stored separately in secure storage) */
apiKey?: string;
/** Which API protocol this provider speaks */
apiProtocol: CustomModelApiProtocol;
/** Base URL for the provider API (optional) */
baseUrl?: string;
/** Whether this model is the user's default */
isDefault?: boolean;
/** ISO-8601 timestamp of when this model was added */
createdAt: string;
}

View File

@@ -141,6 +141,12 @@ export type {
AutomationItem, AutomationItem,
} from './automation'; } from './automation';
// Custom Model Types
export type {
CustomModel,
CustomModelApiProtocol,
} from './config';
// Automation Constants and Functions // Automation Constants and Functions
export { export {
HAND_CATEGORY_MAP, HAND_CATEGORY_MAP,

76
docker-compose.yml Normal file
View File

@@ -0,0 +1,76 @@
# ============================================================
# ZCLAW SaaS Backend - Docker Compose
# ============================================================
# Usage:
# cp saas-env.example .env # then edit .env with real values
# docker compose up -d
# docker compose logs -f saas
# ============================================================
services:
# ---- PostgreSQL 16 ----
postgres:
image: postgres:16-alpine
container_name: zclaw-postgres
restart: unless-stopped
environment:
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-your_secure_password}
POSTGRES_DB: ${POSTGRES_DB:-zclaw}
ports:
- "${POSTGRES_PORT:-5432}:5432"
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER:-postgres} -d ${POSTGRES_DB:-zclaw}"]
interval: 10s
timeout: 5s
retries: 5
start_period: 10s
networks:
- zclaw-saas
# ---- SaaS Backend ----
saas:
build:
context: .
dockerfile: Dockerfile
container_name: zclaw-saas
restart: unless-stopped
ports:
- "${SAAS_PORT:-8080}:8080"
env_file:
- saas-env.example
environment:
DATABASE_URL: postgres://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-your_secure_password}@postgres:5432/${POSTGRES_DB:-zclaw}
depends_on:
postgres:
condition: service_healthy
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:8080/health"]
interval: 30s
timeout: 5s
retries: 3
start_period: 15s
networks:
- zclaw-saas
volumes:
postgres_data:
driver: local
networks:
zclaw-saas:
driver: bridge

View File

@@ -0,0 +1,193 @@
# ZCLAW SaaS 后端部署指南
## 系统要求
| 组件 | 最低要求 | 推荐配置 |
|------|---------|---------|
| CPU | 2 核 | 4 核 |
| 内存 | 2 GB | 4 GB |
| 磁盘 | 10 GB SSD | 20 GB SSD |
| PostgreSQL | 15+ | 16 |
| Docker | 24+ | 最新 |
## 快速部署 (Docker Compose)
### 1. 准备配置
```bash
# 进入项目目录
cd zclaw-saas
# 复制环境变量模板
cp saas-env.example .env
# 编辑 .env填入实际值
# 必须修改: POSTGRES_PASSWORD, ZCLAW_SAAS_JWT_SECRET, ZCLAW_SAAS_FIELD_ENCRYPTION_KEY
```
### 2. 生成密钥
```bash
# JWT 密钥
openssl rand -base64 48
# AES-256-GCM 字段加密密钥
openssl rand -hex 32
```
### 3. 配置 CORS
编辑 `saas-config.toml`,设置允许的来源:
```toml
[server]
cors_origins = ["https://your-admin-domain.com", "https://your-app-domain.com"]
```
### 4. 启动服务
```bash
# 构建并启动
docker compose up -d --build
# 查看日志
docker compose logs -f saas
# 查看状态
docker compose ps
```
### 5. 验证部署
```bash
# 健康检查
curl http://localhost:8080/health
# API 版本
curl http://localhost:8080/api/v1/relay/models
```
## 手动部署 (无 Docker)
### 1. 安装依赖
- Rust 1.75+ (推荐 rustup)
- PostgreSQL 16
- OpenSSL 开发头文件 (`libssl-dev` on Ubuntu)
### 2. 数据库初始化
```bash
# 创建数据库
createdb zclaw
# 启动 SaaS 服务 (首次启动自动创建表结构)
ZCLAW_SAAS_JWT_SECRET=xxx \
ZCLAW_SAAS_FIELD_ENCRYPTION_KEY=xxx \
DATABASE_URL=postgres://user:pass@localhost:5432/zclaw \
cargo run --release --package zclaw-saas
```
### 3. 环境变量
| 变量 | 必需 | 说明 |
|------|------|------|
| `DATABASE_URL` | 是 | PostgreSQL 连接 URL |
| `ZCLAW_SAAS_JWT_SECRET` | 是 | JWT 签名密钥 (>=32 字符) |
| `ZCLAW_SAAS_FIELD_ENCRYPTION_KEY` | 是* | AES-256-GCM 密钥 (64 字符 hex) |
| `ZCLAW_SAAS_CONFIG` | 否 | 配置文件路径 (默认 `./saas-config.toml`) |
| `ZCLAW_SAAS_DEV` | 否 | 开发模式 (`true`/`1`) |
*生产环境必需。开发环境设置 `ZCLAW_SAAS_DEV=true` 可自动生成临时密钥。
### 4. 配置文件 (saas-config.toml)
```toml
[server]
host = "0.0.0.0"
port = 8080
cors_origins = ["https://admin.example.com"]
[database]
url = "postgres://user:pass@localhost:5432/zclaw"
[auth]
jwt_expiration_hours = 24
totp_issuer = "ZCLAW SaaS"
[relay]
max_queue_size = 1000
max_concurrent_per_provider = 5
batch_window_ms = 50
retry_delay_ms = 1000
max_attempts = 3
[rate_limit]
requests_per_minute = 60
burst = 10
```
## 安全加固清单
- [ ] `ZCLAW_SAAS_JWT_SECRET` 使用强随机密钥
- [ ] `ZCLAW_SAAS_FIELD_ENCRYPTION_KEY` 已设置 (数据库 API Key 加密)
- [ ] `ZCLAW_SAAS_DEV` 未设置或为 `false`
- [ ] `cors_origins` 配置为实际域名
- [ ] PostgreSQL 使用独立密码,不使用默认密码
- [ ] 防火墙仅开放 8080 端口
- [ ] HTTPS 反向代理 (Nginx/Caddy) 配置在 SaaS 前面
## Nginx 反向代理示例
```nginx
server {
listen 443 ssl;
server_name saas-api.example.com;
ssl_certificate /path/to/cert.pem;
ssl_certificate_key /path/to/key.pem;
location / {
proxy_pass http://127.0.0.1:8080;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# SSE 超时设置
proxy_read_timeout 300s;
proxy_buffering off;
}
}
```
## 运维命令
```bash
# 使用 Makefile
make saas-build # 编译
make saas-run # 启动
make saas-test # 运行测试
make saas-clippy # 代码检查
make saas-docker-up # Docker 启动
make saas-docker-down # Docker 停止
# 或手动
cargo build --release --package zclaw-saas
cargo run --release --package zclaw-saas
cargo test --package zclaw-saas
cargo clippy --package zclaw-saas
```
## 故障排查
| 问题 | 排查步骤 |
|------|---------|
| 启动失败 "DATABASE_URL 未配置" | 检查 `.env``DATABASE_URL` 是否设置 |
| 启动失败 "ZCLAW_SAAS_JWT_SECRET 未设置" | 设置环境变量或 `ZCLAW_SAAS_DEV=true` |
| 请求 429 Too Many Requests | 调整 `saas-config.toml``rate_limit` 配置 |
| 中转 502 Bad Gateway | 检查 provider URL 是否可达、API Key 是否有效 |
| SSE 流中断 | 检查反向代理超时设置,确保 `proxy_read_timeout >= 300s` |

View File

@@ -14,14 +14,14 @@ ZCLAW SaaS 平台为桌面端用户提供云端能力,包括模型中转、账
└── Mode C: SaaS Cloud ──→ Rust/Axum 后端 ──→ 上游 LLM Provider └── Mode C: SaaS Cloud ──→ Rust/Axum 后端 ──→ 上游 LLM Provider
├── Admin Web (Next.js 管理后台) ├── Admin Web (Next.js 管理后台)
└── SQLite WAL (数据持久化) └── PostgreSQL (数据持久化)
``` ```
## 技术栈 ## 技术栈
| 层级 | 技术 | 说明 | | 层级 | 技术 | 说明 |
|------|------|------| |------|------|------|
| 后端 | Rust + Axum + sqlx + SQLite WAL | JWT + API Token 双认证 | | 后端 | Rust + Axum + sqlx + PostgreSQL | JWT + API Token 双认证 |
| Admin | Next.js 14 + shadcn/ui + Tailwind | 暗色 OLED 主题 | | Admin | Next.js 14 + shadcn/ui + Tailwind | 暗色 OLED 主题 |
| 桌面端 | React 18 + Zustand + TypeScript | saas-client.ts HTTP 通信 | | 桌面端 | React 18 + Zustand + TypeScript | saas-client.ts HTTP 通信 |
| 安全 | argon2 + TOTP 2FA + RBAC | 速率限制 + 操作审计 | | 安全 | argon2 + TOTP 2FA + RBAC | 速率限制 + 操作审计 |

View File

@@ -51,8 +51,8 @@ ZCLAW 当前是纯桌面单用户应用缺少用户账号系统、API 服务
┌───────────────┐ ┌───────────────┐
SQLite (WAL) PostgreSQL
saas-data.db zclaw 数据库
└───────────────┘ └───────────────┘
┌──────────────────────────────────────────────────────┐ ┌──────────────────────────────────────────────────────┐
@@ -136,8 +136,8 @@ saas-admin/ # 独立 React 管理后台
### 3.1 概述 ### 3.1 概述
- 引擎: SQLite WAL 模式 - 引擎: PostgreSQL 16
- 文件: 独立于桌面端 `~/.zclaw/data.db`默认 `./saas-data.db` - 连接: 通过 `DATABASE_URL` 环境变量配置 (推荐) `saas-config.toml` 中指定
- 迁移: 版本化 schema启动时自动迁移 - 迁移: 版本化 schema启动时自动迁移
### 3.2 完整 Schema ### 3.2 完整 Schema
@@ -162,10 +162,10 @@ CREATE TABLE IF NOT EXISTS accounts (
role TEXT NOT NULL DEFAULT 'user', -- 'super_admin' | 'admin' | 'user' role TEXT NOT NULL DEFAULT 'user', -- 'super_admin' | 'admin' | 'user'
status TEXT NOT NULL DEFAULT 'active', -- 'active' | 'disabled' | 'suspended' status TEXT NOT NULL DEFAULT 'active', -- 'active' | 'disabled' | 'suspended'
totp_secret TEXT, -- 加密存储 totp_secret TEXT, -- 加密存储
totp_enabled INTEGER NOT NULL DEFAULT 0, totp_enabled BOOLEAN NOT NULL DEFAULT false,
last_login_at TEXT, last_login_at TIMESTAMPTZ,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_accounts_email ON accounts(email); CREATE INDEX IF NOT EXISTS idx_accounts_email ON accounts(email);
CREATE INDEX IF NOT EXISTS idx_accounts_role ON accounts(role); CREATE INDEX IF NOT EXISTS idx_accounts_role ON accounts(role);
@@ -177,10 +177,10 @@ CREATE TABLE IF NOT EXISTS api_tokens (
token_hash TEXT NOT NULL, -- SHA256(token) token_hash TEXT NOT NULL, -- SHA256(token)
token_prefix TEXT NOT NULL, -- 前 8 字符用于展示 token_prefix TEXT NOT NULL, -- 前 8 字符用于展示
permissions TEXT NOT NULL DEFAULT '[]', -- JSON 权限数组 permissions TEXT NOT NULL DEFAULT '[]', -- JSON 权限数组
last_used_at TEXT, last_used_at TIMESTAMPTZ,
expires_at TEXT, -- NULL = 永不过期 expires_at TIMESTAMPTZ, -- NULL = 永不过期
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TEXT, revoked_at TIMESTAMPTZ,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS idx_api_tokens_account ON api_tokens(account_id); CREATE INDEX IF NOT EXISTS idx_api_tokens_account ON api_tokens(account_id);
@@ -191,9 +191,9 @@ CREATE TABLE IF NOT EXISTS roles (
name TEXT NOT NULL, -- 显示名称 (中文) name TEXT NOT NULL, -- 显示名称 (中文)
description TEXT, description TEXT,
permissions TEXT NOT NULL DEFAULT '[]', -- JSON 权限数组 permissions TEXT NOT NULL DEFAULT '[]', -- JSON 权限数组
is_system INTEGER NOT NULL DEFAULT 0, -- 系统角色不可删除 is_system BOOLEAN NOT NULL DEFAULT false, -- 系统角色不可删除
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE TABLE IF NOT EXISTS permission_templates ( CREATE TABLE IF NOT EXISTS permission_templates (
@@ -201,19 +201,19 @@ CREATE TABLE IF NOT EXISTS permission_templates (
name TEXT NOT NULL, -- e.g. "标准用户", "只读用户" name TEXT NOT NULL, -- e.g. "标准用户", "只读用户"
description TEXT, description TEXT,
permissions TEXT NOT NULL DEFAULT '[]', -- JSON 权限数组 permissions TEXT NOT NULL DEFAULT '[]', -- JSON 权限数组
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE TABLE IF NOT EXISTS operation_logs ( CREATE TABLE IF NOT EXISTS operation_logs (
id INTEGER PRIMARY KEY AUTOINCREMENT, id BIGSERIAL PRIMARY KEY,
account_id TEXT, -- NULL = 系统操作 account_id TEXT, -- NULL = 系统操作
action TEXT NOT NULL, -- e.g. "account.create", "model.update" action TEXT NOT NULL, -- e.g. "account.create", "model.update"
target_type TEXT, -- e.g. "account", "api_key", "model" target_type TEXT, -- e.g. "account", "api_key", "model"
target_id TEXT, target_id TEXT,
details TEXT, -- JSON 详情 details TEXT, -- JSON 详情
ip_address TEXT, ip_address TEXT,
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_op_logs_account ON operation_logs(account_id); CREATE INDEX IF NOT EXISTS idx_op_logs_account ON operation_logs(account_id);
CREATE INDEX IF NOT EXISTS idx_op_logs_action ON operation_logs(action); CREATE INDEX IF NOT EXISTS idx_op_logs_action ON operation_logs(action);
@@ -230,12 +230,12 @@ CREATE TABLE IF NOT EXISTS providers (
api_key TEXT, -- 服务端提供商 API 密钥 (加密存储) api_key TEXT, -- 服务端提供商 API 密钥 (加密存储)
base_url TEXT NOT NULL, base_url TEXT NOT NULL,
api_protocol TEXT NOT NULL DEFAULT 'openai', -- 'openai' | 'anthropic' api_protocol TEXT NOT NULL DEFAULT 'openai', -- 'openai' | 'anthropic'
enabled INTEGER NOT NULL DEFAULT 1, enabled BOOLEAN NOT NULL DEFAULT true,
rate_limit_rpm INTEGER, -- 每分钟请求数 rate_limit_rpm INTEGER, -- 每分钟请求数
rate_limit_tpm INTEGER, -- 每分钟 token 数 rate_limit_tpm INTEGER, -- 每分钟 token 数
config_json TEXT DEFAULT '{}', -- 提供商特定配置 config_json TEXT DEFAULT '{}', -- 提供商特定配置
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE TABLE IF NOT EXISTS models ( CREATE TABLE IF NOT EXISTS models (
@@ -245,13 +245,13 @@ CREATE TABLE IF NOT EXISTS models (
alias TEXT NOT NULL, -- 显示名称 alias TEXT NOT NULL, -- 显示名称
context_window INTEGER NOT NULL DEFAULT 8192, context_window INTEGER NOT NULL DEFAULT 8192,
max_output_tokens INTEGER NOT NULL DEFAULT 4096, max_output_tokens INTEGER NOT NULL DEFAULT 4096,
supports_streaming INTEGER NOT NULL DEFAULT 1, supports_streaming BOOLEAN NOT NULL DEFAULT true,
supports_vision INTEGER NOT NULL DEFAULT 0, supports_vision BOOLEAN NOT NULL DEFAULT false,
enabled INTEGER NOT NULL DEFAULT 1, enabled BOOLEAN NOT NULL DEFAULT true,
pricing_input REAL DEFAULT 0, -- 每 1K token 价格 pricing_input DOUBLE PRECISION DEFAULT 0, -- 每 1K token 价格
pricing_output REAL DEFAULT 0, pricing_output DOUBLE PRECISION DEFAULT 0,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(provider_id, model_id), UNIQUE(provider_id, model_id),
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
); );
@@ -264,18 +264,18 @@ CREATE TABLE IF NOT EXISTS account_api_keys (
key_value TEXT NOT NULL, -- API 密钥 (加密存储) key_value TEXT NOT NULL, -- API 密钥 (加密存储)
key_label TEXT, -- e.g. "主密钥", "备用密钥" key_label TEXT, -- e.g. "主密钥", "备用密钥"
permissions TEXT NOT NULL DEFAULT '[]', -- JSON: 可访问的模型 ID 列表 permissions TEXT NOT NULL DEFAULT '[]', -- JSON: 可访问的模型 ID 列表
enabled INTEGER NOT NULL DEFAULT 1, enabled BOOLEAN NOT NULL DEFAULT true,
last_used_at TEXT, last_used_at TIMESTAMPTZ,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
revoked_at TEXT, revoked_at TIMESTAMPTZ,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE, FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
); );
CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id); CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id);
CREATE TABLE IF NOT EXISTS usage_records ( CREATE TABLE IF NOT EXISTS usage_records (
id INTEGER PRIMARY KEY AUTOINCREMENT, id BIGSERIAL PRIMARY KEY,
account_id TEXT NOT NULL, account_id TEXT NOT NULL,
provider_id TEXT NOT NULL, provider_id TEXT NOT NULL,
model_id TEXT NOT NULL, model_id TEXT NOT NULL,
@@ -284,7 +284,7 @@ CREATE TABLE IF NOT EXISTS usage_records (
latency_ms INTEGER, latency_ms INTEGER,
status TEXT NOT NULL DEFAULT 'success', -- 'success' | 'error' | 'rate_limited' status TEXT NOT NULL DEFAULT 'success', -- 'success' | 'error' | 'rate_limited'
error_message TEXT, error_message TEXT,
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id); CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id);
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at); CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
@@ -308,10 +308,10 @@ CREATE TABLE IF NOT EXISTS relay_tasks (
input_tokens INTEGER DEFAULT 0, input_tokens INTEGER DEFAULT 0,
output_tokens INTEGER DEFAULT 0, output_tokens INTEGER DEFAULT 0,
error_message TEXT, error_message TEXT,
queued_at TEXT NOT NULL, queued_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
started_at TEXT, started_at TIMESTAMPTZ,
completed_at TEXT, completed_at TIMESTAMPTZ,
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status); CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status);
CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id); CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id);
@@ -330,15 +330,15 @@ CREATE TABLE IF NOT EXISTS config_items (
default_value TEXT, -- JSON 编码的默认值 default_value TEXT, -- JSON 编码的默认值
source TEXT NOT NULL DEFAULT 'local', -- 'local' | 'saas' | 'override' source TEXT NOT NULL DEFAULT 'local', -- 'local' | 'saas' | 'override'
description TEXT, -- 中文描述 description TEXT, -- 中文描述
requires_restart INTEGER NOT NULL DEFAULT 0, requires_restart BOOLEAN NOT NULL DEFAULT false,
created_at TEXT NOT NULL, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TEXT NOT NULL, updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(category, key_path) UNIQUE(category, key_path)
); );
CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category); CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category);
CREATE TABLE IF NOT EXISTS config_sync_log ( CREATE TABLE IF NOT EXISTS config_sync_log (
id INTEGER PRIMARY KEY AUTOINCREMENT, id BIGSERIAL PRIMARY KEY,
account_id TEXT NOT NULL, account_id TEXT NOT NULL,
client_fingerprint TEXT NOT NULL, client_fingerprint TEXT NOT NULL,
action TEXT NOT NULL, -- 'push' | 'pull' | 'conflict' action TEXT NOT NULL, -- 'push' | 'pull' | 'conflict'
@@ -346,7 +346,7 @@ CREATE TABLE IF NOT EXISTS config_sync_log (
client_values TEXT, -- JSON: 客户端值 client_values TEXT, -- JSON: 客户端值
saas_values TEXT, -- JSON: SaaS 值 saas_values TEXT, -- JSON: SaaS 值
resolution TEXT, -- 'client_wins' | 'saas_wins' | 'manual' resolution TEXT, -- 'client_wins' | 'saas_wins' | 'manual'
created_at TEXT NOT NULL created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
); );
CREATE INDEX IF NOT EXISTS idx_sync_account ON config_sync_log(account_id); CREATE INDEX IF NOT EXISTS idx_sync_account ON config_sync_log(account_id);
``` ```

View File

@@ -1,9 +1,13 @@
[server] [server]
host = "0.0.0.0" host = "0.0.0.0"
port = 8080 port = 8080
# 生产环境必须配置 cors_origins 白名单,开发环境可设置 ZCLAW_SAAS_DEV=true 绕过
# cors_origins = ["https://your-admin-domain.com"]
[database] [database]
url = "sqlite:./saas-data.db" # 使用 DATABASE_URL 环境变量覆盖此配置(推荐)
# 格式: postgres://user:password@localhost:5432/zclaw
url = "postgres://localhost:5432/zclaw"
[auth] [auth]
jwt_expiration_hours = 24 jwt_expiration_hours = 24
@@ -15,3 +19,7 @@ max_concurrent_per_provider = 5
batch_window_ms = 50 batch_window_ms = 50
retry_delay_ms = 1000 retry_delay_ms = 1000
max_attempts = 3 max_attempts = 3
[rate_limit]
requests_per_minute = 60
burst = 10

BIN
saas-data.db-shm Normal file

Binary file not shown.

0
saas-data.db-wal Normal file
View File

37
saas-env.example Normal file
View File

@@ -0,0 +1,37 @@
# ZCLAW SaaS 后端环境变量配置
# 复制此文件为 .env 并填入实际值: cp saas-env.example .env
# ===================== 必需配置 =====================
# PostgreSQL 数据库连接 URL
# 格式: postgres://user:password@host:5432/zclaw
DATABASE_URL=postgres://postgres:your_secure_password@localhost:5432/zclaw
# JWT 签名密钥 (至少 32 字符的随机字符串)
# 生成方式: openssl rand -base64 48
ZCLAW_SAAS_JWT_SECRET=your-secure-jwt-secret-at-least-32-chars
# AES-256-GCM 字段加密密钥 (32 字节 hex 编码64 字符)
# 用于加密数据库中存储的敏感字段 (如 API Key)
# 生产环境必须设置,密钥丢失将导致已加密数据无法恢复
# 生成方式: openssl rand -hex 32
# ZCLAW_SAAS_FIELD_ENCRYPTION_KEY=0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef
# ===================== 可选配置 =====================
# 配置文件路径 (默认: ./saas-config.toml)
# ZCLAW_SAAS_CONFIG=./saas-config.toml
# 开发模式 (绕过部分安全检查,仅限本地开发使用)
# ZCLAW_SAAS_DEV=true
# ===================== 管理员初始化 =====================
# 首次启动时自动创建超级管理员账户 (可选)
# ZCLAW_ADMIN_USERNAME=admin
# ZCLAW_ADMIN_PASSWORD=your-admin-password
# ===================== 测试配置 =====================
# 测试用数据库 URL (仅 cargo test 使用)
# ZCLAW_TEST_DATABASE_URL=postgres://postgres:your_secure_password@localhost:5432/zclaw_test

View File

@@ -1,5 +1,5 @@
# ZCLAW Full Stack Start Script # ZCLAW Full Stack Start Script
# Starts: ChromeDriver (optional) -> Tauri Desktop # Starts: SaaS Backend (optional) -> ChromeDriver (optional) -> Tauri Desktop
# #
# NOTE: ZCLAW now uses internal Kernel (zclaw-kernel) for all operations. # NOTE: ZCLAW now uses internal Kernel (zclaw-kernel) for all operations.
# No external ZCLAW runtime is required. # No external ZCLAW runtime is required.
@@ -9,7 +9,8 @@ param(
[switch]$Dev, [switch]$Dev,
[switch]$Help, [switch]$Help,
[switch]$Stop, [switch]$Stop,
[switch]$DesktopOnly [switch]$DesktopOnly,
[switch]$NoSaas
) )
$ErrorActionPreference = "Continue" $ErrorActionPreference = "Continue"
@@ -30,8 +31,9 @@ ZCLAW Full Stack Start Script
Usage: .\start-all.ps1 [options] Usage: .\start-all.ps1 [options]
Options: Options:
-DesktopOnly Start desktop only (skip ChromeDriver) -DesktopOnly Start desktop only (skip ChromeDriver + SaaS)
-NoBrowser Skip ChromeDriver startup -NoBrowser Skip ChromeDriver startup
-NoSaas Skip SaaS backend startup
-Dev Development mode (hot reload) -Dev Development mode (hot reload)
-Stop Stop all services -Stop Stop all services
-Help Show this help -Help Show this help
@@ -43,7 +45,7 @@ Note:
Quick Commands: Quick Commands:
pnpm start # Start all services pnpm start # Start all services
pnpm start:dev # Start in dev mode pnpm start:dev # Start in dev mode
pnpm start:desktop # Start desktop only (no browser) pnpm start:desktop # Start desktop only (no browser, no SaaS)
"@ "@
exit 0 exit 0
@@ -57,6 +59,17 @@ if ($Stop) {
Get-Process -Name "chromedriver" -ErrorAction SilentlyContinue | Stop-Process -Force Get-Process -Name "chromedriver" -ErrorAction SilentlyContinue | Stop-Process -Force
ok "ChromeDriver stopped" ok "ChromeDriver stopped"
# Stop SaaS backend
Get-Process -Name "zclaw-saas" -ErrorAction SilentlyContinue | Stop-Process -Force
$port8080 = netstat -ano | Select-String ":8080.*LISTENING"
if ($port8080) {
$pid8080 = ($port8080 -split '\s+')[-1]
if ($pid8080 -match '^\d+$') {
Stop-Process -Id $pid8080 -Force -ErrorAction SilentlyContinue
ok "Stopped SaaS backend on port 8080 (PID: $pid8080)"
}
}
# Stop any process on port 4200 (legacy, may still be in use) # Stop any process on port 4200 (legacy, may still be in use)
$port4200 = netstat -ano | Select-String ":4200.*LISTENING" $port4200 = netstat -ano | Select-String ":4200.*LISTENING"
if ($port4200) { if ($port4200) {
@@ -108,12 +121,56 @@ function Cleanup {
trap { Cleanup; break } trap { Cleanup; break }
Register-EngineEvent -SourceIdentifier PowerShell.Exiting -Action { Cleanup } | Out-Null Register-EngineEvent -SourceIdentifier PowerShell.Exiting -Action { Cleanup } | Out-Null
# Skip ChromeDriver if DesktopOnly # Skip ChromeDriver and SaaS if DesktopOnly
if ($DesktopOnly) { if ($DesktopOnly) {
$NoBrowser = $true $NoBrowser = $true
$NoSaas = $true
} }
# 1. ChromeDriver (optional - for Browser Hand automation) # 1. SaaS Backend (for cloud features: account, relay, config sync)
if (-not $NoSaas) {
info "Checking SaaS backend..."
# Check if port 8080 is already in use
$port8080 = netstat -ano | Select-String ":8080.*LISTENING"
if ($port8080) {
$pid8080 = ($port8080 -split '\s+')[-1]
if ($pid8080 -match '^\d+$') {
ok "SaaS backend already running on port 8080 (PID: $pid8080)"
}
} else {
# Check if zclaw-saas binary exists
$saasBin = "$ScriptDir\target\debug\zclaw-saas.exe"
$saasBinRelease = "$ScriptDir\target\release\zclaw-saas.exe"
$saasExe = if (Test-Path $saasBinRelease) { $saasBinRelease } elseif (Test-Path $saasBin) { $saasBin } else { $null }
if ($saasExe) {
ok "SaaS backend binary found: $saasExe"
info "Starting SaaS backend on port 8080..."
$env:ZCLAW_SAAS_DEV = "true"
$proc = Start-Process -FilePath $saasExe -PassThru -WindowStyle Minimized
$Jobs += $proc
Start-Sleep -Seconds 3
if ($proc.HasExited) {
err "SaaS backend exited unexpectedly. Run manually: cd $ScriptDir && ZCLAW_SAAS_DEV=true cargo run --bin zclaw-saas"
} else {
ok "SaaS backend started (PID: $($proc.Id))"
}
} else {
warn "SaaS backend binary not found. Building..."
info "Run: cd $ScriptDir && cargo build --bin zclaw-saas"
warn "SaaS cloud features will be unavailable. Start SaaS manually after build."
}
}
} else {
info "Skipping SaaS backend"
}
Write-Host ""
# 2. ChromeDriver (optional - for Browser Hand automation)
if (-not $NoBrowser) { if (-not $NoBrowser) {
info "Checking ChromeDriver..." info "Checking ChromeDriver..."
@@ -146,7 +203,7 @@ if (-not $NoBrowser) {
Write-Host "" Write-Host ""
# 2. Start Tauri Desktop # 3. Start Tauri Desktop
info "Starting ZCLAW Desktop..." info "Starting ZCLAW Desktop..."
Set-Location "$ScriptDir/desktop" Set-Location "$ScriptDir/desktop"

File diff suppressed because one or more lines are too long

View File

@@ -1 +1 @@
{"rustc_fingerprint":5915500824126575890,"outputs":{"7971740275564407648":{"success":true,"status":"","code":0,"stdout":"___.exe\nlib___.rlib\n___.dll\n___.dll\n___.lib\n___.dll\nC:\\Users\\szend\\.rustup\\toolchains\\stable-x86_64-pc-windows-msvc\npacked\n___\ndebug_assertions\npanic=\"unwind\"\nproc_macro\ntarget_abi=\"\"\ntarget_arch=\"x86_64\"\ntarget_endian=\"little\"\ntarget_env=\"msvc\"\ntarget_family=\"windows\"\ntarget_feature=\"cmpxchg16b\"\ntarget_feature=\"fxsr\"\ntarget_feature=\"sse\"\ntarget_feature=\"sse2\"\ntarget_feature=\"sse3\"\ntarget_has_atomic=\"128\"\ntarget_has_atomic=\"16\"\ntarget_has_atomic=\"32\"\ntarget_has_atomic=\"64\"\ntarget_has_atomic=\"8\"\ntarget_has_atomic=\"ptr\"\ntarget_os=\"windows\"\ntarget_pointer_width=\"64\"\ntarget_vendor=\"pc\"\nwindows\n","stderr":""},"17747080675513052775":{"success":true,"status":"","code":0,"stdout":"rustc 1.93.1 (01f6ddf75 2026-02-11)\nbinary: rustc\ncommit-hash: 01f6ddf7588f42ae2d7eb0a2f21d44e8e96674cf\ncommit-date: 2026-02-11\nhost: x86_64-pc-windows-msvc\nrelease: 1.93.1\nLLVM version: 21.1.8\n","stderr":""}},"successes":{}} {"rustc_fingerprint":5915500824126575890,"outputs":{"17747080675513052775":{"success":true,"status":"","code":0,"stdout":"rustc 1.93.1 (01f6ddf75 2026-02-11)\nbinary: rustc\ncommit-hash: 01f6ddf7588f42ae2d7eb0a2f21d44e8e96674cf\ncommit-date: 2026-02-11\nhost: x86_64-pc-windows-msvc\nrelease: 1.93.1\nLLVM version: 21.1.8\n","stderr":""},"7971740275564407648":{"success":true,"status":"","code":0,"stdout":"___.exe\nlib___.rlib\n___.dll\n___.dll\n___.lib\n___.dll\nC:\\Users\\szend\\.rustup\\toolchains\\stable-x86_64-pc-windows-msvc\npacked\n___\ndebug_assertions\npanic=\"unwind\"\nproc_macro\ntarget_abi=\"\"\ntarget_arch=\"x86_64\"\ntarget_endian=\"little\"\ntarget_env=\"msvc\"\ntarget_family=\"windows\"\ntarget_feature=\"cmpxchg16b\"\ntarget_feature=\"fxsr\"\ntarget_feature=\"sse\"\ntarget_feature=\"sse2\"\ntarget_feature=\"sse3\"\ntarget_has_atomic=\"128\"\ntarget_has_atomic=\"16\"\ntarget_has_atomic=\"32\"\ntarget_has_atomic=\"64\"\ntarget_has_atomic=\"8\"\ntarget_has_atomic=\"ptr\"\ntarget_os=\"windows\"\ntarget_pointer_width=\"64\"\ntarget_vendor=\"pc\"\nwindows\n","stderr":""}},"successes":{}}

View File

@@ -0,0 +1 @@
{"rustc_vv":"rustc 1.93.1 (01f6ddf75 2026-02-11)\nbinary: rustc\ncommit-hash: 01f6ddf7588f42ae2d7eb0a2f21d44e8e96674cf\ncommit-date: 2026-02-11\nhost: x86_64-pc-windows-msvc\nrelease: 1.93.1\nLLVM version: 21.1.8\n"}

0
target/doc/.lock Normal file
View File

2
target/doc/crates.js Normal file
View File

@@ -0,0 +1,2 @@
window.ALL_CRATES = ["totp_rs"];
//{"start":21,"fragment_lengths":[9]}

1
target/doc/help.html Normal file
View File

@@ -0,0 +1 @@
<!DOCTYPE html><html lang="en"><head><meta charset="utf-8"><meta name="viewport" content="width=device-width, initial-scale=1.0"><meta name="generator" content="rustdoc"><meta name="description" content="Documentation for Rustdoc"><title>Help</title><script>if(window.location.protocol!=="file:")document.head.insertAdjacentHTML("beforeend","SourceSerif4-Regular-6b053e98.ttf.woff2,FiraSans-Italic-81dc35de.woff2,FiraSans-Regular-0fe48ade.woff2,FiraSans-MediumItalic-ccf7e434.woff2,FiraSans-Medium-e1aa3f0a.woff2,SourceCodePro-Regular-8badfe75.ttf.woff2,SourceCodePro-Semibold-aa29a496.ttf.woff2".split(",").map(f=>`<link rel="preload" as="font" type="font/woff2"href="./static.files/${f}">`).join(""))</script><link rel="stylesheet" href="./static.files/normalize-9960930a.css"><link rel="stylesheet" href="./static.files/rustdoc-ca0dd0c4.css"><meta name="rustdoc-vars" data-root-path="./" data-static-root-path="./static.files/" data-current-crate="totp_rs" data-themes="" data-resource-suffix="" data-rustdoc-version="1.93.1 (01f6ddf75 2026-02-11)" data-channel="1.93.1" data-search-js="search-9e2438ea.js" data-stringdex-js="stringdex-a3946164.js" data-settings-js="settings-c38705f0.js" ><script src="./static.files/storage-e2aeef58.js"></script><script defer src="./static.files/main-a410ff4d.js"></script><noscript><link rel="stylesheet" href="./static.files/noscript-263c88ec.css"></noscript><link rel="alternate icon" type="image/png" href="./static.files/favicon-32x32-eab170b8.png"><link rel="icon" type="image/svg+xml" href="./static.files/favicon-044be391.svg"></head><body class="rustdoc mod sys"><!--[if lte IE 11]><div class="warning">This old browser is unsupported and will most likely display funky things.</div><![endif]--><rustdoc-topbar><h2><a href="#">All</a></h2></rustdoc-topbar><nav class="sidebar"><div class="sidebar-crate"><a class="logo-container" href="./index.html"><img class="rust-logo" src="./static.files/rust-logo-9a9549ea.svg" alt="logo"></a><h2><a href="./index.html">Rustdoc</a><span class="version">1.93.1</span></h2></div><div class="version">(01f6ddf75 2026-02-11)</div><h2 class="location">Help</h2><div class="sidebar-elems"></div></nav><div class="sidebar-resizer" title="Drag to resize sidebar"></div><main><div class="width-limiter"><section id="main-content" class="content"><div class="main-heading"><h1>Rustdoc help</h1><span class="out-of-band"><a id="back" href="javascript:void(0)" onclick="history.back();">Back</a></span></div><noscript><section><p>You need to enable JavaScript to use keyboard commands or search.</p><p>For more information, browse the <a href="https://doc.rust-lang.org/1.93.1/rustdoc/">rustdoc handbook</a>.</p></section></noscript></section></div></main></body></html>

View File

@@ -0,0 +1 @@
rn_("FQRAAACDXwBzAHQAdQB2AHcAeAB5AHoAc2AAewB8AH0AfgB/AIAAgQBzTQBkAGUAZgBnAGgAaQBqAAytAAJtAGZpcHN0CABbAA==")

View File

@@ -0,0 +1 @@
rn_("dQBAAAByhACSAJMAlACVAJYAlwCYAG1CAE4ATwBQAFEAUgBTAFQAQUAAAJAAogCjAK0AswD7AnJ3IUIAAKcAtAC2APsCZW8=")

View File

@@ -0,0 +1 @@
rn_("BQHAAAAKtQACkQBlaVFGAAClAK4ArwCwALEAsgCnAACFoFAAAAC1oJAAAACzoHAAAACrAsAG9EoAAAABAAwAAQA=")

View File

@@ -0,0 +1 @@
rn_("dQBAAAByhACSAJMAlACVAJYAlwCYAG1CAE4ATwBQAFEAUgBTAFQAQUAAAJAAogCjAK0AswABAAGgYAAAAKRyOzAAAAEAABEABAAfAAAAJgAHAF8AAABzAAcA+wNtcnc=")

View File

@@ -0,0 +1 @@
rn_("dQFAAAByhACSAJMAlACVAJYAlwCYAFJeAG4AbwBwAHEAcgBtbkIATgBPAFAAUQBSAFMAVAAFAEQAAAe3AGSfAFFDAAClAK4ArwCwALEAsgDzgQJpdAUBQQAAAWMAB6MAcnUkAAUAQAAAY2sAigCLAIwAjQCOAI8AciAAKwKgMAAAAKttdHUAQgAAcoQAkgCTAJQAlQCWAJcAmABtQgBOAE8AUABRAFIAUwBUACFDAACCAIMAhwCjAISgEAAAAFmgUAAAAIUBEQH1kAAAABIAAQAKAAYABQLAAAAAJQAEqwBTXgBuAG8AcABxAHIAbHN3VQFBAABDmQCaAJsAnACdAFelAK4ArwCwALEAsgBpdDkAPAA9AD4APwBAAAEAAaBgAAAApHI7MAAAAQAAEQAEAB8AAAAmAAcAXwAAAHMABwAFAcAAABFcAJ4AEqoAtwBlaQEAjaBAAAAAbWRBVk6hx6BgAAAAnqBAAAAAYqAAAAAAJWJv9PaXLw0csInIvjkW/rsltmfsxlIz6xdUJGJCF0X6vsIHLBPi/tc7YAxrmNMNQxp5HzswAAABAAAUAAUAIQAAAC4ABgBgAAAAewAGAJkABAA=")

View File

@@ -0,0 +1 @@
rn_("NQBAAAAAqQBzRQBGAFoAYQBVAUIAAEOZAJoAmwCcAJ0AV6UArgCvALAAsQCyAGl0OQA8AD0APgA/AEAAIUYAAIIAgwCHAAUBwAAAAIYAE5AAowBxcmcGAIegEAAAAFugcAAAAKagAAAAAKtRSQTzHQAAAAgAhgA=")

View File

@@ -0,0 +1 @@
rn_("AQAAOzAAAAEAABQABQAhAAAALgAGAGAAAAB7AAYAmQAEAAUAQwAAB7cAZJ8AYwADoBAAAABiZW5v86cAAAANAAIA")

View File

@@ -0,0 +1 @@
rn_("NQNCAAAFtQAJswADkQAHqwBicHN0SgBLAFcAWAAFAUAAAAanAAGoAGNlOwBTAISwQACqAA2gUAAAALQARAkBCgAAAA==")

View File

@@ -0,0 +1 @@
rd_("")

View File

@@ -0,0 +1 @@
rd_("gtotp_rs")

View File

@@ -0,0 +1 @@
rd_("CkWill check that to_bytes() returns the same. One secret \xe2\x80\xa6CjWill not check for issuer and account_name equality As \xe2\x80\xa6AmNon-encoded \xe2\x80\x9craw\xe2\x80\x9d secret.CmWill create a new instance of TOTP with given parameters. \xe2\x80\xa6BnGive the ttl (in seconds) of the current tokenmInvalid host.CcHMAC-SHA1 is the default algorithm of most TOTP \xe2\x80\xa6BeCouldn\xe2\x80\x99t decode step into a number.CnTOTP holds informations as to how to generate an auth code \xe2\x80\xa6AoReturns the argument unchanged.000000BaCalls <code>U::from(self)</code>.000000AmWill sign the given timestampCnNumber of steps allowed as network delay. 1 would mean one \xe2\x80\xa6ClDuration in seconds of a step. The recommended value per \xe2\x80\xa6CmWill check if token is valid given the provided timestamp \xe2\x80\xa6BaCharacters should only be digits.CaIssuer contains invalid character <code>:</code>.CiHMAC-SHA256. Supported in theory according to yubico. \xe2\x80\xa6CiHMAC-SHA512. Supported in theory according to yubico. \xe2\x80\xa6oInvalid scheme.AcWrong base32 input.CmShared secret between client and server to validate token \xe2\x80\xa6AlSet the <code>digits</code>.CmThe number of digits composing the auth code. Per rfc-4226\xe2\x80\xa6ChAs per rfc-4226 the secret should come from a strong \xe2\x80\xa6DoTry to transform a <code>Secret::Encoded</code> into a <code>Secret::Raw</code>AfBase32 encoded secret.Cbrfc-6238 compliant set of options to create a TOTPCmThis library permits the creation of 2FA authentification \xe2\x80\xa6CiWill generate a token given the provided timestamp in \xe2\x80\xa6BmGet the inner String value as a Vec of bytes.BkTry to create a TOTP from a Rfc6238 config.AbUnknown algorithm.CkAlgorithm enum holds the three standards algorithms for \xe2\x80\xa6CmSHA-1 is the most widespread algorithm used, and for totp \xe2\x80\xa6CkReturns the timestamp of the first second for the next stepBaInvalid secret size. (Too short?)EaTry to transforms a <code>Secret::Raw</code> into a <code>Secret::Encoded</code>.DkAccount name contains invalid character <code>:</code> or couldn\xe2\x80\x99t be \xe2\x80\xa6AeInvalid base32 input.BaDigits should be between 6 and 8.CgError returned when input is not compliant to rfc-6238.BiErrors returned mostly upon decoding URL.CmWill create a new instance of TOTP from the given Rfc6238 \xe2\x80\xa6ClImplementations MUST extract a 6-digit code at a minimum \xe2\x80\xa6ChWill check if token is valid by current system time, \xe2\x80\xa6CmWill create a new instance of TOTP with given parameters. \xe2\x80\xa6AiCouldn\xe2\x80\x99t decode issuer.CjThe length of the shared secret MUST be at least 128 bits.AkIssuers should be the same.BeDifferent ways secret parsing failed.BmGenerate a token from the current system timeCnWill return the base32 representation of the secret, which \xe2\x80\xa6CnReturns the timestamp of the first second of the next step \xe2\x80\xa6AnCouldn\xe2\x80\x99t parse account name.")

View File

@@ -0,0 +1 @@
rd_("Ah[99,13,100,28,163,135,0]Ah[99,13,100,88,180,135,0]Ag[99,13,100,88,76,135,0]Ai[99,13,100,145,164,135,0]Ai[99,13,100,100,132,135,0]Ah[99,13,100,100,37,135,0]Ae[99,15,100,88,76,0,0]Cg[99,13,100,28,163,92,0,\"impl-Display-for-Rfc6238Error\"]Ce[99,13,100,28,163,59,0,\"impl-Debug-for-Rfc6238Error\"]Af[99,13,100,28,94,59,0]Ck[99,13,100,88,180,92,0,\"impl-Display-for-SecretParseError\"]Ci[99,13,100,88,180,59,0,\"impl-Debug-for-SecretParseError\"]C`[99,13,100,88,76,92,0,\"impl-Display-for-Secret\"]Bn[99,13,100,88,76,59,0,\"impl-Debug-for-Secret\"]Cf[99,13,100,145,164,59,0,\"impl-Debug-for-TotpUrlError\"]Ch[99,13,100,145,164,92,0,\"impl-Display-for-TotpUrlError\"]Ce[99,13,100,100,132,92,0,\"impl-Display-for-Algorithm\"]Cc[99,13,100,100,132,59,0,\"impl-Debug-for-Algorithm\"]Bm[99,13,100,100,37,59,0,\"impl-Debug-for-TOTP\"]Bo[99,13,100,100,37,92,0,\"impl-Display-for-TOTP\"]Ae[99,13,100,28,94,0,0]Af[99,13,100,100,37,0,0]0Ag[99,15,100,145,164,0,0]Ag[99,15,100,100,132,0,0]1Ab[99,5,100,0,0,0,0]Ag[99,13,100,28,163,32,0]Af[99,13,100,28,94,32,0]Ag[99,13,100,88,180,32,0]Af[99,13,100,88,76,32,0]Ah[99,13,100,145,164,32,0]0Ah[99,13,100,100,132,32,0]Ag[99,13,100,100,37,32,0]Ag[99,13,100,28,163,34,0]Af[99,13,100,28,94,34,0]Ag[99,13,100,88,180,34,0]Af[99,13,100,88,76,34,0]Ah[99,13,100,145,164,34,0]Ah[99,13,100,100,132,34,0]Ag[99,13,100,100,37,34,0]Af[99,13,100,100,37,0,0]Af[99,14,100,100,37,0,0]01Af[99,13,100,28,94,58,0]Ag[99,13,100,88,180,58,0]Af[99,13,100,88,76,58,0]Ah[99,13,100,100,132,58,0]Ag[99,13,100,100,37,58,0]Ag[99,15,100,145,164,0,0]0Ag[99,15,100,100,132,0,0]011Ac[99,6,100,88,0,0,0]Ag[99,13,100,28,163,67,0]Af[99,13,100,28,94,67,0]Ag[99,13,100,88,180,67,0]Af[99,13,100,88,76,67,0]Ah[99,13,100,145,164,67,0]Ah[99,13,100,100,132,67,0]Ag[99,13,100,100,37,67,0]Ae[99,13,100,28,94,0,0]Af[99,14,100,100,37,0,0]0Ae[99,13,100,88,76,0,0]Ae[99,15,100,88,76,0,0]Ac[99,5,100,28,0,0,0]Ah[99,13,100,100,132,91,0]A`[99,3,0,0,0,0,0]Af[99,13,100,28,163,9,0]Ae[99,13,100,28,94,9,0]Af[99,13,100,88,180,9,0]Ae[99,13,100,88,76,9,0]Ag[99,13,100,145,164,9,0]Ag[99,13,100,100,132,9,0]Af[99,13,100,100,37,9,0]Af[99,13,100,100,37,0,0]<Af[99,13,100,28,94,95,0]Ag[99,13,100,88,180,95,0]Af[99,13,100,88,76,95,0]Ah[99,13,100,100,132,95,0]Ag[99,13,100,100,37,95,0]Ag[99,13,100,28,163,96,0]Af[99,13,100,28,94,96,0]Ag[99,13,100,88,180,96,0]Af[99,13,100,88,76,96,0]Ah[99,13,100,145,164,96,0]Ah[99,13,100,100,132,96,0]Ag[99,13,100,100,37,96,0]0Ag[99,13,100,28,163,97,0]Af[99,13,100,28,94,97,0]Ag[99,13,100,88,180,97,0]Af[99,13,100,88,76,97,0]Ah[99,13,100,145,164,97,0]Ah[99,13,100,100,132,97,0]Ag[99,13,100,100,37,97,0]Ag[99,15,100,145,164,0,0]Ab[99,6,100,0,0,0,0]Af[99,14,100,100,37,0,0]Af[99,13,100,100,37,0,0]Ah[99,13,100,28,163,108,0]Ah[99,13,100,88,180,108,0]Ag[99,13,100,88,76,108,0]Ai[99,13,100,145,164,108,0]Ai[99,13,100,100,132,108,0]Ah[99,13,100,100,37,108,0]9Ah[99,13,100,28,163,133,0]Ag[99,13,100,28,94,133,0]Ah[99,13,100,88,180,133,0]Ag[99,13,100,88,76,133,0]Ai[99,13,100,145,164,133,0]Ai[99,13,100,100,132,133,0]Ah[99,13,100,100,37,133,0]Af[99,13,100,28,94,95,0]Ag[99,13,100,88,180,95,0]Af[99,13,100,88,76,95,0]Ah[99,13,100,100,132,95,0]Ag[99,13,100,100,37,95,0]Ae[99,13,100,88,76,0,0]Ag[99,15,100,145,164,0,0]Af[99,15,100,88,180,0,0]1Ac[99,6,100,28,0,0,0]Ad[99,6,100,145,0,0,0]Af[99,13,100,100,37,0,0]Af[99,15,100,28,163,0,0]11Ae[99,13,100,28,94,0,0]616Ag[99,13,100,28,94,166,0]Ah[99,13,100,88,180,166,0]Ag[99,13,100,88,76,166,0]Ai[99,13,100,100,132,166,0]Ah[99,13,100,100,37,166,0]Ac[99,6,100,88,0,0,0]888<")

View File

@@ -0,0 +1 @@
rd_("Ba[\"{{{AAd{ADf}}{AAd{ADf}}}Dl}\",[]]Ba[\"{{{AAd{AFh}}{AAd{AFh}}}Dl}\",[]]Ao[\"{{{AAd{Ih}}{AAd{Ih}}}Dl}\",[]]Ba[\"{{{AAd{ADh}}{AAd{ADh}}}Dl}\",[]]Ba[\"{{{AAd{A@h}}{AAd{A@h}}}Dl}\",[]]Ao[\"{{{AAd{Dj}}{AAd{Dj}}}Dl}\",[]]Bc[\"{{{AAd{ADf}}{AAd{CbA@l}}}Hl}\",[]]0Bb[\"{{{AAd{Kl}}{AAd{CbA@l}}}Hl}\",[]]Bc[\"{{{AAd{AFh}}{AAd{CbA@l}}}Hl}\",[]]0Bb[\"{{{AAd{Ih}}{AAd{CbA@l}}}Hl}\",[]]0Bc[\"{{{AAd{ADh}}{AAd{CbA@l}}}Hl}\",[]]0Bc[\"{{{AAd{A@h}}{AAd{CbA@l}}}Hl}\",[]]0Bb[\"{{{AAd{Dj}}{AAd{CbA@l}}}Hl}\",[]]0Bb[\"{{Hd{Af{A`}}}{{Hn{KlADf}}}}\",[]]Bi[\"{{A@hHdA`Cn{Af{A`}}}{{Hn{DjADh}}}}\",[]]Ba[\"{{{AAd{Dj}}}{{Hn{CnAEl}}}}\",[]]A`[\"{cc{}}\",[\"T\"]]000o[\"{ADfADh}\",[]]111Aa[\"{{}c{}}\",[\"U\"]]000000B`[\"{{{AAd{Dj}}Cn}{{Af{A`}}}}\",[]]m[\"{DjA`}\",[]]m[\"{DjCn}\",[]]Ba[\"{{{AAd{Dj}}{AAd{Cj}}Cn}Dl}\",[]]Af[\"{{{AAd{Kl}}}Kl}\",[]]Ah[\"{{{AAd{AFh}}}AFh}\",[]]Af[\"{{{AAd{Ih}}}Ih}\",[]]Ah[\"{{{AAd{A@h}}}A@h}\",[]]Af[\"{{{AAd{Dj}}}Dj}\",[]]Ak[\"{AAd{{AAd{c}}}{}}\",[\"T\"]]000000Be[\"{{{AAd{CbKl}}Hd}{{Hn{GbADf}}}}\",[]]m[\"{DjHd}\",[]]m[\"{DjAf}\",[]]Ba[\"{{{AAd{Ih}}}{{Hn{IhAFh}}}}\",[]]n[\"{{}A@h}\",[]]n[\"{AAdIl}\",[]]000000Ah[\"{{{AAd{Dj}}Cn}Ij}\",[]]Bg[\"{{{AAd{Ih}}}{{Hn{{Af{A`}}AFh}}}}\",[]]Ab[\"{AAdc{}}\",[\"T\"]]0000An[\"{c{{Hn{e}}}{}{}}\",[\"U\",\"T\"]]000000Bh[\"{Kl{{Hn{Djc}}}{}}\",[\"TryFrom::Error\"]]Ai[\"{{}{{Hn{c}}}{}}\",[\"U\"]]000000n[\"{DjA@h}\",[]]Ah[\"{{{AAd{Dj}}Cn}Cn}\",[]]n[\"{AAdIj}\",[]]00000Be[\"{{{AAd{Cb}}}{{AAd{Cbc}}}{}}\",[\"T\"]]000000Ao[\"{{AAd{AAd{Cbc}}}Gb{}}\",[\"T\"]]0000Af[\"{{{AAd{Ih}}}Ih}\",[]]Ah[\"{Kl{{Hn{DjADh}}}}\",[]]Bj[\"{{{AAd{Dj}}{AAd{Cj}}}{{Hn{DlAEl}}}}\",[]]An[\"{{A@hHdA`Cn{Af{A`}}}Dj}\",[]]B`[\"{{{Af{A`}}}{{Hn{KlADf}}}}\",[]]Aj[\"{{AAd{Lf{CbA`}}}Gb}\",[]]0000Ba[\"{{{AAd{Dj}}}{{Hn{IjAEl}}}}\",[]]Af[\"{{{AAd{Dj}}}Ij}\",[]]Ba[\"{{{AAd{Dj}}}{{Hn{CnAEl}}}}\",[]]")

View File

@@ -0,0 +1 @@
rb_("RWIAOzAAAAEAAAYAAQAuAAYAOzAAAAEAABIABAAmAAMAKwACAG4ABAB7AAYAOzAAAAEAAA0AAgBOAAYAcwAGAAF6AAAAOzAAAAEAAAsAAQCSAAsAQWIAAAA7MAAAAQAABgABAHMABgA=")

View File

@@ -0,0 +1 @@
rd_("b()beq00000bu8cAnycRawcVeccfmt000000000000cmutcnew0crfccstrcttlcu64dFromdHostdIntodSHA1dStepdTOTPdbooldfrom0000000dinto000000dsigndskewdstepduniteCloneeDebugecheckeclone0000eusizefBorrowfDigitsfIssuerfResult0fSHA256fSHA512fSchemefSecret0fStringfTypeIdfborrow000000fdigits0fsecret0fto_rawgDefaultgDisplaygEncodedgRfc6238gToOwnedgTryFromgTryIntogdefaultgpointergtotp_rsgtype_id000000hToStringhgeneratehto_byteshto_owned0000htry_from0000000htry_into000000iAlgorithm0iBorrowMutiFormatteriPartialEqialgorithminext_stepireferenceito_string00000iurl_errorjSecretSizejborrow_mut000000jclone_into0000jto_encodedkAccountNamekParseBase32lDigitsNumberlRfc6238ErrorlTotpUrlErrorlfrom_rfc6238mCloneToUninitmInvalidDigitsmcheck_currentmnew_uncheckedmwith_defaultsnIssuerDecodingnSecretTooSmalloIssuerMistmatchoSystemTimeErroroclone_to_uninit0000A`SecretParseErrorA`generate_currentAaget_secret_base32Aanext_step_currentAcAccountNameDecoding")

View File

@@ -0,0 +1 @@
rd_("b()beq00000bu8canycrawcveccfmt000000000000cmutcnew0crfccstrcttlcu64dfromdhostdintodsha1dstepdtotpdbool666666664444444dsigndskew4dunitecloneedebugecheck22222eusizefborrowfdigitsfissuerfresult0fsha256fsha512fschemefsecret0fstringftypeid99999998822etorawgdefaultgdisplaygencodedgrfc6238gtoownedgtryfromgtryinto6gpointerftotprs:::::::htostringhgenerategtobytes77777666666665555555ialgorithm0iborrowmutiformatteripartialeq3hnextstepireference888888hurlerrorjsecretsize6666666icloneinto0000itoencodedkaccountnamekparsebase32ldigitsnumberlrfc6238errorltotpurlerrorkfromrfc6238mclonetouninitminvaliddigitslcheckcurrentlnewuncheckedlwithdefaultsnissuerdecodingnsecrettoosmalloissuermistmatchosystemtimeerror88888A`secretparseerrorogeneratecurrentogetsecretbase32onextstepcurrentAcaccountnamedecoding")

Some files were not shown because too many files have changed in this diff Show More