Compare commits
1 Commits
13c0b18bbc
...
worktree-s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
44256a511c |
Binary file not shown.
Binary file not shown.
Submodule .claude/worktrees/saas-backend deleted from 4d8d560d1f
93
.dockerignore
Normal file
93
.dockerignore
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# ============================================================
|
||||||
|
# ZCLAW SaaS Backend - Docker Ignore
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Build artifacts
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Frontend applications (not needed for SaaS backend)
|
||||||
|
desktop/
|
||||||
|
admin/
|
||||||
|
design-system/
|
||||||
|
|
||||||
|
# Node.js
|
||||||
|
node_modules/
|
||||||
|
.pnpm-store/
|
||||||
|
bun.lock
|
||||||
|
pnpm-lock.yaml
|
||||||
|
package.json
|
||||||
|
package-lock.json
|
||||||
|
|
||||||
|
# Git
|
||||||
|
.git/
|
||||||
|
.gitignore
|
||||||
|
|
||||||
|
# IDE and editor
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# OS files
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
.docker/
|
||||||
|
docker-compose*.yml
|
||||||
|
Dockerfile
|
||||||
|
.dockerignore
|
||||||
|
|
||||||
|
# Documentation
|
||||||
|
docs/
|
||||||
|
*.md
|
||||||
|
!saas-config.toml
|
||||||
|
CLAUDE.md
|
||||||
|
CLAUDE*.md
|
||||||
|
|
||||||
|
# Environment files (secrets)
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
saas-env.example
|
||||||
|
|
||||||
|
# Data files
|
||||||
|
saas-data/
|
||||||
|
saas-data.db
|
||||||
|
saas-data.db-shm
|
||||||
|
saas-data.db-wal
|
||||||
|
*.db
|
||||||
|
*.db-shm
|
||||||
|
*.db-wal
|
||||||
|
|
||||||
|
# Test artifacts
|
||||||
|
tests/
|
||||||
|
test-results/
|
||||||
|
test.rs
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# Temporary files
|
||||||
|
tmp-screenshot.png
|
||||||
|
tmp/
|
||||||
|
temp/
|
||||||
|
*.tmp
|
||||||
|
|
||||||
|
# Claude worktree metadata
|
||||||
|
.claude/
|
||||||
|
plans/
|
||||||
|
pipelines/
|
||||||
|
scripts/
|
||||||
|
hands/
|
||||||
|
skills/
|
||||||
|
plugins/
|
||||||
|
config/
|
||||||
|
extract.js
|
||||||
|
extract_models.js
|
||||||
|
extract_privacy.js
|
||||||
|
start-all.ps1
|
||||||
|
start.ps1
|
||||||
|
start.sh
|
||||||
|
Makefile
|
||||||
|
PROGRESS.md
|
||||||
|
CHANGELOG.md
|
||||||
|
pencil-new.pen
|
||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -12,10 +12,6 @@ build/
|
|||||||
.env.local
|
.env.local
|
||||||
.env.*.local
|
.env.*.local
|
||||||
|
|
||||||
# SaaS config (contains database credentials)
|
|
||||||
saas-config.toml
|
|
||||||
!saas-config.toml.example
|
|
||||||
|
|
||||||
# Logs
|
# Logs
|
||||||
logs/
|
logs/
|
||||||
*.log
|
*.log
|
||||||
|
|||||||
37
CLAUDE.md
37
CLAUDE.md
@@ -36,20 +36,17 @@ ZCLAW/
|
|||||||
│ ├── zclaw-kernel/ # L4: 核心协调 (注册, 调度, 事件, 工作流)
|
│ ├── zclaw-kernel/ # L4: 核心协调 (注册, 调度, 事件, 工作流)
|
||||||
│ ├── zclaw-skills/ # 技能系统 (SKILL.md解析, 执行器)
|
│ ├── zclaw-skills/ # 技能系统 (SKILL.md解析, 执行器)
|
||||||
│ ├── zclaw-hands/ # 自主能力 (Hand/Trigger 注册管理)
|
│ ├── zclaw-hands/ # 自主能力 (Hand/Trigger 注册管理)
|
||||||
│ ├── zclaw-protocols/ # 协议支持 (MCP, A2A)
|
│ ├── zclaw-channels/ # 通道适配器 (仅 ConsoleChannel 测试适配器)
|
||||||
│ └── zclaw-saas/ # SaaS 后端 (账号, 模型配置, 中转, 配置同步)
|
│ └── zclaw-protocols/ # 协议支持 (MCP, A2A)
|
||||||
├── admin/ # Next.js 管理后台
|
|
||||||
├── desktop/ # Tauri 桌面应用
|
├── desktop/ # Tauri 桌面应用
|
||||||
│ ├── src/
|
│ ├── src/
|
||||||
│ │ ├── components/ # React UI 组件 (含 SaaS 集成)
|
│ │ ├── components/ # React UI 组件
|
||||||
│ │ ├── store/ # Zustand 状态管理 (含 saasStore)
|
│ │ ├── store/ # Zustand 状态管理
|
||||||
│ │ └── lib/ # 客户端通信 / 工具函数 (含 saas-client)
|
│ │ └── lib/ # 客户端通信 / 工具函数
|
||||||
│ └── src-tauri/ # Tauri Rust 后端 (集成 Kernel)
|
│ └── src-tauri/ # Tauri Rust 后端 (集成 Kernel)
|
||||||
├── skills/ # SKILL.md 技能定义
|
├── skills/ # SKILL.md 技能定义
|
||||||
├── hands/ # HAND.toml 自主能力配置
|
├── hands/ # HAND.toml 自主能力配置
|
||||||
├── config/ # TOML 配置文件
|
├── config/ # TOML 配置文件
|
||||||
├── saas-config.toml # SaaS 后端配置 (PostgreSQL 连接等)
|
|
||||||
├── docker-compose.yml # PostgreSQL 容器配置
|
|
||||||
├── docs/ # 架构文档和知识库
|
├── docs/ # 架构文档和知识库
|
||||||
└── tests/ # Vitest 回归测试
|
└── tests/ # Vitest 回归测试
|
||||||
```
|
```
|
||||||
@@ -69,9 +66,7 @@ ZCLAW/
|
|||||||
| 桌面框架 | Tauri 2.x |
|
| 桌面框架 | Tauri 2.x |
|
||||||
| 样式方案 | Tailwind CSS |
|
| 样式方案 | Tailwind CSS |
|
||||||
| 配置格式 | TOML |
|
| 配置格式 | TOML |
|
||||||
| 后端核心 | Rust Workspace (9 crates) |
|
| 后端核心 | Rust Workspace (8 crates) |
|
||||||
| SaaS 后端 | Axum + PostgreSQL (zclaw-saas) |
|
|
||||||
| 管理后台 | Next.js (admin/) |
|
|
||||||
|
|
||||||
### 2.3 Crate 依赖关系
|
### 2.3 Crate 依赖关系
|
||||||
|
|
||||||
@@ -84,9 +79,7 @@ zclaw-runtime (→ types, memory)
|
|||||||
↑
|
↑
|
||||||
zclaw-kernel (→ types, memory, runtime)
|
zclaw-kernel (→ types, memory, runtime)
|
||||||
↑
|
↑
|
||||||
zclaw-saas (→ types, 独立运行于 8080 端口)
|
desktop/src-tauri (→ kernel, skills, hands, channels, protocols)
|
||||||
↑
|
|
||||||
desktop/src-tauri (→ kernel, skills, hands, protocols)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
***
|
***
|
||||||
@@ -198,10 +191,10 @@ ZCLAW 提供 11 个自主能力包:
|
|||||||
| Predictor | 预测分析 | ❌ 已禁用 (enabled=false),无 Rust 实现 |
|
| Predictor | 预测分析 | ❌ 已禁用 (enabled=false),无 Rust 实现 |
|
||||||
| Lead | 销售线索发现 | ❌ 已禁用 (enabled=false),无 Rust 实现 |
|
| Lead | 销售线索发现 | ❌ 已禁用 (enabled=false),无 Rust 实现 |
|
||||||
| Clip | 视频处理 | ⚠️ 需 FFmpeg |
|
| Clip | 视频处理 | ⚠️ 需 FFmpeg |
|
||||||
| Twitter | Twitter 自动化 | ✅ 可用(12 个 API v2 真实调用,写操作需 OAuth 1.0a) |
|
| Twitter | Twitter 自动化 | ⚠️ 需 API Key |
|
||||||
| Whiteboard | 白板演示 | ✅ 可用(导出功能开发中,标注 demo) |
|
| Whiteboard | 白板演示 | ✅ 可用(导出功能开发中,标注 demo) |
|
||||||
| Slideshow | 幻灯片生成 | ✅ 可用 |
|
| Slideshow | 幻灯片生成 | ✅ 可用 |
|
||||||
| Speech | 语音合成 | ✅ 可用(Browser TTS 前端集成完成) |
|
| Speech | 语音合成 | ✅ 可用 |
|
||||||
| Quiz | 测验生成 | ✅ 可用 |
|
| Quiz | 测验生成 | ✅ 可用 |
|
||||||
|
|
||||||
**触发 Hand 时:**
|
**触发 Hand 时:**
|
||||||
@@ -267,18 +260,6 @@ docs/
|
|||||||
- **面向未来** - 文档要帮助未来的开发者快速理解
|
- **面向未来** - 文档要帮助未来的开发者快速理解
|
||||||
- **中文优先** - 所有面向用户的文档使用中文
|
- **中文优先** - 所有面向用户的文档使用中文
|
||||||
|
|
||||||
### 8.3 完成工作后的文档同步(强制)
|
|
||||||
|
|
||||||
每次完成功能实现、架构变更、问题修复后,**必须**同步更新以下文档:
|
|
||||||
|
|
||||||
1. **CLAUDE.md** — 如果涉及项目结构、技术栈、工作流程、命令的变化
|
|
||||||
2. **docs/features/** — 如果涉及新功能、功能变更、功能状态更新
|
|
||||||
3. **docs/knowledge-base/** — 如果涉及新知识、故障排查经验、配置说明
|
|
||||||
4. **saas-config.toml 注释** — 如果涉及 SaaS 配置项变更
|
|
||||||
5. **CHANGELOG** — 如果涉及对外可见的行为变化
|
|
||||||
|
|
||||||
**执行时机:** 代码编译通过且验证成功后,在标记任务完成之前,立即执行文档更新。文档更新是任务完成的必要条件,不是可选步骤。
|
|
||||||
|
|
||||||
***
|
***
|
||||||
|
|
||||||
## 9. 常见问题排查
|
## 9. 常见问题排查
|
||||||
|
|||||||
1087
Cargo.lock
generated
1087
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
10
Cargo.toml
10
Cargo.toml
@@ -9,6 +9,7 @@ members = [
|
|||||||
# ZCLAW Extension Crates
|
# ZCLAW Extension Crates
|
||||||
"crates/zclaw-skills",
|
"crates/zclaw-skills",
|
||||||
"crates/zclaw-hands",
|
"crates/zclaw-hands",
|
||||||
|
"crates/zclaw-channels",
|
||||||
"crates/zclaw-protocols",
|
"crates/zclaw-protocols",
|
||||||
"crates/zclaw-pipeline",
|
"crates/zclaw-pipeline",
|
||||||
"crates/zclaw-growth",
|
"crates/zclaw-growth",
|
||||||
@@ -56,7 +57,7 @@ chrono = { version = "0.4", features = ["serde"] }
|
|||||||
uuid = { version = "1", features = ["v4", "v5", "serde"] }
|
uuid = { version = "1", features = ["v4", "v5", "serde"] }
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres"] }
|
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres", "chrono"] }
|
||||||
libsqlite3-sys = { version = "0.27", features = ["bundled"] }
|
libsqlite3-sys = { version = "0.27", features = ["bundled"] }
|
||||||
|
|
||||||
# HTTP client (for LLM drivers)
|
# HTTP client (for LLM drivers)
|
||||||
@@ -93,10 +94,6 @@ regex = "1"
|
|||||||
# Shell parsing
|
# Shell parsing
|
||||||
shlex = "1"
|
shlex = "1"
|
||||||
|
|
||||||
# WASM runtime
|
|
||||||
wasmtime = { version = "43", default-features = false, features = ["cranelift"] }
|
|
||||||
wasmtime-wasi = { version = "43" }
|
|
||||||
|
|
||||||
# Testing
|
# Testing
|
||||||
tempfile = "3"
|
tempfile = "3"
|
||||||
|
|
||||||
@@ -104,7 +101,7 @@ tempfile = "3"
|
|||||||
axum = { version = "0.7", features = ["macros"] }
|
axum = { version = "0.7", features = ["macros"] }
|
||||||
axum-extra = { version = "0.9", features = ["typed-header"] }
|
axum-extra = { version = "0.9", features = ["typed-header"] }
|
||||||
tower = { version = "0.4", features = ["util"] }
|
tower = { version = "0.4", features = ["util"] }
|
||||||
tower-http = { version = "0.5", features = ["cors", "trace", "limit", "timeout"] }
|
tower-http = { version = "0.5", features = ["cors", "trace", "limit"] }
|
||||||
jsonwebtoken = "9"
|
jsonwebtoken = "9"
|
||||||
argon2 = "0.5"
|
argon2 = "0.5"
|
||||||
totp-rs = "5"
|
totp-rs = "5"
|
||||||
@@ -117,6 +114,7 @@ zclaw-runtime = { path = "crates/zclaw-runtime" }
|
|||||||
zclaw-kernel = { path = "crates/zclaw-kernel" }
|
zclaw-kernel = { path = "crates/zclaw-kernel" }
|
||||||
zclaw-skills = { path = "crates/zclaw-skills" }
|
zclaw-skills = { path = "crates/zclaw-skills" }
|
||||||
zclaw-hands = { path = "crates/zclaw-hands" }
|
zclaw-hands = { path = "crates/zclaw-hands" }
|
||||||
|
zclaw-channels = { path = "crates/zclaw-channels" }
|
||||||
zclaw-protocols = { path = "crates/zclaw-protocols" }
|
zclaw-protocols = { path = "crates/zclaw-protocols" }
|
||||||
zclaw-pipeline = { path = "crates/zclaw-pipeline" }
|
zclaw-pipeline = { path = "crates/zclaw-pipeline" }
|
||||||
zclaw-growth = { path = "crates/zclaw-growth" }
|
zclaw-growth = { path = "crates/zclaw-growth" }
|
||||||
|
|||||||
83
Dockerfile
Normal file
83
Dockerfile
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
# ============================================================
|
||||||
|
# ZCLAW SaaS Backend - Multi-stage Docker Build
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# ---- Stage 1: Builder ----
|
||||||
|
FROM rust:1.75-bookworm AS builder
|
||||||
|
|
||||||
|
# Install build dependencies for sqlx (postgres) and libsqlite3-sys (bundled)
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
pkg-config \
|
||||||
|
libssl-dev \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy workspace manifests first to leverage Docker layer caching
|
||||||
|
COPY Cargo.toml Cargo.lock ./
|
||||||
|
|
||||||
|
# Create stub source files so cargo can resolve and cache dependencies
|
||||||
|
# This avoids rebuilding dependencies when only application code changes
|
||||||
|
RUN mkdir -p crates/zclaw-saas/src \
|
||||||
|
&& echo 'fn main() {}' > crates/zclaw-saas/src/main.rs \
|
||||||
|
&& for crate in zclaw-types zclaw-memory zclaw-runtime zclaw-kernel \
|
||||||
|
zclaw-skills zclaw-hands zclaw-channels zclaw-protocols \
|
||||||
|
zclaw-pipeline zclaw-growth; do \
|
||||||
|
mkdir -p crates/$crate/src && echo '' > crates/$crate/src/lib.rs; \
|
||||||
|
done \
|
||||||
|
&& mkdir -p desktop/src-tauri/src && echo 'fn main() {}' > desktop/src-tauri/src/main.rs
|
||||||
|
|
||||||
|
# Pre-build dependencies (release profile with caching)
|
||||||
|
RUN cargo build --release --package zclaw-saas 2>/dev/null || true
|
||||||
|
|
||||||
|
# Copy actual source code (invalidates stubs, triggers recompile of app code only)
|
||||||
|
COPY crates/ crates/
|
||||||
|
COPY desktop/ desktop/
|
||||||
|
|
||||||
|
# Touch source files to invalidate the stub timestamps
|
||||||
|
RUN touch crates/zclaw-saas/src/main.rs \
|
||||||
|
&& for crate in zclaw-types zclaw-memory zclaw-runtime zclaw-kernel \
|
||||||
|
zclaw-skills zclaw-hands zclaw-channels zclaw-protocols \
|
||||||
|
zclaw-pipeline zclaw-growth; do \
|
||||||
|
touch crates/$crate/src/lib.rs 2>/dev/null || true; \
|
||||||
|
done \
|
||||||
|
&& touch desktop/src-tauri/src/main.rs 2>/dev/null || true
|
||||||
|
|
||||||
|
# Build the actual binary
|
||||||
|
RUN cargo build --release --package zclaw-saas
|
||||||
|
|
||||||
|
# ---- Stage 2: Runtime ----
|
||||||
|
FROM debian:bookworm-slim AS runtime
|
||||||
|
|
||||||
|
# Install runtime dependencies (ca-certificates for HTTPS, libgcc for Rust runtime)
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
ca-certificates \
|
||||||
|
libgcc-s \
|
||||||
|
&& rm -rf /var/lib/apt/lists/* \
|
||||||
|
&& update-ca-certificates
|
||||||
|
|
||||||
|
# Create non-root user for security
|
||||||
|
RUN groupadd --gid 1000 zclaw \
|
||||||
|
&& useradd --uid 1000 --gid zclaw --shell /bin/false zclaw
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy binary from builder
|
||||||
|
COPY --from=builder /app/target/release/zclaw-saas /app/zclaw-saas
|
||||||
|
|
||||||
|
# Copy configuration file
|
||||||
|
COPY saas-config.toml /app/saas-config.toml
|
||||||
|
|
||||||
|
# Ensure the non-root user owns the application files
|
||||||
|
RUN chown -R zclaw:zclaw /app
|
||||||
|
|
||||||
|
USER zclaw
|
||||||
|
|
||||||
|
# Expose the SaaS API port
|
||||||
|
EXPOSE 8080
|
||||||
|
|
||||||
|
# Health check endpoint (matches the saas-config.toml port)
|
||||||
|
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
||||||
|
CMD ["/app/zclaw-saas", "--healthcheck"] || exit 1
|
||||||
|
|
||||||
|
ENTRYPOINT ["/app/zclaw-saas"]
|
||||||
33
Makefile
33
Makefile
@@ -1,7 +1,9 @@
|
|||||||
# ZCLAW Makefile
|
# ZCLAW Makefile
|
||||||
# Cross-platform task runner
|
# Cross-platform task runner
|
||||||
|
|
||||||
.PHONY: help start start-dev start-no-browser desktop desktop-build setup test clean
|
.PHONY: help start start-dev start-no-browser desktop desktop-build setup test clean \
|
||||||
|
saas-build saas-run saas-test saas-test-integration saas-clippy saas-migrate \
|
||||||
|
saas-docker-up saas-docker-down saas-docker-build
|
||||||
|
|
||||||
help: ## Show this help message
|
help: ## Show this help message
|
||||||
@echo "ZCLAW - AI Agent Desktop Client"
|
@echo "ZCLAW - AI Agent Desktop Client"
|
||||||
@@ -71,3 +73,32 @@ clean-deep: clean ## Deep clean (including pnpm cache)
|
|||||||
@rm -rf desktop/pnpm-lock.yaml
|
@rm -rf desktop/pnpm-lock.yaml
|
||||||
@rm -rf pnpm-lock.yaml
|
@rm -rf pnpm-lock.yaml
|
||||||
@echo "Deep clean complete. Run 'pnpm install' to reinstall."
|
@echo "Deep clean complete. Run 'pnpm install' to reinstall."
|
||||||
|
|
||||||
|
# === SaaS Backend ===
|
||||||
|
|
||||||
|
saas-build: ## Build zclaw-saas crate
|
||||||
|
@cargo build -p zclaw-saas
|
||||||
|
|
||||||
|
saas-run: ## Start SaaS backend (cargo run)
|
||||||
|
@cargo run -p zclaw-saas
|
||||||
|
|
||||||
|
saas-test: ## Run SaaS unit tests
|
||||||
|
@cargo test -p zclaw-saas -- --test-threads=1
|
||||||
|
|
||||||
|
saas-test-integration: ## Run SaaS integration tests (requires PostgreSQL)
|
||||||
|
@cargo test -p zclaw-saas -- --ignored --test-threads=1
|
||||||
|
|
||||||
|
saas-clippy: ## Run clippy on zclaw-saas
|
||||||
|
@cargo clippy -p zclaw-saas -- -D warnings
|
||||||
|
|
||||||
|
saas-migrate: ## Run database migrations
|
||||||
|
@cargo run -p zclaw-saas -- --migrate
|
||||||
|
|
||||||
|
saas-docker-up: ## Start SaaS services (PostgreSQL + backend)
|
||||||
|
@docker compose up -d
|
||||||
|
|
||||||
|
saas-docker-down: ## Stop SaaS services
|
||||||
|
@docker compose down
|
||||||
|
|
||||||
|
saas-docker-build: ## Build SaaS Docker images
|
||||||
|
@docker compose build
|
||||||
|
|||||||
2
admin/.gitignore
vendored
2
admin/.gitignore
vendored
@@ -1,2 +1,4 @@
|
|||||||
.next/
|
.next/
|
||||||
node_modules/
|
node_modules/
|
||||||
|
.env.local
|
||||||
|
.env*.local
|
||||||
|
|||||||
@@ -1,10 +1,41 @@
|
|||||||
/** @type {import('next').NextConfig} */
|
/** @type {import('next').NextConfig} */
|
||||||
const nextConfig = {
|
const nextConfig = {
|
||||||
async rewrites() {
|
async headers() {
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
source: '/api/:path*',
|
source: '/(.*)',
|
||||||
destination: 'http://localhost:8080/api/:path*',
|
headers: [
|
||||||
|
{
|
||||||
|
key: 'X-Frame-Options',
|
||||||
|
value: 'DENY',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'X-Content-Type-Options',
|
||||||
|
value: 'nosniff',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'Referrer-Policy',
|
||||||
|
value: 'strict-origin-when-cross-origin',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'Content-Security-Policy',
|
||||||
|
value: [
|
||||||
|
"default-src 'self'",
|
||||||
|
"script-src 'self' 'unsafe-eval' 'unsafe-inline'",
|
||||||
|
"style-src 'self' 'unsafe-inline' https://fonts.googleapis.com",
|
||||||
|
"font-src 'self' https://fonts.gstatic.com",
|
||||||
|
"img-src 'self' data: blob:",
|
||||||
|
"connect-src 'self'",
|
||||||
|
"frame-ancestors 'none'",
|
||||||
|
"base-uri 'self'",
|
||||||
|
"form-action 'self'",
|
||||||
|
].join('; '),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
key: 'Permissions-Policy',
|
||||||
|
value: 'camera=(), microphone=(), geolocation=()',
|
||||||
|
},
|
||||||
|
],
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -22,7 +22,7 @@
|
|||||||
"react": "^18.3.1",
|
"react": "^18.3.1",
|
||||||
"react-dom": "^18.3.1",
|
"react-dom": "^18.3.1",
|
||||||
"recharts": "^2.15.3",
|
"recharts": "^2.15.3",
|
||||||
"swr": "^2.4.1",
|
"sonner": "^2.0.7",
|
||||||
"tailwind-merge": "^3.0.2"
|
"tailwind-merge": "^3.0.2"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|||||||
43
admin/pnpm-lock.yaml
generated
43
admin/pnpm-lock.yaml
generated
@@ -47,9 +47,9 @@ importers:
|
|||||||
recharts:
|
recharts:
|
||||||
specifier: ^2.15.3
|
specifier: ^2.15.3
|
||||||
version: 2.15.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
version: 2.15.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||||
swr:
|
sonner:
|
||||||
specifier: ^2.4.1
|
specifier: ^2.0.7
|
||||||
version: 2.4.1(react@18.3.1)
|
version: 2.0.7(react-dom@18.3.1(react@18.3.1))(react@18.3.1)
|
||||||
tailwind-merge:
|
tailwind-merge:
|
||||||
specifier: ^3.0.2
|
specifier: ^3.0.2
|
||||||
version: 3.5.0
|
version: 3.5.0
|
||||||
@@ -722,10 +722,6 @@ packages:
|
|||||||
decimal.js-light@2.5.1:
|
decimal.js-light@2.5.1:
|
||||||
resolution: {integrity: sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==}
|
resolution: {integrity: sha512-qIMFpTMZmny+MMIitAB6D7iVPEorVw6YQRWkvarTkT4tBeSLLiHzcwj6q0MmYSFCiVpiqPJTJEYIrpcPzVEIvg==}
|
||||||
|
|
||||||
dequal@2.0.3:
|
|
||||||
resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==}
|
|
||||||
engines: {node: '>=6'}
|
|
||||||
|
|
||||||
detect-node-es@1.1.0:
|
detect-node-es@1.1.0:
|
||||||
resolution: {integrity: sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==}
|
resolution: {integrity: sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==}
|
||||||
|
|
||||||
@@ -1070,6 +1066,12 @@ packages:
|
|||||||
scheduler@0.23.2:
|
scheduler@0.23.2:
|
||||||
resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==}
|
resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==}
|
||||||
|
|
||||||
|
sonner@2.0.7:
|
||||||
|
resolution: {integrity: sha512-W6ZN4p58k8aDKA4XPcx2hpIQXBRAgyiWVkYhT7CvK6D3iAu7xjvVyhQHg2/iaKJZ1XVJ4r7XuwGL+WGEK37i9w==}
|
||||||
|
peerDependencies:
|
||||||
|
react: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
|
||||||
|
react-dom: ^18.0.0 || ^19.0.0 || ^19.0.0-rc
|
||||||
|
|
||||||
source-map-js@1.2.1:
|
source-map-js@1.2.1:
|
||||||
resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==}
|
resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==}
|
||||||
engines: {node: '>=0.10.0'}
|
engines: {node: '>=0.10.0'}
|
||||||
@@ -1100,11 +1102,6 @@ packages:
|
|||||||
resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==}
|
resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==}
|
||||||
engines: {node: '>= 0.4'}
|
engines: {node: '>= 0.4'}
|
||||||
|
|
||||||
swr@2.4.1:
|
|
||||||
resolution: {integrity: sha512-2CC6CiKQtEwaEeNiqWTAw9PGykW8SR5zZX8MZk6TeAvEAnVS7Visz8WzphqgtQ8v2xz/4Q5K+j+SeMaKXeeQIA==}
|
|
||||||
peerDependencies:
|
|
||||||
react: ^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
|
||||||
|
|
||||||
tailwind-merge@3.5.0:
|
tailwind-merge@3.5.0:
|
||||||
resolution: {integrity: sha512-I8K9wewnVDkL1NTGoqWmVEIlUcB9gFriAEkXkfCjX5ib8ezGxtR3xD7iZIxrfArjEsH7F1CHD4RFUtxefdqV/A==}
|
resolution: {integrity: sha512-I8K9wewnVDkL1NTGoqWmVEIlUcB9gFriAEkXkfCjX5ib8ezGxtR3xD7iZIxrfArjEsH7F1CHD4RFUtxefdqV/A==}
|
||||||
|
|
||||||
@@ -1171,11 +1168,6 @@ packages:
|
|||||||
'@types/react':
|
'@types/react':
|
||||||
optional: true
|
optional: true
|
||||||
|
|
||||||
use-sync-external-store@1.6.0:
|
|
||||||
resolution: {integrity: sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==}
|
|
||||||
peerDependencies:
|
|
||||||
react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0
|
|
||||||
|
|
||||||
util-deprecate@1.0.2:
|
util-deprecate@1.0.2:
|
||||||
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
|
resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==}
|
||||||
|
|
||||||
@@ -1761,8 +1753,6 @@ snapshots:
|
|||||||
|
|
||||||
decimal.js-light@2.5.1: {}
|
decimal.js-light@2.5.1: {}
|
||||||
|
|
||||||
dequal@2.0.3: {}
|
|
||||||
|
|
||||||
detect-node-es@1.1.0: {}
|
detect-node-es@1.1.0: {}
|
||||||
|
|
||||||
didyoumean@1.2.2: {}
|
didyoumean@1.2.2: {}
|
||||||
@@ -2071,6 +2061,11 @@ snapshots:
|
|||||||
dependencies:
|
dependencies:
|
||||||
loose-envify: 1.4.0
|
loose-envify: 1.4.0
|
||||||
|
|
||||||
|
sonner@2.0.7(react-dom@18.3.1(react@18.3.1))(react@18.3.1):
|
||||||
|
dependencies:
|
||||||
|
react: 18.3.1
|
||||||
|
react-dom: 18.3.1(react@18.3.1)
|
||||||
|
|
||||||
source-map-js@1.2.1: {}
|
source-map-js@1.2.1: {}
|
||||||
|
|
||||||
streamsearch@1.1.0: {}
|
streamsearch@1.1.0: {}
|
||||||
@@ -2092,12 +2087,6 @@ snapshots:
|
|||||||
|
|
||||||
supports-preserve-symlinks-flag@1.0.0: {}
|
supports-preserve-symlinks-flag@1.0.0: {}
|
||||||
|
|
||||||
swr@2.4.1(react@18.3.1):
|
|
||||||
dependencies:
|
|
||||||
dequal: 2.0.3
|
|
||||||
react: 18.3.1
|
|
||||||
use-sync-external-store: 1.6.0(react@18.3.1)
|
|
||||||
|
|
||||||
tailwind-merge@3.5.0: {}
|
tailwind-merge@3.5.0: {}
|
||||||
|
|
||||||
tailwindcss@3.4.19:
|
tailwindcss@3.4.19:
|
||||||
@@ -2176,10 +2165,6 @@ snapshots:
|
|||||||
optionalDependencies:
|
optionalDependencies:
|
||||||
'@types/react': 18.3.28
|
'@types/react': 18.3.28
|
||||||
|
|
||||||
use-sync-external-store@1.6.0(react@18.3.1):
|
|
||||||
dependencies:
|
|
||||||
react: 18.3.1
|
|
||||||
|
|
||||||
util-deprecate@1.0.2: {}
|
util-deprecate@1.0.2: {}
|
||||||
|
|
||||||
victory-vendor@36.9.2:
|
victory-vendor@36.9.2:
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useEffect, useState, useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
|
||||||
import {
|
import {
|
||||||
Search,
|
Search,
|
||||||
Plus,
|
Plus,
|
||||||
@@ -41,10 +40,7 @@ import {
|
|||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate, getSwrErrorMessage } from '@/lib/utils'
|
import { formatDate } from '@/lib/utils'
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import { useDebounce } from '@/hooks/use-debounce'
|
|
||||||
import type { AccountPublic } from '@/lib/types'
|
import type { AccountPublic } from '@/lib/types'
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
@@ -68,28 +64,21 @@ const statusLabels: Record<string, string> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function AccountsPage() {
|
export default function AccountsPage() {
|
||||||
|
const [accounts, setAccounts] = useState<AccountPublic[]>([])
|
||||||
|
const [total, setTotal] = useState(0)
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [search, setSearch] = useState('')
|
const [search, setSearch] = useState('')
|
||||||
|
|
||||||
|
// 搜索 debounce: 输入后 300ms 再触发请求
|
||||||
|
const [debouncedSearchState, setDebouncedSearchState] = useState('')
|
||||||
|
useEffect(() => {
|
||||||
|
const timer = setTimeout(() => setDebouncedSearchState(search), 300)
|
||||||
|
return () => clearTimeout(timer)
|
||||||
|
}, [search])
|
||||||
const [roleFilter, setRoleFilter] = useState<string>('all')
|
const [roleFilter, setRoleFilter] = useState<string>('all')
|
||||||
const [statusFilter, setStatusFilter] = useState<string>('all')
|
const [statusFilter, setStatusFilter] = useState<string>('all')
|
||||||
const [mutationError, setMutationError] = useState('')
|
const [loading, setLoading] = useState(true)
|
||||||
|
const [error, setError] = useState('')
|
||||||
const debouncedSearch = useDebounce(search, 300)
|
|
||||||
|
|
||||||
const { data, error: swrError, isLoading, mutate } = useSWR(
|
|
||||||
['accounts', page, debouncedSearch, roleFilter, statusFilter],
|
|
||||||
() => {
|
|
||||||
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
|
||||||
if (debouncedSearch.trim()) params.search = debouncedSearch.trim()
|
|
||||||
if (roleFilter !== 'all') params.role = roleFilter
|
|
||||||
if (statusFilter !== 'all') params.status = statusFilter
|
|
||||||
return api.accounts.list(params)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
const accounts = data?.items ?? []
|
|
||||||
const total = data?.total ?? 0
|
|
||||||
const error = getSwrErrorMessage(swrError) || mutationError
|
|
||||||
|
|
||||||
// 编辑 Dialog
|
// 编辑 Dialog
|
||||||
const [editTarget, setEditTarget] = useState<AccountPublic | null>(null)
|
const [editTarget, setEditTarget] = useState<AccountPublic | null>(null)
|
||||||
@@ -100,6 +89,33 @@ export default function AccountsPage() {
|
|||||||
const [confirmTarget, setConfirmTarget] = useState<{ id: string; action: string; status: string } | null>(null)
|
const [confirmTarget, setConfirmTarget] = useState<{ id: string; action: string; status: string } | null>(null)
|
||||||
const [confirmSaving, setConfirmSaving] = useState(false)
|
const [confirmSaving, setConfirmSaving] = useState(false)
|
||||||
|
|
||||||
|
const fetchAccounts = useCallback(async () => {
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
||||||
|
if (debouncedSearchState.trim()) params.search = debouncedSearchState.trim()
|
||||||
|
if (roleFilter !== 'all') params.role = roleFilter
|
||||||
|
if (statusFilter !== 'all') params.status = statusFilter
|
||||||
|
|
||||||
|
const res = await api.accounts.list(params)
|
||||||
|
setAccounts(res.items)
|
||||||
|
setTotal(res.total)
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) {
|
||||||
|
setError(err.body.message)
|
||||||
|
} else {
|
||||||
|
setError('加载失败')
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [page, debouncedSearchState, roleFilter, statusFilter])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchAccounts()
|
||||||
|
}, [fetchAccounts])
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
function openEditDialog(account: AccountPublic) {
|
function openEditDialog(account: AccountPublic) {
|
||||||
@@ -121,10 +137,10 @@ export default function AccountsPage() {
|
|||||||
role: editForm.role as AccountPublic['role'],
|
role: editForm.role as AccountPublic['role'],
|
||||||
})
|
})
|
||||||
setEditTarget(null)
|
setEditTarget(null)
|
||||||
mutate()
|
fetchAccounts()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) {
|
if (err instanceof ApiRequestError) {
|
||||||
setMutationError(err.body.message)
|
setError(err.body.message)
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
setEditSaving(false)
|
setEditSaving(false)
|
||||||
@@ -148,10 +164,10 @@ export default function AccountsPage() {
|
|||||||
status: confirmTarget.status as AccountPublic['status'],
|
status: confirmTarget.status as AccountPublic['status'],
|
||||||
})
|
})
|
||||||
setConfirmTarget(null)
|
setConfirmTarget(null)
|
||||||
mutate()
|
fetchAccounts()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) {
|
if (err instanceof ApiRequestError) {
|
||||||
setMutationError(err.body.message)
|
setError(err.body.message)
|
||||||
}
|
}
|
||||||
} finally {
|
} finally {
|
||||||
setConfirmSaving(false)
|
setConfirmSaving(false)
|
||||||
@@ -196,13 +212,24 @@ export default function AccountsPage() {
|
|||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 错误提示 */}
|
{/* 错误提示 */}
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => { setMutationError('') }} />}
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">
|
||||||
|
关闭
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* 表格 */}
|
{/* 表格 */}
|
||||||
{isLoading ? (
|
{loading ? (
|
||||||
<TableSkeleton rows={6} cols={7} />
|
<div className="flex h-64 items-center justify-center">
|
||||||
) : error ? null : accounts.length === 0 ? (
|
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||||
<EmptyState />
|
</div>
|
||||||
|
) : accounts.length === 0 ? (
|
||||||
|
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
||||||
|
暂无数据
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
|
|||||||
@@ -1,290 +0,0 @@
|
|||||||
'use client'
|
|
||||||
|
|
||||||
import { useState } from 'react'
|
|
||||||
import useSWR from 'swr'
|
|
||||||
import { api } from '@/lib/api-client'
|
|
||||||
import type { AgentTemplate } from '@/lib/types'
|
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
|
|
||||||
export default function AgentTemplatesPage() {
|
|
||||||
const [page, setPage] = useState(1)
|
|
||||||
const [error, setError] = useState('')
|
|
||||||
const [showCreate, setShowCreate] = useState(false)
|
|
||||||
const [editingId, setEditingId] = useState<string | null>(null)
|
|
||||||
|
|
||||||
const { data, isLoading, mutate } = useSWR(
|
|
||||||
['agentTemplates.list', page],
|
|
||||||
() => api.agentTemplates.list({ page, page_size: 50 }),
|
|
||||||
)
|
|
||||||
|
|
||||||
const templates = data?.items ?? []
|
|
||||||
const total = data?.total ?? 0
|
|
||||||
|
|
||||||
const handleCreate = async (e: React.FormEvent<HTMLFormElement>) => {
|
|
||||||
e.preventDefault()
|
|
||||||
const fd = new FormData(e.currentTarget)
|
|
||||||
try {
|
|
||||||
const tools = (fd.get('tools') as string || '').split(',').map(s => s.trim()).filter(Boolean)
|
|
||||||
const capabilities = (fd.get('capabilities') as string || '').split(',').map(s => s.trim()).filter(Boolean)
|
|
||||||
await api.agentTemplates.create({
|
|
||||||
name: fd.get('name') as string,
|
|
||||||
description: (fd.get('description') as string) || undefined,
|
|
||||||
category: (fd.get('category') as string) || 'general',
|
|
||||||
model: (fd.get('model') as string) || undefined,
|
|
||||||
system_prompt: (fd.get('system_prompt') as string) || undefined,
|
|
||||||
tools: tools.length > 0 ? tools : undefined,
|
|
||||||
capabilities: capabilities.length > 0 ? capabilities : undefined,
|
|
||||||
temperature: (fd.get('temperature') as string) ? parseFloat(fd.get('temperature') as string) : undefined,
|
|
||||||
max_tokens: (fd.get('max_tokens') as string) ? parseInt(fd.get('max_tokens') as string, 10) : undefined,
|
|
||||||
visibility: (fd.get('visibility') as string) || 'public',
|
|
||||||
})
|
|
||||||
setShowCreate(false)
|
|
||||||
mutate()
|
|
||||||
} catch {
|
|
||||||
setError('创建失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleArchive = async (id: string, name: string) => {
|
|
||||||
if (!confirm(`确认归档模板 "${name}"?`)) return
|
|
||||||
try {
|
|
||||||
await api.agentTemplates.archive(id)
|
|
||||||
mutate()
|
|
||||||
} catch {
|
|
||||||
setError('归档失败')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const statusBadge = (status: string) => {
|
|
||||||
const colors: Record<string, string> = {
|
|
||||||
active: 'bg-emerald-500/20 text-emerald-400',
|
|
||||||
archived: 'bg-zinc-500/20 text-zinc-400',
|
|
||||||
}
|
|
||||||
return <span className={`px-2 py-0.5 text-xs rounded-full ${colors[status] || colors.archived}`}>{status}</span>
|
|
||||||
}
|
|
||||||
|
|
||||||
const sourceBadge = (source: string) => {
|
|
||||||
const colors: Record<string, string> = {
|
|
||||||
builtin: 'bg-blue-500/20 text-blue-400',
|
|
||||||
custom: 'bg-purple-500/20 text-purple-400',
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<span className={`px-2 py-0.5 text-xs rounded-full ${colors[source] || ''}`}>
|
|
||||||
{source === 'builtin' ? '内置' : '自定义'}
|
|
||||||
</span>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="space-y-6">
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<div>
|
|
||||||
<h1 className="text-2xl font-bold text-white">Agent 配置模板</h1>
|
|
||||||
<p className="text-sm text-zinc-400 mt-1">管理 Agent 配置模板,支持团队共享和一键复用</p>
|
|
||||||
</div>
|
|
||||||
<button
|
|
||||||
onClick={() => setShowCreate(true)}
|
|
||||||
className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors text-sm"
|
|
||||||
>
|
|
||||||
+ 新建模板
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
|
||||||
|
|
||||||
<div className="bg-zinc-900 rounded-xl border border-zinc-800 overflow-hidden">
|
|
||||||
<table className="w-full text-sm">
|
|
||||||
<thead>
|
|
||||||
<tr className="border-b border-zinc-800">
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">名称</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">分类</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">来源</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">模型</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">工具数</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">可见性</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">状态</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">更新时间</th>
|
|
||||||
<th className="text-right px-4 py-3 text-zinc-400 font-medium">操作</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
{isLoading ? (
|
|
||||||
<tr>
|
|
||||||
<td colSpan={9}>
|
|
||||||
<TableSkeleton rows={5} cols={9} hasToolbar={false} />
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
) : templates.length === 0 ? (
|
|
||||||
<tr><td colSpan={9}><EmptyState message="暂无 Agent 模板" /></td></tr>
|
|
||||||
) : (
|
|
||||||
templates.map(t => (
|
|
||||||
<tr key={t.id} className="border-b border-zinc-800/50 hover:bg-zinc-800/30">
|
|
||||||
<td className="px-4 py-3">
|
|
||||||
<div>
|
|
||||||
<span className="text-white font-medium">{t.name}</span>
|
|
||||||
{t.description && (
|
|
||||||
<p className="text-xs text-zinc-500 mt-0.5 truncate max-w-[200px]">{t.description}</p>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-400">{t.category}</td>
|
|
||||||
<td className="px-4 py-3">{sourceBadge(t.source)}</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-300 font-mono text-xs">{t.model || '-'}</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-400">{t.tools.length}</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-400">{t.visibility}</td>
|
|
||||||
<td className="px-4 py-3">{statusBadge(t.status)}</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-500 text-xs">
|
|
||||||
{new Date(t.updated_at).toLocaleString('zh-CN')}
|
|
||||||
</td>
|
|
||||||
<td className="px-4 py-3 text-right">
|
|
||||||
<button
|
|
||||||
onClick={() => setEditingId(editingId === t.id ? null : t.id)}
|
|
||||||
className="text-zinc-400 hover:text-white mr-2"
|
|
||||||
>
|
|
||||||
详情
|
|
||||||
</button>
|
|
||||||
{t.source === 'custom' && (
|
|
||||||
<button
|
|
||||||
onClick={() => handleArchive(t.id, t.name)}
|
|
||||||
className="text-red-400 hover:text-red-300"
|
|
||||||
>
|
|
||||||
归档
|
|
||||||
</button>
|
|
||||||
)}
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
))
|
|
||||||
)}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
<div className="px-4 py-2 text-xs text-zinc-500 border-t border-zinc-800">
|
|
||||||
共 {total} 个模板
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* 展开详情 */}
|
|
||||||
{editingId && (() => {
|
|
||||||
const t = templates.find(t => t.id === editingId)
|
|
||||||
if (!t) return null
|
|
||||||
return (
|
|
||||||
<div className="bg-zinc-900 rounded-xl border border-zinc-800 p-4">
|
|
||||||
<div className="flex items-center justify-between mb-3">
|
|
||||||
<h2 className="text-lg font-semibold text-white">{t.name} — 详情</h2>
|
|
||||||
<button onClick={() => setEditingId(null)} className="text-zinc-400 hover:text-white text-sm">关闭</button>
|
|
||||||
</div>
|
|
||||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
|
||||||
<div>
|
|
||||||
<span className="text-zinc-500">分类:</span>
|
|
||||||
<span className="text-zinc-300">{t.category}</span>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<span className="text-zinc-500">模型:</span>
|
|
||||||
<span className="text-zinc-300 font-mono">{t.model || '未指定'}</span>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<span className="text-zinc-500">温度:</span>
|
|
||||||
<span className="text-zinc-300">{t.temperature?.toFixed(2) || '默认'}</span>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<span className="text-zinc-500">最大 Token:</span>
|
|
||||||
<span className="text-zinc-300">{t.max_tokens || '未限制'}</span>
|
|
||||||
</div>
|
|
||||||
<div className="col-span-2">
|
|
||||||
<span className="text-zinc-500">工具:</span>
|
|
||||||
<div className="flex flex-wrap gap-1 mt-1">
|
|
||||||
{t.tools.length > 0 ? t.tools.map(tool => (
|
|
||||||
<span key={tool} className="px-2 py-0.5 bg-zinc-800 rounded text-xs text-zinc-300">{tool}</span>
|
|
||||||
)) : <span className="text-zinc-600">无</span>}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div className="col-span-2">
|
|
||||||
<span className="text-zinc-500">能力:</span>
|
|
||||||
<div className="flex flex-wrap gap-1 mt-1">
|
|
||||||
{t.capabilities.length > 0 ? t.capabilities.map(cap => (
|
|
||||||
<span key={cap} className="px-2 py-0.5 bg-blue-500/10 rounded text-xs text-blue-400">{cap}</span>
|
|
||||||
)) : <span className="text-zinc-600">无</span>}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{t.system_prompt && (
|
|
||||||
<div className="col-span-2">
|
|
||||||
<span className="text-zinc-500">系统提示词:</span>
|
|
||||||
<pre className="text-xs text-zinc-400 bg-zinc-800/50 rounded p-2 mt-1 overflow-x-auto max-h-32">
|
|
||||||
{t.system_prompt.substring(0, 500)}{t.system_prompt.length > 500 ? '...' : ''}
|
|
||||||
</pre>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
})()}
|
|
||||||
|
|
||||||
{/* Create Modal */}
|
|
||||||
{showCreate && (
|
|
||||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
|
||||||
<form onSubmit={handleCreate} className="bg-zinc-900 rounded-xl border border-zinc-700 p-6 w-full max-w-lg space-y-4 max-h-[80vh] overflow-y-auto">
|
|
||||||
<h2 className="text-lg font-semibold text-white">新建 Agent 模板</h2>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">名称 *</label>
|
|
||||||
<input name="name" required className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="my_agent" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">描述</label>
|
|
||||||
<input name="description" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="可选" />
|
|
||||||
</div>
|
|
||||||
<div className="grid grid-cols-2 gap-4">
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">分类</label>
|
|
||||||
<select name="category" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm">
|
|
||||||
<option value="general">通用</option>
|
|
||||||
<option value="coding">编程</option>
|
|
||||||
<option value="research">研究</option>
|
|
||||||
<option value="creative">创意</option>
|
|
||||||
<option value="assistant">助手</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">模型</label>
|
|
||||||
<input name="model" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="如 glm-4-plus" />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">系统提示词</label>
|
|
||||||
<textarea name="system_prompt" rows={4} className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm font-mono" placeholder="Agent 系统提示词" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">工具(逗号分隔)</label>
|
|
||||||
<input name="tools" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="browser, file_system, code_execute" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">能力(逗号分隔)</label>
|
|
||||||
<input name="capabilities" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="streaming, vision, function_calling" />
|
|
||||||
</div>
|
|
||||||
<div className="grid grid-cols-3 gap-4">
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">温度</label>
|
|
||||||
<input name="temperature" type="number" step="0.1" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="默认" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">最大 Token</label>
|
|
||||||
<input name="max_tokens" type="number" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="不限" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">可见性</label>
|
|
||||||
<select name="visibility" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm">
|
|
||||||
<option value="public">公开</option>
|
|
||||||
<option value="team">团队</option>
|
|
||||||
<option value="private">私有</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div className="flex gap-2 justify-end">
|
|
||||||
<button type="button" onClick={() => setShowCreate(false)} className="px-4 py-2 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-sm">取消</button>
|
|
||||||
<button type="submit" className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-sm">创建</button>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useEffect, useState, useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
|
||||||
import {
|
import {
|
||||||
Plus,
|
Plus,
|
||||||
Loader2,
|
Loader2,
|
||||||
@@ -33,10 +32,8 @@ import {
|
|||||||
DialogDescription,
|
DialogDescription,
|
||||||
} from '@/components/ui/dialog'
|
} from '@/components/ui/dialog'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate, getSwrErrorMessage } from '@/lib/utils'
|
import { formatDate } from '@/lib/utils'
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import type { TokenInfo } from '@/lib/types'
|
import type { TokenInfo } from '@/lib/types'
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
@@ -48,17 +45,11 @@ const allPermissions = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
export default function ApiKeysPage() {
|
export default function ApiKeysPage() {
|
||||||
|
const [tokens, setTokens] = useState<TokenInfo[]>([])
|
||||||
|
const [total, setTotal] = useState(0)
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [mutationError, setMutationError] = useState('')
|
const [loading, setLoading] = useState(true)
|
||||||
|
const [error, setError] = useState('')
|
||||||
const { data, error: swrError, isLoading, mutate } = useSWR(
|
|
||||||
['tokens', page],
|
|
||||||
() => api.tokens.list({ page, page_size: PAGE_SIZE }),
|
|
||||||
)
|
|
||||||
|
|
||||||
const tokens = data?.items ?? []
|
|
||||||
const total = data?.total ?? 0
|
|
||||||
const error = getSwrErrorMessage(swrError) || mutationError
|
|
||||||
|
|
||||||
// 创建 Dialog
|
// 创建 Dialog
|
||||||
const [createOpen, setCreateOpen] = useState(false)
|
const [createOpen, setCreateOpen] = useState(false)
|
||||||
@@ -73,6 +64,25 @@ export default function ApiKeysPage() {
|
|||||||
const [revokeTarget, setRevokeTarget] = useState<TokenInfo | null>(null)
|
const [revokeTarget, setRevokeTarget] = useState<TokenInfo | null>(null)
|
||||||
const [revoking, setRevoking] = useState(false)
|
const [revoking, setRevoking] = useState(false)
|
||||||
|
|
||||||
|
const fetchTokens = useCallback(async () => {
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
const res = await api.tokens.list({ page, page_size: PAGE_SIZE })
|
||||||
|
setTokens(res.items)
|
||||||
|
setTotal(res.total)
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
else setError('加载失败')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [page])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchTokens()
|
||||||
|
}, [fetchTokens])
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
function togglePermission(perm: string) {
|
function togglePermission(perm: string) {
|
||||||
@@ -97,9 +107,9 @@ export default function ApiKeysPage() {
|
|||||||
setCreateOpen(false)
|
setCreateOpen(false)
|
||||||
setCreatedToken(res)
|
setCreatedToken(res)
|
||||||
setCreateForm({ name: '', expires_days: '', permissions: ['chat'] })
|
setCreateForm({ name: '', expires_days: '', permissions: ['chat'] })
|
||||||
mutate()
|
fetchTokens()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setMutationError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
setCreating(false)
|
setCreating(false)
|
||||||
}
|
}
|
||||||
@@ -111,9 +121,9 @@ export default function ApiKeysPage() {
|
|||||||
try {
|
try {
|
||||||
await api.tokens.revoke(revokeTarget.id)
|
await api.tokens.revoke(revokeTarget.id)
|
||||||
setRevokeTarget(null)
|
setRevokeTarget(null)
|
||||||
mutate()
|
fetchTokens()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setMutationError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
setRevoking(false)
|
setRevoking(false)
|
||||||
}
|
}
|
||||||
@@ -148,12 +158,21 @@ export default function ApiKeysPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => setMutationError('')} />}
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{isLoading ? (
|
{loading ? (
|
||||||
<TableSkeleton rows={6} cols={7} />
|
<div className="flex h-64 items-center justify-center">
|
||||||
) : error ? null : tokens.length === 0 ? (
|
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||||
<EmptyState />
|
</div>
|
||||||
|
) : tokens.length === 0 ? (
|
||||||
|
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
||||||
|
暂无数据
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useEffect, useState, useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
|
||||||
import {
|
import {
|
||||||
Loader2,
|
Loader2,
|
||||||
Pencil,
|
Pencil,
|
||||||
@@ -36,8 +35,6 @@ import {
|
|||||||
} from '@/components/ui/dialog'
|
} from '@/components/ui/dialog'
|
||||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import type { ConfigItem } from '@/lib/types'
|
import type { ConfigItem } from '@/lib/types'
|
||||||
|
|
||||||
@@ -54,31 +51,56 @@ const sourceVariants: Record<string, 'secondary' | 'info' | 'default'> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function ConfigPage() {
|
export default function ConfigPage() {
|
||||||
|
const [configs, setConfigs] = useState<ConfigItem[]>([])
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
const [activeTab, setActiveTab] = useState('all')
|
const [activeTab, setActiveTab] = useState('all')
|
||||||
|
|
||||||
// SWR for config list
|
|
||||||
const { data: configs = [], isLoading, mutate } = useSWR(
|
|
||||||
['config', activeTab],
|
|
||||||
() => {
|
|
||||||
const params: Record<string, unknown> = {}
|
|
||||||
if (activeTab !== 'all') params.category = activeTab
|
|
||||||
return api.config.list(params)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// 编辑 Dialog
|
// 编辑 Dialog
|
||||||
const [editTarget, setEditTarget] = useState<ConfigItem | null>(null)
|
const [editTarget, setEditTarget] = useState<ConfigItem | null>(null)
|
||||||
const [editValue, setEditValue] = useState('')
|
const [editValue, setEditValue] = useState('')
|
||||||
const [saving, setSaving] = useState(false)
|
const [saving, setSaving] = useState(false)
|
||||||
|
|
||||||
|
const fetchConfigs = useCallback(async (category?: string) => {
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
const params: Record<string, unknown> = {}
|
||||||
|
if (category && category !== 'all') params.category = category
|
||||||
|
const res = await api.config.list(params)
|
||||||
|
setConfigs(res)
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
else setError('加载失败')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchConfigs(activeTab)
|
||||||
|
}, [fetchConfigs, activeTab])
|
||||||
|
|
||||||
function openEditDialog(config: ConfigItem) {
|
function openEditDialog(config: ConfigItem) {
|
||||||
setEditTarget(config)
|
setEditTarget(config)
|
||||||
setEditValue(config.current_value ?? '')
|
setEditValue(config.current_value !== undefined ? String(config.current_value) : '')
|
||||||
}
|
}
|
||||||
|
|
||||||
async function handleSave() {
|
async function handleSave() {
|
||||||
if (!editTarget) return
|
if (!editTarget) return
|
||||||
|
// 表单验证
|
||||||
|
if (editValue.trim() === '') {
|
||||||
|
setError('配置值不能为空')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (editTarget.value_type === 'number' && isNaN(Number(editValue))) {
|
||||||
|
setError('请输入有效的数字')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (editTarget.value_type === 'boolean' && editValue !== 'true' && editValue !== 'false') {
|
||||||
|
setError('布尔值只能为 true 或 false')
|
||||||
|
return
|
||||||
|
}
|
||||||
setSaving(true)
|
setSaving(true)
|
||||||
try {
|
try {
|
||||||
let parsedValue: string | number | boolean = editValue
|
let parsedValue: string | number | boolean = editValue
|
||||||
@@ -87,9 +109,9 @@ export default function ConfigPage() {
|
|||||||
} else if (editTarget.value_type === 'boolean') {
|
} else if (editTarget.value_type === 'boolean') {
|
||||||
parsedValue = editValue === 'true'
|
parsedValue = editValue === 'true'
|
||||||
}
|
}
|
||||||
await api.config.update(editTarget.id, { value: parsedValue })
|
await api.config.update(editTarget.id, { current_value: parsedValue })
|
||||||
setEditTarget(null)
|
setEditTarget(null)
|
||||||
mutate()
|
fetchConfigs(activeTab)
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -103,15 +125,7 @@ export default function ConfigPage() {
|
|||||||
return String(value)
|
return String(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
const categoryLabels: Record<string, string> = {
|
const categories = ['all', 'auth', 'relay', 'model', 'system']
|
||||||
all: '全部',
|
|
||||||
server: '服务器',
|
|
||||||
agent: 'Agent',
|
|
||||||
memory: '记忆',
|
|
||||||
llm: 'LLM',
|
|
||||||
security: '安全策略',
|
|
||||||
}
|
|
||||||
const categories = Object.keys(categoryLabels)
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
@@ -120,18 +134,27 @@ export default function ConfigPage() {
|
|||||||
<TabsList>
|
<TabsList>
|
||||||
{categories.map((cat) => (
|
{categories.map((cat) => (
|
||||||
<TabsTrigger key={cat} value={cat}>
|
<TabsTrigger key={cat} value={cat}>
|
||||||
{categoryLabels[cat] || cat}
|
{cat === 'all' ? '全部' : cat}
|
||||||
</TabsTrigger>
|
</TabsTrigger>
|
||||||
))}
|
))}
|
||||||
</TabsList>
|
</TabsList>
|
||||||
</Tabs>
|
</Tabs>
|
||||||
|
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{isLoading ? (
|
{loading ? (
|
||||||
<TableSkeleton rows={8} cols={8} hasToolbar={false} />
|
<div className="flex h-64 items-center justify-center">
|
||||||
) : error ? null : configs.length === 0 ? (
|
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||||
<EmptyState message="暂无配置项" />
|
</div>
|
||||||
|
) : configs.length === 0 ? (
|
||||||
|
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
||||||
|
暂无配置项
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<Table>
|
<Table>
|
||||||
<TableHeader>
|
<TableHeader>
|
||||||
@@ -210,7 +233,7 @@ export default function ConfigPage() {
|
|||||||
</div>
|
</div>
|
||||||
<div className="space-y-2">
|
<div className="space-y-2">
|
||||||
<Label>
|
<Label>
|
||||||
新值 {editTarget?.default_value != null && (
|
新值 {editTarget?.default_value !== undefined && (
|
||||||
<span className="text-xs text-muted-foreground ml-2">
|
<span className="text-xs text-muted-foreground ml-2">
|
||||||
(默认: {formatValue(editTarget.default_value)})
|
(默认: {formatValue(editTarget.default_value)})
|
||||||
</span>
|
</span>
|
||||||
@@ -239,7 +262,7 @@ export default function ConfigPage() {
|
|||||||
<Button
|
<Button
|
||||||
variant="outline"
|
variant="outline"
|
||||||
onClick={() => {
|
onClick={() => {
|
||||||
if (editTarget?.default_value != null) {
|
if (editTarget?.default_value !== undefined) {
|
||||||
setEditValue(String(editTarget.default_value))
|
setEditValue(String(editTarget.default_value))
|
||||||
}
|
}
|
||||||
}}
|
}}
|
||||||
|
|||||||
125
admin/src/app/(dashboard)/devices/page.tsx
Normal file
125
admin/src/app/(dashboard)/devices/page.tsx
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
'use client'
|
||||||
|
|
||||||
|
import { useEffect, useState } from 'react'
|
||||||
|
import { Monitor, Loader2, RefreshCw } from 'lucide-react'
|
||||||
|
import { Badge } from '@/components/ui/badge'
|
||||||
|
import {
|
||||||
|
Table, TableBody, TableCell, TableHead, TableHeader, TableRow,
|
||||||
|
} from '@/components/ui/table'
|
||||||
|
import { api } from '@/lib/api-client'
|
||||||
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
|
import type { DeviceInfo } from '@/lib/types'
|
||||||
|
|
||||||
|
function formatRelativeTime(dateStr: string): string {
|
||||||
|
const now = Date.now()
|
||||||
|
const then = new Date(dateStr).getTime()
|
||||||
|
const diffMs = now - then
|
||||||
|
const diffMin = Math.floor(diffMs / 60000)
|
||||||
|
const diffHour = Math.floor(diffMs / 3600000)
|
||||||
|
const diffDay = Math.floor(diffMs / 86400000)
|
||||||
|
|
||||||
|
if (diffMin < 1) return '刚刚'
|
||||||
|
if (diffMin < 60) return `${diffMin} 分钟前`
|
||||||
|
if (diffHour < 24) return `${diffHour} 小时前`
|
||||||
|
return `${diffDay} 天前`
|
||||||
|
}
|
||||||
|
|
||||||
|
function isOnline(lastSeen: string): boolean {
|
||||||
|
return Date.now() - new Date(lastSeen).getTime() < 5 * 60 * 1000
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function DevicesPage() {
|
||||||
|
const [devices, setDevices] = useState<DeviceInfo[]>([])
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
|
async function fetchDevices() {
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
const res = await api.devices.list()
|
||||||
|
setDevices(res)
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
else setError('加载失败')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
useEffect(() => { fetchDevices() }, [])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="flex items-center justify-between">
|
||||||
|
<h2 className="text-lg font-semibold text-foreground">设备管理</h2>
|
||||||
|
<button
|
||||||
|
onClick={fetchDevices}
|
||||||
|
disabled={loading}
|
||||||
|
className="flex items-center gap-2 rounded-md border border-border px-3 py-1.5 text-sm text-muted-foreground hover:bg-muted hover:text-foreground transition-colors cursor-pointer disabled:opacity-50"
|
||||||
|
>
|
||||||
|
<RefreshCw className={`h-4 w-4 ${loading ? 'animate-spin' : ''}`} />
|
||||||
|
刷新
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{loading && !devices.length ? (
|
||||||
|
<div className="flex items-center justify-center py-12">
|
||||||
|
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||||
|
</div>
|
||||||
|
) : devices.length === 0 ? (
|
||||||
|
<div className="flex flex-col items-center justify-center py-12 text-muted-foreground">
|
||||||
|
<Monitor className="h-10 w-10 mb-3" />
|
||||||
|
<p className="text-sm">暂无已注册设备</p>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="rounded-md border border-border">
|
||||||
|
<Table>
|
||||||
|
<TableHeader>
|
||||||
|
<TableRow>
|
||||||
|
<TableHead>设备名称</TableHead>
|
||||||
|
<TableHead>平台</TableHead>
|
||||||
|
<TableHead>版本</TableHead>
|
||||||
|
<TableHead>状态</TableHead>
|
||||||
|
<TableHead>最后活跃</TableHead>
|
||||||
|
<TableHead>注册时间</TableHead>
|
||||||
|
</TableRow>
|
||||||
|
</TableHeader>
|
||||||
|
<TableBody>
|
||||||
|
{devices.map((d) => (
|
||||||
|
<TableRow key={d.id}>
|
||||||
|
<TableCell className="font-medium">
|
||||||
|
{d.device_name || d.device_id}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
<Badge variant="secondary">{d.platform || 'unknown'}</Badge>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-muted-foreground">
|
||||||
|
{d.app_version || '-'}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell>
|
||||||
|
<Badge variant={isOnline(d.last_seen_at) ? 'success' : 'outline'}>
|
||||||
|
{isOnline(d.last_seen_at) ? '在线' : '离线'}
|
||||||
|
</Badge>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-muted-foreground">
|
||||||
|
{formatRelativeTime(d.last_seen_at)}
|
||||||
|
</TableCell>
|
||||||
|
<TableCell className="text-muted-foreground text-xs">
|
||||||
|
{new Date(d.created_at).toLocaleString('zh-CN')}
|
||||||
|
</TableCell>
|
||||||
|
</TableRow>
|
||||||
|
))}
|
||||||
|
</TableBody>
|
||||||
|
</Table>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState, type ReactNode } from 'react'
|
import { useState, useEffect, type ReactNode } from 'react'
|
||||||
import Link from 'next/link'
|
import Link from 'next/link'
|
||||||
import { usePathname, useRouter } from 'next/navigation'
|
import { usePathname, useRouter } from 'next/navigation'
|
||||||
import {
|
import {
|
||||||
@@ -13,73 +13,75 @@ import {
|
|||||||
ArrowLeftRight,
|
ArrowLeftRight,
|
||||||
Settings,
|
Settings,
|
||||||
FileText,
|
FileText,
|
||||||
MessageSquare,
|
|
||||||
Bot,
|
|
||||||
LogOut,
|
LogOut,
|
||||||
ChevronLeft,
|
ChevronLeft,
|
||||||
Menu,
|
Menu,
|
||||||
Bell,
|
Bell,
|
||||||
|
UserCog,
|
||||||
|
ShieldCheck,
|
||||||
|
Monitor,
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
import { AuthGuard, useAuth } from '@/components/auth-guard'
|
import { AuthGuard, useAuth } from '@/components/auth-guard'
|
||||||
import { logout } from '@/lib/auth'
|
import { logout } from '@/lib/auth'
|
||||||
import { cn } from '@/lib/utils'
|
import { cn } from '@/lib/utils'
|
||||||
|
|
||||||
/** 权限常量 — 与后端 db.rs SEED_ROLES 保持同步 */
|
|
||||||
const ROLE_PERMISSIONS: Record<string, string[]> = {
|
|
||||||
super_admin: ['admin:full', 'account:admin', 'provider:manage', 'model:manage', 'relay:admin', 'config:write', 'prompt:read', 'prompt:write', 'prompt:publish', 'prompt:admin'],
|
|
||||||
admin: ['account:read', 'account:admin', 'provider:manage', 'model:read', 'model:manage', 'relay:use', 'relay:admin', 'config:read', 'config:write', 'prompt:read', 'prompt:write', 'prompt:publish'],
|
|
||||||
user: ['model:read', 'relay:use', 'config:read', 'prompt:read'],
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 根据 role 获取权限列表 */
|
|
||||||
function getPermissionsForRole(role: string): string[] {
|
|
||||||
return ROLE_PERMISSIONS[role] ?? []
|
|
||||||
}
|
|
||||||
|
|
||||||
const navItems = [
|
const navItems = [
|
||||||
{ href: '/', label: '仪表盘', icon: LayoutDashboard },
|
{ href: '/', label: '仪表盘', icon: LayoutDashboard, permission: null },
|
||||||
{ href: '/accounts', label: '账号管理', icon: Users, permission: 'account:admin' },
|
{ href: '/accounts', label: '账号管理', icon: Users, permission: 'account:admin' },
|
||||||
{ href: '/providers', label: '服务商', icon: Server, permission: 'provider:manage' },
|
{ href: '/providers', label: '服务商', icon: Server, permission: 'model:admin' },
|
||||||
{ href: '/models', label: '模型管理', icon: Cpu, permission: 'model:read' },
|
{ href: '/models', label: '模型管理', icon: Cpu, permission: 'model:admin' },
|
||||||
{ href: '/agent-templates', label: 'Agent 模板', icon: Bot, permission: 'model:read' },
|
{ href: '/api-keys', label: 'API 密钥', icon: Key, permission: null },
|
||||||
{ href: '/api-keys', label: 'API 密钥', icon: Key, permission: 'admin:full' },
|
{ href: '/usage', label: '用量统计', icon: BarChart3, permission: null },
|
||||||
{ href: '/usage', label: '用量统计', icon: BarChart3, permission: 'admin:full' },
|
{ href: '/relay', label: '中转任务', icon: ArrowLeftRight, permission: 'relay:admin' },
|
||||||
{ href: '/relay', label: '中转任务', icon: ArrowLeftRight, permission: 'relay:use' },
|
{ href: '/config', label: '系统配置', icon: Settings, permission: 'admin:full' },
|
||||||
{ href: '/config', label: '系统配置', icon: Settings, permission: 'config:read' },
|
|
||||||
{ href: '/prompts', label: '提示词管理', icon: MessageSquare, permission: 'prompt:read' },
|
|
||||||
{ href: '/logs', label: '操作日志', icon: FileText, permission: 'admin:full' },
|
{ href: '/logs', label: '操作日志', icon: FileText, permission: 'admin:full' },
|
||||||
|
{ href: '/profile', label: '个人设置', icon: UserCog, permission: null },
|
||||||
|
{ href: '/security', label: '安全设置', icon: ShieldCheck, permission: null },
|
||||||
|
{ href: '/devices', label: '设备管理', icon: Monitor, permission: null },
|
||||||
]
|
]
|
||||||
|
|
||||||
function Sidebar({
|
function Sidebar({
|
||||||
collapsed,
|
collapsed,
|
||||||
onToggle,
|
onToggle,
|
||||||
|
mobileOpen,
|
||||||
|
onMobileClose,
|
||||||
}: {
|
}: {
|
||||||
collapsed: boolean
|
collapsed: boolean
|
||||||
onToggle: () => void
|
onToggle: () => void
|
||||||
|
mobileOpen: boolean
|
||||||
|
onMobileClose: () => void
|
||||||
}) {
|
}) {
|
||||||
const pathname = usePathname()
|
const pathname = usePathname()
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const { account } = useAuth()
|
const { account } = useAuth()
|
||||||
|
|
||||||
const permissions = account ? getPermissionsForRole(account.role) : []
|
// 路由变化时关闭移动端菜单
|
||||||
|
useEffect(() => {
|
||||||
|
onMobileClose()
|
||||||
|
}, [pathname, onMobileClose])
|
||||||
|
|
||||||
function handleLogout() {
|
function handleLogout() {
|
||||||
logout()
|
logout()
|
||||||
router.replace('/login')
|
router.replace('/login')
|
||||||
}
|
}
|
||||||
|
|
||||||
const filteredNavItems = navItems.filter((item) => {
|
|
||||||
if (!item.permission) return true
|
|
||||||
return permissions.includes(item.permission) || permissions.includes('admin:full')
|
|
||||||
})
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<aside
|
<>
|
||||||
className={cn(
|
{/* 移动端 overlay */}
|
||||||
'fixed left-0 top-0 z-40 flex h-screen flex-col border-r border-border bg-card transition-all duration-300',
|
{mobileOpen && (
|
||||||
collapsed ? 'w-16' : 'w-64',
|
<div
|
||||||
|
className="fixed inset-0 z-40 bg-black/50 lg:hidden"
|
||||||
|
onClick={onMobileClose}
|
||||||
|
/>
|
||||||
)}
|
)}
|
||||||
>
|
<aside
|
||||||
|
className={cn(
|
||||||
|
'fixed left-0 top-0 z-50 flex h-screen flex-col border-r border-border bg-card transition-all duration-300',
|
||||||
|
collapsed ? 'w-16' : 'w-64',
|
||||||
|
'lg:z-40',
|
||||||
|
mobileOpen ? 'translate-x-0' : '-translate-x-full lg:translate-x-0',
|
||||||
|
)}
|
||||||
|
>
|
||||||
{/* Logo */}
|
{/* Logo */}
|
||||||
<div className="flex h-14 items-center border-b border-border px-4">
|
<div className="flex h-14 items-center border-b border-border px-4">
|
||||||
<Link href="/" className="flex items-center gap-2 cursor-pointer">
|
<Link href="/" className="flex items-center gap-2 cursor-pointer">
|
||||||
@@ -98,7 +100,15 @@ function Sidebar({
|
|||||||
{/* 导航 */}
|
{/* 导航 */}
|
||||||
<nav className="flex-1 overflow-y-auto scrollbar-thin py-2 px-2">
|
<nav className="flex-1 overflow-y-auto scrollbar-thin py-2 px-2">
|
||||||
<ul className="space-y-1">
|
<ul className="space-y-1">
|
||||||
{filteredNavItems.map((item) => {
|
{navItems
|
||||||
|
.filter((item) => {
|
||||||
|
if (!item.permission) return true
|
||||||
|
if (!account) return false
|
||||||
|
// super_admin 拥有所有权限
|
||||||
|
if (account.role === 'super_admin') return true
|
||||||
|
return account.permissions?.includes(item.permission) ?? false
|
||||||
|
})
|
||||||
|
.map((item) => {
|
||||||
const isActive =
|
const isActive =
|
||||||
item.href === '/'
|
item.href === '/'
|
||||||
? pathname === '/'
|
? pathname === '/'
|
||||||
@@ -142,6 +152,19 @@ function Sidebar({
|
|||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
{/* 折叠时显示退出按钮 */}
|
||||||
|
{collapsed && (
|
||||||
|
<div className="border-t border-border p-2">
|
||||||
|
<button
|
||||||
|
onClick={handleLogout}
|
||||||
|
className="flex w-full items-center justify-center rounded-md px-3 py-2 text-muted-foreground hover:bg-muted hover:text-destructive transition-colors duration-200 cursor-pointer"
|
||||||
|
title="退出登录"
|
||||||
|
>
|
||||||
|
<LogOut className="h-4 w-4" />
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{/* 用户信息 */}
|
{/* 用户信息 */}
|
||||||
{!collapsed && (
|
{!collapsed && (
|
||||||
<div className="border-t border-border p-3">
|
<div className="border-t border-border p-3">
|
||||||
@@ -168,10 +191,11 @@ function Sidebar({
|
|||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</aside>
|
</aside>
|
||||||
|
</>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
function Header() {
|
function Header({ children }: { children?: ReactNode }) {
|
||||||
const pathname = usePathname()
|
const pathname = usePathname()
|
||||||
const currentNav = navItems.find(
|
const currentNav = navItems.find(
|
||||||
(item) =>
|
(item) =>
|
||||||
@@ -183,7 +207,7 @@ function Header() {
|
|||||||
return (
|
return (
|
||||||
<header className="sticky top-0 z-30 flex h-14 items-center border-b border-border bg-background/80 backdrop-blur-sm px-6">
|
<header className="sticky top-0 z-30 flex h-14 items-center border-b border-border bg-background/80 backdrop-blur-sm px-6">
|
||||||
{/* 移动端菜单按钮 */}
|
{/* 移动端菜单按钮 */}
|
||||||
<MobileMenuButton />
|
{children}
|
||||||
|
|
||||||
{/* 页面标题 */}
|
{/* 页面标题 */}
|
||||||
<h1 className="text-lg font-semibold text-foreground">
|
<h1 className="text-lg font-semibold text-foreground">
|
||||||
@@ -203,10 +227,10 @@ function Header() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
function MobileMenuButton() {
|
function MobileMenuButton({ onClick }: { onClick: () => void }) {
|
||||||
// Placeholder for mobile menu toggle
|
|
||||||
return (
|
return (
|
||||||
<button
|
<button
|
||||||
|
onClick={onClick}
|
||||||
className="mr-3 rounded-md p-2 text-muted-foreground hover:bg-muted hover:text-foreground transition-colors duration-200 lg:hidden cursor-pointer"
|
className="mr-3 rounded-md p-2 text-muted-foreground hover:bg-muted hover:text-foreground transition-colors duration-200 lg:hidden cursor-pointer"
|
||||||
>
|
>
|
||||||
<Menu className="h-5 w-5" />
|
<Menu className="h-5 w-5" />
|
||||||
@@ -214,28 +238,68 @@ function MobileMenuButton() {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** 路由级权限守卫:隐藏导航项但用户直接访问 URL 时拦截 */
|
||||||
|
function PageGuard({ children }: { children: ReactNode }) {
|
||||||
|
const pathname = usePathname()
|
||||||
|
const router = useRouter()
|
||||||
|
const { account } = useAuth()
|
||||||
|
|
||||||
|
const matchedNav = navItems.find((item) =>
|
||||||
|
item.href === '/' ? pathname === '/' : pathname.startsWith(item.href),
|
||||||
|
)
|
||||||
|
|
||||||
|
if (matchedNav?.permission && account) {
|
||||||
|
if (account.role !== 'super_admin' && !(account.permissions?.includes(matchedNav.permission) ?? false)) {
|
||||||
|
return (
|
||||||
|
<div className="flex flex-1 items-center justify-center">
|
||||||
|
<div className="text-center space-y-3">
|
||||||
|
<p className="text-lg font-medium text-muted-foreground">权限不足</p>
|
||||||
|
<p className="text-sm text-muted-foreground">您没有访问「{matchedNav.label}」的权限</p>
|
||||||
|
<button
|
||||||
|
onClick={() => router.replace('/')}
|
||||||
|
className="text-sm text-primary hover:underline cursor-pointer"
|
||||||
|
>
|
||||||
|
返回仪表盘
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return <>{children}</>
|
||||||
|
}
|
||||||
|
|
||||||
export default function DashboardLayout({ children }: { children: ReactNode }) {
|
export default function DashboardLayout({ children }: { children: ReactNode }) {
|
||||||
const [sidebarCollapsed, setSidebarCollapsed] = useState(false)
|
const [sidebarCollapsed, setSidebarCollapsed] = useState(false)
|
||||||
|
const [mobileOpen, setMobileOpen] = useState(false)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<AuthGuard>
|
<AuthGuard>
|
||||||
<div className="flex min-h-screen">
|
<PageGuard>
|
||||||
|
<div className="flex min-h-screen">
|
||||||
<Sidebar
|
<Sidebar
|
||||||
collapsed={sidebarCollapsed}
|
collapsed={sidebarCollapsed}
|
||||||
onToggle={() => setSidebarCollapsed(!sidebarCollapsed)}
|
onToggle={() => setSidebarCollapsed(!sidebarCollapsed)}
|
||||||
|
mobileOpen={mobileOpen}
|
||||||
|
onMobileClose={() => setMobileOpen(false)}
|
||||||
/>
|
/>
|
||||||
<div
|
<div
|
||||||
className={cn(
|
className={cn(
|
||||||
'flex flex-1 flex-col transition-all duration-300',
|
'flex flex-1 flex-col transition-all duration-300',
|
||||||
sidebarCollapsed ? 'ml-16' : 'ml-64',
|
'ml-0 lg:transition-all',
|
||||||
|
sidebarCollapsed ? 'lg:ml-16' : 'lg:ml-64',
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
<Header />
|
<Header>
|
||||||
|
<MobileMenuButton onClick={() => setMobileOpen(true)} />
|
||||||
|
</Header>
|
||||||
<main className="flex-1 overflow-auto p-6 scrollbar-thin">
|
<main className="flex-1 overflow-auto p-6 scrollbar-thin">
|
||||||
{children}
|
{children}
|
||||||
</main>
|
</main>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
</PageGuard>
|
||||||
</AuthGuard>
|
</AuthGuard>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useEffect, useState, useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
|
||||||
import {
|
import {
|
||||||
Plus,
|
Plus,
|
||||||
Loader2,
|
Loader2,
|
||||||
@@ -38,8 +37,6 @@ import {
|
|||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatNumber } from '@/lib/utils'
|
import { formatNumber } from '@/lib/utils'
|
||||||
@@ -74,29 +71,14 @@ const emptyForm: ModelForm = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function ModelsPage() {
|
export default function ModelsPage() {
|
||||||
|
const [models, setModels] = useState<Model[]>([])
|
||||||
|
const [providers, setProviders] = useState<Provider[]>([])
|
||||||
|
const [total, setTotal] = useState(0)
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [providerFilter, setProviderFilter] = useState<string>('all')
|
const [providerFilter, setProviderFilter] = useState<string>('all')
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
// SWR for models list
|
|
||||||
const { data, isLoading, mutate } = useSWR(
|
|
||||||
['models', page, providerFilter],
|
|
||||||
() => {
|
|
||||||
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
|
||||||
if (providerFilter !== 'all') params.provider_id = providerFilter
|
|
||||||
return api.models.list(params)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
const models = data?.items ?? []
|
|
||||||
const total = data?.total ?? 0
|
|
||||||
|
|
||||||
// SWR for providers list (dropdown)
|
|
||||||
const { data: providersData } = useSWR(
|
|
||||||
['providers.all'],
|
|
||||||
() => api.providers.list({ page: 1, page_size: 100 })
|
|
||||||
)
|
|
||||||
const providers = providersData?.items ?? []
|
|
||||||
|
|
||||||
// Dialog
|
// Dialog
|
||||||
const [dialogOpen, setDialogOpen] = useState(false)
|
const [dialogOpen, setDialogOpen] = useState(false)
|
||||||
const [editTarget, setEditTarget] = useState<Model | null>(null)
|
const [editTarget, setEditTarget] = useState<Model | null>(null)
|
||||||
@@ -107,6 +89,37 @@ export default function ModelsPage() {
|
|||||||
const [deleteTarget, setDeleteTarget] = useState<Model | null>(null)
|
const [deleteTarget, setDeleteTarget] = useState<Model | null>(null)
|
||||||
const [deleting, setDeleting] = useState(false)
|
const [deleting, setDeleting] = useState(false)
|
||||||
|
|
||||||
|
const fetchModels = useCallback(async () => {
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
||||||
|
if (providerFilter !== 'all') params.provider_id = providerFilter
|
||||||
|
const res = await api.models.list(params)
|
||||||
|
setModels(res.items)
|
||||||
|
setTotal(res.total)
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
else setError('加载失败')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [page, providerFilter])
|
||||||
|
|
||||||
|
const fetchProviders = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
const res = await api.providers.list()
|
||||||
|
setProviders(res)
|
||||||
|
} catch {
|
||||||
|
// ignore
|
||||||
|
}
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
fetchModels()
|
||||||
|
fetchProviders()
|
||||||
|
}, [fetchModels, fetchProviders])
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
const providerMap = new Map(providers.map((p) => [p.id, p.display_name || p.name]))
|
const providerMap = new Map(providers.map((p) => [p.id, p.display_name || p.name]))
|
||||||
@@ -156,7 +169,7 @@ export default function ModelsPage() {
|
|||||||
await api.models.create(payload)
|
await api.models.create(payload)
|
||||||
}
|
}
|
||||||
setDialogOpen(false)
|
setDialogOpen(false)
|
||||||
mutate()
|
fetchModels()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -170,7 +183,7 @@ export default function ModelsPage() {
|
|||||||
try {
|
try {
|
||||||
await api.models.delete(deleteTarget.id)
|
await api.models.delete(deleteTarget.id)
|
||||||
setDeleteTarget(null)
|
setDeleteTarget(null)
|
||||||
mutate()
|
fetchModels()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -200,12 +213,21 @@ export default function ModelsPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{isLoading ? (
|
{loading ? (
|
||||||
<TableSkeleton rows={8} cols={9} hasToolbar={false} />
|
<div className="flex h-64 items-center justify-center">
|
||||||
) : error ? null : models.length === 0 ? (
|
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||||
<EmptyState />
|
</div>
|
||||||
|
) : models.length === 0 ? (
|
||||||
|
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
||||||
|
暂无数据
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
|
import { useEffect, useState } from 'react'
|
||||||
import {
|
import {
|
||||||
Users,
|
Users,
|
||||||
Server,
|
Server,
|
||||||
ArrowLeftRight,
|
ArrowLeftRight,
|
||||||
Zap,
|
Zap,
|
||||||
|
Loader2,
|
||||||
TrendingUp,
|
TrendingUp,
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
import {
|
import {
|
||||||
@@ -19,12 +21,8 @@ import {
|
|||||||
Bar,
|
Bar,
|
||||||
Legend,
|
Legend,
|
||||||
} from 'recharts'
|
} from 'recharts'
|
||||||
import useSWR from 'swr'
|
|
||||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
||||||
import { Badge } from '@/components/ui/badge'
|
import { Badge } from '@/components/ui/badge'
|
||||||
import { StatsSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import { ChartSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import {
|
import {
|
||||||
Table,
|
Table,
|
||||||
TableBody,
|
TableBody,
|
||||||
@@ -37,7 +35,7 @@ import { api } from '@/lib/api-client'
|
|||||||
import { formatNumber, formatDate } from '@/lib/utils'
|
import { formatNumber, formatDate } from '@/lib/utils'
|
||||||
import type {
|
import type {
|
||||||
DashboardStats,
|
DashboardStats,
|
||||||
UsageRecord,
|
UsageStats,
|
||||||
OperationLog,
|
OperationLog,
|
||||||
} from '@/lib/types'
|
} from '@/lib/types'
|
||||||
|
|
||||||
@@ -88,26 +86,65 @@ function StatusBadge({ status }: { status: string }) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function DashboardPage() {
|
export default function DashboardPage() {
|
||||||
const { data: stats, isLoading: statsLoading } = useSWR(
|
const [stats, setStats] = useState<DashboardStats | null>(null)
|
||||||
['stats.dashboard'],
|
const [usageStats, setUsageStats] = useState<UsageStats | null>(null)
|
||||||
() => api.stats.dashboard(),
|
const [recentLogs, setRecentLogs] = useState<OperationLog[]>([])
|
||||||
)
|
const [loading, setLoading] = useState(true)
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
const { data: usageData = [], isLoading: usageLoading } = useSWR(
|
useEffect(() => {
|
||||||
['usage.daily.30'],
|
async function fetchData() {
|
||||||
() => api.usage.daily({ days: 30 }),
|
try {
|
||||||
)
|
const [statsRes, usageRes, logsRes] = await Promise.allSettled([
|
||||||
|
api.stats.dashboard(),
|
||||||
|
api.usage.get(),
|
||||||
|
api.logs.list({ page: 1, page_size: 5 }),
|
||||||
|
])
|
||||||
|
|
||||||
const { data: logsData, isLoading: logsLoading } = useSWR(
|
if (statsRes.status === 'fulfilled') setStats(statsRes.value)
|
||||||
['logs.recent'],
|
if (usageRes.status === 'fulfilled') setUsageStats(usageRes.value)
|
||||||
() => api.logs.list({ page: 1, page_size: 5 }),
|
if (logsRes.status === 'fulfilled') setRecentLogs(logsRes.value)
|
||||||
)
|
|
||||||
|
|
||||||
const recentLogs: OperationLog[] = logsData?.items ?? []
|
if (statsRes.status === 'rejected' && usageRes.status === 'rejected' && logsRes.status === 'rejected') {
|
||||||
|
setError('加载数据失败,请检查后端服务是否启动')
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fetchData()
|
||||||
|
}, [])
|
||||||
|
|
||||||
const chartData = usageData.map((r: UsageRecord) => ({
|
if (loading) {
|
||||||
day: r.day.slice(5), // MM-DD
|
return (
|
||||||
请求量: r.count,
|
<div className="flex h-[60vh] items-center justify-center">
|
||||||
|
<div className="flex flex-col items-center gap-3">
|
||||||
|
<Loader2 className="h-8 w-8 animate-spin text-primary" />
|
||||||
|
<p className="text-sm text-muted-foreground">加载中...</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return (
|
||||||
|
<div className="flex h-[60vh] items-center justify-center">
|
||||||
|
<div className="text-center">
|
||||||
|
<p className="text-destructive">{error}</p>
|
||||||
|
<button
|
||||||
|
onClick={() => window.location.reload()}
|
||||||
|
className="mt-4 text-sm text-primary hover:underline cursor-pointer"
|
||||||
|
>
|
||||||
|
重新加载
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const chartData = (usageStats?.by_day ?? []).map((r) => ({
|
||||||
|
day: r.date.slice(5), // MM-DD
|
||||||
|
请求量: r.request_count,
|
||||||
Input: r.input_tokens,
|
Input: r.input_tokens,
|
||||||
Output: r.output_tokens,
|
Output: r.output_tokens,
|
||||||
}))
|
}))
|
||||||
@@ -115,151 +152,139 @@ export default function DashboardPage() {
|
|||||||
return (
|
return (
|
||||||
<div className="space-y-6">
|
<div className="space-y-6">
|
||||||
{/* 统计卡片 */}
|
{/* 统计卡片 */}
|
||||||
{statsLoading ? (
|
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-4">
|
||||||
<StatsSkeleton count={4} />
|
<StatCard
|
||||||
) : (
|
title="总账号数"
|
||||||
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-4">
|
value={stats?.total_accounts ?? '-'}
|
||||||
<StatCard
|
icon={<Users className="h-5 w-5 text-blue-400" />}
|
||||||
title="总账号数"
|
color="bg-blue-500/10"
|
||||||
value={stats?.total_accounts ?? '-'}
|
subtitle={`活跃 ${stats?.active_accounts ?? 0}`}
|
||||||
icon={<Users className="h-5 w-5 text-blue-400" />}
|
/>
|
||||||
color="bg-blue-500/10"
|
<StatCard
|
||||||
subtitle={`活跃 ${stats?.active_accounts ?? 0}`}
|
title="活跃服务商"
|
||||||
/>
|
value={stats?.active_providers ?? '-'}
|
||||||
<StatCard
|
icon={<Server className="h-5 w-5 text-green-400" />}
|
||||||
title="活跃服务商"
|
color="bg-green-500/10"
|
||||||
value={stats?.active_providers ?? '-'}
|
subtitle={`模型 ${stats?.active_models ?? 0}`}
|
||||||
icon={<Server className="h-5 w-5 text-green-400" />}
|
/>
|
||||||
color="bg-green-500/10"
|
<StatCard
|
||||||
subtitle={`模型 ${stats?.active_models ?? 0}`}
|
title="今日请求"
|
||||||
/>
|
value={stats?.tasks_today ?? '-'}
|
||||||
<StatCard
|
icon={<ArrowLeftRight className="h-5 w-5 text-purple-400" />}
|
||||||
title="今日请求"
|
color="bg-purple-500/10"
|
||||||
value={stats?.tasks_today ?? '-'}
|
subtitle="中转任务"
|
||||||
icon={<ArrowLeftRight className="h-5 w-5 text-purple-400" />}
|
/>
|
||||||
color="bg-purple-500/10"
|
<StatCard
|
||||||
subtitle="中转任务"
|
title="今日 Token"
|
||||||
/>
|
value={formatNumber((stats?.tokens_today_input ?? 0) + (stats?.tokens_today_output ?? 0))}
|
||||||
<StatCard
|
icon={<Zap className="h-5 w-5 text-orange-400" />}
|
||||||
title="今日 Token"
|
color="bg-orange-500/10"
|
||||||
value={formatNumber((stats?.tokens_today_input ?? 0) + (stats?.tokens_today_output ?? 0))}
|
subtitle={`In: ${formatNumber(stats?.tokens_today_input ?? 0)} / Out: ${formatNumber(stats?.tokens_today_output ?? 0)}`}
|
||||||
icon={<Zap className="h-5 w-5 text-orange-400" />}
|
/>
|
||||||
color="bg-orange-500/10"
|
</div>
|
||||||
subtitle={`In: ${formatNumber(stats?.tokens_today_input ?? 0)} / Out: ${formatNumber(stats?.tokens_today_output ?? 0)}`}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* 图表 */}
|
{/* 图表 */}
|
||||||
<div className="grid grid-cols-1 gap-4 lg:grid-cols-2">
|
<div className="grid grid-cols-1 gap-4 lg:grid-cols-2">
|
||||||
{/* 请求趋势 */}
|
{/* 请求趋势 */}
|
||||||
{usageLoading ? (
|
<Card>
|
||||||
<ChartSkeleton height={280} />
|
<CardHeader>
|
||||||
) : (
|
<CardTitle className="flex items-center gap-2 text-base">
|
||||||
<Card>
|
<TrendingUp className="h-4 w-4 text-primary" />
|
||||||
<CardHeader>
|
请求趋势 (30 天)
|
||||||
<CardTitle className="flex items-center gap-2 text-base">
|
</CardTitle>
|
||||||
<TrendingUp className="h-4 w-4 text-primary" />
|
</CardHeader>
|
||||||
请求趋势 (30 天)
|
<CardContent>
|
||||||
</CardTitle>
|
{chartData.length > 0 ? (
|
||||||
</CardHeader>
|
<ResponsiveContainer width="100%" height={280}>
|
||||||
<CardContent>
|
<AreaChart data={chartData}>
|
||||||
{chartData.length > 0 ? (
|
<defs>
|
||||||
<ResponsiveContainer width="100%" height={280}>
|
<linearGradient id="colorRequests" x1="0" y1="0" x2="0" y2="1">
|
||||||
<AreaChart data={chartData}>
|
<stop offset="5%" stopColor="#22C55E" stopOpacity={0.3} />
|
||||||
<defs>
|
<stop offset="95%" stopColor="#22C55E" stopOpacity={0} />
|
||||||
<linearGradient id="colorRequests" x1="0" y1="0" x2="0" y2="1">
|
</linearGradient>
|
||||||
<stop offset="5%" stopColor="#22C55E" stopOpacity={0.3} />
|
</defs>
|
||||||
<stop offset="95%" stopColor="#22C55E" stopOpacity={0} />
|
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
||||||
</linearGradient>
|
<XAxis
|
||||||
</defs>
|
dataKey="day"
|
||||||
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
<XAxis
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
dataKey="day"
|
/>
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
<YAxis
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
/>
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
<YAxis
|
/>
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
<Tooltip
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
contentStyle={{
|
||||||
/>
|
backgroundColor: '#0F172A',
|
||||||
<Tooltip
|
border: '1px solid #1E293B',
|
||||||
contentStyle={{
|
borderRadius: '8px',
|
||||||
backgroundColor: '#0F172A',
|
color: '#F8FAFC',
|
||||||
border: '1px solid #1E293B',
|
fontSize: '12px',
|
||||||
borderRadius: '8px',
|
}}
|
||||||
color: '#F8FAFC',
|
/>
|
||||||
fontSize: '12px',
|
<Area
|
||||||
}}
|
type="monotone"
|
||||||
/>
|
dataKey="请求量"
|
||||||
<Area
|
stroke="#22C55E"
|
||||||
type="monotone"
|
fillOpacity={1}
|
||||||
dataKey="请求量"
|
fill="url(#colorRequests)"
|
||||||
stroke="#22C55E"
|
strokeWidth={2}
|
||||||
fillOpacity={1}
|
/>
|
||||||
fill="url(#colorRequests)"
|
</AreaChart>
|
||||||
strokeWidth={2}
|
</ResponsiveContainer>
|
||||||
/>
|
) : (
|
||||||
</AreaChart>
|
<div className="flex h-[280px] items-center justify-center text-muted-foreground text-sm">
|
||||||
</ResponsiveContainer>
|
暂无数据
|
||||||
) : (
|
</div>
|
||||||
<div className="flex h-[280px] items-center justify-center text-muted-foreground text-sm">
|
)}
|
||||||
暂无数据
|
</CardContent>
|
||||||
</div>
|
</Card>
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Token 用量 */}
|
{/* Token 用量 */}
|
||||||
{usageLoading ? (
|
<Card>
|
||||||
<ChartSkeleton height={280} />
|
<CardHeader>
|
||||||
) : (
|
<CardTitle className="flex items-center gap-2 text-base">
|
||||||
<Card>
|
<Zap className="h-4 w-4 text-orange-400" />
|
||||||
<CardHeader>
|
Token 用量 (30 天)
|
||||||
<CardTitle className="flex items-center gap-2 text-base">
|
</CardTitle>
|
||||||
<Zap className="h-4 w-4 text-orange-400" />
|
</CardHeader>
|
||||||
Token 用量 (30 天)
|
<CardContent>
|
||||||
</CardTitle>
|
{chartData.length > 0 ? (
|
||||||
</CardHeader>
|
<ResponsiveContainer width="100%" height={280}>
|
||||||
<CardContent>
|
<BarChart data={chartData}>
|
||||||
{chartData.length > 0 ? (
|
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
||||||
<ResponsiveContainer width="100%" height={280}>
|
<XAxis
|
||||||
<BarChart data={chartData}>
|
dataKey="day"
|
||||||
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
<XAxis
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
dataKey="day"
|
/>
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
<YAxis
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
/>
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
<YAxis
|
/>
|
||||||
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
<Tooltip
|
||||||
axisLine={{ stroke: '#1E293B' }}
|
contentStyle={{
|
||||||
/>
|
backgroundColor: '#0F172A',
|
||||||
<Tooltip
|
border: '1px solid #1E293B',
|
||||||
contentStyle={{
|
borderRadius: '8px',
|
||||||
backgroundColor: '#0F172A',
|
color: '#F8FAFC',
|
||||||
border: '1px solid #1E293B',
|
fontSize: '12px',
|
||||||
borderRadius: '8px',
|
}}
|
||||||
color: '#F8FAFC',
|
/>
|
||||||
fontSize: '12px',
|
<Legend
|
||||||
}}
|
wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }}
|
||||||
/>
|
/>
|
||||||
<Legend
|
<Bar dataKey="Input" fill="#3B82F6" radius={[2, 2, 0, 0]} />
|
||||||
wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }}
|
<Bar dataKey="Output" fill="#F97316" radius={[2, 2, 0, 0]} />
|
||||||
/>
|
</BarChart>
|
||||||
<Bar dataKey="Input" fill="#3B82F6" radius={[2, 2, 0, 0]} />
|
</ResponsiveContainer>
|
||||||
<Bar dataKey="Output" fill="#F97316" radius={[2, 2, 0, 0]} />
|
) : (
|
||||||
</BarChart>
|
<div className="flex h-[280px] items-center justify-center text-muted-foreground text-sm">
|
||||||
</ResponsiveContainer>
|
暂无数据
|
||||||
) : (
|
</div>
|
||||||
<div className="flex h-[280px] items-center justify-center text-muted-foreground text-sm">
|
)}
|
||||||
暂无数据
|
</CardContent>
|
||||||
</div>
|
</Card>
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
)}
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 最近操作日志 */}
|
{/* 最近操作日志 */}
|
||||||
@@ -268,9 +293,7 @@ export default function DashboardPage() {
|
|||||||
<CardTitle className="text-base">最近操作</CardTitle>
|
<CardTitle className="text-base">最近操作</CardTitle>
|
||||||
</CardHeader>
|
</CardHeader>
|
||||||
<CardContent>
|
<CardContent>
|
||||||
{logsLoading ? (
|
{recentLogs.length > 0 ? (
|
||||||
<TableSkeleton rows={5} cols={5} hasToolbar={false} />
|
|
||||||
) : recentLogs.length > 0 ? (
|
|
||||||
<Table>
|
<Table>
|
||||||
<TableHeader>
|
<TableHeader>
|
||||||
<TableRow>
|
<TableRow>
|
||||||
|
|||||||
154
admin/src/app/(dashboard)/profile/page.tsx
Normal file
154
admin/src/app/(dashboard)/profile/page.tsx
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
'use client'
|
||||||
|
|
||||||
|
import { useState } from 'react'
|
||||||
|
import { Lock, Loader2, Eye, EyeOff, Check } from 'lucide-react'
|
||||||
|
import { Button } from '@/components/ui/button'
|
||||||
|
import { Input } from '@/components/ui/input'
|
||||||
|
import { Label } from '@/components/ui/label'
|
||||||
|
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
|
||||||
|
import { api } from '@/lib/api-client'
|
||||||
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
|
|
||||||
|
export default function ProfilePage() {
|
||||||
|
const [oldPassword, setOldPassword] = useState('')
|
||||||
|
const [newPassword, setNewPassword] = useState('')
|
||||||
|
const [confirmPassword, setConfirmPassword] = useState('')
|
||||||
|
const [showOld, setShowOld] = useState(false)
|
||||||
|
const [showNew, setShowNew] = useState(false)
|
||||||
|
const [showConfirm, setShowConfirm] = useState(false)
|
||||||
|
const [saving, setSaving] = useState(false)
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
const [success, setSuccess] = useState('')
|
||||||
|
|
||||||
|
async function handleSubmit(e: React.FormEvent) {
|
||||||
|
e.preventDefault()
|
||||||
|
setError('')
|
||||||
|
setSuccess('')
|
||||||
|
|
||||||
|
if (newPassword.length < 8) {
|
||||||
|
setError('新密码至少 8 个字符')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (newPassword !== confirmPassword) {
|
||||||
|
setError('两次输入的新密码不一致')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
setSaving(true)
|
||||||
|
try {
|
||||||
|
await api.auth.changePassword({ old_password: oldPassword, new_password: newPassword })
|
||||||
|
setSuccess('密码修改成功')
|
||||||
|
setOldPassword('')
|
||||||
|
setNewPassword('')
|
||||||
|
setConfirmPassword('')
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) {
|
||||||
|
setError(err.body.message || '修改失败')
|
||||||
|
} else {
|
||||||
|
setError('网络错误,请稍后重试')
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
setSaving(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-lg">
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle className="flex items-center gap-2">
|
||||||
|
<Lock className="h-5 w-5" />
|
||||||
|
修改密码
|
||||||
|
</CardTitle>
|
||||||
|
<CardDescription>修改您的登录密码。修改后需要重新登录。</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
<form onSubmit={handleSubmit} className="space-y-4">
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label htmlFor="old-password">当前密码</Label>
|
||||||
|
<div className="relative">
|
||||||
|
<Input
|
||||||
|
id="old-password"
|
||||||
|
type={showOld ? 'text' : 'password'}
|
||||||
|
value={oldPassword}
|
||||||
|
onChange={(e) => setOldPassword(e.target.value)}
|
||||||
|
placeholder="请输入当前密码"
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setShowOld(!showOld)}
|
||||||
|
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
|
||||||
|
>
|
||||||
|
{showOld ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label htmlFor="new-password">新密码</Label>
|
||||||
|
<div className="relative">
|
||||||
|
<Input
|
||||||
|
id="new-password"
|
||||||
|
type={showNew ? 'text' : 'password'}
|
||||||
|
value={newPassword}
|
||||||
|
onChange={(e) => setNewPassword(e.target.value)}
|
||||||
|
placeholder="至少 8 个字符"
|
||||||
|
required
|
||||||
|
minLength={8}
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setShowNew(!showNew)}
|
||||||
|
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
|
||||||
|
>
|
||||||
|
{showNew ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label htmlFor="confirm-password">确认新密码</Label>
|
||||||
|
<div className="relative">
|
||||||
|
<Input
|
||||||
|
id="confirm-password"
|
||||||
|
type={showConfirm ? 'text' : 'password'}
|
||||||
|
value={confirmPassword}
|
||||||
|
onChange={(e) => setConfirmPassword(e.target.value)}
|
||||||
|
placeholder="再次输入新密码"
|
||||||
|
required
|
||||||
|
minLength={8}
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setShowConfirm(!showConfirm)}
|
||||||
|
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
|
||||||
|
>
|
||||||
|
{showConfirm ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{success && (
|
||||||
|
<div className="rounded-md bg-emerald-500/10 border border-emerald-500/20 px-4 py-3 text-sm text-emerald-500 flex items-center gap-2">
|
||||||
|
<Check className="h-4 w-4" />
|
||||||
|
{success}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
<Button type="submit" disabled={saving || !oldPassword || !newPassword || !confirmPassword}>
|
||||||
|
{saving && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||||
|
修改密码
|
||||||
|
</Button>
|
||||||
|
</form>
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,341 +0,0 @@
|
|||||||
'use client'
|
|
||||||
|
|
||||||
import { useState } from 'react'
|
|
||||||
import useSWR from 'swr'
|
|
||||||
import { api } from '@/lib/api-client'
|
|
||||||
import type { PromptTemplate, PromptVersion } from '@/lib/types'
|
|
||||||
import { EmptyState } from '@/components/ui/state'
|
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
|
|
||||||
export default function PromptsPage() {
|
|
||||||
const [page, setPage] = useState(1)
|
|
||||||
const [selectedName, setSelectedName] = useState<string | null>(null)
|
|
||||||
const [versions, setVersions] = useState<PromptVersion[]>([])
|
|
||||||
const [showCreate, setShowCreate] = useState(false)
|
|
||||||
const [showNewVersion, setShowNewVersion] = useState(false)
|
|
||||||
const [filter, setFilter] = useState<{ source?: string; status?: string }>({})
|
|
||||||
|
|
||||||
const { data, error, isLoading, mutate } = useSWR(
|
|
||||||
['prompts.list', page, filter.source, filter.status],
|
|
||||||
() => api.prompts.list({ page, page_size: 50, ...filter }),
|
|
||||||
)
|
|
||||||
|
|
||||||
const templates = data?.items ?? []
|
|
||||||
const total = data?.total ?? 0
|
|
||||||
|
|
||||||
const fetchVersions = async (name: string) => {
|
|
||||||
try {
|
|
||||||
const res = await api.prompts.listVersions(name)
|
|
||||||
setVersions(res)
|
|
||||||
setSelectedName(name)
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Failed to fetch versions:', err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleCreate = async (e: React.FormEvent<HTMLFormElement>) => {
|
|
||||||
e.preventDefault()
|
|
||||||
const fd = new FormData(e.currentTarget)
|
|
||||||
try {
|
|
||||||
await api.prompts.create({
|
|
||||||
name: fd.get('name') as string,
|
|
||||||
category: fd.get('category') as string,
|
|
||||||
description: (fd.get('description') as string) || undefined,
|
|
||||||
source: 'custom',
|
|
||||||
system_prompt: fd.get('system_prompt') as string,
|
|
||||||
})
|
|
||||||
setShowCreate(false)
|
|
||||||
mutate()
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Failed to create prompt:', err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleNewVersion = async (e: React.FormEvent<HTMLFormElement>) => {
|
|
||||||
e.preventDefault()
|
|
||||||
if (!selectedName) return
|
|
||||||
const fd = new FormData(e.currentTarget)
|
|
||||||
try {
|
|
||||||
await api.prompts.createVersion(selectedName, {
|
|
||||||
system_prompt: fd.get('system_prompt') as string,
|
|
||||||
changelog: (fd.get('changelog') as string) || undefined,
|
|
||||||
})
|
|
||||||
setShowNewVersion(false)
|
|
||||||
fetchVersions(selectedName)
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Failed to create version:', err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleRollback = async (name: string, version: number) => {
|
|
||||||
if (!confirm(`确认回退到版本 ${version}?`)) return
|
|
||||||
try {
|
|
||||||
await api.prompts.rollback(name, version)
|
|
||||||
fetchVersions(name)
|
|
||||||
mutate()
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Failed to rollback:', err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const handleArchive = async (name: string) => {
|
|
||||||
if (!confirm(`确认归档 ${name}?`)) return
|
|
||||||
try {
|
|
||||||
await api.prompts.archive(name)
|
|
||||||
mutate()
|
|
||||||
} catch (err) {
|
|
||||||
console.error('Failed to archive:', err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const statusBadge = (status: string) => {
|
|
||||||
const colors: Record<string, string> = {
|
|
||||||
active: 'bg-emerald-500/20 text-emerald-400',
|
|
||||||
deprecated: 'bg-amber-500/20 text-amber-400',
|
|
||||||
archived: 'bg-zinc-500/20 text-zinc-400',
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<span className={`px-2 py-0.5 text-xs rounded-full ${colors[status] || colors.archived}`}>
|
|
||||||
{status}
|
|
||||||
</span>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const sourceBadge = (source: string) => {
|
|
||||||
const colors: Record<string, string> = {
|
|
||||||
builtin: 'bg-blue-500/20 text-blue-400',
|
|
||||||
custom: 'bg-purple-500/20 text-purple-400',
|
|
||||||
}
|
|
||||||
return (
|
|
||||||
<span className={`px-2 py-0.5 text-xs rounded-full ${colors[source] || ''}`}>
|
|
||||||
{source === 'builtin' ? '内置' : '自定义'}
|
|
||||||
</span>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="space-y-6">
|
|
||||||
{/* Header */}
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<div>
|
|
||||||
<h1 className="text-2xl font-bold text-white">提示词管理</h1>
|
|
||||||
<p className="text-sm text-zinc-400 mt-1">管理内置和自定义提示词模板,支持版本控制和 OTA 分发</p>
|
|
||||||
</div>
|
|
||||||
<button
|
|
||||||
onClick={() => setShowCreate(true)}
|
|
||||||
className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 transition-colors text-sm"
|
|
||||||
>
|
|
||||||
+ 新建模板
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Filters */}
|
|
||||||
<div className="flex gap-2">
|
|
||||||
{(['all', 'builtin', 'custom'] as const).map(s => (
|
|
||||||
<button
|
|
||||||
key={s}
|
|
||||||
onClick={() => setFilter(s === 'all' ? {} : { source: s })}
|
|
||||||
className={`px-3 py-1 text-sm rounded-lg transition-colors ${
|
|
||||||
(filter.source || 'all') === s
|
|
||||||
? 'bg-zinc-700 text-white'
|
|
||||||
: 'bg-zinc-800 text-zinc-400 hover:text-white'
|
|
||||||
}`}
|
|
||||||
>
|
|
||||||
{s === 'all' ? '全部' : s === 'builtin' ? '内置' : '自定义'}
|
|
||||||
</button>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Template List */}
|
|
||||||
<div className="bg-zinc-900 rounded-xl border border-zinc-800 overflow-hidden">
|
|
||||||
<table className="w-full text-sm">
|
|
||||||
<thead>
|
|
||||||
<tr className="border-b border-zinc-800">
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">名称</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">分类</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">来源</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">版本</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">状态</th>
|
|
||||||
<th className="text-left px-4 py-3 text-zinc-400 font-medium">更新时间</th>
|
|
||||||
<th className="text-right px-4 py-3 text-zinc-400 font-medium">操作</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
{isLoading ? (
|
|
||||||
<tr>
|
|
||||||
<td colSpan={7}>
|
|
||||||
<TableSkeleton rows={5} cols={7} hasToolbar={false} />
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
) : error ? (
|
|
||||||
<tr><td colSpan={7} className="px-4 py-8 text-center text-red-400">加载失败</td></tr>
|
|
||||||
) : templates.length === 0 ? (
|
|
||||||
<tr><td colSpan={7}><EmptyState message="暂无提示词模板" /></td></tr>
|
|
||||||
) : (
|
|
||||||
templates.map(t => (
|
|
||||||
<tr key={t.id} className="border-b border-zinc-800/50 hover:bg-zinc-800/30">
|
|
||||||
<td className="px-4 py-3">
|
|
||||||
<button
|
|
||||||
onClick={() => fetchVersions(t.name)}
|
|
||||||
className="text-blue-400 hover:text-blue-300 font-mono"
|
|
||||||
>
|
|
||||||
{t.name}
|
|
||||||
</button>
|
|
||||||
</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-400">{t.category}</td>
|
|
||||||
<td className="px-4 py-3">{sourceBadge(t.source)}</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-300">v{t.current_version}</td>
|
|
||||||
<td className="px-4 py-3">{statusBadge(t.status)}</td>
|
|
||||||
<td className="px-4 py-3 text-zinc-500 text-xs">
|
|
||||||
{new Date(t.updated_at).toLocaleString('zh-CN')}
|
|
||||||
</td>
|
|
||||||
<td className="px-4 py-3 text-right">
|
|
||||||
<button
|
|
||||||
onClick={() => fetchVersions(t.name)}
|
|
||||||
className="text-zinc-400 hover:text-white mr-2"
|
|
||||||
>
|
|
||||||
历史
|
|
||||||
</button>
|
|
||||||
{t.source === 'custom' && (
|
|
||||||
<button
|
|
||||||
onClick={() => handleArchive(t.name)}
|
|
||||||
className="text-red-400 hover:text-red-300"
|
|
||||||
>
|
|
||||||
归档
|
|
||||||
</button>
|
|
||||||
)}
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
))
|
|
||||||
)}
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
<div className="px-4 py-2 text-xs text-zinc-500 border-t border-zinc-800">
|
|
||||||
共 {total} 个模板
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* Version History Panel */}
|
|
||||||
{selectedName && (
|
|
||||||
<div className="bg-zinc-900 rounded-xl border border-zinc-800 p-4">
|
|
||||||
<div className="flex items-center justify-between mb-4">
|
|
||||||
<h2 className="text-lg font-semibold text-white">
|
|
||||||
{selectedName} — 版本历史
|
|
||||||
</h2>
|
|
||||||
<div className="flex gap-2">
|
|
||||||
<button
|
|
||||||
onClick={() => setShowNewVersion(true)}
|
|
||||||
className="px-3 py-1.5 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-xs"
|
|
||||||
>
|
|
||||||
发布新版本
|
|
||||||
</button>
|
|
||||||
<button
|
|
||||||
onClick={() => { setSelectedName(null); setVersions([]) }}
|
|
||||||
className="px-3 py-1.5 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-xs"
|
|
||||||
>
|
|
||||||
关闭
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<div className="space-y-3">
|
|
||||||
{versions.map(v => (
|
|
||||||
<div key={v.id} className="bg-zinc-800/50 rounded-lg p-3">
|
|
||||||
<div className="flex items-center justify-between mb-2">
|
|
||||||
<span className="text-sm font-mono text-zinc-300">v{v.version}</span>
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<span className="text-xs text-zinc-500">
|
|
||||||
{new Date(v.created_at).toLocaleString('zh-CN')}
|
|
||||||
</span>
|
|
||||||
{v.changelog && (
|
|
||||||
<span className="text-xs text-zinc-400">— {v.changelog}</span>
|
|
||||||
)}
|
|
||||||
{v.min_app_version && (
|
|
||||||
<span className="text-xs text-amber-400">最低版本: {v.min_app_version}</span>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<pre className="text-xs text-zinc-400 bg-zinc-900 rounded p-2 overflow-x-auto max-h-32">
|
|
||||||
{v.system_prompt.substring(0, 300)}{v.system_prompt.length > 300 ? '...' : ''}
|
|
||||||
</pre>
|
|
||||||
<div className="mt-2 flex gap-2">
|
|
||||||
<button
|
|
||||||
onClick={() => {
|
|
||||||
navigator.clipboard.writeText(v.system_prompt)
|
|
||||||
}}
|
|
||||||
className="text-xs text-zinc-500 hover:text-white"
|
|
||||||
>
|
|
||||||
复制
|
|
||||||
</button>
|
|
||||||
<button
|
|
||||||
onClick={() => handleRollback(selectedName, v.version)}
|
|
||||||
className="text-xs text-amber-500 hover:text-amber-400"
|
|
||||||
>
|
|
||||||
回退到此版本
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
{versions.length === 0 && (
|
|
||||||
<EmptyState message="暂无版本历史" />
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* Create Modal */}
|
|
||||||
{showCreate && (
|
|
||||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
|
||||||
<form onSubmit={handleCreate} className="bg-zinc-900 rounded-xl border border-zinc-700 p-6 w-full max-w-lg space-y-4">
|
|
||||||
<h2 className="text-lg font-semibold text-white">新建提示词模板</h2>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">名称</label>
|
|
||||||
<input name="name" required className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="my_prompt" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">分类</label>
|
|
||||||
<select name="category" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm">
|
|
||||||
<option value="custom_system">系统提示词</option>
|
|
||||||
<option value="custom_extraction">提取提示词</option>
|
|
||||||
<option value="custom_compaction">压缩提示词</option>
|
|
||||||
<option value="custom_other">其他</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">描述</label>
|
|
||||||
<input name="description" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="可选" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">系统提示词</label>
|
|
||||||
<textarea name="system_prompt" required rows={6} className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm font-mono" />
|
|
||||||
</div>
|
|
||||||
<div className="flex gap-2 justify-end">
|
|
||||||
<button type="button" onClick={() => setShowCreate(false)} className="px-4 py-2 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-sm">取消</button>
|
|
||||||
<button type="submit" className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-sm">创建</button>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
|
|
||||||
{/* New Version Modal */}
|
|
||||||
{showNewVersion && selectedName && (
|
|
||||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
|
||||||
<form onSubmit={handleNewVersion} className="bg-zinc-900 rounded-xl border border-zinc-700 p-6 w-full max-w-lg space-y-4">
|
|
||||||
<h2 className="text-lg font-semibold text-white">发布 {selectedName} 新版本</h2>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">系统提示词</label>
|
|
||||||
<textarea name="system_prompt" required rows={6} className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm font-mono" />
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<label className="block text-sm text-zinc-400 mb-1">变更说明</label>
|
|
||||||
<input name="changelog" className="w-full px-3 py-2 bg-zinc-800 border border-zinc-700 rounded-lg text-white text-sm" placeholder="描述本次变更" />
|
|
||||||
</div>
|
|
||||||
<div className="flex gap-2 justify-end">
|
|
||||||
<button type="button" onClick={() => setShowNewVersion(false)} className="px-4 py-2 bg-zinc-700 text-white rounded-lg hover:bg-zinc-600 text-sm">取消</button>
|
|
||||||
<button type="submit" className="px-4 py-2 bg-blue-600 text-white rounded-lg hover:bg-blue-700 text-sm">发布</button>
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useEffect, useState, useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
|
||||||
import {
|
import {
|
||||||
Plus,
|
Plus,
|
||||||
Loader2,
|
Loader2,
|
||||||
@@ -9,9 +8,6 @@ import {
|
|||||||
ChevronRight,
|
ChevronRight,
|
||||||
Pencil,
|
Pencil,
|
||||||
Trash2,
|
Trash2,
|
||||||
KeyRound,
|
|
||||||
Power,
|
|
||||||
PowerOff,
|
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
import { Button } from '@/components/ui/button'
|
import { Button } from '@/components/ui/button'
|
||||||
import { Input } from '@/components/ui/input'
|
import { Input } from '@/components/ui/input'
|
||||||
@@ -41,18 +37,10 @@ import {
|
|||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate, maskApiKey } from '@/lib/utils'
|
import { formatDate } from '@/lib/utils'
|
||||||
|
import type { Provider } from '@/lib/types'
|
||||||
function formatTokens(tokens: number): string {
|
|
||||||
if (tokens >= 1_000_000) return `${(tokens / 1_000_000).toFixed(1)}M`
|
|
||||||
if (tokens >= 1_000) return `${(tokens / 1_000).toFixed(1)}K`
|
|
||||||
return String(tokens)
|
|
||||||
}
|
|
||||||
import type { Provider, ProviderKey } from '@/lib/types'
|
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
|
|
||||||
@@ -61,7 +49,6 @@ interface ProviderForm {
|
|||||||
display_name: string
|
display_name: string
|
||||||
base_url: string
|
base_url: string
|
||||||
api_protocol: 'openai' | 'anthropic'
|
api_protocol: 'openai' | 'anthropic'
|
||||||
api_key: string
|
|
||||||
enabled: boolean
|
enabled: boolean
|
||||||
rate_limit_rpm: string
|
rate_limit_rpm: string
|
||||||
rate_limit_tpm: string
|
rate_limit_tpm: string
|
||||||
@@ -72,24 +59,18 @@ const emptyForm: ProviderForm = {
|
|||||||
display_name: '',
|
display_name: '',
|
||||||
base_url: '',
|
base_url: '',
|
||||||
api_protocol: 'openai',
|
api_protocol: 'openai',
|
||||||
api_key: '',
|
|
||||||
enabled: true,
|
enabled: true,
|
||||||
rate_limit_rpm: '',
|
rate_limit_rpm: '',
|
||||||
rate_limit_tpm: '',
|
rate_limit_tpm: '',
|
||||||
}
|
}
|
||||||
|
|
||||||
export default function ProvidersPage() {
|
export default function ProvidersPage() {
|
||||||
|
const [providers, setProviders] = useState<Provider[]>([])
|
||||||
|
const [total, setTotal] = useState(0)
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
// SWR for providers list
|
|
||||||
const { data, isLoading, mutate } = useSWR(
|
|
||||||
['providers', page],
|
|
||||||
() => api.providers.list({ page, page_size: PAGE_SIZE })
|
|
||||||
)
|
|
||||||
const providers = data?.items ?? []
|
|
||||||
const total = data?.total ?? 0
|
|
||||||
|
|
||||||
// 创建/编辑 Dialog
|
// 创建/编辑 Dialog
|
||||||
const [dialogOpen, setDialogOpen] = useState(false)
|
const [dialogOpen, setDialogOpen] = useState(false)
|
||||||
const [editTarget, setEditTarget] = useState<Provider | null>(null)
|
const [editTarget, setEditTarget] = useState<Provider | null>(null)
|
||||||
@@ -100,24 +81,24 @@ export default function ProvidersPage() {
|
|||||||
const [deleteTarget, setDeleteTarget] = useState<Provider | null>(null)
|
const [deleteTarget, setDeleteTarget] = useState<Provider | null>(null)
|
||||||
const [deleting, setDeleting] = useState(false)
|
const [deleting, setDeleting] = useState(false)
|
||||||
|
|
||||||
// Key Pool 管理
|
const fetchProviders = useCallback(async () => {
|
||||||
const [keyPoolProvider, setKeyPoolProvider] = useState<Provider | null>(null)
|
setLoading(true)
|
||||||
const [showAddKey, setShowAddKey] = useState(false)
|
setError('')
|
||||||
const [addKeyForm, setAddKeyForm] = useState({
|
try {
|
||||||
key_label: '',
|
const res = await api.providers.list({ page, page_size: PAGE_SIZE })
|
||||||
key_value: '',
|
setProviders(res.items)
|
||||||
priority: 0,
|
setTotal(res.total)
|
||||||
max_rpm: '',
|
} catch (err) {
|
||||||
max_tpm: '',
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
quota_reset_interval: '',
|
else setError('加载失败')
|
||||||
})
|
} finally {
|
||||||
const [addingKey, setAddingKey] = useState(false)
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [page])
|
||||||
|
|
||||||
// SWR for key pool — only fetches when dialog is open
|
useEffect(() => {
|
||||||
const { data: providerKeys = [], isLoading: keysLoading, mutate: mutateKeys } = useSWR(
|
fetchProviders()
|
||||||
keyPoolProvider ? ['provider.keys', keyPoolProvider.id] : null,
|
}, [fetchProviders])
|
||||||
() => api.providers.listKeys(keyPoolProvider!.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
@@ -134,7 +115,6 @@ export default function ProvidersPage() {
|
|||||||
display_name: provider.display_name,
|
display_name: provider.display_name,
|
||||||
base_url: provider.base_url,
|
base_url: provider.base_url,
|
||||||
api_protocol: provider.api_protocol,
|
api_protocol: provider.api_protocol,
|
||||||
api_key: provider.api_key || '',
|
|
||||||
enabled: provider.enabled,
|
enabled: provider.enabled,
|
||||||
rate_limit_rpm: provider.rate_limit_rpm?.toString() || '',
|
rate_limit_rpm: provider.rate_limit_rpm?.toString() || '',
|
||||||
rate_limit_tpm: provider.rate_limit_tpm?.toString() || '',
|
rate_limit_tpm: provider.rate_limit_tpm?.toString() || '',
|
||||||
@@ -151,7 +131,6 @@ export default function ProvidersPage() {
|
|||||||
display_name: form.display_name.trim(),
|
display_name: form.display_name.trim(),
|
||||||
base_url: form.base_url.trim(),
|
base_url: form.base_url.trim(),
|
||||||
api_protocol: form.api_protocol,
|
api_protocol: form.api_protocol,
|
||||||
api_key: form.api_key.trim() || undefined,
|
|
||||||
enabled: form.enabled,
|
enabled: form.enabled,
|
||||||
rate_limit_rpm: form.rate_limit_rpm ? parseInt(form.rate_limit_rpm, 10) : undefined,
|
rate_limit_rpm: form.rate_limit_rpm ? parseInt(form.rate_limit_rpm, 10) : undefined,
|
||||||
rate_limit_tpm: form.rate_limit_tpm ? parseInt(form.rate_limit_tpm, 10) : undefined,
|
rate_limit_tpm: form.rate_limit_tpm ? parseInt(form.rate_limit_tpm, 10) : undefined,
|
||||||
@@ -162,7 +141,7 @@ export default function ProvidersPage() {
|
|||||||
await api.providers.create(payload)
|
await api.providers.create(payload)
|
||||||
}
|
}
|
||||||
setDialogOpen(false)
|
setDialogOpen(false)
|
||||||
mutate()
|
fetchProviders()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -176,7 +155,7 @@ export default function ProvidersPage() {
|
|||||||
try {
|
try {
|
||||||
await api.providers.delete(deleteTarget.id)
|
await api.providers.delete(deleteTarget.id)
|
||||||
setDeleteTarget(null)
|
setDeleteTarget(null)
|
||||||
mutate()
|
fetchProviders()
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
} finally {
|
} finally {
|
||||||
@@ -184,55 +163,6 @@ export default function ProvidersPage() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Key Pool 管理 ─────────────────────────────────────
|
|
||||||
|
|
||||||
function openKeyPool(provider: Provider) {
|
|
||||||
setKeyPoolProvider(provider)
|
|
||||||
setShowAddKey(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
async function handleAddKey() {
|
|
||||||
if (!keyPoolProvider || !addKeyForm.key_label.trim() || !addKeyForm.key_value.trim()) return
|
|
||||||
setAddingKey(true)
|
|
||||||
try {
|
|
||||||
await api.providers.addKey(keyPoolProvider.id, {
|
|
||||||
key_label: addKeyForm.key_label.trim(),
|
|
||||||
key_value: addKeyForm.key_value.trim(),
|
|
||||||
priority: addKeyForm.priority,
|
|
||||||
max_rpm: addKeyForm.max_rpm ? parseInt(addKeyForm.max_rpm, 10) : undefined,
|
|
||||||
max_tpm: addKeyForm.max_tpm ? parseInt(addKeyForm.max_tpm, 10) : undefined,
|
|
||||||
quota_reset_interval: addKeyForm.quota_reset_interval.trim() || undefined,
|
|
||||||
})
|
|
||||||
setAddKeyForm({ key_label: '', key_value: '', priority: 0, max_rpm: '', max_tpm: '', quota_reset_interval: '' })
|
|
||||||
setShowAddKey(false)
|
|
||||||
mutateKeys()
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
|
||||||
} finally {
|
|
||||||
setAddingKey(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function handleToggleKey(keyId: string, active: boolean) {
|
|
||||||
if (!keyPoolProvider) return
|
|
||||||
try {
|
|
||||||
await api.providers.toggleKey(keyPoolProvider.id, keyId, active)
|
|
||||||
mutateKeys()
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function handleDeleteKey(keyId: string) {
|
|
||||||
if (!keyPoolProvider || !confirm('确认删除此 Key?')) return
|
|
||||||
try {
|
|
||||||
await api.providers.deleteKey(keyPoolProvider.id, keyId)
|
|
||||||
mutateKeys()
|
|
||||||
} catch (err) {
|
|
||||||
if (err instanceof ApiRequestError) setError(err.body.message)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
{/* 工具栏 */}
|
{/* 工具栏 */}
|
||||||
@@ -244,12 +174,21 @@ export default function ProvidersPage() {
|
|||||||
</Button>
|
</Button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{isLoading ? (
|
{loading ? (
|
||||||
<TableSkeleton rows={6} cols={9} hasToolbar={false} />
|
<div className="flex h-64 items-center justify-center">
|
||||||
) : error ? null : providers.length === 0 ? (
|
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||||
<EmptyState />
|
</div>
|
||||||
|
) : providers.length === 0 ? (
|
||||||
|
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
||||||
|
暂无数据
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
@@ -259,7 +198,6 @@ export default function ProvidersPage() {
|
|||||||
<TableHead>显示名</TableHead>
|
<TableHead>显示名</TableHead>
|
||||||
<TableHead>Base URL</TableHead>
|
<TableHead>Base URL</TableHead>
|
||||||
<TableHead>协议</TableHead>
|
<TableHead>协议</TableHead>
|
||||||
<TableHead>API Key</TableHead>
|
|
||||||
<TableHead>启用</TableHead>
|
<TableHead>启用</TableHead>
|
||||||
<TableHead>RPM 限制</TableHead>
|
<TableHead>RPM 限制</TableHead>
|
||||||
<TableHead>创建时间</TableHead>
|
<TableHead>创建时间</TableHead>
|
||||||
@@ -279,9 +217,6 @@ export default function ProvidersPage() {
|
|||||||
{p.api_protocol}
|
{p.api_protocol}
|
||||||
</Badge>
|
</Badge>
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell className="font-mono text-xs text-muted-foreground">
|
|
||||||
{maskApiKey(p.api_key)}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell>
|
<TableCell>
|
||||||
<Badge variant={p.enabled ? 'success' : 'secondary'}>
|
<Badge variant={p.enabled ? 'success' : 'secondary'}>
|
||||||
{p.enabled ? '是' : '否'}
|
{p.enabled ? '是' : '否'}
|
||||||
@@ -295,9 +230,6 @@ export default function ProvidersPage() {
|
|||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell className="text-right">
|
<TableCell className="text-right">
|
||||||
<div className="flex items-center justify-end gap-1">
|
<div className="flex items-center justify-end gap-1">
|
||||||
<Button variant="ghost" size="icon" onClick={() => openKeyPool(p)} title="Key Pool">
|
|
||||||
<KeyRound className="h-4 w-4" />
|
|
||||||
</Button>
|
|
||||||
<Button variant="ghost" size="icon" onClick={() => openEditDialog(p)} title="编辑">
|
<Button variant="ghost" size="icon" onClick={() => openEditDialog(p)} title="编辑">
|
||||||
<Pencil className="h-4 w-4" />
|
<Pencil className="h-4 w-4" />
|
||||||
</Button>
|
</Button>
|
||||||
@@ -376,15 +308,6 @@ export default function ProvidersPage() {
|
|||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
<div className="space-y-2">
|
|
||||||
<Label>API Key</Label>
|
|
||||||
<Input
|
|
||||||
type="password"
|
|
||||||
value={form.api_key}
|
|
||||||
onChange={(e) => setForm({ ...form, api_key: e.target.value })}
|
|
||||||
placeholder={editTarget ? '留空则不修改' : 'sk-...'}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="flex items-center gap-3">
|
<div className="flex items-center gap-3">
|
||||||
<Switch
|
<Switch
|
||||||
checked={form.enabled}
|
checked={form.enabled}
|
||||||
@@ -441,165 +364,6 @@ export default function ProvidersPage() {
|
|||||||
</DialogFooter>
|
</DialogFooter>
|
||||||
</DialogContent>
|
</DialogContent>
|
||||||
</Dialog>
|
</Dialog>
|
||||||
|
|
||||||
{/* Key Pool 管理 Dialog */}
|
|
||||||
<Dialog open={!!keyPoolProvider} onOpenChange={() => setKeyPoolProvider(null)}>
|
|
||||||
<DialogContent className="max-w-2xl">
|
|
||||||
<DialogHeader>
|
|
||||||
<DialogTitle>Key Pool 管理 — {keyPoolProvider?.display_name || keyPoolProvider?.name}</DialogTitle>
|
|
||||||
<DialogDescription>
|
|
||||||
管理此服务商的多个 API Key,实现智能轮转绕过限额。优先级数字越小越优先。
|
|
||||||
</DialogDescription>
|
|
||||||
</DialogHeader>
|
|
||||||
|
|
||||||
<div className="max-h-[50vh] overflow-y-auto scrollbar-thin">
|
|
||||||
{keysLoading ? (
|
|
||||||
<TableSkeleton rows={4} cols={8} hasToolbar={false} />
|
|
||||||
) : providerKeys.length === 0 && !showAddKey ? (
|
|
||||||
<div className="text-center py-8 text-muted-foreground text-sm">
|
|
||||||
<p>尚未配置 Key Pool</p>
|
|
||||||
<p className="mt-1 text-xs">将使用服务商主 API Key 作为回退</p>
|
|
||||||
</div>
|
|
||||||
) : (
|
|
||||||
<Table>
|
|
||||||
<TableHeader>
|
|
||||||
<TableRow>
|
|
||||||
<TableHead>标签</TableHead>
|
|
||||||
<TableHead>优先级</TableHead>
|
|
||||||
<TableHead>RPM</TableHead>
|
|
||||||
<TableHead>TPM</TableHead>
|
|
||||||
<TableHead>状态</TableHead>
|
|
||||||
<TableHead>请求/Token</TableHead>
|
|
||||||
<TableHead>最后 429</TableHead>
|
|
||||||
<TableHead className="text-right">操作</TableHead>
|
|
||||||
</TableRow>
|
|
||||||
</TableHeader>
|
|
||||||
<TableBody>
|
|
||||||
{providerKeys.map((k) => {
|
|
||||||
const isCooling = k.cooldown_until && new Date(k.cooldown_until) > new Date()
|
|
||||||
return (
|
|
||||||
<TableRow key={k.id} className={isCooling ? 'opacity-60' : ''}>
|
|
||||||
<TableCell className="font-medium">{k.key_label}</TableCell>
|
|
||||||
<TableCell>{k.priority}</TableCell>
|
|
||||||
<TableCell className="text-muted-foreground">{k.max_rpm ?? '-'}</TableCell>
|
|
||||||
<TableCell className="text-muted-foreground">{k.max_tpm ?? '-'}</TableCell>
|
|
||||||
<TableCell>
|
|
||||||
<Badge variant={k.is_active ? 'success' : 'secondary'}>
|
|
||||||
{isCooling ? '冷却中' : k.is_active ? '活跃' : '禁用'}
|
|
||||||
</Badge>
|
|
||||||
</TableCell>
|
|
||||||
<TableCell className="text-xs text-muted-foreground">
|
|
||||||
{k.total_requests} / {formatTokens(k.total_tokens)}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell className="text-xs text-muted-foreground">
|
|
||||||
{k.last_429_at ? formatDate(k.last_429_at) : '-'}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell className="text-right">
|
|
||||||
<div className="flex items-center justify-end gap-1">
|
|
||||||
<Button
|
|
||||||
variant="ghost"
|
|
||||||
size="icon"
|
|
||||||
onClick={() => handleToggleKey(k.id, !k.is_active)}
|
|
||||||
title={k.is_active ? '禁用' : '启用'}
|
|
||||||
>
|
|
||||||
{k.is_active ? <PowerOff className="h-3.5 w-3.5 text-amber-500" /> : <Power className="h-3.5 w-3.5 text-green-500" />}
|
|
||||||
</Button>
|
|
||||||
<Button
|
|
||||||
variant="ghost"
|
|
||||||
size="icon"
|
|
||||||
onClick={() => handleDeleteKey(k.id)}
|
|
||||||
title="删除"
|
|
||||||
>
|
|
||||||
<Trash2 className="h-3.5 w-3.5 text-destructive" />
|
|
||||||
</Button>
|
|
||||||
</div>
|
|
||||||
</TableCell>
|
|
||||||
</TableRow>
|
|
||||||
)
|
|
||||||
})}
|
|
||||||
</TableBody>
|
|
||||||
</Table>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{!showAddKey ? (
|
|
||||||
<DialogFooter>
|
|
||||||
<Button variant="outline" onClick={() => setKeyPoolProvider(null)}>关闭</Button>
|
|
||||||
<Button onClick={() => setShowAddKey(true)}>
|
|
||||||
<Plus className="h-4 w-4 mr-2" />
|
|
||||||
添加 Key
|
|
||||||
</Button>
|
|
||||||
</DialogFooter>
|
|
||||||
) : (
|
|
||||||
<div className="space-y-3 border-t pt-4">
|
|
||||||
<p className="text-sm font-medium">添加新 Key</p>
|
|
||||||
<div className="grid grid-cols-2 gap-3">
|
|
||||||
<div className="space-y-1">
|
|
||||||
<Label className="text-xs">标签 *</Label>
|
|
||||||
<Input
|
|
||||||
value={addKeyForm.key_label}
|
|
||||||
onChange={(e) => setAddKeyForm({ ...addKeyForm, key_label: e.target.value })}
|
|
||||||
placeholder="如 zhipu-coding-1"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="space-y-1">
|
|
||||||
<Label className="text-xs">优先级</Label>
|
|
||||||
<Input
|
|
||||||
type="number"
|
|
||||||
value={addKeyForm.priority}
|
|
||||||
onChange={(e) => setAddKeyForm({ ...addKeyForm, priority: parseInt(e.target.value, 10) || 0 })}
|
|
||||||
placeholder="0"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="col-span-2 space-y-1">
|
|
||||||
<Label className="text-xs">API Key *</Label>
|
|
||||||
<Input
|
|
||||||
type="password"
|
|
||||||
value={addKeyForm.key_value}
|
|
||||||
onChange={(e) => setAddKeyForm({ ...addKeyForm, key_value: e.target.value })}
|
|
||||||
placeholder="输入 API Key"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="space-y-1">
|
|
||||||
<Label className="text-xs">RPM 限额</Label>
|
|
||||||
<Input
|
|
||||||
type="number"
|
|
||||||
value={addKeyForm.max_rpm}
|
|
||||||
onChange={(e) => setAddKeyForm({ ...addKeyForm, max_rpm: e.target.value })}
|
|
||||||
placeholder="不限"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="space-y-1">
|
|
||||||
<Label className="text-xs">TPM 限额</Label>
|
|
||||||
<Input
|
|
||||||
type="number"
|
|
||||||
value={addKeyForm.max_tpm}
|
|
||||||
onChange={(e) => setAddKeyForm({ ...addKeyForm, max_tpm: e.target.value })}
|
|
||||||
placeholder="不限"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="col-span-2 space-y-1">
|
|
||||||
<Label className="text-xs">限额重置周期</Label>
|
|
||||||
<Input
|
|
||||||
value={addKeyForm.quota_reset_interval}
|
|
||||||
onChange={(e) => setAddKeyForm({ ...addKeyForm, quota_reset_interval: e.target.value })}
|
|
||||||
placeholder="如 5h, 1d(可选)"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<DialogFooter>
|
|
||||||
<Button variant="outline" onClick={() => { setShowAddKey(false); setAddKeyForm({ key_label: '', key_value: '', priority: 0, max_rpm: '', max_tpm: '', quota_reset_interval: '' }) }}>
|
|
||||||
取消
|
|
||||||
</Button>
|
|
||||||
<Button onClick={handleAddKey} disabled={addingKey || !addKeyForm.key_label.trim() || !addKeyForm.key_value.trim()}>
|
|
||||||
{addingKey && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
|
||||||
添加
|
|
||||||
</Button>
|
|
||||||
</DialogFooter>
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</DialogContent>
|
|
||||||
</Dialog>
|
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,14 +1,13 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useEffect, useState, useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
|
||||||
import {
|
import {
|
||||||
Search,
|
|
||||||
Loader2,
|
Loader2,
|
||||||
ChevronLeft,
|
ChevronLeft,
|
||||||
ChevronRight,
|
ChevronRight,
|
||||||
ChevronDown,
|
ChevronDown,
|
||||||
ChevronUp,
|
ChevronUp,
|
||||||
|
RotateCcw,
|
||||||
} from 'lucide-react'
|
} from 'lucide-react'
|
||||||
import { Button } from '@/components/ui/button'
|
import { Button } from '@/components/ui/button'
|
||||||
import { Badge } from '@/components/ui/badge'
|
import { Badge } from '@/components/ui/badge'
|
||||||
@@ -29,9 +28,7 @@ import {
|
|||||||
} from '@/components/ui/table'
|
} from '@/components/ui/table'
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import { ApiRequestError } from '@/lib/api-client'
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatDate, formatNumber, getSwrErrorMessage } from '@/lib/utils'
|
import { formatDate, formatNumber } from '@/lib/utils'
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { TableSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import type { RelayTask } from '@/lib/types'
|
import type { RelayTask } from '@/lib/types'
|
||||||
|
|
||||||
const PAGE_SIZE = 20
|
const PAGE_SIZE = 20
|
||||||
@@ -51,22 +48,35 @@ const statusLabels: Record<string, string> = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function RelayPage() {
|
export default function RelayPage() {
|
||||||
|
const [tasks, setTasks] = useState<RelayTask[]>([])
|
||||||
|
const [total, setTotal] = useState(0)
|
||||||
const [page, setPage] = useState(1)
|
const [page, setPage] = useState(1)
|
||||||
const [statusFilter, setStatusFilter] = useState<string>('all')
|
const [statusFilter, setStatusFilter] = useState<string>('all')
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
|
const [error, setError] = useState('')
|
||||||
const [expandedId, setExpandedId] = useState<string | null>(null)
|
const [expandedId, setExpandedId] = useState<string | null>(null)
|
||||||
|
const [retryingId, setRetryingId] = useState<string | null>(null)
|
||||||
|
|
||||||
const { data, error: swrError, isLoading } = useSWR(
|
const fetchTasks = useCallback(async () => {
|
||||||
['relay', page, statusFilter],
|
setLoading(true)
|
||||||
() => {
|
setError('')
|
||||||
|
try {
|
||||||
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
const params: Record<string, unknown> = { page, page_size: PAGE_SIZE }
|
||||||
if (statusFilter !== 'all') params.status = statusFilter
|
if (statusFilter !== 'all') params.status = statusFilter
|
||||||
return api.relay.list(params)
|
const res = await api.relay.list(params)
|
||||||
},
|
setTasks(res.items)
|
||||||
)
|
setTotal(res.total)
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
else setError('加载失败')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}, [page, statusFilter])
|
||||||
|
|
||||||
const tasks = data?.items ?? []
|
useEffect(() => {
|
||||||
const total = data?.total ?? 0
|
fetchTasks()
|
||||||
const error = getSwrErrorMessage(swrError)
|
}, [fetchTasks])
|
||||||
|
|
||||||
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
const totalPages = Math.max(1, Math.ceil(total / PAGE_SIZE))
|
||||||
|
|
||||||
@@ -74,6 +84,20 @@ export default function RelayPage() {
|
|||||||
setExpandedId((prev) => (prev === id ? null : id))
|
setExpandedId((prev) => (prev === id ? null : id))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async function handleRetry(taskId: string, e: React.MouseEvent) {
|
||||||
|
e.stopPropagation()
|
||||||
|
setRetryingId(taskId)
|
||||||
|
try {
|
||||||
|
await api.relay.retry(taskId)
|
||||||
|
fetchTasks()
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
|
else setError('重试失败')
|
||||||
|
} finally {
|
||||||
|
setRetryingId(null)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-4">
|
<div className="space-y-4">
|
||||||
{/* 筛选 */}
|
{/* 筛选 */}
|
||||||
@@ -92,12 +116,21 @@ export default function RelayPage() {
|
|||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => {}} />}
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
{error}
|
||||||
|
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
{isLoading ? (
|
{loading ? (
|
||||||
<TableSkeleton rows={6} cols={10} />
|
<div className="flex h-64 items-center justify-center">
|
||||||
) : error ? null : tasks.length === 0 ? (
|
<Loader2 className="h-6 w-6 animate-spin text-muted-foreground" />
|
||||||
<EmptyState />
|
</div>
|
||||||
|
) : tasks.length === 0 ? (
|
||||||
|
<div className="flex h-64 items-center justify-center text-muted-foreground text-sm">
|
||||||
|
暂无数据
|
||||||
|
</div>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<>
|
||||||
<Table>
|
<Table>
|
||||||
@@ -113,6 +146,7 @@ export default function RelayPage() {
|
|||||||
<TableHead>Output Tokens</TableHead>
|
<TableHead>Output Tokens</TableHead>
|
||||||
<TableHead>错误信息</TableHead>
|
<TableHead>错误信息</TableHead>
|
||||||
<TableHead>创建时间</TableHead>
|
<TableHead>创建时间</TableHead>
|
||||||
|
<TableHead className="text-right">操作</TableHead>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
</TableHeader>
|
</TableHeader>
|
||||||
<TableBody>
|
<TableBody>
|
||||||
@@ -151,10 +185,27 @@ export default function RelayPage() {
|
|||||||
<TableCell className="font-mono text-xs text-muted-foreground">
|
<TableCell className="font-mono text-xs text-muted-foreground">
|
||||||
{formatDate(task.created_at)}
|
{formatDate(task.created_at)}
|
||||||
</TableCell>
|
</TableCell>
|
||||||
|
<TableCell className="text-right">
|
||||||
|
{task.status === 'failed' && (
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="icon"
|
||||||
|
onClick={(e) => handleRetry(task.id, e)}
|
||||||
|
disabled={retryingId === task.id}
|
||||||
|
title="重试"
|
||||||
|
>
|
||||||
|
{retryingId === task.id ? (
|
||||||
|
<Loader2 className="h-4 w-4 animate-spin" />
|
||||||
|
) : (
|
||||||
|
<RotateCcw className="h-4 w-4" />
|
||||||
|
)}
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</TableCell>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
{expandedId === task.id && (
|
{expandedId === task.id && (
|
||||||
<TableRow key={`${task.id}-detail`}>
|
<TableRow key={`${task.id}-detail`}>
|
||||||
<TableCell colSpan={10} className="bg-muted/20 px-8 py-4">
|
<TableCell colSpan={11} className="bg-muted/20 px-8 py-4">
|
||||||
<div className="grid grid-cols-2 gap-4 text-sm">
|
<div className="grid grid-cols-2 gap-4 text-sm">
|
||||||
<div>
|
<div>
|
||||||
<p className="text-muted-foreground">任务 ID</p>
|
<p className="text-muted-foreground">任务 ID</p>
|
||||||
|
|||||||
203
admin/src/app/(dashboard)/security/page.tsx
Normal file
203
admin/src/app/(dashboard)/security/page.tsx
Normal file
@@ -0,0 +1,203 @@
|
|||||||
|
'use client'
|
||||||
|
|
||||||
|
import { useState } from 'react'
|
||||||
|
import { ShieldCheck, Loader2, Eye, EyeOff, QrCode, Key, AlertTriangle } from 'lucide-react'
|
||||||
|
import { Button } from '@/components/ui/button'
|
||||||
|
import { Input } from '@/components/ui/input'
|
||||||
|
import { Label } from '@/components/ui/label'
|
||||||
|
import { Card, CardContent, CardHeader, CardTitle, CardDescription } from '@/components/ui/card'
|
||||||
|
import { Badge } from '@/components/ui/badge'
|
||||||
|
import { api } from '@/lib/api-client'
|
||||||
|
import { useAuth } from '@/components/auth-guard'
|
||||||
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
|
|
||||||
|
export default function SecurityPage() {
|
||||||
|
const { account } = useAuth()
|
||||||
|
const totpEnabled = account?.totp_enabled ?? false
|
||||||
|
|
||||||
|
// Setup state
|
||||||
|
const [step, setStep] = useState<'idle' | 'verify' | 'done'>('idle')
|
||||||
|
const [otpauthUri, setOtpauthUri] = useState('')
|
||||||
|
const [secret, setSecret] = useState('')
|
||||||
|
const [verifyCode, setVerifyCode] = useState('')
|
||||||
|
const [loading, setLoading] = useState(false)
|
||||||
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
|
// Disable state
|
||||||
|
const [disablePassword, setDisablePassword] = useState('')
|
||||||
|
const [showDisablePassword, setShowDisablePassword] = useState(false)
|
||||||
|
const [disabling, setDisabling] = useState(false)
|
||||||
|
|
||||||
|
async function handleSetup() {
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
const res = await api.auth.totpSetup()
|
||||||
|
setOtpauthUri(res.otpauth_uri)
|
||||||
|
setSecret(res.secret)
|
||||||
|
setStep('verify')
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message || '获取密钥失败')
|
||||||
|
else setError('网络错误')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleVerify() {
|
||||||
|
if (verifyCode.length !== 6) {
|
||||||
|
setError('请输入 6 位验证码')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setLoading(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
await api.auth.totpVerify({ code: verifyCode })
|
||||||
|
setStep('done')
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message || '验证失败')
|
||||||
|
else setError('网络错误')
|
||||||
|
} finally {
|
||||||
|
setLoading(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async function handleDisable() {
|
||||||
|
if (!disablePassword) {
|
||||||
|
setError('请输入密码以确认禁用')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setDisabling(true)
|
||||||
|
setError('')
|
||||||
|
try {
|
||||||
|
await api.auth.totpDisable({ password: disablePassword })
|
||||||
|
setDisablePassword('')
|
||||||
|
window.location.reload()
|
||||||
|
} catch (err) {
|
||||||
|
if (err instanceof ApiRequestError) setError(err.body.message || '禁用失败')
|
||||||
|
else setError('网络错误')
|
||||||
|
} finally {
|
||||||
|
setDisabling(false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="max-w-lg space-y-6">
|
||||||
|
{/* TOTP 状态 */}
|
||||||
|
<Card>
|
||||||
|
<CardHeader>
|
||||||
|
<CardTitle className="flex items-center gap-2">
|
||||||
|
<ShieldCheck className="h-5 w-5" />
|
||||||
|
双因素认证 (TOTP)
|
||||||
|
</CardTitle>
|
||||||
|
<CardDescription>
|
||||||
|
使用 Google Authenticator 等应用生成一次性验证码,增强账号安全。
|
||||||
|
</CardDescription>
|
||||||
|
</CardHeader>
|
||||||
|
<CardContent>
|
||||||
|
<div className="flex items-center gap-3 mb-4">
|
||||||
|
<span className="text-sm text-muted-foreground">当前状态:</span>
|
||||||
|
<Badge variant={totpEnabled ? 'success' : 'secondary'}>
|
||||||
|
{totpEnabled ? '已启用' : '未启用'}
|
||||||
|
</Badge>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{error && (
|
||||||
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive mb-4">
|
||||||
|
{error}
|
||||||
|
<button onClick={() => setError('')} className="ml-2 underline cursor-pointer">关闭</button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 未启用: 设置流程 */}
|
||||||
|
{!totpEnabled && step === 'idle' && (
|
||||||
|
<Button onClick={handleSetup} disabled={loading}>
|
||||||
|
{loading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||||
|
<Key className="mr-2 h-4 w-4" />
|
||||||
|
启用双因素认证
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!totpEnabled && step === 'verify' && (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="rounded-md border border-border p-4 space-y-3">
|
||||||
|
<div className="flex items-center gap-2 text-sm font-medium">
|
||||||
|
<QrCode className="h-4 w-4" />
|
||||||
|
步骤 1: 扫描二维码或手动输入密钥
|
||||||
|
</div>
|
||||||
|
<div className="bg-muted rounded-md p-3 font-mono text-xs break-all">
|
||||||
|
{otpauthUri}
|
||||||
|
</div>
|
||||||
|
<div className="space-y-1">
|
||||||
|
<p className="text-xs text-muted-foreground">手动输入密钥:</p>
|
||||||
|
<p className="font-mono text-sm font-medium select-all">{secret}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label>
|
||||||
|
步骤 2: 输入 6 位验证码
|
||||||
|
</Label>
|
||||||
|
<Input
|
||||||
|
value={verifyCode}
|
||||||
|
onChange={(e) => setVerifyCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
|
||||||
|
placeholder="请输入应用中显示的 6 位数字"
|
||||||
|
maxLength={6}
|
||||||
|
className="font-mono tracking-widest text-center"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="flex gap-2">
|
||||||
|
<Button variant="outline" onClick={() => { setStep('idle'); setVerifyCode('') }}>
|
||||||
|
取消
|
||||||
|
</Button>
|
||||||
|
<Button onClick={handleVerify} disabled={loading || verifyCode.length !== 6}>
|
||||||
|
{loading && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||||
|
验证并启用
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{!totpEnabled && step === 'done' && (
|
||||||
|
<div className="rounded-md bg-emerald-500/10 border border-emerald-500/20 p-4 text-sm text-emerald-500">
|
||||||
|
双因素认证已成功启用。下次登录时需要输入验证码。
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{/* 已启用: 禁用流程 */}
|
||||||
|
{totpEnabled && (
|
||||||
|
<div className="space-y-4">
|
||||||
|
<div className="rounded-md bg-amber-500/10 border border-amber-500/20 p-3 flex items-start gap-2 text-sm text-amber-600">
|
||||||
|
<AlertTriangle className="h-4 w-4 mt-0.5 shrink-0" />
|
||||||
|
<span>禁用双因素认证会降低账号安全性,建议仅在必要时操作。</span>
|
||||||
|
</div>
|
||||||
|
<div className="space-y-2">
|
||||||
|
<Label>输入当前密码以确认禁用</Label>
|
||||||
|
<div className="relative">
|
||||||
|
<Input
|
||||||
|
type={showDisablePassword ? 'text' : 'password'}
|
||||||
|
value={disablePassword}
|
||||||
|
onChange={(e) => setDisablePassword(e.target.value)}
|
||||||
|
placeholder="请输入当前密码"
|
||||||
|
/>
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => setShowDisablePassword(!showDisablePassword)}
|
||||||
|
className="absolute right-3 top-1/2 -translate-y-1/2 text-muted-foreground hover:text-foreground cursor-pointer"
|
||||||
|
>
|
||||||
|
{showDisablePassword ? <EyeOff className="h-4 w-4" /> : <Eye className="h-4 w-4" />}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<Button variant="destructive" onClick={handleDisable} disabled={disabling || !disablePassword}>
|
||||||
|
{disabling && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||||
|
禁用双因素认证
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useState } from 'react'
|
import { useEffect, useState, useCallback } from 'react'
|
||||||
import useSWR from 'swr'
|
import { Loader2, Zap } from 'lucide-react'
|
||||||
import { Zap, Monitor, Smartphone } from 'lucide-react'
|
|
||||||
import {
|
import {
|
||||||
LineChart,
|
LineChart,
|
||||||
Line,
|
Line,
|
||||||
@@ -16,8 +15,6 @@ import {
|
|||||||
Legend,
|
Legend,
|
||||||
} from 'recharts'
|
} from 'recharts'
|
||||||
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
import { Card, CardContent, CardHeader, CardTitle } from '@/components/ui/card'
|
||||||
import { ErrorBanner, EmptyState } from '@/components/ui/state'
|
|
||||||
import { TableSkeleton, ChartSkeleton } from '@/components/ui/skeleton'
|
|
||||||
import {
|
import {
|
||||||
Select,
|
Select,
|
||||||
SelectContent,
|
SelectContent,
|
||||||
@@ -25,87 +22,83 @@ import {
|
|||||||
SelectTrigger,
|
SelectTrigger,
|
||||||
SelectValue,
|
SelectValue,
|
||||||
} from '@/components/ui/select'
|
} from '@/components/ui/select'
|
||||||
import {
|
|
||||||
Table,
|
|
||||||
TableBody,
|
|
||||||
TableCell,
|
|
||||||
TableHead,
|
|
||||||
TableHeader,
|
|
||||||
TableRow,
|
|
||||||
} from '@/components/ui/table'
|
|
||||||
import { Badge } from '@/components/ui/badge'
|
|
||||||
import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'
|
|
||||||
import { api } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
|
import { ApiRequestError } from '@/lib/api-client'
|
||||||
import { formatNumber } from '@/lib/utils'
|
import { formatNumber } from '@/lib/utils'
|
||||||
import type { UsageRecord, UsageByModel, ModelUsageStat, DailyUsageStat } from '@/lib/types'
|
import type { UsageStats } from '@/lib/types'
|
||||||
|
|
||||||
export default function UsagePage() {
|
export default function UsagePage() {
|
||||||
const [days, setDays] = useState(7)
|
const [days, setDays] = useState(7)
|
||||||
const [activeTab, setActiveTab] = useState('relay')
|
const [usageStats, setUsageStats] = useState<UsageStats | null>(null)
|
||||||
|
const [loading, setLoading] = useState(true)
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
// 4 parallel SWR calls — each loads independently
|
const fetchData = useCallback(async () => {
|
||||||
const { data: dailyData = [], isLoading: dailyLoading } = useSWR(
|
setLoading(true)
|
||||||
['usage.daily', days],
|
setError('')
|
||||||
() => api.usage.daily({ days })
|
try {
|
||||||
)
|
const from = new Date()
|
||||||
const { data: modelData = [], isLoading: modelLoading } = useSWR(
|
from.setDate(from.getDate() - days)
|
||||||
['usage.byModel', days],
|
const fromStr = from.toISOString().slice(0, 10)
|
||||||
() => api.usage.byModel({ days })
|
const res = await api.usage.get({ from: fromStr })
|
||||||
)
|
setUsageStats(res)
|
||||||
const { data: telemetryModels = [] } = useSWR(
|
} catch (err) {
|
||||||
['telemetry.modelStats'],
|
if (err instanceof ApiRequestError) setError(err.body.message)
|
||||||
() => api.telemetry.modelStats()
|
else setError('加载数据失败')
|
||||||
)
|
} finally {
|
||||||
const { data: telemetryDaily = [] } = useSWR(
|
setLoading(false)
|
||||||
['telemetry.dailyStats', days],
|
}
|
||||||
() => api.telemetry.dailyStats({ days })
|
}, [days])
|
||||||
)
|
|
||||||
|
|
||||||
const relayLoading = dailyLoading || modelLoading
|
useEffect(() => {
|
||||||
const telemetryLoading = !telemetryModels.length && !telemetryDaily.length && (dailyLoading || modelLoading)
|
fetchData()
|
||||||
|
}, [fetchData])
|
||||||
|
|
||||||
// === Relay 用量图表数据 ===
|
const byDay = usageStats?.by_day ?? []
|
||||||
|
|
||||||
const relayLineData = dailyData.map((r) => ({
|
const lineChartData = byDay.map((r) => ({
|
||||||
day: r.day.slice(5),
|
day: r.date.slice(5),
|
||||||
Input: r.input_tokens,
|
Input: r.input_tokens,
|
||||||
Output: r.output_tokens,
|
Output: r.output_tokens,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
const relayBarData = modelData.map((r) => ({
|
const barChartData = (usageStats?.by_model ?? []).map((r) => ({
|
||||||
model: r.model_id,
|
model: r.model_id,
|
||||||
请求量: r.count,
|
请求量: r.request_count,
|
||||||
Input: r.input_tokens,
|
Input: r.input_tokens,
|
||||||
Output: r.output_tokens,
|
Output: r.output_tokens,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
const relayTotalInput = dailyData.reduce((s, r) => s + r.input_tokens, 0)
|
const totalInput = byDay.reduce((s, r) => s + r.input_tokens, 0)
|
||||||
const relayTotalOutput = dailyData.reduce((s, r) => s + r.output_tokens, 0)
|
const totalOutput = byDay.reduce((s, r) => s + r.output_tokens, 0)
|
||||||
const relayTotalRequests = dailyData.reduce((s, r) => s + r.count, 0)
|
const totalRequests = byDay.reduce((s, r) => s + r.request_count, 0)
|
||||||
|
|
||||||
// === 遥测图表数据 ===
|
if (loading) {
|
||||||
|
return (
|
||||||
|
<div className="flex h-[60vh] items-center justify-center">
|
||||||
|
<div className="flex flex-col items-center gap-3">
|
||||||
|
<Loader2 className="h-8 w-8 animate-spin text-primary" />
|
||||||
|
<p className="text-sm text-muted-foreground">加载中...</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const telemetryLineData = telemetryDaily.map((r) => ({
|
if (error) {
|
||||||
day: r.day.slice(5),
|
return (
|
||||||
Input: r.input_tokens,
|
<div className="flex h-[60vh] items-center justify-center">
|
||||||
Output: r.output_tokens,
|
<div className="text-center">
|
||||||
设备数: r.unique_devices,
|
<p className="text-destructive">{error}</p>
|
||||||
}))
|
<button onClick={() => fetchData()} className="mt-4 text-sm text-primary hover:underline cursor-pointer">
|
||||||
|
重新加载
|
||||||
const telemetryTotalInput = telemetryDaily.reduce((s, r) => s + r.input_tokens, 0)
|
</button>
|
||||||
const telemetryTotalOutput = telemetryDaily.reduce((s, r) => s + r.output_tokens, 0)
|
</div>
|
||||||
const telemetryTotalRequests = telemetryDaily.reduce((s, r) => s + r.request_count, 0)
|
</div>
|
||||||
|
)
|
||||||
// === 合计 ===
|
}
|
||||||
|
|
||||||
const totalInput = relayTotalInput + telemetryTotalInput
|
|
||||||
const totalOutput = relayTotalOutput + telemetryTotalOutput
|
|
||||||
const totalRequests = relayTotalRequests + telemetryTotalRequests
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="space-y-6">
|
<div className="space-y-6">
|
||||||
{error && <ErrorBanner message={error} onDismiss={() => setError('')} />}
|
|
||||||
{/* 时间范围 */}
|
{/* 时间范围 */}
|
||||||
<div className="flex items-center gap-3">
|
<div className="flex items-center gap-3">
|
||||||
<span className="text-sm text-muted-foreground">时间范围:</span>
|
<span className="text-sm text-muted-foreground">时间范围:</span>
|
||||||
@@ -121,8 +114,8 @@ export default function UsagePage() {
|
|||||||
</Select>
|
</Select>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* 汇总统计 — render immediately, use 0 while loading */}
|
{/* 汇总统计 */}
|
||||||
<div className="grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-5">
|
<div className="grid grid-cols-1 gap-4 sm:grid-cols-3">
|
||||||
<Card>
|
<Card>
|
||||||
<CardContent className="p-6">
|
<CardContent className="p-6">
|
||||||
<p className="text-sm text-muted-foreground">总请求数</p>
|
<p className="text-sm text-muted-foreground">总请求数</p>
|
||||||
@@ -133,7 +126,7 @@ export default function UsagePage() {
|
|||||||
</Card>
|
</Card>
|
||||||
<Card>
|
<Card>
|
||||||
<CardContent className="p-6">
|
<CardContent className="p-6">
|
||||||
<p className="text-sm text-muted-foreground">总 Input Tokens</p>
|
<p className="text-sm text-muted-foreground">Input Tokens</p>
|
||||||
<p className="mt-1 text-2xl font-bold text-blue-400">
|
<p className="mt-1 text-2xl font-bold text-blue-400">
|
||||||
{formatNumber(totalInput)}
|
{formatNumber(totalInput)}
|
||||||
</p>
|
</p>
|
||||||
@@ -141,190 +134,101 @@ export default function UsagePage() {
|
|||||||
</Card>
|
</Card>
|
||||||
<Card>
|
<Card>
|
||||||
<CardContent className="p-6">
|
<CardContent className="p-6">
|
||||||
<p className="text-sm text-muted-foreground">总 Output Tokens</p>
|
<p className="text-sm text-muted-foreground">Output Tokens</p>
|
||||||
<p className="mt-1 text-2xl font-bold text-orange-400">
|
<p className="mt-1 text-2xl font-bold text-orange-400">
|
||||||
{formatNumber(totalOutput)}
|
{formatNumber(totalOutput)}
|
||||||
</p>
|
</p>
|
||||||
</CardContent>
|
</CardContent>
|
||||||
</Card>
|
</Card>
|
||||||
<Card>
|
|
||||||
<CardContent className="p-6">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<Monitor className="h-4 w-4 text-green-400" />
|
|
||||||
<p className="text-sm text-muted-foreground">中转请求</p>
|
|
||||||
</div>
|
|
||||||
<p className="mt-1 text-2xl font-bold text-green-400">
|
|
||||||
{formatNumber(relayTotalRequests)}
|
|
||||||
</p>
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
<Card>
|
|
||||||
<CardContent className="p-6">
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<Smartphone className="h-4 w-4 text-purple-400" />
|
|
||||||
<p className="text-sm text-muted-foreground">桌面端调用</p>
|
|
||||||
</div>
|
|
||||||
<p className="mt-1 text-2xl font-bold text-purple-400">
|
|
||||||
{formatNumber(telemetryTotalRequests)}
|
|
||||||
</p>
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* Tab 切换 */}
|
{/* Token 用量趋势 */}
|
||||||
<Tabs value={activeTab} onValueChange={setActiveTab}>
|
<Card>
|
||||||
<TabsList>
|
<CardHeader>
|
||||||
<TabsTrigger value="relay">
|
<CardTitle className="flex items-center gap-2 text-base">
|
||||||
<Monitor className="h-4 w-4 mr-1" />
|
<Zap className="h-4 w-4 text-primary" />
|
||||||
中转用量
|
Token 用量趋势
|
||||||
</TabsTrigger>
|
</CardTitle>
|
||||||
<TabsTrigger value="telemetry">
|
</CardHeader>
|
||||||
<Smartphone className="h-4 w-4 mr-1" />
|
<CardContent>
|
||||||
桌面端遥测
|
{lineChartData.length > 0 ? (
|
||||||
</TabsTrigger>
|
<ResponsiveContainer width="100%" height={320}>
|
||||||
</TabsList>
|
<LineChart data={lineChartData}>
|
||||||
|
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
||||||
{/* Relay 用量 Tab */}
|
<XAxis
|
||||||
<TabsContent value="relay" className="space-y-6">
|
dataKey="day"
|
||||||
{relayLoading ? (
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
<>
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
<ChartSkeleton height={320} />
|
/>
|
||||||
<ChartSkeleton height={280} />
|
<YAxis
|
||||||
</>
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
|
/>
|
||||||
|
<Tooltip
|
||||||
|
contentStyle={{
|
||||||
|
backgroundColor: '#0F172A',
|
||||||
|
border: '1px solid #1E293B',
|
||||||
|
borderRadius: '8px',
|
||||||
|
color: '#F8FAFC',
|
||||||
|
fontSize: '12px',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
||||||
|
<Line type="monotone" dataKey="Input" stroke="#3B82F6" strokeWidth={2} dot={false} />
|
||||||
|
<Line type="monotone" dataKey="Output" stroke="#F97316" strokeWidth={2} dot={false} />
|
||||||
|
</LineChart>
|
||||||
|
</ResponsiveContainer>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<div className="flex h-[320px] items-center justify-center text-muted-foreground text-sm">
|
||||||
<Card>
|
暂无数据
|
||||||
<CardHeader>
|
</div>
|
||||||
<CardTitle className="flex items-center gap-2 text-base">
|
|
||||||
<Zap className="h-4 w-4 text-primary" />
|
|
||||||
中转 Token 用量趋势
|
|
||||||
</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent>
|
|
||||||
{relayLineData.length > 0 ? (
|
|
||||||
<ResponsiveContainer width="100%" height={320}>
|
|
||||||
<LineChart data={relayLineData}>
|
|
||||||
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
|
||||||
<XAxis dataKey="day" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
|
||||||
<YAxis tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
|
||||||
<Tooltip contentStyle={{ backgroundColor: '#0F172A', border: '1px solid #1E293B', borderRadius: '8px', color: '#F8FAFC', fontSize: '12px' }} />
|
|
||||||
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
|
||||||
<Line type="monotone" dataKey="Input" stroke="#3B82F6" strokeWidth={2} dot={false} />
|
|
||||||
<Line type="monotone" dataKey="Output" stroke="#F97316" strokeWidth={2} dot={false} />
|
|
||||||
</LineChart>
|
|
||||||
</ResponsiveContainer>
|
|
||||||
) : (
|
|
||||||
<EmptyState message="暂无中转数据" />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
|
|
||||||
<Card>
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="text-base">中转按模型分布</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent>
|
|
||||||
{relayBarData.length > 0 ? (
|
|
||||||
<ResponsiveContainer width="100%" height={Math.max(200, relayBarData.length * 40)}>
|
|
||||||
<BarChart data={relayBarData} layout="vertical">
|
|
||||||
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
|
||||||
<XAxis type="number" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
|
||||||
<YAxis type="category" dataKey="model" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} width={120} />
|
|
||||||
<Tooltip contentStyle={{ backgroundColor: '#0F172A', border: '1px solid #1E293B', borderRadius: '8px', color: '#F8FAFC', fontSize: '12px' }} />
|
|
||||||
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
|
||||||
<Bar dataKey="Input" fill="#3B82F6" radius={[0, 2, 2, 0]} />
|
|
||||||
<Bar dataKey="Output" fill="#F97316" radius={[0, 2, 2, 0]} />
|
|
||||||
</BarChart>
|
|
||||||
</ResponsiveContainer>
|
|
||||||
) : (
|
|
||||||
<EmptyState />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
</>
|
|
||||||
)}
|
)}
|
||||||
</TabsContent>
|
</CardContent>
|
||||||
|
</Card>
|
||||||
|
|
||||||
{/* 遥测 Tab */}
|
{/* 按模型分布 */}
|
||||||
<TabsContent value="telemetry" className="space-y-6">
|
<Card>
|
||||||
{telemetryLoading ? (
|
<CardHeader>
|
||||||
<>
|
<CardTitle className="text-base">按模型分布</CardTitle>
|
||||||
<ChartSkeleton height={320} />
|
</CardHeader>
|
||||||
<TableSkeleton rows={5} cols={6} hasToolbar={false} />
|
<CardContent>
|
||||||
</>
|
{barChartData.length > 0 ? (
|
||||||
|
<ResponsiveContainer width="100%" height={320}>
|
||||||
|
<BarChart data={barChartData} layout="vertical">
|
||||||
|
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
||||||
|
<XAxis
|
||||||
|
type="number"
|
||||||
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
|
/>
|
||||||
|
<YAxis
|
||||||
|
type="category"
|
||||||
|
dataKey="model"
|
||||||
|
tick={{ fontSize: 12, fill: '#94A3B8' }}
|
||||||
|
axisLine={{ stroke: '#1E293B' }}
|
||||||
|
width={120}
|
||||||
|
/>
|
||||||
|
<Tooltip
|
||||||
|
contentStyle={{
|
||||||
|
backgroundColor: '#0F172A',
|
||||||
|
border: '1px solid #1E293B',
|
||||||
|
borderRadius: '8px',
|
||||||
|
color: '#F8FAFC',
|
||||||
|
fontSize: '12px',
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
||||||
|
<Bar dataKey="Input" fill="#3B82F6" radius={[0, 2, 2, 0]} />
|
||||||
|
<Bar dataKey="Output" fill="#F97316" radius={[0, 2, 2, 0]} />
|
||||||
|
</BarChart>
|
||||||
|
</ResponsiveContainer>
|
||||||
) : (
|
) : (
|
||||||
<>
|
<div className="flex h-[320px] items-center justify-center text-muted-foreground text-sm">
|
||||||
<Card>
|
暂无数据
|
||||||
<CardHeader>
|
</div>
|
||||||
<CardTitle className="flex items-center gap-2 text-base">
|
|
||||||
<Smartphone className="h-4 w-4 text-purple-400" />
|
|
||||||
桌面端 Token 用量趋势
|
|
||||||
</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent>
|
|
||||||
{telemetryLineData.length > 0 ? (
|
|
||||||
<ResponsiveContainer width="100%" height={320}>
|
|
||||||
<LineChart data={telemetryLineData}>
|
|
||||||
<CartesianGrid strokeDasharray="3 3" stroke="#1E293B" />
|
|
||||||
<XAxis dataKey="day" tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
|
||||||
<YAxis tick={{ fontSize: 12, fill: '#94A3B8' }} axisLine={{ stroke: '#1E293B' }} />
|
|
||||||
<Tooltip contentStyle={{ backgroundColor: '#0F172A', border: '1px solid #1E293B', borderRadius: '8px', color: '#F8FAFC', fontSize: '12px' }} />
|
|
||||||
<Legend wrapperStyle={{ fontSize: '12px', color: '#94A3B8' }} />
|
|
||||||
<Line type="monotone" dataKey="Input" stroke="#3B82F6" strokeWidth={2} dot={false} />
|
|
||||||
<Line type="monotone" dataKey="Output" stroke="#F97316" strokeWidth={2} dot={false} />
|
|
||||||
</LineChart>
|
|
||||||
</ResponsiveContainer>
|
|
||||||
) : (
|
|
||||||
<EmptyState message="暂无桌面端遥测数据(需要桌面端上报)" />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
|
|
||||||
<Card>
|
|
||||||
<CardHeader>
|
|
||||||
<CardTitle className="text-base">桌面端按模型统计</CardTitle>
|
|
||||||
</CardHeader>
|
|
||||||
<CardContent>
|
|
||||||
{telemetryModels.length > 0 ? (
|
|
||||||
<Table>
|
|
||||||
<TableHeader>
|
|
||||||
<TableRow>
|
|
||||||
<TableHead>模型</TableHead>
|
|
||||||
<TableHead className="text-right">请求数</TableHead>
|
|
||||||
<TableHead className="text-right">Input Tokens</TableHead>
|
|
||||||
<TableHead className="text-right">Output Tokens</TableHead>
|
|
||||||
<TableHead className="text-right">平均延迟</TableHead>
|
|
||||||
<TableHead className="text-right">成功率</TableHead>
|
|
||||||
</TableRow>
|
|
||||||
</TableHeader>
|
|
||||||
<TableBody>
|
|
||||||
{telemetryModels.map((stat) => (
|
|
||||||
<TableRow key={stat.model_id}>
|
|
||||||
<TableCell className="font-mono text-sm">{stat.model_id}</TableCell>
|
|
||||||
<TableCell className="text-right">{formatNumber(stat.request_count)}</TableCell>
|
|
||||||
<TableCell className="text-right text-blue-400">{formatNumber(stat.input_tokens)}</TableCell>
|
|
||||||
<TableCell className="text-right text-orange-400">{formatNumber(stat.output_tokens)}</TableCell>
|
|
||||||
<TableCell className="text-right">
|
|
||||||
{stat.avg_latency_ms !== null ? `${Math.round(stat.avg_latency_ms)}ms` : '-'}
|
|
||||||
</TableCell>
|
|
||||||
<TableCell className="text-right">
|
|
||||||
<Badge variant={stat.success_rate >= 0.95 ? 'default' : 'destructive'}>
|
|
||||||
{(stat.success_rate * 100).toFixed(1)}%
|
|
||||||
</Badge>
|
|
||||||
</TableCell>
|
|
||||||
</TableRow>
|
|
||||||
))}
|
|
||||||
</TableBody>
|
|
||||||
</Table>
|
|
||||||
) : (
|
|
||||||
<EmptyState />
|
|
||||||
)}
|
|
||||||
</CardContent>
|
|
||||||
</Card>
|
|
||||||
</>
|
|
||||||
)}
|
)}
|
||||||
</TabsContent>
|
</CardContent>
|
||||||
</Tabs>
|
</Card>
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +0,0 @@
|
|||||||
<svg xmlns="http://www.w3.org/2000/svg" width="32" height="32" viewBox="0 0 32 32">
|
|
||||||
<rect width="32" height="32" rx="6" fill="#0f172a"/>
|
|
||||||
<text x="16" y="22" font-family="system-ui, sans-serif" font-size="16" font-weight="700" fill="#60a5fa" text-anchor="middle">Z</text>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 282 B |
@@ -1,5 +1,5 @@
|
|||||||
import type { Metadata } from 'next'
|
import type { Metadata } from 'next'
|
||||||
import { SWRProvider } from '@/lib/swr-provider'
|
import { Toaster } from 'sonner'
|
||||||
import './globals.css'
|
import './globals.css'
|
||||||
|
|
||||||
export const metadata: Metadata = {
|
export const metadata: Metadata = {
|
||||||
@@ -21,9 +21,8 @@ export default function RootLayout({
|
|||||||
/>
|
/>
|
||||||
</head>
|
</head>
|
||||||
<body className="min-h-screen bg-background font-sans antialiased">
|
<body className="min-h-screen bg-background font-sans antialiased">
|
||||||
<SWRProvider>
|
{children}
|
||||||
{children}
|
<Toaster richColors position="top-right" />
|
||||||
</SWRProvider>
|
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,10 +11,9 @@ export default function LoginPage() {
|
|||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const [username, setUsername] = useState('')
|
const [username, setUsername] = useState('')
|
||||||
const [password, setPassword] = useState('')
|
const [password, setPassword] = useState('')
|
||||||
const [totpCode, setTotpCode] = useState('')
|
|
||||||
const [showPassword, setShowPassword] = useState(false)
|
const [showPassword, setShowPassword] = useState(false)
|
||||||
const [needTotp, setNeedTotp] = useState(false)
|
const [totpCode, setTotpCode] = useState('')
|
||||||
const [remember, setRemember] = useState(false)
|
const [showTotp, setShowTotp] = useState(false)
|
||||||
const [loading, setLoading] = useState(false)
|
const [loading, setLoading] = useState(false)
|
||||||
const [error, setError] = useState('')
|
const [error, setError] = useState('')
|
||||||
|
|
||||||
@@ -36,19 +35,18 @@ export default function LoginPage() {
|
|||||||
const res = await api.auth.login({
|
const res = await api.auth.login({
|
||||||
username: username.trim(),
|
username: username.trim(),
|
||||||
password,
|
password,
|
||||||
totp_code: totpCode.trim() || undefined,
|
totp_code: showTotp ? totpCode.trim() || undefined : undefined,
|
||||||
})
|
})
|
||||||
login(res.token, res.account)
|
login(res.token, res.account)
|
||||||
router.replace('/')
|
router.replace('/')
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
if (err instanceof ApiRequestError) {
|
if (err instanceof ApiRequestError) {
|
||||||
const msg = err.body.message || ''
|
// 检测 TOTP 错误码,自动显示验证码输入框
|
||||||
// 后端返回 "需要 TOTP" 时显示 TOTP 输入框
|
if (err.body.error === 'totp_required' || err.body.message?.includes('双因素认证') || err.body.message?.includes('TOTP')) {
|
||||||
if (msg.includes('TOTP') || msg.includes('totp') || msg.includes('2FA') || msg.includes('验证码') || err.status === 403) {
|
setShowTotp(true)
|
||||||
setNeedTotp(true)
|
setError(err.body.message || '此账号已启用双因素认证,请输入验证码')
|
||||||
setError(msg || '请输入两步验证码')
|
|
||||||
} else {
|
} else {
|
||||||
setError(msg || '登录失败,请检查用户名和密码')
|
setError(err.body.message || '登录失败,请检查用户名和密码')
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
setError('网络错误,请稍后重试')
|
setError('网络错误,请稍后重试')
|
||||||
@@ -165,52 +163,31 @@ export default function LoginPage() {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
{/* TOTP 验证码 */}
|
{/* TOTP 验证码 (仅账号启用 2FA 时显示) */}
|
||||||
{needTotp && (
|
{showTotp && (
|
||||||
<div className="space-y-2">
|
<div className="space-y-2">
|
||||||
<label
|
<label
|
||||||
htmlFor="totp"
|
htmlFor="totp_code"
|
||||||
className="text-sm font-medium text-foreground"
|
className="text-sm font-medium text-foreground"
|
||||||
>
|
>
|
||||||
两步验证码
|
<span className="inline-flex items-center gap-1">
|
||||||
|
<ShieldCheck className="h-3.5 w-3.5" />
|
||||||
|
双因素验证码
|
||||||
|
</span>
|
||||||
</label>
|
</label>
|
||||||
<div className="relative">
|
<input
|
||||||
<ShieldCheck className="absolute left-3 top-1/2 -translate-y-1/2 h-4 w-4 text-muted-foreground" />
|
id="totp_code"
|
||||||
<input
|
type="text"
|
||||||
id="totp"
|
placeholder="请输入 6 位验证码"
|
||||||
type="text"
|
value={totpCode}
|
||||||
placeholder="请输入 6 位验证码"
|
onChange={(e) => setTotpCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
|
||||||
value={totpCode}
|
className="flex h-10 w-full rounded-md border border-input bg-transparent px-3 py-2 text-sm tracking-widest text-center font-mono shadow-sm transition-colors duration-200 placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
|
||||||
onChange={(e) => setTotpCode(e.target.value.replace(/\D/g, '').slice(0, 6))}
|
maxLength={6}
|
||||||
maxLength={6}
|
autoFocus
|
||||||
className="flex h-10 w-full rounded-md border border-input bg-transparent pl-10 pr-3 py-2 text-sm shadow-sm transition-colors duration-200 placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring tracking-widest"
|
/>
|
||||||
autoComplete="one-time-code"
|
|
||||||
inputMode="numeric"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<p className="text-xs text-muted-foreground">
|
|
||||||
请使用身份验证器 App(如 Google Authenticator)扫描二维码后生成的验证码
|
|
||||||
</p>
|
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
{/* 记住我 */}
|
|
||||||
<div className="flex items-center gap-2">
|
|
||||||
<input
|
|
||||||
id="remember"
|
|
||||||
type="checkbox"
|
|
||||||
checked={remember}
|
|
||||||
onChange={(e) => setRemember(e.target.checked)}
|
|
||||||
className="h-4 w-4 rounded border-input bg-transparent accent-primary cursor-pointer"
|
|
||||||
/>
|
|
||||||
<label
|
|
||||||
htmlFor="remember"
|
|
||||||
className="text-sm text-muted-foreground cursor-pointer select-none"
|
|
||||||
>
|
|
||||||
记住我
|
|
||||||
</label>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
{/* 错误信息 */}
|
{/* 错误信息 */}
|
||||||
{error && (
|
{error && (
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive">
|
||||||
|
|||||||
@@ -1,11 +1,26 @@
|
|||||||
'use client'
|
'use client'
|
||||||
|
|
||||||
import { useEffect, useState, useRef, useCallback, type ReactNode } from 'react'
|
import { createContext, useContext, useEffect, useState, useCallback, type ReactNode } from 'react'
|
||||||
import { useRouter } from 'next/navigation'
|
import { useRouter } from 'next/navigation'
|
||||||
import { isAuthenticated, getAccount, clearAuth } from '@/lib/auth'
|
import { isAuthenticated, getAccount, logout as clearCredentials, scheduleTokenRefresh, cancelTokenRefresh, setOnSessionExpired } from '@/lib/auth'
|
||||||
import { api, ApiRequestError } from '@/lib/api-client'
|
import { api } from '@/lib/api-client'
|
||||||
import type { AccountPublic } from '@/lib/types'
|
import type { AccountPublic } from '@/lib/types'
|
||||||
import { AlertTriangle, RefreshCw } from 'lucide-react'
|
|
||||||
|
interface AuthContextValue {
|
||||||
|
account: AccountPublic | null
|
||||||
|
loading: boolean
|
||||||
|
refresh: () => Promise<void>
|
||||||
|
}
|
||||||
|
|
||||||
|
const AuthContext = createContext<AuthContextValue>({
|
||||||
|
account: null,
|
||||||
|
loading: true,
|
||||||
|
refresh: async () => {},
|
||||||
|
})
|
||||||
|
|
||||||
|
export function useAuth() {
|
||||||
|
return useContext(AuthContext)
|
||||||
|
}
|
||||||
|
|
||||||
interface AuthGuardProps {
|
interface AuthGuardProps {
|
||||||
children: ReactNode
|
children: ReactNode
|
||||||
@@ -13,72 +28,44 @@ interface AuthGuardProps {
|
|||||||
|
|
||||||
export function AuthGuard({ children }: AuthGuardProps) {
|
export function AuthGuard({ children }: AuthGuardProps) {
|
||||||
const router = useRouter()
|
const router = useRouter()
|
||||||
const [authorized, setAuthorized] = useState(false)
|
|
||||||
const [account, setAccount] = useState<AccountPublic | null>(null)
|
const [account, setAccount] = useState<AccountPublic | null>(null)
|
||||||
const [verifying, setVerifying] = useState(true)
|
const [loading, setLoading] = useState(true)
|
||||||
const [connectionError, setConnectionError] = useState(false)
|
|
||||||
|
|
||||||
// Ref 跟踪授权状态,避免 useCallback 闭包捕获过时的 state
|
|
||||||
const authorizedRef = useRef(false)
|
|
||||||
// 防止并发验证(RSC 导航可能触发多次 effect)
|
|
||||||
const verifyingRef = useRef(false)
|
|
||||||
|
|
||||||
const verifyAuth = useCallback(async () => {
|
|
||||||
// 防止并发验证
|
|
||||||
if (verifyingRef.current) return
|
|
||||||
verifyingRef.current = true
|
|
||||||
setVerifying(true)
|
|
||||||
setConnectionError(false)
|
|
||||||
|
|
||||||
if (!isAuthenticated()) {
|
|
||||||
setVerifying(false)
|
|
||||||
verifyingRef.current = false
|
|
||||||
router.replace('/login')
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
|
const refresh = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
const serverAccount = await api.auth.me()
|
const me = await api.auth.me()
|
||||||
setAccount(serverAccount)
|
setAccount(me)
|
||||||
setAuthorized(true)
|
} catch {
|
||||||
authorizedRef.current = true
|
clearCredentials()
|
||||||
} catch (err) {
|
router.replace('/login')
|
||||||
// AbortError: 导航/SWR 取消了请求,忽略
|
|
||||||
// 如果已有授权(ref 跟踪),保持不变;否则尝试 localStorage 缓存
|
|
||||||
if (err instanceof DOMException && err.name === 'AbortError') {
|
|
||||||
if (!authorizedRef.current) {
|
|
||||||
const cachedAccount = getAccount()
|
|
||||||
if (cachedAccount) {
|
|
||||||
setAccount(cachedAccount)
|
|
||||||
setAuthorized(true)
|
|
||||||
authorizedRef.current = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// 401/403: 真正的认证失败,清除 token
|
|
||||||
if (err instanceof ApiRequestError && (err.status === 401 || err.status === 403)) {
|
|
||||||
clearAuth()
|
|
||||||
authorizedRef.current = false
|
|
||||||
router.replace('/login')
|
|
||||||
} else {
|
|
||||||
// 网络错误/超时 — 仅在未授权时显示连接错误
|
|
||||||
// 已授权的情况下忽略瞬态错误,保持当前状态
|
|
||||||
if (!authorizedRef.current) {
|
|
||||||
setConnectionError(true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} finally {
|
|
||||||
setVerifying(false)
|
|
||||||
verifyingRef.current = false
|
|
||||||
}
|
}
|
||||||
}, [router])
|
}, [router])
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
verifyAuth()
|
if (!isAuthenticated()) {
|
||||||
}, [verifyAuth])
|
router.replace('/login')
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 验证 token 有效性并获取最新账号信息
|
||||||
|
refresh().finally(() => setLoading(false))
|
||||||
|
}, [router, refresh])
|
||||||
|
|
||||||
if (verifying) {
|
// Set up proactive token refresh with session-expired handler
|
||||||
|
useEffect(() => {
|
||||||
|
const handleSessionExpired = () => {
|
||||||
|
clearCredentials()
|
||||||
|
router.replace('/login')
|
||||||
|
}
|
||||||
|
setOnSessionExpired(handleSessionExpired)
|
||||||
|
scheduleTokenRefresh()
|
||||||
|
|
||||||
|
return () => {
|
||||||
|
cancelTokenRefresh()
|
||||||
|
setOnSessionExpired(null)
|
||||||
|
}
|
||||||
|
}, [router])
|
||||||
|
|
||||||
|
if (loading) {
|
||||||
return (
|
return (
|
||||||
<div className="flex h-screen w-screen items-center justify-center bg-background">
|
<div className="flex h-screen w-screen items-center justify-center bg-background">
|
||||||
<div className="h-8 w-8 animate-spin rounded-full border-2 border-primary border-t-transparent" />
|
<div className="h-8 w-8 animate-spin rounded-full border-2 border-primary border-t-transparent" />
|
||||||
@@ -86,39 +73,13 @@ export function AuthGuard({ children }: AuthGuardProps) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (connectionError) {
|
if (!account) {
|
||||||
return (
|
|
||||||
<div className="flex h-screen w-screen flex-col items-center justify-center gap-4 bg-background">
|
|
||||||
<AlertTriangle className="h-12 w-12 text-yellow-500" />
|
|
||||||
<h2 className="text-lg font-semibold text-foreground">连接中断</h2>
|
|
||||||
<p className="text-sm text-muted-foreground">无法连接到服务器,请检查网络后重试</p>
|
|
||||||
<button
|
|
||||||
onClick={verifyAuth}
|
|
||||||
className="mt-2 inline-flex items-center gap-2 rounded-md bg-primary px-4 py-2 text-sm font-medium text-primary-foreground hover:bg-primary/90 transition-colors cursor-pointer"
|
|
||||||
>
|
|
||||||
<RefreshCw className="h-4 w-4" />
|
|
||||||
重新连接
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!authorized) {
|
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
return <>{children}</>
|
return (
|
||||||
}
|
<AuthContext.Provider value={{ account, loading, refresh }}>
|
||||||
|
{children}
|
||||||
export function useAuth() {
|
</AuthContext.Provider>
|
||||||
const [account, setAccount] = useState<AccountPublic | null>(null)
|
)
|
||||||
const [loading, setLoading] = useState(true)
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const acc = getAccount()
|
|
||||||
setAccount(acc)
|
|
||||||
setLoading(false)
|
|
||||||
}, [])
|
|
||||||
|
|
||||||
return { account, loading, isAuthenticated: isAuthenticated() }
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,115 +0,0 @@
|
|||||||
// ============================================================
|
|
||||||
// Skeleton 组件 — 替代全屏 spinner 的骨架屏
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
import { cn } from '@/lib/utils'
|
|
||||||
|
|
||||||
function SkeletonBase({ className }: { className?: string }) {
|
|
||||||
return (
|
|
||||||
<div
|
|
||||||
className={cn(
|
|
||||||
'animate-pulse rounded-md bg-muted',
|
|
||||||
className,
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 表格骨架屏 */
|
|
||||||
export function TableSkeleton({
|
|
||||||
rows = 5,
|
|
||||||
cols = 5,
|
|
||||||
hasToolbar = true,
|
|
||||||
}: {
|
|
||||||
rows?: number
|
|
||||||
cols?: number
|
|
||||||
hasToolbar?: boolean
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<div className="space-y-4">
|
|
||||||
{hasToolbar && (
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<SkeletonBase className="h-9 w-[200px]" />
|
|
||||||
<SkeletonBase className="h-9 w-[120px]" />
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
<div className="rounded-md border border-border overflow-hidden">
|
|
||||||
{/* Header */}
|
|
||||||
<div className="border-b border-border bg-muted/30 px-4 py-3">
|
|
||||||
<div className="flex gap-4">
|
|
||||||
{Array.from({ length: cols }).map((_, i) => (
|
|
||||||
<SkeletonBase
|
|
||||||
key={i}
|
|
||||||
className={cn(
|
|
||||||
'h-4',
|
|
||||||
i === 0 ? 'w-[120px]' : i === cols - 1 ? 'w-[80px]' : 'w-[100px]',
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{/* Rows */}
|
|
||||||
{Array.from({ length: rows }).map((_, rowIdx) => (
|
|
||||||
<div
|
|
||||||
key={rowIdx}
|
|
||||||
className={cn(
|
|
||||||
'px-4 py-3',
|
|
||||||
rowIdx < rows - 1 && 'border-b border-border',
|
|
||||||
)}
|
|
||||||
>
|
|
||||||
<div className="flex gap-4">
|
|
||||||
{Array.from({ length: cols }).map((_, colIdx) => (
|
|
||||||
<SkeletonBase
|
|
||||||
key={colIdx}
|
|
||||||
className={cn(
|
|
||||||
'h-4',
|
|
||||||
colIdx === 0 ? 'w-[120px]' : colIdx === cols - 1 ? 'w-[80px]' : 'w-[100px]',
|
|
||||||
)}
|
|
||||||
/>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
{/* Pagination */}
|
|
||||||
<div className="flex items-center justify-between">
|
|
||||||
<SkeletonBase className="h-4 w-[140px]" />
|
|
||||||
<div className="flex gap-2">
|
|
||||||
<SkeletonBase className="h-8 w-[80px]" />
|
|
||||||
<SkeletonBase className="h-8 w-[80px]" />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 统计卡片骨架屏 */
|
|
||||||
export function StatsSkeleton({ count = 4 }: { count?: number }) {
|
|
||||||
return (
|
|
||||||
<div className={`grid grid-cols-1 gap-4 sm:grid-cols-2 lg:grid-cols-${count}`}>
|
|
||||||
{Array.from({ length: count }).map((_, i) => (
|
|
||||||
<div key={i} className="rounded-lg border border-border p-6">
|
|
||||||
<SkeletonBase className="h-4 w-[80px]" />
|
|
||||||
<SkeletonBase className="mt-2 h-8 w-[100px]" />
|
|
||||||
<SkeletonBase className="mt-1 h-3 w-[120px]" />
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 图表骨架屏 */
|
|
||||||
export function ChartSkeleton({ height }: { height?: number }) {
|
|
||||||
return (
|
|
||||||
<div className="rounded-lg border border-border">
|
|
||||||
<div className="border-b border-border px-6 py-4">
|
|
||||||
<SkeletonBase className="h-5 w-[140px]" />
|
|
||||||
</div>
|
|
||||||
<div className="p-6">
|
|
||||||
<SkeletonBase className="w-full" />
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
export { SkeletonBase as Skeleton }
|
|
||||||
@@ -1,63 +0,0 @@
|
|||||||
'use client'
|
|
||||||
|
|
||||||
import { AlertCircle, Inbox } from 'lucide-react'
|
|
||||||
|
|
||||||
/** 统一的错误提示横幅 */
|
|
||||||
export function ErrorBanner({
|
|
||||||
message,
|
|
||||||
onDismiss,
|
|
||||||
}: {
|
|
||||||
message: string
|
|
||||||
onDismiss?: () => void
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<div className="rounded-md bg-destructive/10 border border-destructive/20 px-4 py-3 text-sm text-destructive flex items-center gap-2">
|
|
||||||
<AlertCircle className="h-4 w-4 shrink-0" />
|
|
||||||
<span className="flex-1">{message}</span>
|
|
||||||
{onDismiss && (
|
|
||||||
<button
|
|
||||||
onClick={onDismiss}
|
|
||||||
className="underline cursor-pointer shrink-0"
|
|
||||||
>
|
|
||||||
关闭
|
|
||||||
</button>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 统一的空状态占位 */
|
|
||||||
export function EmptyState({
|
|
||||||
message = '暂无数据',
|
|
||||||
}: {
|
|
||||||
message?: string
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<div className="flex h-64 flex-col items-center justify-center gap-2 text-muted-foreground">
|
|
||||||
<Inbox className="h-8 w-8" />
|
|
||||||
<span className="text-sm">{message}</span>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 统一的加载失败提示 + 重试 */
|
|
||||||
export function ErrorRetry({
|
|
||||||
message = '请求失败,请重试',
|
|
||||||
onRetry,
|
|
||||||
}: {
|
|
||||||
message?: string
|
|
||||||
onRetry: () => void
|
|
||||||
}) {
|
|
||||||
return (
|
|
||||||
<div className="flex h-64 flex-col items-center justify-center gap-3 text-muted-foreground">
|
|
||||||
<AlertCircle className="h-8 w-8 text-destructive" />
|
|
||||||
<span className="text-sm">{message}</span>
|
|
||||||
<button
|
|
||||||
onClick={onRetry}
|
|
||||||
className="rounded-md bg-primary px-4 py-2 text-sm text-primary-foreground hover:bg-primary/90 transition-colors cursor-pointer"
|
|
||||||
>
|
|
||||||
重新加载
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
// ============================================================
|
|
||||||
// useDebounce — 防抖 hook
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
import { useState, useEffect } from 'react'
|
|
||||||
|
|
||||||
export function useDebounce<T>(value: T, delay = 300): T {
|
|
||||||
const [debouncedValue, setDebouncedValue] = useState<T>(value)
|
|
||||||
|
|
||||||
useEffect(() => {
|
|
||||||
const handler = setTimeout(() => setDebouncedValue(value), delay)
|
|
||||||
return () => clearTimeout(handler)
|
|
||||||
}, [value, delay])
|
|
||||||
|
|
||||||
return debouncedValue
|
|
||||||
}
|
|
||||||
@@ -2,29 +2,25 @@
|
|||||||
// ZCLAW SaaS Admin — 类型化 HTTP 客户端
|
// ZCLAW SaaS Admin — 类型化 HTTP 客户端
|
||||||
// ============================================================
|
// ============================================================
|
||||||
|
|
||||||
import { getToken, login as saveToken, logout, getAccount } from './auth'
|
import { getToken, logout, refreshToken } from './auth'
|
||||||
|
import { toast } from 'sonner'
|
||||||
import type {
|
import type {
|
||||||
AccountPublic,
|
AccountPublic,
|
||||||
AgentTemplate,
|
|
||||||
ApiError,
|
ApiError,
|
||||||
ConfigItem,
|
ConfigItem,
|
||||||
CreateTokenRequest,
|
CreateTokenRequest,
|
||||||
DashboardStats,
|
DashboardStats,
|
||||||
DailyUsageStat,
|
DeviceInfo,
|
||||||
LoginRequest,
|
LoginRequest,
|
||||||
LoginResponse,
|
LoginResponse,
|
||||||
Model,
|
Model,
|
||||||
ModelUsageStat,
|
|
||||||
OperationLog,
|
OperationLog,
|
||||||
PaginatedResponse,
|
PaginatedResponse,
|
||||||
PromptTemplate,
|
|
||||||
PromptVersion,
|
|
||||||
Provider,
|
Provider,
|
||||||
ProviderKey,
|
|
||||||
RelayTask,
|
RelayTask,
|
||||||
TokenInfo,
|
TokenInfo,
|
||||||
UsageByModel,
|
UsageByModel,
|
||||||
UsageRecord,
|
UsageStats,
|
||||||
} from './types'
|
} from './types'
|
||||||
|
|
||||||
// ── 错误类 ────────────────────────────────────────────────
|
// ── 错误类 ────────────────────────────────────────────────
|
||||||
@@ -41,132 +37,76 @@ export class ApiRequestError extends Error {
|
|||||||
|
|
||||||
// ── 基础请求 ──────────────────────────────────────────────
|
// ── 基础请求 ──────────────────────────────────────────────
|
||||||
|
|
||||||
const BASE_URL = process.env.NEXT_PUBLIC_SAAS_API_URL || '/api/v1'
|
const BASE_URL = process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080'
|
||||||
|
const API_PREFIX = '/api/v1'
|
||||||
const DEFAULT_TIMEOUT_MS = 10_000
|
|
||||||
const MAX_RETRIES = 2
|
|
||||||
|
|
||||||
function sleep(ms: number): Promise<void> {
|
|
||||||
return new Promise(resolve => setTimeout(resolve, ms))
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 判断是否为可重试的网络错误(不含 AbortError) */
|
|
||||||
function isRetryableNetworkError(err: unknown): boolean {
|
|
||||||
// AbortError 不重试:可能是组件卸载或路由切换导致的外部取消
|
|
||||||
if (err instanceof DOMException && err.name === 'AbortError') return false
|
|
||||||
if (err instanceof TypeError) {
|
|
||||||
const msg = (err as TypeError).message
|
|
||||||
return msg.includes('Failed to fetch') || msg.includes('NetworkError') || msg.includes('ECONNREFUSED')
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 尝试刷新 Token,成功返回新 token,失败返回 null */
|
|
||||||
async function tryRefreshToken(): Promise<string | null> {
|
|
||||||
try {
|
|
||||||
const token = getToken()
|
|
||||||
if (!token) return null
|
|
||||||
|
|
||||||
const res = await fetch(`${BASE_URL}/auth/refresh`, {
|
|
||||||
method: 'POST',
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
Authorization: `Bearer ${token}`,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!res.ok) return null
|
|
||||||
|
|
||||||
const data = await res.json()
|
|
||||||
const newToken = data.token as string
|
|
||||||
const account = getAccount()
|
|
||||||
if (account && newToken) {
|
|
||||||
saveToken(newToken, account)
|
|
||||||
}
|
|
||||||
return newToken
|
|
||||||
} catch {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async function request<T>(
|
async function request<T>(
|
||||||
method: string,
|
method: string,
|
||||||
path: string,
|
path: string,
|
||||||
body?: unknown,
|
body?: unknown,
|
||||||
_isRetry = false,
|
|
||||||
externalSignal?: AbortSignal,
|
|
||||||
): Promise<T> {
|
): Promise<T> {
|
||||||
let lastError: unknown
|
const token = getToken()
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
}
|
||||||
|
if (token) {
|
||||||
|
headers['Authorization'] = `Bearer ${token}`
|
||||||
|
}
|
||||||
|
|
||||||
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
const res = await fetch(`${BASE_URL}${API_PREFIX}${path}`, {
|
||||||
// Merge external signal (e.g. from SWR) with a timeout signal
|
method,
|
||||||
const signals: AbortSignal[] = [AbortSignal.timeout(DEFAULT_TIMEOUT_MS)]
|
headers,
|
||||||
if (externalSignal) signals.push(externalSignal)
|
body: body ? JSON.stringify(body) : undefined,
|
||||||
const signal = signals.length === 1 ? signals[0] : AbortSignal.any(signals)
|
})
|
||||||
|
|
||||||
|
if (res.status === 401) {
|
||||||
|
// 尝试刷新 token 后重试
|
||||||
try {
|
try {
|
||||||
const token = getToken()
|
const newToken = await refreshToken()
|
||||||
const headers: Record<string, string> = {
|
headers['Authorization'] = `Bearer ${newToken}`
|
||||||
'Content-Type': 'application/json',
|
const retryRes = await fetch(`${BASE_URL}${API_PREFIX}${path}`, {
|
||||||
}
|
|
||||||
if (token) {
|
|
||||||
headers['Authorization'] = `Bearer ${token}`
|
|
||||||
}
|
|
||||||
|
|
||||||
const res = await fetch(`${BASE_URL}${path}`, {
|
|
||||||
method,
|
method,
|
||||||
headers,
|
headers,
|
||||||
body: body ? JSON.stringify(body) : undefined,
|
body: body ? JSON.stringify(body) : undefined,
|
||||||
signal,
|
|
||||||
})
|
})
|
||||||
|
if (retryRes.ok || retryRes.status === 204) {
|
||||||
// 401: 尝试刷新 Token 后重试
|
return retryRes.status === 204 ? (undefined as T) : retryRes.json()
|
||||||
if (res.status === 401 && !_isRetry) {
|
|
||||||
const newToken = await tryRefreshToken()
|
|
||||||
if (newToken) {
|
|
||||||
return request<T>(method, path, body, true)
|
|
||||||
}
|
|
||||||
logout()
|
|
||||||
if (typeof window !== 'undefined') {
|
|
||||||
window.location.href = '/login'
|
|
||||||
}
|
|
||||||
throw new ApiRequestError(401, { error: 'unauthorized', message: '登录已过期,请重新登录' })
|
|
||||||
}
|
}
|
||||||
|
// 刷新成功但重试仍失败,走正常错误处理
|
||||||
if (!res.ok) {
|
if (!retryRes.ok) {
|
||||||
let errorBody: ApiError
|
let errorBody: ApiError
|
||||||
try {
|
try { errorBody = await retryRes.json() } catch { errorBody = { error: 'unknown', message: `请求失败 (${retryRes.status})` } }
|
||||||
errorBody = await res.json()
|
throw new ApiRequestError(retryRes.status, errorBody)
|
||||||
} catch {
|
|
||||||
errorBody = { error: 'unknown', message: `请求失败 (${res.status})` }
|
|
||||||
}
|
|
||||||
throw new ApiRequestError(res.status, errorBody)
|
|
||||||
}
|
}
|
||||||
|
} catch {
|
||||||
// 204 No Content
|
// 刷新失败,执行登出
|
||||||
if (res.status === 204) {
|
|
||||||
return undefined as T
|
|
||||||
}
|
|
||||||
|
|
||||||
return res.json() as Promise<T>
|
|
||||||
} catch (err) {
|
|
||||||
// API 错误和外部取消的 AbortError 直接抛出,不重试
|
|
||||||
if (err instanceof ApiRequestError) throw err
|
|
||||||
if (err instanceof DOMException && err.name === 'AbortError') throw err
|
|
||||||
|
|
||||||
lastError = err
|
|
||||||
|
|
||||||
// 仅对可重试的网络错误重试
|
|
||||||
if (attempt < MAX_RETRIES && isRetryableNetworkError(err)) {
|
|
||||||
await sleep(1000 * Math.pow(2, attempt))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
throw err
|
|
||||||
}
|
}
|
||||||
|
logout()
|
||||||
|
if (typeof window !== 'undefined') {
|
||||||
|
window.location.href = '/login'
|
||||||
|
}
|
||||||
|
throw new ApiRequestError(401, { error: 'unauthorized', message: '登录已过期,请重新登录' })
|
||||||
}
|
}
|
||||||
|
|
||||||
throw lastError
|
if (!res.ok) {
|
||||||
|
let errorBody: ApiError
|
||||||
|
try {
|
||||||
|
errorBody = await res.json()
|
||||||
|
} catch {
|
||||||
|
errorBody = { error: 'unknown', message: `请求失败 (${res.status})` }
|
||||||
|
}
|
||||||
|
if (typeof window !== 'undefined') {
|
||||||
|
toast.error(errorBody.message || `请求失败 (${res.status})`)
|
||||||
|
}
|
||||||
|
throw new ApiRequestError(res.status, errorBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 204 No Content
|
||||||
|
if (res.status === 204) {
|
||||||
|
return undefined as T
|
||||||
|
}
|
||||||
|
|
||||||
|
return res.json() as Promise<T>
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── API 客户端 ────────────────────────────────────────────
|
// ── API 客户端 ────────────────────────────────────────────
|
||||||
@@ -190,6 +130,22 @@ export const api = {
|
|||||||
async me(): Promise<AccountPublic> {
|
async me(): Promise<AccountPublic> {
|
||||||
return request<AccountPublic>('GET', '/auth/me')
|
return request<AccountPublic>('GET', '/auth/me')
|
||||||
},
|
},
|
||||||
|
|
||||||
|
async changePassword(data: { old_password: string; new_password: string }): Promise<void> {
|
||||||
|
return request<void>('PUT', '/auth/password', data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async totpSetup(): Promise<{ otpauth_uri: string; secret: string; issuer: string }> {
|
||||||
|
return request<{ otpauth_uri: string; secret: string; issuer: string }>('POST', '/auth/totp/setup')
|
||||||
|
},
|
||||||
|
|
||||||
|
async totpVerify(data: { code: string }): Promise<void> {
|
||||||
|
return request<void>('POST', '/auth/totp/verify', data)
|
||||||
|
},
|
||||||
|
|
||||||
|
async totpDisable(data: { password: string }): Promise<void> {
|
||||||
|
return request<void>('POST', '/auth/totp/disable', data)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// ── 账号管理 ──────────────────────────────────────────
|
// ── 账号管理 ──────────────────────────────────────────
|
||||||
@@ -213,7 +169,7 @@ export const api = {
|
|||||||
id: string,
|
id: string,
|
||||||
data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>,
|
data: Partial<Pick<AccountPublic, 'display_name' | 'email' | 'role'>>,
|
||||||
): Promise<AccountPublic> {
|
): Promise<AccountPublic> {
|
||||||
return request<AccountPublic>('PATCH', `/accounts/${id}`, data)
|
return request<AccountPublic>('PUT', `/accounts/${id}`, data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async updateStatus(
|
async updateStatus(
|
||||||
@@ -234,6 +190,10 @@ export const api = {
|
|||||||
return request<PaginatedResponse<Provider>>('GET', `/providers${qs}`)
|
return request<PaginatedResponse<Provider>>('GET', `/providers${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
|
async get(id: string): Promise<Provider> {
|
||||||
|
return request<Provider>('GET', `/providers/${id}`)
|
||||||
|
},
|
||||||
|
|
||||||
async create(data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>): Promise<Provider> {
|
async create(data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>): Promise<Provider> {
|
||||||
return request<Provider>('POST', '/providers', data)
|
return request<Provider>('POST', '/providers', data)
|
||||||
},
|
},
|
||||||
@@ -242,36 +202,12 @@ export const api = {
|
|||||||
id: string,
|
id: string,
|
||||||
data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>,
|
data: Partial<Omit<Provider, 'id' | 'created_at' | 'updated_at'>>,
|
||||||
): Promise<Provider> {
|
): Promise<Provider> {
|
||||||
return request<Provider>('PATCH', `/providers/${id}`, data)
|
return request<Provider>('PUT', `/providers/${id}`, data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async delete(id: string): Promise<void> {
|
async delete(id: string): Promise<void> {
|
||||||
return request<void>('DELETE', `/providers/${id}`)
|
return request<void>('DELETE', `/providers/${id}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
// Key Pool 管理
|
|
||||||
async listKeys(providerId: string): Promise<ProviderKey[]> {
|
|
||||||
return request<ProviderKey[]>('GET', `/providers/${providerId}/keys`)
|
|
||||||
},
|
|
||||||
|
|
||||||
async addKey(providerId: string, data: {
|
|
||||||
key_label: string
|
|
||||||
key_value: string
|
|
||||||
priority?: number
|
|
||||||
max_rpm?: number
|
|
||||||
max_tpm?: number
|
|
||||||
quota_reset_interval?: string
|
|
||||||
}): Promise<{ ok: boolean; key_id: string }> {
|
|
||||||
return request<{ ok: boolean; key_id: string }>('POST', `/providers/${providerId}/keys`, data)
|
|
||||||
},
|
|
||||||
|
|
||||||
async toggleKey(providerId: string, keyId: string, active: boolean): Promise<{ ok: boolean }> {
|
|
||||||
return request<{ ok: boolean }>('PUT', `/providers/${providerId}/keys/${keyId}/toggle`, { active })
|
|
||||||
},
|
|
||||||
|
|
||||||
async deleteKey(providerId: string, keyId: string): Promise<{ ok: boolean }> {
|
|
||||||
return request<{ ok: boolean }>('DELETE', `/providers/${providerId}/keys/${keyId}`)
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
|
|
||||||
// ── 模型管理 ──────────────────────────────────────────
|
// ── 模型管理 ──────────────────────────────────────────
|
||||||
@@ -285,12 +221,16 @@ export const api = {
|
|||||||
return request<PaginatedResponse<Model>>('GET', `/models${qs}`)
|
return request<PaginatedResponse<Model>>('GET', `/models${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
|
async get(id: string): Promise<Model> {
|
||||||
|
return request<Model>('GET', `/models/${id}`)
|
||||||
|
},
|
||||||
|
|
||||||
async create(data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
async create(data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
||||||
return request<Model>('POST', '/models', data)
|
return request<Model>('POST', '/models', data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async update(id: string, data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
async update(id: string, data: Partial<Omit<Model, 'id'>>): Promise<Model> {
|
||||||
return request<Model>('PATCH', `/models/${id}`, data)
|
return request<Model>('PUT', `/models/${id}`, data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async delete(id: string): Promise<void> {
|
async delete(id: string): Promise<void> {
|
||||||
@@ -305,30 +245,23 @@ export const api = {
|
|||||||
page_size?: number
|
page_size?: number
|
||||||
}): Promise<PaginatedResponse<TokenInfo>> {
|
}): Promise<PaginatedResponse<TokenInfo>> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<TokenInfo>>('GET', `/keys${qs}`)
|
return request<PaginatedResponse<TokenInfo>>('GET', `/tokens${qs}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
async create(data: CreateTokenRequest): Promise<TokenInfo> {
|
async create(data: CreateTokenRequest): Promise<TokenInfo> {
|
||||||
return request<TokenInfo>('POST', '/keys', data)
|
return request<TokenInfo>('POST', '/tokens', data)
|
||||||
},
|
},
|
||||||
|
|
||||||
async revoke(id: string): Promise<void> {
|
async revoke(id: string): Promise<void> {
|
||||||
return request<void>('DELETE', `/keys/${id}`)
|
return request<void>('DELETE', `/tokens/${id}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// ── 用量统计 ──────────────────────────────────────────
|
// ── 用量统计 ──────────────────────────────────────────
|
||||||
usage: {
|
usage: {
|
||||||
async daily(params?: { days?: number }): Promise<UsageRecord[]> {
|
async get(params?: { from?: string; to?: string; provider_id?: string; model_id?: string }): Promise<UsageStats> {
|
||||||
const qs = buildQueryString({ ...params, group_by: 'day' })
|
const qs = buildQueryString(params)
|
||||||
const result = await request<{ by_day: UsageRecord[] }>('GET', `/usage${qs}`)
|
return request<UsageStats>('GET', `/usage${qs}`)
|
||||||
return result.by_day || []
|
|
||||||
},
|
|
||||||
|
|
||||||
async byModel(params?: { days?: number }): Promise<UsageByModel[]> {
|
|
||||||
const qs = buildQueryString({ ...params, group_by: 'model' })
|
|
||||||
const result = await request<{ by_model: UsageByModel[] }>('GET', `/usage${qs}`)
|
|
||||||
return result.by_model || []
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -346,22 +279,23 @@ export const api = {
|
|||||||
async get(id: string): Promise<RelayTask> {
|
async get(id: string): Promise<RelayTask> {
|
||||||
return request<RelayTask>('GET', `/relay/tasks/${id}`)
|
return request<RelayTask>('GET', `/relay/tasks/${id}`)
|
||||||
},
|
},
|
||||||
|
|
||||||
|
async retry(id: string): Promise<void> {
|
||||||
|
return request<void>('POST', `/relay/tasks/${id}/retry`)
|
||||||
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// ── 系统配置 ──────────────────────────────────────────
|
// ── 系统配置 ──────────────────────────────────────────
|
||||||
config: {
|
config: {
|
||||||
async list(params?: {
|
async list(params?: {
|
||||||
category?: string
|
category?: string
|
||||||
page?: number
|
|
||||||
page_size?: number
|
|
||||||
}): Promise<ConfigItem[]> {
|
}): Promise<ConfigItem[]> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
const result = await request<PaginatedResponse<ConfigItem>>('GET', `/config/items${qs}`)
|
return request<ConfigItem[]>('GET', `/config/items${qs}`)
|
||||||
return result.items
|
|
||||||
},
|
},
|
||||||
|
|
||||||
async update(id: string, data: { value: string | number | boolean }): Promise<ConfigItem> {
|
async update(id: string, data: { current_value: string | number | boolean }): Promise<ConfigItem> {
|
||||||
return request<ConfigItem>('PATCH', `/config/items/${id}`, data)
|
return request<ConfigItem>('PUT', `/config/items/${id}`, data)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -371,9 +305,9 @@ export const api = {
|
|||||||
page?: number
|
page?: number
|
||||||
page_size?: number
|
page_size?: number
|
||||||
action?: string
|
action?: string
|
||||||
}): Promise<PaginatedResponse<OperationLog>> {
|
}): Promise<OperationLog[]> {
|
||||||
const qs = buildQueryString(params)
|
const qs = buildQueryString(params)
|
||||||
return request<PaginatedResponse<OperationLog>>('GET', `/logs/operations${qs}`)
|
return request<OperationLog[]>('GET', `/logs/operations${qs}`)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
@@ -384,138 +318,16 @@ export const api = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|
||||||
// ── 提示词管理 ────────────────────────────────────────
|
// ── 设备管理 ──────────────────────────────────────────
|
||||||
prompts: {
|
devices: {
|
||||||
async list(params?: {
|
async list(): Promise<DeviceInfo[]> {
|
||||||
category?: string
|
return request<DeviceInfo[]>('GET', '/devices')
|
||||||
source?: string
|
|
||||||
status?: string
|
|
||||||
page?: number
|
|
||||||
page_size?: number
|
|
||||||
}): Promise<PaginatedResponse<PromptTemplate>> {
|
|
||||||
const qs = buildQueryString(params)
|
|
||||||
return request<PaginatedResponse<PromptTemplate>>('GET', `/prompts${qs}`)
|
|
||||||
},
|
},
|
||||||
|
async register(data: { device_id: string; device_name?: string; platform?: string; app_version?: string }) {
|
||||||
async get(name: string): Promise<PromptTemplate> {
|
return request<{ ok: boolean; device_id: string }>('POST', '/devices/register', data)
|
||||||
return request<PromptTemplate>('GET', `/prompts/${encodeURIComponent(name)}`)
|
|
||||||
},
|
},
|
||||||
|
async heartbeat(data: { device_id: string }) {
|
||||||
async create(data: {
|
return request<{ ok: boolean }>('POST', '/devices/heartbeat', data)
|
||||||
name: string
|
|
||||||
category: string
|
|
||||||
description?: string
|
|
||||||
source?: string
|
|
||||||
system_prompt: string
|
|
||||||
user_prompt_template?: string
|
|
||||||
variables?: unknown[]
|
|
||||||
min_app_version?: string
|
|
||||||
}): Promise<PromptTemplate> {
|
|
||||||
return request<PromptTemplate>('POST', '/prompts', data)
|
|
||||||
},
|
|
||||||
|
|
||||||
async update(name: string, data: {
|
|
||||||
description?: string
|
|
||||||
status?: string
|
|
||||||
}): Promise<PromptTemplate> {
|
|
||||||
return request<PromptTemplate>('PUT', `/prompts/${encodeURIComponent(name)}`, data)
|
|
||||||
},
|
|
||||||
|
|
||||||
async archive(name: string): Promise<PromptTemplate> {
|
|
||||||
return request<PromptTemplate>('DELETE', `/prompts/${encodeURIComponent(name)}`)
|
|
||||||
},
|
|
||||||
|
|
||||||
async listVersions(name: string): Promise<PromptVersion[]> {
|
|
||||||
return request<PromptVersion[]>('GET', `/prompts/${encodeURIComponent(name)}/versions`)
|
|
||||||
},
|
|
||||||
|
|
||||||
async createVersion(name: string, data: {
|
|
||||||
system_prompt: string
|
|
||||||
user_prompt_template?: string
|
|
||||||
variables?: unknown[]
|
|
||||||
changelog?: string
|
|
||||||
min_app_version?: string
|
|
||||||
}): Promise<PromptVersion> {
|
|
||||||
return request<PromptVersion>('POST', `/prompts/${encodeURIComponent(name)}/versions`, data)
|
|
||||||
},
|
|
||||||
|
|
||||||
async rollback(name: string, version: number): Promise<PromptTemplate> {
|
|
||||||
return request<PromptTemplate>('POST', `/prompts/${encodeURIComponent(name)}/rollback/${version}`)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
// ── Agent 配置模板 ──────────────────────────────────
|
|
||||||
agentTemplates: {
|
|
||||||
async list(params?: {
|
|
||||||
category?: string
|
|
||||||
source?: string
|
|
||||||
visibility?: string
|
|
||||||
status?: string
|
|
||||||
page?: number
|
|
||||||
page_size?: number
|
|
||||||
}): Promise<PaginatedResponse<AgentTemplate>> {
|
|
||||||
const qs = buildQueryString(params)
|
|
||||||
return request<PaginatedResponse<AgentTemplate>>('GET', `/agent-templates${qs}`)
|
|
||||||
},
|
|
||||||
|
|
||||||
async get(id: string): Promise<AgentTemplate> {
|
|
||||||
return request<AgentTemplate>('GET', `/agent-templates/${id}`)
|
|
||||||
},
|
|
||||||
|
|
||||||
async create(data: {
|
|
||||||
name: string
|
|
||||||
description?: string
|
|
||||||
category?: string
|
|
||||||
source?: string
|
|
||||||
model?: string
|
|
||||||
system_prompt?: string
|
|
||||||
tools?: string[]
|
|
||||||
capabilities?: string[]
|
|
||||||
temperature?: number
|
|
||||||
max_tokens?: number
|
|
||||||
visibility?: string
|
|
||||||
}): Promise<AgentTemplate> {
|
|
||||||
return request<AgentTemplate>('POST', '/agent-templates', data)
|
|
||||||
},
|
|
||||||
|
|
||||||
async update(id: string, data: {
|
|
||||||
description?: string
|
|
||||||
model?: string
|
|
||||||
system_prompt?: string
|
|
||||||
tools?: string[]
|
|
||||||
capabilities?: string[]
|
|
||||||
temperature?: number
|
|
||||||
max_tokens?: number
|
|
||||||
visibility?: string
|
|
||||||
status?: string
|
|
||||||
}): Promise<AgentTemplate> {
|
|
||||||
return request<AgentTemplate>('POST', `/agent-templates/${id}`, data)
|
|
||||||
},
|
|
||||||
|
|
||||||
async archive(id: string): Promise<AgentTemplate> {
|
|
||||||
return request<AgentTemplate>('DELETE', `/agent-templates/${id}`)
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
// ── 遥测统计 ──────────────────────────────────────────
|
|
||||||
telemetry: {
|
|
||||||
/** 按模型聚合用量统计 */
|
|
||||||
async modelStats(params?: {
|
|
||||||
from?: string
|
|
||||||
to?: string
|
|
||||||
model_id?: string
|
|
||||||
connection_mode?: string
|
|
||||||
}): Promise<ModelUsageStat[]> {
|
|
||||||
const qs = buildQueryString(params)
|
|
||||||
return request<ModelUsageStat[]>('GET', `/telemetry/stats${qs}`)
|
|
||||||
},
|
|
||||||
|
|
||||||
/** 按天聚合用量统计 */
|
|
||||||
async dailyStats(params?: {
|
|
||||||
days?: number
|
|
||||||
}): Promise<DailyUsageStat[]> {
|
|
||||||
const qs = buildQueryString(params)
|
|
||||||
return request<DailyUsageStat[]>('GET', `/telemetry/daily${qs}`)
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
// ============================================================
|
|
||||||
// API Error 类 — 与 swr-fetcher 共享
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
export class ApiRequestError extends Error {
|
|
||||||
constructor(
|
|
||||||
public status: number,
|
|
||||||
public body: { error?: string; message?: string },
|
|
||||||
) {
|
|
||||||
super(body.message || `Request failed with status ${status}`)
|
|
||||||
this.name = 'ApiRequestError'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,28 +2,74 @@
|
|||||||
// ZCLAW SaaS Admin — JWT Token 管理
|
// ZCLAW SaaS Admin — JWT Token 管理
|
||||||
// ============================================================
|
// ============================================================
|
||||||
|
|
||||||
import type { AccountPublic } from './types'
|
import type { AccountPublic, LoginResponse } from './types'
|
||||||
|
|
||||||
const TOKEN_KEY = 'zclaw_admin_token'
|
const TOKEN_KEY = 'zclaw_admin_token'
|
||||||
const ACCOUNT_KEY = 'zclaw_admin_account'
|
const ACCOUNT_KEY = 'zclaw_admin_account'
|
||||||
|
|
||||||
/** 保存登录凭证 */
|
// ── JWT 辅助函数 ────────────────────────────────────────────
|
||||||
|
|
||||||
|
interface JwtPayload {
|
||||||
|
exp?: number
|
||||||
|
iat?: number
|
||||||
|
sub?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decode a JWT payload without verifying the signature.
|
||||||
|
* Returns the parsed JSON payload, or null if the token is malformed.
|
||||||
|
*/
|
||||||
|
function decodeJwtPayload<T = Record<string, unknown>>(token: string): T | null {
|
||||||
|
try {
|
||||||
|
const parts = token.split('.')
|
||||||
|
if (parts.length !== 3) return null
|
||||||
|
const base64 = parts[1].replace(/-/g, '+').replace(/_/g, '/')
|
||||||
|
const json = decodeURIComponent(
|
||||||
|
atob(base64)
|
||||||
|
.split('')
|
||||||
|
.map((c) => '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2))
|
||||||
|
.join(''),
|
||||||
|
)
|
||||||
|
return JSON.parse(json) as T
|
||||||
|
} catch {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the delay (ms) until 80% of the token's remaining lifetime
|
||||||
|
* has elapsed. Returns null if the token is already past that point.
|
||||||
|
*/
|
||||||
|
function getRefreshDelay(exp: number): number | null {
|
||||||
|
const now = Math.floor(Date.now() / 1000)
|
||||||
|
const totalLifetime = exp - now
|
||||||
|
if (totalLifetime <= 0) return null
|
||||||
|
|
||||||
|
const refreshAt = now + Math.floor(totalLifetime * 0.8)
|
||||||
|
const delayMs = (refreshAt - now) * 1000
|
||||||
|
return delayMs > 5000 ? delayMs : 5000
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 定时刷新状态 ────────────────────────────────────────────
|
||||||
|
|
||||||
|
let refreshTimerId: ReturnType<typeof setTimeout> | null = null
|
||||||
|
let visibilityHandler: (() => void) | null = null
|
||||||
|
let sessionExpiredCallback: (() => void) | null = null
|
||||||
|
|
||||||
|
// ── 凭证操作 ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/** 保存登录凭证并启动自动刷新 */
|
||||||
export function login(token: string, account: AccountPublic): void {
|
export function login(token: string, account: AccountPublic): void {
|
||||||
if (typeof window === 'undefined') return
|
if (typeof window === 'undefined') return
|
||||||
localStorage.setItem(TOKEN_KEY, token)
|
localStorage.setItem(TOKEN_KEY, token)
|
||||||
localStorage.setItem(ACCOUNT_KEY, JSON.stringify(account))
|
localStorage.setItem(ACCOUNT_KEY, JSON.stringify(account))
|
||||||
|
scheduleTokenRefresh()
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 清除登录凭证 */
|
/** 清除登录凭证并停止自动刷新 */
|
||||||
export function logout(): void {
|
export function logout(): void {
|
||||||
if (typeof window === 'undefined') return
|
if (typeof window === 'undefined') return
|
||||||
localStorage.removeItem(TOKEN_KEY)
|
cancelTokenRefresh()
|
||||||
localStorage.removeItem(ACCOUNT_KEY)
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 清除认证状态(用于 Token 验证失败时) */
|
|
||||||
export function clearAuth(): void {
|
|
||||||
if (typeof window === 'undefined') return
|
|
||||||
localStorage.removeItem(TOKEN_KEY)
|
localStorage.removeItem(TOKEN_KEY)
|
||||||
localStorage.removeItem(ACCOUNT_KEY)
|
localStorage.removeItem(ACCOUNT_KEY)
|
||||||
}
|
}
|
||||||
@@ -50,3 +96,121 @@ export function getAccount(): AccountPublic | null {
|
|||||||
export function isAuthenticated(): boolean {
|
export function isAuthenticated(): boolean {
|
||||||
return !!getToken()
|
return !!getToken()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** 尝试刷新 token,成功则更新 localStorage 并返回新 token */
|
||||||
|
export async function refreshToken(): Promise<string> {
|
||||||
|
const res = await fetch(
|
||||||
|
`${process.env.NEXT_PUBLIC_SAAS_API_URL || 'http://localhost:8080'}/api/v1/auth/refresh`,
|
||||||
|
{
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
Authorization: `Bearer ${getToken()}`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if (!res.ok) {
|
||||||
|
throw new Error('Token 刷新失败')
|
||||||
|
}
|
||||||
|
const data: LoginResponse = await res.json()
|
||||||
|
login(data.token, data.account)
|
||||||
|
return data.token
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── 自动刷新调度 ────────────────────────────────────────────
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Register a callback invoked when the proactive token refresh fails.
|
||||||
|
* The caller should use this to trigger a logout/redirect flow.
|
||||||
|
*/
|
||||||
|
export function setOnSessionExpired(handler: (() => void) | null): void {
|
||||||
|
sessionExpiredCallback = handler
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Schedule a proactive token refresh at 80% of the token's remaining lifetime.
|
||||||
|
* Also registers a visibilitychange listener to re-check when the tab regains focus.
|
||||||
|
*/
|
||||||
|
export function scheduleTokenRefresh(): void {
|
||||||
|
cancelTokenRefresh()
|
||||||
|
|
||||||
|
const token = getToken()
|
||||||
|
if (!token) return
|
||||||
|
|
||||||
|
const payload = decodeJwtPayload<JwtPayload>(token)
|
||||||
|
if (!payload?.exp) return
|
||||||
|
|
||||||
|
const delay = getRefreshDelay(payload.exp)
|
||||||
|
if (delay === null) {
|
||||||
|
attemptTokenRefresh()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
refreshTimerId = setTimeout(() => {
|
||||||
|
attemptTokenRefresh()
|
||||||
|
}, delay)
|
||||||
|
|
||||||
|
if (typeof document !== 'undefined' && !visibilityHandler) {
|
||||||
|
visibilityHandler = () => {
|
||||||
|
if (document.visibilityState === 'visible') {
|
||||||
|
checkAndRefreshToken()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
document.addEventListener('visibilitychange', visibilityHandler)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Cancel any pending token refresh timer and remove the visibility listener.
|
||||||
|
*/
|
||||||
|
export function cancelTokenRefresh(): void {
|
||||||
|
if (refreshTimerId !== null) {
|
||||||
|
clearTimeout(refreshTimerId)
|
||||||
|
refreshTimerId = null
|
||||||
|
}
|
||||||
|
if (visibilityHandler !== null && typeof document !== 'undefined') {
|
||||||
|
document.removeEventListener('visibilitychange', visibilityHandler)
|
||||||
|
visibilityHandler = null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if the current token is close to expiry and refresh if needed.
|
||||||
|
* Called on visibility change to handle clock skew / long background tabs.
|
||||||
|
*/
|
||||||
|
function checkAndRefreshToken(): void {
|
||||||
|
const token = getToken()
|
||||||
|
if (!token) return
|
||||||
|
|
||||||
|
const payload = decodeJwtPayload<JwtPayload>(token)
|
||||||
|
if (!payload?.exp) return
|
||||||
|
|
||||||
|
const now = Math.floor(Date.now() / 1000)
|
||||||
|
const remaining = payload.exp - now
|
||||||
|
|
||||||
|
if (remaining <= 0) {
|
||||||
|
attemptTokenRefresh()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const delay = getRefreshDelay(payload.exp)
|
||||||
|
if (delay !== null && delay < 60_000) {
|
||||||
|
attemptTokenRefresh()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attempt to refresh the token. On success, the new token is persisted via
|
||||||
|
* login() which also reschedules the next refresh. On failure, invoke the
|
||||||
|
* session-expired callback.
|
||||||
|
*/
|
||||||
|
async function attemptTokenRefresh(): Promise<void> {
|
||||||
|
try {
|
||||||
|
await refreshToken()
|
||||||
|
} catch {
|
||||||
|
cancelTokenRefresh()
|
||||||
|
if (sessionExpiredCallback) {
|
||||||
|
sessionExpiredCallback()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,75 +0,0 @@
|
|||||||
// ============================================================
|
|
||||||
// SWR fetcher — 将 SWR key 映射到 api-client 调用
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
import { api } from './api-client'
|
|
||||||
import { ApiRequestError } from './api-client'
|
|
||||||
|
|
||||||
type ApiMethod = typeof api
|
|
||||||
|
|
||||||
/** SWR fetcher: key 可以是字符串或 [method-path, params] 元组 */
|
|
||||||
type SwrKey =
|
|
||||||
| string
|
|
||||||
| [string, ...unknown[]]
|
|
||||||
|
|
||||||
/** SWR fetcher 支持 AbortSignal 传递 */
|
|
||||||
type SwrFetcherArgs = { signal?: AbortSignal } | null
|
|
||||||
|
|
||||||
async function resolveApiCall(key: SwrKey, args: SwrFetcherArgs): Promise<unknown> {
|
|
||||||
if (typeof key === 'string') {
|
|
||||||
// 简单字符串 key,直接 fetch
|
|
||||||
return fetchGeneric(key, args?.signal)
|
|
||||||
}
|
|
||||||
|
|
||||||
const [path, ...rest] = key
|
|
||||||
return callByPath(path, rest, args?.signal)
|
|
||||||
}
|
|
||||||
|
|
||||||
async function fetchGeneric(path: string, signal?: AbortSignal): Promise<unknown> {
|
|
||||||
const res = await fetch(path, {
|
|
||||||
headers: {
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
},
|
|
||||||
signal,
|
|
||||||
})
|
|
||||||
if (!res.ok) {
|
|
||||||
const body = await res.json().catch(() => ({ error: 'unknown', message: `请求失败 (${res.status})` }))
|
|
||||||
throw new ApiRequestError(res.status, body)
|
|
||||||
}
|
|
||||||
if (res.status === 204) return null
|
|
||||||
return res.json()
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 根据 path 调用对应的 api 方法 */
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
async function callByPath(path: string, callArgs: unknown[], signal?: AbortSignal): Promise<unknown> {
|
|
||||||
const parts = path.split('.')
|
|
||||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
|
||||||
let target: any = api
|
|
||||||
for (const part of parts) {
|
|
||||||
target = target[part]
|
|
||||||
if (!target) throw new Error(`API method not found: ${path}`)
|
|
||||||
}
|
|
||||||
// Append signal as last argument if the target is the request function
|
|
||||||
// For api.xxx() calls that ultimately use request(), we pass signal through
|
|
||||||
// The simplest approach: pass signal as part of an options bag
|
|
||||||
return target(...callArgs, signal ? { signal } : undefined)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* SWR fetcher — 接受 SWR 自动传入的 AbortSignal
|
|
||||||
*
|
|
||||||
* 用法: useSWR(key, swrFetcher)
|
|
||||||
* SWR 会自动在组件卸载或 key 变化时 abort 请求
|
|
||||||
*/
|
|
||||||
export function swrFetcher<T = unknown>(key: SwrKey, args: SwrFetcherArgs): Promise<T> {
|
|
||||||
return resolveApiCall(key, args) as Promise<T>
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 创建 SWR key helper — 类型安全 */
|
|
||||||
export function createKey<TMethod extends string>(
|
|
||||||
method: TMethod,
|
|
||||||
...args: unknown[]
|
|
||||||
): [TMethod, ...unknown[]] {
|
|
||||||
return [method, ...args]
|
|
||||||
}
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
'use client'
|
|
||||||
|
|
||||||
import { SWRConfig } from 'swr'
|
|
||||||
import type { ReactNode } from 'react'
|
|
||||||
|
|
||||||
/** 判断是否为请求被中断(页面导航等场景) */
|
|
||||||
function isAbortError(err: unknown): boolean {
|
|
||||||
if (err instanceof DOMException && err.name === 'AbortError') return true
|
|
||||||
if (err instanceof Error && err.message?.includes('aborted')) return true
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
export function SWRProvider({ children }: { children: ReactNode }) {
|
|
||||||
return (
|
|
||||||
<SWRConfig
|
|
||||||
value={{
|
|
||||||
revalidateOnFocus: false,
|
|
||||||
dedupingInterval: 5000,
|
|
||||||
errorRetryCount: 2,
|
|
||||||
errorRetryInterval: 3000,
|
|
||||||
shouldRetryOnError: (err: unknown) => {
|
|
||||||
if (isAbortError(err)) return false
|
|
||||||
if (err && typeof err === 'object' && 'status' in err) {
|
|
||||||
const status = (err as { status: number }).status
|
|
||||||
return status !== 401 && status !== 403
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
},
|
|
||||||
onError: (err: unknown) => {
|
|
||||||
// 中断错误静默忽略,不展示给用户
|
|
||||||
if (isAbortError(err)) return
|
|
||||||
},
|
|
||||||
}}
|
|
||||||
>
|
|
||||||
{children}
|
|
||||||
</SWRConfig>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
@@ -9,9 +9,9 @@ export interface AccountPublic {
|
|||||||
email: string
|
email: string
|
||||||
display_name: string
|
display_name: string
|
||||||
role: 'super_admin' | 'admin' | 'user'
|
role: 'super_admin' | 'admin' | 'user'
|
||||||
|
permissions: string[]
|
||||||
status: 'active' | 'disabled' | 'suspended'
|
status: 'active' | 'disabled' | 'suspended'
|
||||||
totp_enabled: boolean
|
totp_enabled: boolean
|
||||||
last_login_at: string | null
|
|
||||||
created_at: string
|
created_at: string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -25,7 +25,6 @@ export interface LoginRequest {
|
|||||||
/** 登录响应 */
|
/** 登录响应 */
|
||||||
export interface LoginResponse {
|
export interface LoginResponse {
|
||||||
token: string
|
token: string
|
||||||
refresh_token: string
|
|
||||||
account: AccountPublic
|
account: AccountPublic
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -50,12 +49,11 @@ export interface Provider {
|
|||||||
id: string
|
id: string
|
||||||
name: string
|
name: string
|
||||||
display_name: string
|
display_name: string
|
||||||
api_key?: string
|
|
||||||
base_url: string
|
base_url: string
|
||||||
api_protocol: string
|
api_protocol: 'openai' | 'anthropic'
|
||||||
enabled: boolean
|
enabled: boolean
|
||||||
rate_limit_rpm: number | null
|
rate_limit_rpm?: number
|
||||||
rate_limit_tpm: number | null
|
rate_limit_tpm?: number
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
}
|
}
|
||||||
@@ -100,31 +98,40 @@ export interface RelayTask {
|
|||||||
account_id: string
|
account_id: string
|
||||||
provider_id: string
|
provider_id: string
|
||||||
model_id: string
|
model_id: string
|
||||||
status: string
|
status: 'queued' | 'processing' | 'completed' | 'failed'
|
||||||
priority: number
|
priority: number
|
||||||
attempt_count: number
|
attempt_count: number
|
||||||
max_attempts: number
|
|
||||||
input_tokens: number
|
input_tokens: number
|
||||||
output_tokens: number
|
output_tokens: number
|
||||||
error_message: string | null
|
error_message?: string
|
||||||
queued_at: string
|
queued_at?: string
|
||||||
started_at: string | null
|
started_at?: string
|
||||||
completed_at: string | null
|
completed_at?: string
|
||||||
created_at: string
|
created_at: string
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 用量记录 */
|
/** 用量统计 — 后端返回的完整结构 */
|
||||||
export interface UsageRecord {
|
export interface UsageStats {
|
||||||
day: string
|
total_requests: number
|
||||||
count: number
|
total_input_tokens: number
|
||||||
|
total_output_tokens: number
|
||||||
|
by_model: UsageByModel[]
|
||||||
|
by_day: DailyUsage[]
|
||||||
|
}
|
||||||
|
|
||||||
|
/** 每日用量 */
|
||||||
|
export interface DailyUsage {
|
||||||
|
date: string
|
||||||
|
request_count: number
|
||||||
input_tokens: number
|
input_tokens: number
|
||||||
output_tokens: number
|
output_tokens: number
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 按模型用量 */
|
/** 按模型用量 */
|
||||||
export interface UsageByModel {
|
export interface UsageByModel {
|
||||||
|
provider_id: string
|
||||||
model_id: string
|
model_id: string
|
||||||
count: number
|
request_count: number
|
||||||
input_tokens: number
|
input_tokens: number
|
||||||
output_tokens: number
|
output_tokens: number
|
||||||
}
|
}
|
||||||
@@ -134,11 +141,11 @@ export interface ConfigItem {
|
|||||||
id: string
|
id: string
|
||||||
category: string
|
category: string
|
||||||
key_path: string
|
key_path: string
|
||||||
value_type: string
|
value_type: 'string' | 'number' | 'boolean'
|
||||||
current_value: string | null
|
current_value?: string
|
||||||
default_value: string | null
|
default_value?: string
|
||||||
source: string
|
source: 'default' | 'env' | 'db'
|
||||||
description: string | null
|
description?: string
|
||||||
requires_restart: boolean
|
requires_restart: boolean
|
||||||
created_at: string
|
created_at: string
|
||||||
updated_at: string
|
updated_at: string
|
||||||
@@ -147,12 +154,12 @@ export interface ConfigItem {
|
|||||||
/** 操作日志 */
|
/** 操作日志 */
|
||||||
export interface OperationLog {
|
export interface OperationLog {
|
||||||
id: number
|
id: number
|
||||||
account_id: string | null
|
account_id: string
|
||||||
action: string
|
action: string
|
||||||
target_type: string | null
|
target_type: string
|
||||||
target_id: string | null
|
target_id: string
|
||||||
details: Record<string, unknown> | null
|
details?: Record<string, unknown>
|
||||||
ip_address: string | null
|
ip_address?: string
|
||||||
created_at: string
|
created_at: string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,133 +174,20 @@ export interface DashboardStats {
|
|||||||
tokens_today_output: number
|
tokens_today_output: number
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** 设备信息 */
|
||||||
|
export interface DeviceInfo {
|
||||||
|
id: string
|
||||||
|
device_id: string
|
||||||
|
device_name?: string
|
||||||
|
platform?: string
|
||||||
|
app_version?: string
|
||||||
|
last_seen_at: string
|
||||||
|
created_at: string
|
||||||
|
}
|
||||||
|
|
||||||
/** API 错误响应 */
|
/** API 错误响应 */
|
||||||
export interface ApiError {
|
export interface ApiError {
|
||||||
error: string
|
error: string
|
||||||
message: string
|
message: string
|
||||||
status?: number
|
status?: number
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── 提示词模板 ────────────────────────────────────────────
|
|
||||||
|
|
||||||
/** 提示词模板 */
|
|
||||||
export interface PromptTemplate {
|
|
||||||
id: string
|
|
||||||
name: string
|
|
||||||
category: string
|
|
||||||
description?: string
|
|
||||||
source: 'builtin' | 'custom'
|
|
||||||
current_version: number
|
|
||||||
status: 'active' | 'deprecated' | 'archived'
|
|
||||||
created_at: string
|
|
||||||
updated_at: string
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 提示词版本 */
|
|
||||||
export interface PromptVersion {
|
|
||||||
id: string
|
|
||||||
template_id: string
|
|
||||||
version: number
|
|
||||||
system_prompt: string
|
|
||||||
user_prompt_template?: string
|
|
||||||
variables: PromptVariable[]
|
|
||||||
changelog?: string
|
|
||||||
min_app_version?: string
|
|
||||||
created_at: string
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 提示词变量定义 */
|
|
||||||
export interface PromptVariable {
|
|
||||||
name: string
|
|
||||||
type: 'string' | 'number' | 'select' | 'boolean'
|
|
||||||
default_value?: string
|
|
||||||
description?: string
|
|
||||||
required?: boolean
|
|
||||||
}
|
|
||||||
|
|
||||||
/** OTA 更新检查请求 */
|
|
||||||
export interface PromptCheckRequest {
|
|
||||||
device_id: string
|
|
||||||
versions: Record<string, number>
|
|
||||||
}
|
|
||||||
|
|
||||||
/** OTA 更新响应 */
|
|
||||||
export interface PromptCheckResponse {
|
|
||||||
updates: PromptUpdatePayload[]
|
|
||||||
server_time: string
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 单个更新载荷 */
|
|
||||||
export interface PromptUpdatePayload {
|
|
||||||
name: string
|
|
||||||
version: number
|
|
||||||
system_prompt: string
|
|
||||||
user_prompt_template?: string
|
|
||||||
variables: PromptVariable[]
|
|
||||||
source: string
|
|
||||||
min_app_version?: string
|
|
||||||
changelog?: string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Agent 配置模板 ────────────────────────────────────────
|
|
||||||
|
|
||||||
/** Agent 模板 */
|
|
||||||
export interface AgentTemplate {
|
|
||||||
id: string
|
|
||||||
name: string
|
|
||||||
description?: string
|
|
||||||
category: string
|
|
||||||
source: 'builtin' | 'custom'
|
|
||||||
model?: string
|
|
||||||
system_prompt?: string
|
|
||||||
tools: string[]
|
|
||||||
capabilities: string[]
|
|
||||||
temperature?: number
|
|
||||||
max_tokens?: number
|
|
||||||
visibility: 'public' | 'team' | 'private'
|
|
||||||
status: 'active' | 'archived'
|
|
||||||
current_version: number
|
|
||||||
created_at: string
|
|
||||||
updated_at: string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Provider Key Pool ─────────────────────────────────────
|
|
||||||
|
|
||||||
/** Provider Key */
|
|
||||||
export interface ProviderKey {
|
|
||||||
id: string
|
|
||||||
provider_id: string
|
|
||||||
key_label: string
|
|
||||||
priority: number
|
|
||||||
max_rpm?: number
|
|
||||||
max_tpm?: number
|
|
||||||
quota_reset_interval?: string
|
|
||||||
is_active: boolean
|
|
||||||
last_429_at?: string
|
|
||||||
cooldown_until?: string
|
|
||||||
total_requests: number
|
|
||||||
total_tokens: number
|
|
||||||
created_at: string
|
|
||||||
updated_at: string
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── 遥测统计 ────────────────────────────────────────────
|
|
||||||
|
|
||||||
/** 按模型聚合的用量统计 */
|
|
||||||
export interface ModelUsageStat {
|
|
||||||
model_id: string
|
|
||||||
request_count: number
|
|
||||||
input_tokens: number
|
|
||||||
output_tokens: number
|
|
||||||
avg_latency_ms: number | null
|
|
||||||
success_rate: number
|
|
||||||
}
|
|
||||||
|
|
||||||
/** 按天的用量统计 */
|
|
||||||
export interface DailyUsageStat {
|
|
||||||
day: string
|
|
||||||
request_count: number
|
|
||||||
input_tokens: number
|
|
||||||
output_tokens: number
|
|
||||||
unique_devices: number
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -32,14 +32,3 @@ export function maskApiKey(key?: string): string {
|
|||||||
export function sleep(ms: number): Promise<void> {
|
export function sleep(ms: number): Promise<void> {
|
||||||
return new Promise(resolve => setTimeout(resolve, ms))
|
return new Promise(resolve => setTimeout(resolve, ms))
|
||||||
}
|
}
|
||||||
|
|
||||||
/** 从 SWR error 中提取用户可见消息,过滤 abort 错误 */
|
|
||||||
export function getSwrErrorMessage(err: unknown): string | undefined {
|
|
||||||
if (!err) return undefined
|
|
||||||
if (err instanceof DOMException && err.name === 'AbortError') return undefined
|
|
||||||
if (err instanceof Error) {
|
|
||||||
if (err.name === 'AbortError' || err.message?.includes('aborted')) return undefined
|
|
||||||
return err.message
|
|
||||||
}
|
|
||||||
return String(err)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,33 +0,0 @@
|
|||||||
# ZCLAW SaaS 开发环境配置
|
|
||||||
# 通过 ZCLAW_ENV=development 或默认使用此配置
|
|
||||||
|
|
||||||
[server]
|
|
||||||
host = "0.0.0.0"
|
|
||||||
port = 8080
|
|
||||||
cors_origins = [] # 空 = 开发模式允许所有来源
|
|
||||||
|
|
||||||
[database]
|
|
||||||
url = "postgres://postgres:123123@localhost:5432/zclaw"
|
|
||||||
|
|
||||||
[auth]
|
|
||||||
jwt_expiration_hours = 24
|
|
||||||
totp_issuer = "ZCLAW SaaS (dev)"
|
|
||||||
refresh_token_hours = 168
|
|
||||||
|
|
||||||
[relay]
|
|
||||||
max_queue_size = 1000
|
|
||||||
max_concurrent_per_provider = 5
|
|
||||||
batch_window_ms = 50
|
|
||||||
retry_delay_ms = 1000
|
|
||||||
max_attempts = 3
|
|
||||||
|
|
||||||
[rate_limit]
|
|
||||||
requests_per_minute = 120
|
|
||||||
burst = 20
|
|
||||||
|
|
||||||
[scheduler]
|
|
||||||
jobs = [
|
|
||||||
{ name = "cleanup_rate_limit", interval = "5m", task = "cleanup_rate_limit", run_on_start = false },
|
|
||||||
{ name = "cleanup_refresh_tokens", interval = "1h", task = "cleanup_refresh_tokens", run_on_start = false },
|
|
||||||
{ name = "cleanup_devices", interval = "24h", task = "cleanup_devices", run_on_start = false },
|
|
||||||
]
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
# ZCLAW SaaS 生产环境配置
|
|
||||||
# 通过 ZCLAW_ENV=production 使用此配置
|
|
||||||
|
|
||||||
[server]
|
|
||||||
host = "0.0.0.0"
|
|
||||||
port = 8080
|
|
||||||
# 生产环境必须配置 CORS 白名单
|
|
||||||
cors_origins = ["https://admin.zclaw.ai", "https://zclaw.ai"]
|
|
||||||
|
|
||||||
[database]
|
|
||||||
# 生产环境通过 ZCLAW_DATABASE_URL 环境变量覆盖,此处为占位
|
|
||||||
url = "postgres://zclaw:CHANGE_ME@db:5432/zclaw"
|
|
||||||
|
|
||||||
[auth]
|
|
||||||
jwt_expiration_hours = 12
|
|
||||||
totp_issuer = "ZCLAW SaaS"
|
|
||||||
refresh_token_hours = 168
|
|
||||||
|
|
||||||
[relay]
|
|
||||||
max_queue_size = 5000
|
|
||||||
max_concurrent_per_provider = 10
|
|
||||||
batch_window_ms = 50
|
|
||||||
retry_delay_ms = 2000
|
|
||||||
max_attempts = 3
|
|
||||||
|
|
||||||
[rate_limit]
|
|
||||||
requests_per_minute = 60
|
|
||||||
burst = 10
|
|
||||||
|
|
||||||
[scheduler]
|
|
||||||
jobs = [
|
|
||||||
{ name = "cleanup_rate_limit", interval = "5m", task = "cleanup_rate_limit", run_on_start = false },
|
|
||||||
{ name = "cleanup_refresh_tokens", interval = "1h", task = "cleanup_refresh_tokens", run_on_start = false },
|
|
||||||
{ name = "cleanup_devices", interval = "24h", task = "cleanup_devices", run_on_start = true },
|
|
||||||
]
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
# ZCLAW SaaS 测试环境配置
|
|
||||||
# 通过 ZCLAW_ENV=test 使用此配置
|
|
||||||
|
|
||||||
[server]
|
|
||||||
host = "127.0.0.1"
|
|
||||||
port = 8090
|
|
||||||
cors_origins = []
|
|
||||||
|
|
||||||
[database]
|
|
||||||
# 测试环境使用独立数据库
|
|
||||||
url = "postgres://postgres:123123@localhost:5432/zclaw_test"
|
|
||||||
|
|
||||||
[auth]
|
|
||||||
jwt_expiration_hours = 1
|
|
||||||
totp_issuer = "ZCLAW SaaS (test)"
|
|
||||||
refresh_token_hours = 24
|
|
||||||
|
|
||||||
[relay]
|
|
||||||
max_queue_size = 100
|
|
||||||
max_concurrent_per_provider = 2
|
|
||||||
batch_window_ms = 10
|
|
||||||
retry_delay_ms = 100
|
|
||||||
max_attempts = 2
|
|
||||||
|
|
||||||
[rate_limit]
|
|
||||||
requests_per_minute = 200
|
|
||||||
burst = 50
|
|
||||||
|
|
||||||
[scheduler]
|
|
||||||
# 测试环境不启动定时任务
|
|
||||||
jobs = []
|
|
||||||
21
crates/zclaw-channels/Cargo.toml
Normal file
21
crates/zclaw-channels/Cargo.toml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
[package]
|
||||||
|
name = "zclaw-channels"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
repository.workspace = true
|
||||||
|
rust-version.workspace = true
|
||||||
|
description = "ZCLAW Channels - external platform adapters"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
zclaw-types = { workspace = true }
|
||||||
|
|
||||||
|
tokio = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
thiserror = { workspace = true }
|
||||||
|
tracing = { workspace = true }
|
||||||
|
async-trait = { workspace = true }
|
||||||
|
|
||||||
|
reqwest = { workspace = true }
|
||||||
|
chrono = { workspace = true }
|
||||||
71
crates/zclaw-channels/src/adapters/console.rs
Normal file
71
crates/zclaw-channels/src/adapters/console.rs
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
//! Console channel adapter for testing
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
|
||||||
|
use crate::{Channel, ChannelConfig, ChannelStatus, IncomingMessage, OutgoingMessage};
|
||||||
|
|
||||||
|
/// Console channel adapter (for testing)
|
||||||
|
pub struct ConsoleChannel {
|
||||||
|
config: ChannelConfig,
|
||||||
|
status: Arc<tokio::sync::RwLock<ChannelStatus>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConsoleChannel {
|
||||||
|
pub fn new(config: ChannelConfig) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
status: Arc::new(tokio::sync::RwLock::new(ChannelStatus::Disconnected)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Channel for ConsoleChannel {
|
||||||
|
fn config(&self) -> &ChannelConfig {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect(&self) -> Result<()> {
|
||||||
|
let mut status = self.status.write().await;
|
||||||
|
*status = ChannelStatus::Connected;
|
||||||
|
tracing::info!("Console channel connected");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn disconnect(&self) -> Result<()> {
|
||||||
|
let mut status = self.status.write().await;
|
||||||
|
*status = ChannelStatus::Disconnected;
|
||||||
|
tracing::info!("Console channel disconnected");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn status(&self) -> ChannelStatus {
|
||||||
|
self.status.read().await.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send(&self, message: OutgoingMessage) -> Result<String> {
|
||||||
|
// Print to console for testing
|
||||||
|
let msg_id = format!("console_{}", chrono::Utc::now().timestamp());
|
||||||
|
|
||||||
|
match &message.content {
|
||||||
|
crate::MessageContent::Text { text } => {
|
||||||
|
tracing::info!("[Console] To {}: {}", message.conversation_id, text);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
tracing::info!("[Console] To {}: {:?}", message.conversation_id, message.content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(msg_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn receive(&self) -> Result<mpsc::Receiver<IncomingMessage>> {
|
||||||
|
let (_tx, rx) = mpsc::channel(100);
|
||||||
|
// Console channel doesn't receive messages automatically
|
||||||
|
// Messages would need to be injected via a separate method
|
||||||
|
Ok(rx)
|
||||||
|
}
|
||||||
|
}
|
||||||
5
crates/zclaw-channels/src/adapters/mod.rs
Normal file
5
crates/zclaw-channels/src/adapters/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
//! Channel adapters
|
||||||
|
|
||||||
|
mod console;
|
||||||
|
|
||||||
|
pub use console::ConsoleChannel;
|
||||||
94
crates/zclaw-channels/src/bridge.rs
Normal file
94
crates/zclaw-channels/src/bridge.rs
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
//! Channel bridge manager
|
||||||
|
//!
|
||||||
|
//! Coordinates multiple channel adapters and routes messages.
|
||||||
|
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use zclaw_types::Result;
|
||||||
|
|
||||||
|
use super::{Channel, ChannelConfig, OutgoingMessage};
|
||||||
|
|
||||||
|
/// Channel bridge manager
|
||||||
|
pub struct ChannelBridge {
|
||||||
|
channels: RwLock<HashMap<String, Arc<dyn Channel>>>,
|
||||||
|
configs: RwLock<HashMap<String, ChannelConfig>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChannelBridge {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
channels: RwLock::new(HashMap::new()),
|
||||||
|
configs: RwLock::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Register a channel adapter
|
||||||
|
pub async fn register(&self, channel: Arc<dyn Channel>) {
|
||||||
|
let config = channel.config().clone();
|
||||||
|
let mut channels = self.channels.write().await;
|
||||||
|
let mut configs = self.configs.write().await;
|
||||||
|
|
||||||
|
channels.insert(config.id.clone(), channel);
|
||||||
|
configs.insert(config.id.clone(), config);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a channel by ID
|
||||||
|
pub async fn get(&self, id: &str) -> Option<Arc<dyn Channel>> {
|
||||||
|
let channels = self.channels.read().await;
|
||||||
|
channels.get(id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get channel configuration
|
||||||
|
pub async fn get_config(&self, id: &str) -> Option<ChannelConfig> {
|
||||||
|
let configs = self.configs.read().await;
|
||||||
|
configs.get(id).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all channels
|
||||||
|
pub async fn list(&self) -> Vec<ChannelConfig> {
|
||||||
|
let configs = self.configs.read().await;
|
||||||
|
configs.values().cloned().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Connect all channels
|
||||||
|
pub async fn connect_all(&self) -> Result<()> {
|
||||||
|
let channels = self.channels.read().await;
|
||||||
|
for channel in channels.values() {
|
||||||
|
channel.connect().await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Disconnect all channels
|
||||||
|
pub async fn disconnect_all(&self) -> Result<()> {
|
||||||
|
let channels = self.channels.read().await;
|
||||||
|
for channel in channels.values() {
|
||||||
|
channel.disconnect().await?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send message through a specific channel
|
||||||
|
pub async fn send(&self, channel_id: &str, message: OutgoingMessage) -> Result<String> {
|
||||||
|
let channel = self.get(channel_id).await
|
||||||
|
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Channel not found: {}", channel_id)))?;
|
||||||
|
|
||||||
|
channel.send(message).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a channel
|
||||||
|
pub async fn remove(&self, id: &str) {
|
||||||
|
let mut channels = self.channels.write().await;
|
||||||
|
let mut configs = self.configs.write().await;
|
||||||
|
|
||||||
|
channels.remove(id);
|
||||||
|
configs.remove(id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ChannelBridge {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
109
crates/zclaw-channels/src/channel.rs
Normal file
109
crates/zclaw-channels/src/channel.rs
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
//! Channel trait and types
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use zclaw_types::{Result, AgentId};
|
||||||
|
|
||||||
|
/// Channel configuration
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChannelConfig {
|
||||||
|
/// Unique channel identifier
|
||||||
|
pub id: String,
|
||||||
|
/// Channel type (telegram, discord, slack, etc.)
|
||||||
|
pub channel_type: String,
|
||||||
|
/// Human-readable name
|
||||||
|
pub name: String,
|
||||||
|
/// Whether the channel is enabled
|
||||||
|
#[serde(default = "default_enabled")]
|
||||||
|
pub enabled: bool,
|
||||||
|
/// Channel-specific configuration
|
||||||
|
#[serde(default)]
|
||||||
|
pub config: serde_json::Value,
|
||||||
|
/// Associated agent for this channel
|
||||||
|
pub agent_id: Option<AgentId>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_enabled() -> bool { true }
|
||||||
|
|
||||||
|
/// Incoming message from a channel
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct IncomingMessage {
|
||||||
|
/// Message ID from the platform
|
||||||
|
pub platform_id: String,
|
||||||
|
/// Channel/conversation ID
|
||||||
|
pub conversation_id: String,
|
||||||
|
/// Sender information
|
||||||
|
pub sender: MessageSender,
|
||||||
|
/// Message content
|
||||||
|
pub content: MessageContent,
|
||||||
|
/// Timestamp
|
||||||
|
pub timestamp: i64,
|
||||||
|
/// Reply-to message ID if any
|
||||||
|
pub reply_to: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Message sender information
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct MessageSender {
|
||||||
|
pub id: String,
|
||||||
|
pub name: Option<String>,
|
||||||
|
pub username: Option<String>,
|
||||||
|
pub is_bot: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Message content types
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum MessageContent {
|
||||||
|
Text { text: String },
|
||||||
|
Image { url: String, caption: Option<String> },
|
||||||
|
File { url: String, filename: String },
|
||||||
|
Audio { url: String },
|
||||||
|
Video { url: String },
|
||||||
|
Location { latitude: f64, longitude: f64 },
|
||||||
|
Sticker { emoji: Option<String>, url: Option<String> },
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Outgoing message to a channel
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OutgoingMessage {
|
||||||
|
/// Conversation/channel ID to send to
|
||||||
|
pub conversation_id: String,
|
||||||
|
/// Message content
|
||||||
|
pub content: MessageContent,
|
||||||
|
/// Reply-to message ID if any
|
||||||
|
pub reply_to: Option<String>,
|
||||||
|
/// Whether to send silently (no notification)
|
||||||
|
pub silent: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Channel connection status
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
pub enum ChannelStatus {
|
||||||
|
Disconnected,
|
||||||
|
Connecting,
|
||||||
|
Connected,
|
||||||
|
Error(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Channel trait for platform adapters
|
||||||
|
#[async_trait]
|
||||||
|
pub trait Channel: Send + Sync {
|
||||||
|
/// Get channel configuration
|
||||||
|
fn config(&self) -> &ChannelConfig;
|
||||||
|
|
||||||
|
/// Connect to the platform
|
||||||
|
async fn connect(&self) -> Result<()>;
|
||||||
|
|
||||||
|
/// Disconnect from the platform
|
||||||
|
async fn disconnect(&self) -> Result<()>;
|
||||||
|
|
||||||
|
/// Get current connection status
|
||||||
|
async fn status(&self) -> ChannelStatus;
|
||||||
|
|
||||||
|
/// Send a message
|
||||||
|
async fn send(&self, message: OutgoingMessage) -> Result<String>;
|
||||||
|
|
||||||
|
/// Receive incoming messages (streaming)
|
||||||
|
async fn receive(&self) -> Result<tokio::sync::mpsc::Receiver<IncomingMessage>>;
|
||||||
|
}
|
||||||
11
crates/zclaw-channels/src/lib.rs
Normal file
11
crates/zclaw-channels/src/lib.rs
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//! ZCLAW Channels
|
||||||
|
//!
|
||||||
|
//! External platform adapters for unified message handling.
|
||||||
|
|
||||||
|
mod channel;
|
||||||
|
mod bridge;
|
||||||
|
mod adapters;
|
||||||
|
|
||||||
|
pub use channel::*;
|
||||||
|
pub use bridge::*;
|
||||||
|
pub use adapters::*;
|
||||||
@@ -27,7 +27,7 @@ pub struct SqliteStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Database row structure for memory entry
|
/// Database row structure for memory entry
|
||||||
pub(crate) struct MemoryRow {
|
struct MemoryRow {
|
||||||
uri: String,
|
uri: String,
|
||||||
memory_type: String,
|
memory_type: String,
|
||||||
content: String,
|
content: String,
|
||||||
@@ -289,44 +289,6 @@ impl sqlx::FromRow<'_, SqliteRow> for MemoryRow {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Private helper methods on SqliteStorage (NOT in impl VikingStorage block)
|
|
||||||
impl SqliteStorage {
|
|
||||||
/// Fetch memories by scope with importance-based ordering.
|
|
||||||
/// Used internally by find() for scope-based queries.
|
|
||||||
pub(crate) async fn fetch_by_scope_priv(&self, scope: Option<&str>, limit: usize) -> Result<Vec<MemoryRow>> {
|
|
||||||
let rows = if let Some(scope) = scope {
|
|
||||||
sqlx::query_as::<_, MemoryRow>(
|
|
||||||
r#"
|
|
||||||
SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary
|
|
||||||
FROM memories
|
|
||||||
WHERE uri LIKE ?
|
|
||||||
ORDER BY importance DESC, access_count DESC
|
|
||||||
LIMIT ?
|
|
||||||
"#
|
|
||||||
)
|
|
||||||
.bind(format!("{}%", scope))
|
|
||||||
.bind(limit as i64)
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Failed to fetch by scope: {}", e)))?
|
|
||||||
} else {
|
|
||||||
sqlx::query_as::<_, MemoryRow>(
|
|
||||||
r#"
|
|
||||||
SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary
|
|
||||||
FROM memories
|
|
||||||
ORDER BY importance DESC
|
|
||||||
LIMIT ?
|
|
||||||
"#
|
|
||||||
)
|
|
||||||
.bind(limit as i64)
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Failed to fetch by scope: {}", e)))?
|
|
||||||
};
|
|
||||||
Ok(rows)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl VikingStorage for SqliteStorage {
|
impl VikingStorage for SqliteStorage {
|
||||||
async fn store(&self, entry: &MemoryEntry) -> Result<()> {
|
async fn store(&self, entry: &MemoryEntry) -> Result<()> {
|
||||||
@@ -412,61 +374,22 @@ impl VikingStorage for SqliteStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
|
async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
|
||||||
let limit = options.limit.unwrap_or(50).max(20); // Fetch more candidates for reranking
|
// Get all matching entries
|
||||||
|
let rows = if let Some(ref scope) = options.scope {
|
||||||
// Strategy: use FTS5 for initial filtering when query is non-empty,
|
sqlx::query_as::<_, MemoryRow>(
|
||||||
// then score candidates with TF-IDF / embedding for precise ranking.
|
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary FROM memories WHERE uri LIKE ?"
|
||||||
// Fallback to scope-only scan when query is empty (e.g., "list all").
|
)
|
||||||
let rows = if !query.is_empty() {
|
.bind(format!("{}%", scope))
|
||||||
// FTS5-powered candidate retrieval (fast, index-based)
|
.fetch_all(&self.pool)
|
||||||
let fts_candidates = if let Some(ref scope) = options.scope {
|
.await
|
||||||
sqlx::query_as::<_, MemoryRow>(
|
.map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
|
||||||
r#"
|
|
||||||
SELECT m.uri, m.memory_type, m.content, m.keywords, m.importance,
|
|
||||||
m.access_count, m.created_at, m.last_accessed, m.overview, m.abstract_summary
|
|
||||||
FROM memories m
|
|
||||||
INNER JOIN memories_fts f ON m.uri = f.uri
|
|
||||||
WHERE f.memories_fts MATCH ?
|
|
||||||
AND m.uri LIKE ?
|
|
||||||
ORDER BY f.rank
|
|
||||||
LIMIT ?
|
|
||||||
"#
|
|
||||||
)
|
|
||||||
.bind(query)
|
|
||||||
.bind(format!("{}%", scope))
|
|
||||||
.bind(limit as i64)
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await
|
|
||||||
} else {
|
|
||||||
sqlx::query_as::<_, MemoryRow>(
|
|
||||||
r#"
|
|
||||||
SELECT m.uri, m.memory_type, m.content, m.keywords, m.importance,
|
|
||||||
m.access_count, m.created_at, m.last_accessed, m.overview, m.abstract_summary
|
|
||||||
FROM memories m
|
|
||||||
INNER JOIN memories_fts f ON m.uri = f.uri
|
|
||||||
WHERE f.memories_fts MATCH ?
|
|
||||||
ORDER BY f.rank
|
|
||||||
LIMIT ?
|
|
||||||
"#
|
|
||||||
)
|
|
||||||
.bind(query)
|
|
||||||
.bind(limit as i64)
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await
|
|
||||||
};
|
|
||||||
|
|
||||||
match fts_candidates {
|
|
||||||
Ok(rows) if !rows.is_empty() => rows,
|
|
||||||
Ok(_) | Err(_) => {
|
|
||||||
// FTS5 returned nothing or query syntax was invalid —
|
|
||||||
// fallback to scope-based scan (no full table scan unless no scope)
|
|
||||||
tracing::debug!("[SqliteStorage] FTS5 returned no results, falling back to scope scan");
|
|
||||||
self.fetch_by_scope_priv(options.scope.as_deref(), limit).await?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// Empty query: scope-based scan only (no FTS5 needed)
|
sqlx::query_as::<_, MemoryRow>(
|
||||||
self.fetch_by_scope_priv(options.scope.as_deref(), limit).await?
|
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary FROM memories"
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(format!("Failed to find memories: {}", e)))?
|
||||||
};
|
};
|
||||||
|
|
||||||
// Convert to entries and compute semantic scores
|
// Convert to entries and compute semantic scores
|
||||||
@@ -541,8 +464,16 @@ impl VikingStorage for SqliteStorage {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
|
async fn find_by_prefix(&self, prefix: &str) -> Result<Vec<MemoryEntry>> {
|
||||||
let rows = self.fetch_by_scope_priv(Some(prefix), 100).await?;
|
let rows = sqlx::query_as::<_, MemoryRow>(
|
||||||
|
"SELECT uri, memory_type, content, keywords, importance, access_count, created_at, last_accessed, overview, abstract_summary FROM memories WHERE uri LIKE ?"
|
||||||
|
)
|
||||||
|
.bind(format!("{}%", prefix))
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await
|
||||||
|
.map_err(|e| ZclawError::StorageError(format!("Failed to find by prefix: {}", e)))?;
|
||||||
|
|
||||||
let entries = rows.iter().map(|row| self.row_to_entry(row)).collect();
|
let entries = rows.iter().map(|row| self.row_to_entry(row)).collect();
|
||||||
|
|
||||||
Ok(entries)
|
Ok(entries)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -553,13 +484,13 @@ impl VikingStorage for SqliteStorage {
|
|||||||
.await
|
.await
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Failed to delete memory: {}", e)))?;
|
.map_err(|e| ZclawError::StorageError(format!("Failed to delete memory: {}", e)))?;
|
||||||
|
|
||||||
// Remove from FTS index
|
// Remove from FTS
|
||||||
let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
|
let _ = sqlx::query("DELETE FROM memories_fts WHERE uri = ?")
|
||||||
.bind(uri)
|
.bind(uri)
|
||||||
.execute(&self.pool)
|
.execute(&self.pool)
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
// Remove from in-memory scorer
|
// Remove from scorer
|
||||||
let mut scorer = self.scorer.write().await;
|
let mut scorer = self.scorer.write().await;
|
||||||
scorer.remove_entry(uri);
|
scorer.remove_entry(uri);
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,3 @@ thiserror = { workspace = true }
|
|||||||
tracing = { workspace = true }
|
tracing = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
async-trait = { workspace = true }
|
||||||
reqwest = { workspace = true }
|
reqwest = { workspace = true }
|
||||||
hmac = "0.12"
|
|
||||||
sha1 = "0.10"
|
|
||||||
base64 = { workspace = true }
|
|
||||||
|
|||||||
@@ -233,32 +233,17 @@ impl SpeechHand {
|
|||||||
state.playback = PlaybackState::Playing;
|
state.playback = PlaybackState::Playing;
|
||||||
state.current_text = Some(text.clone());
|
state.current_text = Some(text.clone());
|
||||||
|
|
||||||
// Determine TTS method based on provider:
|
// In real implementation, would call TTS API
|
||||||
// - Browser: frontend uses Web Speech API (zero deps, works offline)
|
|
||||||
// - OpenAI: frontend calls speech_tts command (high-quality, needs API key)
|
|
||||||
// - Others: future support
|
|
||||||
let tts_method = match state.config.provider {
|
|
||||||
TtsProvider::Browser => "browser",
|
|
||||||
TtsProvider::OpenAI => "openai_api",
|
|
||||||
TtsProvider::Azure => "azure_api",
|
|
||||||
TtsProvider::ElevenLabs => "elevenlabs_api",
|
|
||||||
TtsProvider::Local => "local_engine",
|
|
||||||
};
|
|
||||||
|
|
||||||
let estimated_duration_ms = (text.chars().count() as f64 / 5.0 * 1000.0) as u64;
|
|
||||||
|
|
||||||
Ok(HandResult::success(serde_json::json!({
|
Ok(HandResult::success(serde_json::json!({
|
||||||
"status": "speaking",
|
"status": "speaking",
|
||||||
"tts_method": tts_method,
|
|
||||||
"text": text,
|
"text": text,
|
||||||
"voice": voice_id,
|
"voice": voice_id,
|
||||||
"language": lang,
|
"language": lang,
|
||||||
"rate": actual_rate,
|
"rate": actual_rate,
|
||||||
"pitch": actual_pitch,
|
"pitch": actual_pitch,
|
||||||
"volume": actual_volume,
|
"volume": actual_volume,
|
||||||
"provider": format!("{:?}", state.config.provider).to_lowercase(),
|
"provider": state.config.provider,
|
||||||
"duration_ms": estimated_duration_ms,
|
"duration_ms": text.len() as u64 * 80, // Rough estimate
|
||||||
"instruction": "Frontend should play this via TTS engine"
|
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
SpeechAction::SpeakSsml { ssml, voice } => {
|
SpeechAction::SpeakSsml { ssml, voice } => {
|
||||||
|
|||||||
@@ -289,435 +289,117 @@ impl TwitterHand {
|
|||||||
c.clone()
|
c.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute tweet action — POST /2/tweets
|
/// Execute tweet action
|
||||||
async fn execute_tweet(&self, config: &TweetConfig) -> Result<Value> {
|
async fn execute_tweet(&self, config: &TweetConfig) -> Result<Value> {
|
||||||
let creds = self.get_credentials().await
|
let _creds = self.get_credentials().await
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
// Simulated tweet response (actual implementation would use Twitter API)
|
||||||
let body = json!({ "text": config.text });
|
// In production, this would call Twitter API v2: POST /2/tweets
|
||||||
|
|
||||||
let response = client.post("https://api.twitter.com/2/tweets")
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.json(&body)
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Twitter API request failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
tracing::warn!("[TwitterHand] Tweet failed: {} - {}", status, response_text);
|
|
||||||
return Ok(json!({
|
|
||||||
"success": false,
|
|
||||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
|
||||||
"status_code": status.as_u16()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the response to extract tweet_id
|
|
||||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
|
||||||
|
|
||||||
Ok(json!({
|
Ok(json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"tweet_id": parsed["data"]["id"].as_str().unwrap_or("unknown"),
|
"tweet_id": format!("simulated_{}", chrono::Utc::now().timestamp()),
|
||||||
"text": config.text,
|
"text": config.text,
|
||||||
"raw_response": parsed,
|
"created_at": chrono::Utc::now().to_rfc3339(),
|
||||||
"message": "Tweet posted successfully"
|
"message": "Tweet posted successfully (simulated)",
|
||||||
|
"note": "Connect Twitter API credentials for actual posting"
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute search action — GET /2/tweets/search/recent
|
/// Execute search action
|
||||||
async fn execute_search(&self, config: &SearchConfig) -> Result<Value> {
|
async fn execute_search(&self, config: &SearchConfig) -> Result<Value> {
|
||||||
let creds = self.get_credentials().await
|
let _creds = self.get_credentials().await
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
// Simulated search response
|
||||||
let max = config.max_results.max(10).min(100);
|
// In production, this would call Twitter API v2: GET /2/tweets/search/recent
|
||||||
|
|
||||||
let response = client.get("https://api.twitter.com/2/tweets/search/recent")
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.query(&[
|
|
||||||
("query", config.query.as_str()),
|
|
||||||
("max_results", max.to_string().as_str()),
|
|
||||||
("tweet.fields", "created_at,author_id,public_metrics,lang"),
|
|
||||||
])
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Twitter search failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return Ok(json!({
|
|
||||||
"success": false,
|
|
||||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
|
||||||
"status_code": status.as_u16()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
|
||||||
|
|
||||||
Ok(json!({
|
Ok(json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"query": config.query,
|
"query": config.query,
|
||||||
"tweets": parsed["data"].as_array().cloned().unwrap_or_default(),
|
"tweets": [],
|
||||||
"meta": parsed["meta"].clone(),
|
"meta": {
|
||||||
"message": "Search completed"
|
"result_count": 0,
|
||||||
|
"newest_id": null,
|
||||||
|
"oldest_id": null,
|
||||||
|
"next_token": null
|
||||||
|
},
|
||||||
|
"message": "Search completed (simulated - no actual results without API)",
|
||||||
|
"note": "Connect Twitter API credentials for actual search results"
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute timeline action — GET /2/users/:id/timelines/reverse_chronological
|
/// Execute timeline action
|
||||||
async fn execute_timeline(&self, config: &TimelineConfig) -> Result<Value> {
|
async fn execute_timeline(&self, config: &TimelineConfig) -> Result<Value> {
|
||||||
let creds = self.get_credentials().await
|
let _creds = self.get_credentials().await
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
// Simulated timeline response
|
||||||
let user_id = config.user_id.as_deref().unwrap_or("me");
|
|
||||||
let url = format!("https://api.twitter.com/2/users/{}/timelines/reverse_chronological", user_id);
|
|
||||||
let max = config.max_results.max(5).min(100);
|
|
||||||
|
|
||||||
let response = client.get(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.query(&[
|
|
||||||
("max_results", max.to_string().as_str()),
|
|
||||||
("tweet.fields", "created_at,author_id,public_metrics"),
|
|
||||||
])
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Timeline fetch failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return Ok(json!({
|
|
||||||
"success": false,
|
|
||||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
|
||||||
"status_code": status.as_u16()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
|
||||||
|
|
||||||
Ok(json!({
|
Ok(json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"user_id": user_id,
|
"user_id": config.user_id,
|
||||||
"tweets": parsed["data"].as_array().cloned().unwrap_or_default(),
|
"tweets": [],
|
||||||
"meta": parsed["meta"].clone(),
|
"meta": {
|
||||||
"message": "Timeline fetched"
|
"result_count": 0,
|
||||||
|
"newest_id": null,
|
||||||
|
"oldest_id": null,
|
||||||
|
"next_token": null
|
||||||
|
},
|
||||||
|
"message": "Timeline fetched (simulated)",
|
||||||
|
"note": "Connect Twitter API credentials for actual timeline"
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get tweet by ID — GET /2/tweets/:id
|
/// Get tweet by ID
|
||||||
async fn execute_get_tweet(&self, tweet_id: &str) -> Result<Value> {
|
async fn execute_get_tweet(&self, tweet_id: &str) -> Result<Value> {
|
||||||
let creds = self.get_credentials().await
|
let _creds = self.get_credentials().await
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = format!("https://api.twitter.com/2/tweets/{}", tweet_id);
|
|
||||||
|
|
||||||
let response = client.get(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.query(&[("tweet.fields", "created_at,author_id,public_metrics,lang")])
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Tweet lookup failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return Ok(json!({
|
|
||||||
"success": false,
|
|
||||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
|
||||||
"status_code": status.as_u16()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
|
||||||
|
|
||||||
Ok(json!({
|
Ok(json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"tweet_id": tweet_id,
|
"tweet_id": tweet_id,
|
||||||
"tweet": parsed["data"].clone(),
|
"tweet": null,
|
||||||
"message": "Tweet fetched"
|
"message": "Tweet lookup (simulated)",
|
||||||
|
"note": "Connect Twitter API credentials for actual tweet data"
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get user by username — GET /2/users/by/username/:username
|
/// Get user by username
|
||||||
async fn execute_get_user(&self, username: &str) -> Result<Value> {
|
async fn execute_get_user(&self, username: &str) -> Result<Value> {
|
||||||
let creds = self.get_credentials().await
|
let _creds = self.get_credentials().await
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = format!("https://api.twitter.com/2/users/by/username/{}", username);
|
|
||||||
|
|
||||||
let response = client.get(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.query(&[("user.fields", "created_at,description,public_metrics,verified")])
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("User lookup failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return Ok(json!({
|
|
||||||
"success": false,
|
|
||||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
|
||||||
"status_code": status.as_u16()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
|
||||||
|
|
||||||
Ok(json!({
|
Ok(json!({
|
||||||
"success": true,
|
"success": true,
|
||||||
"username": username,
|
"username": username,
|
||||||
"user": parsed["data"].clone(),
|
"user": null,
|
||||||
"message": "User fetched"
|
"message": "User lookup (simulated)",
|
||||||
|
"note": "Connect Twitter API credentials for actual user data"
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute like action — PUT /2/users/:id/likes
|
/// Execute like action
|
||||||
async fn execute_like(&self, tweet_id: &str) -> Result<Value> {
|
async fn execute_like(&self, tweet_id: &str) -> Result<Value> {
|
||||||
let creds = self.get_credentials().await
|
let _creds = self.get_credentials().await
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
// Note: For like/retweet, we need OAuth 1.0a user context
|
|
||||||
// Using Bearer token as fallback (may not work for all endpoints)
|
|
||||||
let url = "https://api.twitter.com/2/users/me/likes";
|
|
||||||
|
|
||||||
let response = client.post(url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.json(&json!({"tweet_id": tweet_id}))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Like failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await.unwrap_or_default();
|
|
||||||
|
|
||||||
Ok(json!({
|
Ok(json!({
|
||||||
"success": status.is_success(),
|
"success": true,
|
||||||
"tweet_id": tweet_id,
|
"tweet_id": tweet_id,
|
||||||
"action": "liked",
|
"action": "liked",
|
||||||
"status_code": status.as_u16(),
|
"message": "Tweet liked (simulated)"
|
||||||
"message": if status.is_success() { "Tweet liked" } else { &response_text }
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Execute retweet action — POST /2/users/:id/retweets
|
/// Execute retweet action
|
||||||
async fn execute_retweet(&self, tweet_id: &str) -> Result<Value> {
|
async fn execute_retweet(&self, tweet_id: &str) -> Result<Value> {
|
||||||
let creds = self.get_credentials().await
|
let _creds = self.get_credentials().await
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = "https://api.twitter.com/2/users/me/retweets";
|
|
||||||
|
|
||||||
let response = client.post(url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("Content-Type", "application/json")
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.json(&json!({"tweet_id": tweet_id}))
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Retweet failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await.unwrap_or_default();
|
|
||||||
|
|
||||||
Ok(json!({
|
Ok(json!({
|
||||||
"success": status.is_success(),
|
"success": true,
|
||||||
"tweet_id": tweet_id,
|
"tweet_id": tweet_id,
|
||||||
"action": "retweeted",
|
"action": "retweeted",
|
||||||
"status_code": status.as_u16(),
|
"message": "Tweet retweeted (simulated)"
|
||||||
"message": if status.is_success() { "Tweet retweeted" } else { &response_text }
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute delete tweet — DELETE /2/tweets/:id
|
|
||||||
async fn execute_delete_tweet(&self, tweet_id: &str) -> Result<Value> {
|
|
||||||
let creds = self.get_credentials().await
|
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = format!("https://api.twitter.com/2/tweets/{}", tweet_id);
|
|
||||||
|
|
||||||
let response = client.delete(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Delete tweet failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await.unwrap_or_default();
|
|
||||||
|
|
||||||
Ok(json!({
|
|
||||||
"success": status.is_success(),
|
|
||||||
"tweet_id": tweet_id,
|
|
||||||
"action": "deleted",
|
|
||||||
"status_code": status.as_u16(),
|
|
||||||
"message": if status.is_success() { "Tweet deleted" } else { &response_text }
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute unretweet — DELETE /2/users/:id/retweets/:tweet_id
|
|
||||||
async fn execute_unretweet(&self, tweet_id: &str) -> Result<Value> {
|
|
||||||
let creds = self.get_credentials().await
|
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = format!("https://api.twitter.com/2/users/me/retweets/{}", tweet_id);
|
|
||||||
|
|
||||||
let response = client.delete(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Unretweet failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await.unwrap_or_default();
|
|
||||||
|
|
||||||
Ok(json!({
|
|
||||||
"success": status.is_success(),
|
|
||||||
"tweet_id": tweet_id,
|
|
||||||
"action": "unretweeted",
|
|
||||||
"status_code": status.as_u16(),
|
|
||||||
"message": if status.is_success() { "Tweet unretweeted" } else { &response_text }
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute unlike — DELETE /2/users/:id/likes/:tweet_id
|
|
||||||
async fn execute_unlike(&self, tweet_id: &str) -> Result<Value> {
|
|
||||||
let creds = self.get_credentials().await
|
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = format!("https://api.twitter.com/2/users/me/likes/{}", tweet_id);
|
|
||||||
|
|
||||||
let response = client.delete(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Unlike failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await.unwrap_or_default();
|
|
||||||
|
|
||||||
Ok(json!({
|
|
||||||
"success": status.is_success(),
|
|
||||||
"tweet_id": tweet_id,
|
|
||||||
"action": "unliked",
|
|
||||||
"status_code": status.as_u16(),
|
|
||||||
"message": if status.is_success() { "Tweet unliked" } else { &response_text }
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute followers fetch — GET /2/users/:id/followers
|
|
||||||
async fn execute_followers(&self, user_id: &str, max_results: Option<u32>) -> Result<Value> {
|
|
||||||
let creds = self.get_credentials().await
|
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = format!("https://api.twitter.com/2/users/{}/followers", user_id);
|
|
||||||
let max = max_results.unwrap_or(100).max(1).min(1000);
|
|
||||||
|
|
||||||
let response = client.get(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.query(&[
|
|
||||||
("max_results", max.to_string()),
|
|
||||||
("user.fields", "created_at,description,public_metrics,verified,profile_image_url".to_string()),
|
|
||||||
])
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Followers fetch failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return Ok(json!({
|
|
||||||
"success": false,
|
|
||||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
|
||||||
"status_code": status.as_u16()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
|
||||||
|
|
||||||
Ok(json!({
|
|
||||||
"success": true,
|
|
||||||
"user_id": user_id,
|
|
||||||
"followers": parsed["data"].as_array().cloned().unwrap_or_default(),
|
|
||||||
"meta": parsed["meta"].clone(),
|
|
||||||
"message": "Followers fetched"
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Execute following fetch — GET /2/users/:id/following
|
|
||||||
async fn execute_following(&self, user_id: &str, max_results: Option<u32>) -> Result<Value> {
|
|
||||||
let creds = self.get_credentials().await
|
|
||||||
.ok_or_else(|| zclaw_types::ZclawError::HandError("Twitter credentials not configured".to_string()))?;
|
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let url = format!("https://api.twitter.com/2/users/{}/following", user_id);
|
|
||||||
let max = max_results.unwrap_or(100).max(1).min(1000);
|
|
||||||
|
|
||||||
let response = client.get(&url)
|
|
||||||
.header("Authorization", format!("Bearer {}", creds.bearer_token.as_deref().unwrap_or("")))
|
|
||||||
.header("User-Agent", "ZCLAW/1.0")
|
|
||||||
.query(&[
|
|
||||||
("max_results", max.to_string()),
|
|
||||||
("user.fields", "created_at,description,public_metrics,verified,profile_image_url".to_string()),
|
|
||||||
])
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Following fetch failed: {}", e)))?;
|
|
||||||
|
|
||||||
let status = response.status();
|
|
||||||
let response_text = response.text().await
|
|
||||||
.map_err(|e| zclaw_types::ZclawError::HandError(format!("Failed to read response: {}", e)))?;
|
|
||||||
|
|
||||||
if !status.is_success() {
|
|
||||||
return Ok(json!({
|
|
||||||
"success": false,
|
|
||||||
"error": format!("Twitter API returned {}: {}", status, response_text),
|
|
||||||
"status_code": status.as_u16()
|
|
||||||
}));
|
|
||||||
}
|
|
||||||
|
|
||||||
let parsed: Value = serde_json::from_str(&response_text).unwrap_or(json!({"raw": response_text}));
|
|
||||||
|
|
||||||
Ok(json!({
|
|
||||||
"success": true,
|
|
||||||
"user_id": user_id,
|
|
||||||
"following": parsed["data"].as_array().cloned().unwrap_or_default(),
|
|
||||||
"meta": parsed["meta"].clone(),
|
|
||||||
"message": "Following fetched"
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -779,17 +461,54 @@ impl Hand for TwitterHand {
|
|||||||
|
|
||||||
let result = match action {
|
let result = match action {
|
||||||
TwitterAction::Tweet { config } => self.execute_tweet(&config).await?,
|
TwitterAction::Tweet { config } => self.execute_tweet(&config).await?,
|
||||||
TwitterAction::DeleteTweet { tweet_id } => self.execute_delete_tweet(&tweet_id).await?,
|
TwitterAction::DeleteTweet { tweet_id } => {
|
||||||
|
json!({
|
||||||
|
"success": true,
|
||||||
|
"tweet_id": tweet_id,
|
||||||
|
"action": "deleted",
|
||||||
|
"message": "Tweet deleted (simulated)"
|
||||||
|
})
|
||||||
|
}
|
||||||
TwitterAction::Retweet { tweet_id } => self.execute_retweet(&tweet_id).await?,
|
TwitterAction::Retweet { tweet_id } => self.execute_retweet(&tweet_id).await?,
|
||||||
TwitterAction::Unretweet { tweet_id } => self.execute_unretweet(&tweet_id).await?,
|
TwitterAction::Unretweet { tweet_id } => {
|
||||||
|
json!({
|
||||||
|
"success": true,
|
||||||
|
"tweet_id": tweet_id,
|
||||||
|
"action": "unretweeted",
|
||||||
|
"message": "Tweet unretweeted (simulated)"
|
||||||
|
})
|
||||||
|
}
|
||||||
TwitterAction::Like { tweet_id } => self.execute_like(&tweet_id).await?,
|
TwitterAction::Like { tweet_id } => self.execute_like(&tweet_id).await?,
|
||||||
TwitterAction::Unlike { tweet_id } => self.execute_unlike(&tweet_id).await?,
|
TwitterAction::Unlike { tweet_id } => {
|
||||||
|
json!({
|
||||||
|
"success": true,
|
||||||
|
"tweet_id": tweet_id,
|
||||||
|
"action": "unliked",
|
||||||
|
"message": "Tweet unliked (simulated)"
|
||||||
|
})
|
||||||
|
}
|
||||||
TwitterAction::Search { config } => self.execute_search(&config).await?,
|
TwitterAction::Search { config } => self.execute_search(&config).await?,
|
||||||
TwitterAction::Timeline { config } => self.execute_timeline(&config).await?,
|
TwitterAction::Timeline { config } => self.execute_timeline(&config).await?,
|
||||||
TwitterAction::GetTweet { tweet_id } => self.execute_get_tweet(&tweet_id).await?,
|
TwitterAction::GetTweet { tweet_id } => self.execute_get_tweet(&tweet_id).await?,
|
||||||
TwitterAction::GetUser { username } => self.execute_get_user(&username).await?,
|
TwitterAction::GetUser { username } => self.execute_get_user(&username).await?,
|
||||||
TwitterAction::Followers { user_id, max_results } => self.execute_followers(&user_id, max_results).await?,
|
TwitterAction::Followers { user_id, max_results } => {
|
||||||
TwitterAction::Following { user_id, max_results } => self.execute_following(&user_id, max_results).await?,
|
json!({
|
||||||
|
"success": true,
|
||||||
|
"user_id": user_id,
|
||||||
|
"followers": [],
|
||||||
|
"max_results": max_results.unwrap_or(100),
|
||||||
|
"message": "Followers fetched (simulated)"
|
||||||
|
})
|
||||||
|
}
|
||||||
|
TwitterAction::Following { user_id, max_results } => {
|
||||||
|
json!({
|
||||||
|
"success": true,
|
||||||
|
"user_id": user_id,
|
||||||
|
"following": [],
|
||||||
|
"max_results": max_results.unwrap_or(100),
|
||||||
|
"message": "Following fetched (simulated)"
|
||||||
|
})
|
||||||
|
}
|
||||||
TwitterAction::CheckCredentials => self.execute_check_credentials().await?,
|
TwitterAction::CheckCredentials => self.execute_check_credentials().await?,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -54,11 +54,6 @@ pub struct LlmConfig {
|
|||||||
/// Temperature
|
/// Temperature
|
||||||
#[serde(default = "default_temperature")]
|
#[serde(default = "default_temperature")]
|
||||||
pub temperature: f32,
|
pub temperature: f32,
|
||||||
|
|
||||||
/// Context window size in tokens (default: 128000)
|
|
||||||
/// Used to calculate dynamic compaction threshold.
|
|
||||||
#[serde(default = "default_context_window")]
|
|
||||||
pub context_window: u32,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LlmConfig {
|
impl LlmConfig {
|
||||||
@@ -71,7 +66,6 @@ impl LlmConfig {
|
|||||||
api_protocol: ApiProtocol::OpenAI,
|
api_protocol: ApiProtocol::OpenAI,
|
||||||
max_tokens: default_max_tokens(),
|
max_tokens: default_max_tokens(),
|
||||||
temperature: default_temperature(),
|
temperature: default_temperature(),
|
||||||
context_window: default_context_window(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,10 +140,6 @@ fn default_temperature() -> f32 {
|
|||||||
0.7
|
0.7
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_context_window() -> u32 {
|
|
||||||
128000
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for KernelConfig {
|
impl Default for KernelConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -161,7 +151,6 @@ impl Default for KernelConfig {
|
|||||||
api_protocol: ApiProtocol::OpenAI,
|
api_protocol: ApiProtocol::OpenAI,
|
||||||
max_tokens: default_max_tokens(),
|
max_tokens: default_max_tokens(),
|
||||||
temperature: default_temperature(),
|
temperature: default_temperature(),
|
||||||
context_window: default_context_window(),
|
|
||||||
},
|
},
|
||||||
skills_dir: default_skills_dir(),
|
skills_dir: default_skills_dir(),
|
||||||
}
|
}
|
||||||
@@ -356,17 +345,6 @@ impl KernelConfig {
|
|||||||
pub fn temperature(&self) -> f32 {
|
pub fn temperature(&self) -> f32 {
|
||||||
self.llm.temperature
|
self.llm.temperature
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get context window size in tokens
|
|
||||||
pub fn context_window(&self) -> u32 {
|
|
||||||
self.llm.context_window
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Dynamic compaction threshold = context_window * 0.6
|
|
||||||
/// Leaves 40% headroom for system prompt + response tokens
|
|
||||||
pub fn compaction_threshold(&self) -> usize {
|
|
||||||
(self.llm.context_window as f64 * 0.6) as usize
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Preset configurations for common providers ===
|
// === Preset configurations for common providers ===
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -8,8 +8,6 @@ mod capabilities;
|
|||||||
mod events;
|
mod events;
|
||||||
pub mod trigger_manager;
|
pub mod trigger_manager;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod scheduler;
|
|
||||||
pub mod skill_router;
|
|
||||||
#[cfg(feature = "multi-agent")]
|
#[cfg(feature = "multi-agent")]
|
||||||
pub mod director;
|
pub mod director;
|
||||||
pub mod generation;
|
pub mod generation;
|
||||||
@@ -23,16 +21,8 @@ pub use config::*;
|
|||||||
pub use trigger_manager::{TriggerManager, TriggerEntry, TriggerUpdateRequest, TriggerManagerConfig};
|
pub use trigger_manager::{TriggerManager, TriggerEntry, TriggerUpdateRequest, TriggerManagerConfig};
|
||||||
#[cfg(feature = "multi-agent")]
|
#[cfg(feature = "multi-agent")]
|
||||||
pub use director::*;
|
pub use director::*;
|
||||||
#[cfg(feature = "multi-agent")]
|
|
||||||
pub use zclaw_protocols::{
|
|
||||||
A2aRouter, A2aAgentProfile, A2aCapability, A2aEnvelope, A2aMessageType, A2aRecipient,
|
|
||||||
A2aReceiver,
|
|
||||||
BasicA2aClient,
|
|
||||||
A2aClient,
|
|
||||||
};
|
|
||||||
pub use generation::*;
|
pub use generation::*;
|
||||||
pub use export::{ExportFormat, ExportOptions, ExportResult, Exporter, export_classroom};
|
pub use export::{ExportFormat, ExportOptions, ExportResult, Exporter, export_classroom};
|
||||||
|
|
||||||
// Re-export hands types for convenience
|
// Re-export hands types for convenience
|
||||||
pub use zclaw_hands::{HandRegistry, HandContext, HandResult, HandConfig, Hand, HandStatus};
|
pub use zclaw_hands::{HandRegistry, HandContext, HandResult, HandConfig, Hand, HandStatus};
|
||||||
pub use scheduler::SchedulerService;
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ pub struct AgentRegistry {
|
|||||||
agents: DashMap<AgentId, AgentConfig>,
|
agents: DashMap<AgentId, AgentConfig>,
|
||||||
states: DashMap<AgentId, AgentState>,
|
states: DashMap<AgentId, AgentState>,
|
||||||
created_at: DashMap<AgentId, chrono::DateTime<Utc>>,
|
created_at: DashMap<AgentId, chrono::DateTime<Utc>>,
|
||||||
message_counts: DashMap<AgentId, u64>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentRegistry {
|
impl AgentRegistry {
|
||||||
@@ -18,7 +17,6 @@ impl AgentRegistry {
|
|||||||
agents: DashMap::new(),
|
agents: DashMap::new(),
|
||||||
states: DashMap::new(),
|
states: DashMap::new(),
|
||||||
created_at: DashMap::new(),
|
created_at: DashMap::new(),
|
||||||
message_counts: DashMap::new(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -35,7 +33,6 @@ impl AgentRegistry {
|
|||||||
self.agents.remove(id);
|
self.agents.remove(id);
|
||||||
self.states.remove(id);
|
self.states.remove(id);
|
||||||
self.created_at.remove(id);
|
self.created_at.remove(id);
|
||||||
self.message_counts.remove(id);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get an agent by ID
|
/// Get an agent by ID
|
||||||
@@ -56,7 +53,7 @@ impl AgentRegistry {
|
|||||||
model: config.model.model.clone(),
|
model: config.model.model.clone(),
|
||||||
provider: config.model.provider.clone(),
|
provider: config.model.provider.clone(),
|
||||||
state,
|
state,
|
||||||
message_count: self.message_counts.get(id).map(|c| *c as usize).unwrap_or(0),
|
message_count: 0, // TODO: Track this
|
||||||
created_at,
|
created_at,
|
||||||
updated_at: Utc::now(),
|
updated_at: Utc::now(),
|
||||||
})
|
})
|
||||||
@@ -86,11 +83,6 @@ impl AgentRegistry {
|
|||||||
pub fn count(&self) -> usize {
|
pub fn count(&self) -> usize {
|
||||||
self.agents.len()
|
self.agents.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Increment message count for an agent
|
|
||||||
pub fn increment_message_count(&self, id: &AgentId) {
|
|
||||||
self.message_counts.entry(*id).and_modify(|c| *c += 1).or_insert(1);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AgentRegistry {
|
impl Default for AgentRegistry {
|
||||||
|
|||||||
@@ -1,341 +0,0 @@
|
|||||||
//! Scheduler service for automatic trigger execution
|
|
||||||
//!
|
|
||||||
//! Periodically scans scheduled triggers and fires them at the appropriate time.
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use chrono::{Datelike, Timelike};
|
|
||||||
use tokio::sync::RwLock;
|
|
||||||
use tokio::time::{self, Duration};
|
|
||||||
use zclaw_types::Result;
|
|
||||||
use crate::Kernel;
|
|
||||||
|
|
||||||
/// Scheduler service that runs in the background and executes scheduled triggers
|
|
||||||
pub struct SchedulerService {
|
|
||||||
kernel: Arc<RwLock<Option<Kernel>>>,
|
|
||||||
running: Arc<AtomicBool>,
|
|
||||||
check_interval: Duration,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SchedulerService {
|
|
||||||
/// Create a new scheduler service
|
|
||||||
pub fn new(kernel: Arc<RwLock<Option<Kernel>>>, check_interval_secs: u64) -> Self {
|
|
||||||
Self {
|
|
||||||
kernel,
|
|
||||||
running: Arc::new(AtomicBool::new(false)),
|
|
||||||
check_interval: Duration::from_secs(check_interval_secs),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Start the scheduler loop in the background
|
|
||||||
pub fn start(&self) {
|
|
||||||
if self.running.compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst).is_err() {
|
|
||||||
tracing::warn!("[Scheduler] Already running, ignoring start request");
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let kernel = self.kernel.clone();
|
|
||||||
let running = self.running.clone();
|
|
||||||
let interval = self.check_interval;
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
|
||||||
tracing::info!("[Scheduler] Starting scheduler loop with {}s interval", interval.as_secs());
|
|
||||||
|
|
||||||
let mut ticker = time::interval(interval);
|
|
||||||
// First tick fires immediately — skip it
|
|
||||||
ticker.tick().await;
|
|
||||||
|
|
||||||
while running.load(Ordering::Relaxed) {
|
|
||||||
ticker.tick().await;
|
|
||||||
|
|
||||||
if !running.load(Ordering::Relaxed) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Err(e) = Self::check_and_fire_scheduled_triggers(&kernel).await {
|
|
||||||
tracing::error!("[Scheduler] Error checking triggers: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::info!("[Scheduler] Scheduler loop stopped");
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Stop the scheduler loop
|
|
||||||
pub fn stop(&self) {
|
|
||||||
self.running.store(false, Ordering::Relaxed);
|
|
||||||
tracing::info!("[Scheduler] Stop requested");
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if the scheduler is running
|
|
||||||
pub fn is_running(&self) -> bool {
|
|
||||||
self.running.load(Ordering::Relaxed)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check all scheduled triggers and fire those that are due
|
|
||||||
async fn check_and_fire_scheduled_triggers(
|
|
||||||
kernel_lock: &Arc<RwLock<Option<Kernel>>>,
|
|
||||||
) -> Result<()> {
|
|
||||||
let kernel_read = kernel_lock.read().await;
|
|
||||||
let kernel = match kernel_read.as_ref() {
|
|
||||||
Some(k) => k,
|
|
||||||
None => return Ok(()),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Get all triggers
|
|
||||||
let triggers = kernel.list_triggers().await;
|
|
||||||
let now = chrono::Utc::now();
|
|
||||||
|
|
||||||
// Filter to enabled Schedule triggers
|
|
||||||
let scheduled: Vec<_> = triggers.iter()
|
|
||||||
.filter(|t| {
|
|
||||||
t.config.enabled && matches!(t.config.trigger_type, zclaw_hands::TriggerType::Schedule { .. })
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if scheduled.is_empty() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
tracing::debug!("[Scheduler] Checking {} scheduled triggers", scheduled.len());
|
|
||||||
|
|
||||||
// Drop the read lock before executing
|
|
||||||
let to_execute: Vec<(String, String, String)> = scheduled.iter()
|
|
||||||
.filter_map(|t| {
|
|
||||||
if let zclaw_hands::TriggerType::Schedule { ref cron } = t.config.trigger_type {
|
|
||||||
// Simple cron matching: check if we should fire now
|
|
||||||
if Self::should_fire_cron(cron, &now) {
|
|
||||||
Some((t.config.id.clone(), t.config.hand_id.clone(), cron.clone()))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
drop(kernel_read);
|
|
||||||
|
|
||||||
// Execute due triggers (with write lock since execute_hand may need it)
|
|
||||||
for (trigger_id, hand_id, cron_expr) in to_execute {
|
|
||||||
tracing::info!(
|
|
||||||
"[Scheduler] Firing scheduled trigger '{}' → hand '{}' (cron: {})",
|
|
||||||
trigger_id, hand_id, cron_expr
|
|
||||||
);
|
|
||||||
|
|
||||||
let kernel_read = kernel_lock.read().await;
|
|
||||||
if let Some(kernel) = kernel_read.as_ref() {
|
|
||||||
let trigger_source = zclaw_types::TriggerSource::Scheduled {
|
|
||||||
trigger_id: trigger_id.clone(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let input = serde_json::json!({
|
|
||||||
"trigger_id": trigger_id,
|
|
||||||
"trigger_type": "schedule",
|
|
||||||
"cron": cron_expr,
|
|
||||||
"fired_at": now.to_rfc3339(),
|
|
||||||
});
|
|
||||||
|
|
||||||
match kernel.execute_hand_with_source(&hand_id, input, trigger_source).await {
|
|
||||||
Ok((_result, run_id)) => {
|
|
||||||
tracing::info!(
|
|
||||||
"[Scheduler] Successfully fired trigger '{}' → run {}",
|
|
||||||
trigger_id, run_id
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(
|
|
||||||
"[Scheduler] Failed to execute trigger '{}': {}",
|
|
||||||
trigger_id, e
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Simple cron expression matcher
|
|
||||||
///
|
|
||||||
/// Supports basic cron format: `minute hour day month weekday`
|
|
||||||
/// Also supports interval shorthand: `every:Ns`, `every:Nm`, `every:Nh`
|
|
||||||
fn should_fire_cron(cron: &str, now: &chrono::DateTime<chrono::Utc>) -> bool {
|
|
||||||
let cron = cron.trim();
|
|
||||||
|
|
||||||
// Handle interval shorthand: "every:30s", "every:5m", "every:1h"
|
|
||||||
if let Some(interval_str) = cron.strip_prefix("every:") {
|
|
||||||
return Self::check_interval_shorthand(interval_str, now);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle ISO timestamp for one-shot: "2026-03-29T10:00:00Z"
|
|
||||||
if cron.contains('T') && cron.contains('-') {
|
|
||||||
if let Ok(target) = chrono::DateTime::parse_from_rfc3339(cron) {
|
|
||||||
let target_utc = target.with_timezone(&chrono::Utc);
|
|
||||||
// Fire if within the check window (± check_interval/2, approx 30s)
|
|
||||||
let diff = (*now - target_utc).num_seconds().abs();
|
|
||||||
return diff <= 30;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Standard 5-field cron: minute hour day_of_month month day_of_week
|
|
||||||
let parts: Vec<&str> = cron.split_whitespace().collect();
|
|
||||||
if parts.len() != 5 {
|
|
||||||
tracing::warn!("[Scheduler] Invalid cron expression (expected 5 fields): '{}'", cron);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let minute = now.minute() as i32;
|
|
||||||
let hour = now.hour() as i32;
|
|
||||||
let day = now.day() as i32;
|
|
||||||
let month = now.month() as i32;
|
|
||||||
let weekday = now.weekday().num_days_from_monday() as i32; // Mon=0..Sun=6
|
|
||||||
|
|
||||||
Self::cron_field_matches(parts[0], minute)
|
|
||||||
&& Self::cron_field_matches(parts[1], hour)
|
|
||||||
&& Self::cron_field_matches(parts[2], day)
|
|
||||||
&& Self::cron_field_matches(parts[3], month)
|
|
||||||
&& Self::cron_field_matches(parts[4], weekday)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if a single cron field matches the current value
|
|
||||||
fn cron_field_matches(field: &str, value: i32) -> bool {
|
|
||||||
if field == "*" || field == "?" {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle step: */N
|
|
||||||
if let Some(step_str) = field.strip_prefix("*/") {
|
|
||||||
if let Ok(step) = step_str.parse::<i32>() {
|
|
||||||
if step > 0 {
|
|
||||||
return value % step == 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle range: N-M
|
|
||||||
if field.contains('-') {
|
|
||||||
let range_parts: Vec<&str> = field.split('-').collect();
|
|
||||||
if range_parts.len() == 2 {
|
|
||||||
if let (Ok(start), Ok(end)) = (range_parts[0].parse::<i32>(), range_parts[1].parse::<i32>()) {
|
|
||||||
return value >= start && value <= end;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle list: N,M,O
|
|
||||||
if field.contains(',') {
|
|
||||||
return field.split(',').any(|part| {
|
|
||||||
part.trim().parse::<i32>().map(|p| p == value).unwrap_or(false)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simple value
|
|
||||||
field.parse::<i32>().map(|p| p == value).unwrap_or(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check interval shorthand expressions
|
|
||||||
fn check_interval_shorthand(interval: &str, now: &chrono::DateTime<chrono::Utc>) -> bool {
|
|
||||||
let (num_str, unit) = if interval.ends_with('s') {
|
|
||||||
(&interval[..interval.len()-1], 's')
|
|
||||||
} else if interval.ends_with('m') {
|
|
||||||
(&interval[..interval.len()-1], 'm')
|
|
||||||
} else if interval.ends_with('h') {
|
|
||||||
(&interval[..interval.len()-1], 'h')
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
};
|
|
||||||
|
|
||||||
let num: i64 = match num_str.parse() {
|
|
||||||
Ok(n) => n,
|
|
||||||
Err(_) => return false,
|
|
||||||
};
|
|
||||||
|
|
||||||
if num <= 0 {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
let interval_secs = match unit {
|
|
||||||
's' => num,
|
|
||||||
'm' => num * 60,
|
|
||||||
'h' => num * 3600,
|
|
||||||
_ => return false,
|
|
||||||
};
|
|
||||||
|
|
||||||
// Check if current timestamp aligns with the interval
|
|
||||||
let timestamp = now.timestamp();
|
|
||||||
timestamp % interval_secs == 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use chrono::Timelike;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_cron_field_wildcard() {
|
|
||||||
assert!(SchedulerService::cron_field_matches("*", 5));
|
|
||||||
assert!(SchedulerService::cron_field_matches("?", 5));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_cron_field_exact() {
|
|
||||||
assert!(SchedulerService::cron_field_matches("5", 5));
|
|
||||||
assert!(!SchedulerService::cron_field_matches("5", 6));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_cron_field_step() {
|
|
||||||
assert!(SchedulerService::cron_field_matches("*/5", 0));
|
|
||||||
assert!(SchedulerService::cron_field_matches("*/5", 5));
|
|
||||||
assert!(SchedulerService::cron_field_matches("*/5", 10));
|
|
||||||
assert!(!SchedulerService::cron_field_matches("*/5", 3));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_cron_field_range() {
|
|
||||||
assert!(SchedulerService::cron_field_matches("1-5", 1));
|
|
||||||
assert!(SchedulerService::cron_field_matches("1-5", 3));
|
|
||||||
assert!(SchedulerService::cron_field_matches("1-5", 5));
|
|
||||||
assert!(!SchedulerService::cron_field_matches("1-5", 0));
|
|
||||||
assert!(!SchedulerService::cron_field_matches("1-5", 6));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_cron_field_list() {
|
|
||||||
assert!(SchedulerService::cron_field_matches("1,3,5", 1));
|
|
||||||
assert!(SchedulerService::cron_field_matches("1,3,5", 3));
|
|
||||||
assert!(SchedulerService::cron_field_matches("1,3,5", 5));
|
|
||||||
assert!(!SchedulerService::cron_field_matches("1,3,5", 2));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_should_fire_every_minute() {
|
|
||||||
let now = chrono::Utc::now();
|
|
||||||
assert!(SchedulerService::should_fire_cron("every:1m", &now));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_should_fire_cron_wildcard() {
|
|
||||||
let now = chrono::Utc::now();
|
|
||||||
// Every minute match
|
|
||||||
assert!(SchedulerService::should_fire_cron(
|
|
||||||
&format!("{} * * * *", now.minute()),
|
|
||||||
&now,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_should_not_fire_cron() {
|
|
||||||
let now = chrono::Utc::now();
|
|
||||||
let wrong_minute = if now.minute() < 59 { now.minute() + 1 } else { 0 };
|
|
||||||
assert!(!SchedulerService::should_fire_cron(
|
|
||||||
&format!("{} * * * *", wrong_minute),
|
|
||||||
&now,
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
//! Skill router integration for the Kernel
|
|
||||||
//!
|
|
||||||
//! Bridges zclaw-growth's `EmbeddingClient` to zclaw-skills' `Embedder` trait,
|
|
||||||
//! enabling the `SemanticSkillRouter` to use real embedding APIs.
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
|
|
||||||
/// Adapter: zclaw-growth EmbeddingClient → zclaw-skills Embedder
|
|
||||||
pub struct EmbeddingAdapter {
|
|
||||||
client: Arc<dyn zclaw_runtime::EmbeddingClient>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl EmbeddingAdapter {
|
|
||||||
pub fn new(client: Arc<dyn zclaw_runtime::EmbeddingClient>) -> Self {
|
|
||||||
Self { client }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl zclaw_skills::semantic_router::Embedder for EmbeddingAdapter {
|
|
||||||
async fn embed(&self, text: &str) -> Option<Vec<f32>> {
|
|
||||||
self.client.embed(text).await.ok()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -49,26 +49,8 @@ CREATE TABLE IF NOT EXISTS schema_version (
|
|||||||
version INTEGER PRIMARY KEY
|
version INTEGER PRIMARY KEY
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Hand execution runs table
|
|
||||||
CREATE TABLE IF NOT EXISTS hand_runs (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
hand_name TEXT NOT NULL,
|
|
||||||
trigger_source TEXT NOT NULL,
|
|
||||||
params TEXT NOT NULL,
|
|
||||||
status TEXT NOT NULL DEFAULT 'pending',
|
|
||||||
result TEXT,
|
|
||||||
error TEXT,
|
|
||||||
duration_ms INTEGER,
|
|
||||||
created_at TEXT NOT NULL,
|
|
||||||
started_at TEXT,
|
|
||||||
completed_at TEXT
|
|
||||||
);
|
|
||||||
|
|
||||||
-- Indexes
|
-- Indexes
|
||||||
CREATE INDEX IF NOT EXISTS idx_sessions_agent ON sessions(agent_id);
|
CREATE INDEX IF NOT EXISTS idx_sessions_agent ON sessions(agent_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
|
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_kv_agent ON kv_store(agent_id);
|
CREATE INDEX IF NOT EXISTS idx_kv_agent ON kv_store(agent_id);
|
||||||
CREATE INDEX IF NOT EXISTS idx_hand_runs_hand ON hand_runs(hand_name);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_hand_runs_status ON hand_runs(status);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_hand_runs_created ON hand_runs(created_at);
|
|
||||||
"#;
|
"#;
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
//! Memory store implementation
|
//! Memory store implementation
|
||||||
|
|
||||||
use sqlx::SqlitePool;
|
use sqlx::SqlitePool;
|
||||||
use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError, HandRun, HandRunId, HandRunStatus, HandRunFilter};
|
use zclaw_types::{AgentConfig, AgentId, SessionId, Message, Result, ZclawError};
|
||||||
|
|
||||||
/// Memory store for persisting ZCLAW data
|
/// Memory store for persisting ZCLAW data
|
||||||
pub struct MemoryStore {
|
pub struct MemoryStore {
|
||||||
@@ -283,193 +283,6 @@ impl MemoryStore {
|
|||||||
|
|
||||||
Ok(rows.into_iter().map(|(key,)| key).collect())
|
Ok(rows.into_iter().map(|(key,)| key).collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Hand Run Tracking ===
|
|
||||||
|
|
||||||
/// Save a new hand run record
|
|
||||||
pub async fn save_hand_run(&self, run: &HandRun) -> Result<()> {
|
|
||||||
let id = run.id.to_string();
|
|
||||||
let trigger_source = serde_json::to_string(&run.trigger_source)?;
|
|
||||||
let params = serde_json::to_string(&run.params)?;
|
|
||||||
let result = run.result.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
|
|
||||||
let error = run.error.as_ref().map(|e| serde_json::to_string(e)).transpose()?;
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
INSERT INTO hand_runs (id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&id)
|
|
||||||
.bind(&run.hand_name)
|
|
||||||
.bind(&trigger_source)
|
|
||||||
.bind(¶ms)
|
|
||||||
.bind(run.status.to_string())
|
|
||||||
.bind(result.as_deref())
|
|
||||||
.bind(error.as_deref())
|
|
||||||
.bind(run.duration_ms.map(|d| d as i64))
|
|
||||||
.bind(&run.created_at)
|
|
||||||
.bind(run.started_at.as_deref())
|
|
||||||
.bind(run.completed_at.as_deref())
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update an existing hand run record
|
|
||||||
pub async fn update_hand_run(&self, run: &HandRun) -> Result<()> {
|
|
||||||
let id = run.id.to_string();
|
|
||||||
let trigger_source = serde_json::to_string(&run.trigger_source)?;
|
|
||||||
let params = serde_json::to_string(&run.params)?;
|
|
||||||
let result = run.result.as_ref().map(|v| serde_json::to_string(v)).transpose()?;
|
|
||||||
let error = run.error.as_ref().map(|e| serde_json::to_string(e)).transpose()?;
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
r#"
|
|
||||||
UPDATE hand_runs SET
|
|
||||||
hand_name = ?, trigger_source = ?, params = ?, status = ?,
|
|
||||||
result = ?, error = ?, duration_ms = ?,
|
|
||||||
started_at = ?, completed_at = ?
|
|
||||||
WHERE id = ?
|
|
||||||
"#,
|
|
||||||
)
|
|
||||||
.bind(&run.hand_name)
|
|
||||||
.bind(&trigger_source)
|
|
||||||
.bind(¶ms)
|
|
||||||
.bind(run.status.to_string())
|
|
||||||
.bind(result.as_deref())
|
|
||||||
.bind(error.as_deref())
|
|
||||||
.bind(run.duration_ms.map(|d| d as i64))
|
|
||||||
.bind(run.started_at.as_deref())
|
|
||||||
.bind(run.completed_at.as_deref())
|
|
||||||
.bind(&id)
|
|
||||||
.execute(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get a hand run by ID
|
|
||||||
pub async fn get_hand_run(&self, id: &HandRunId) -> Result<Option<HandRun>> {
|
|
||||||
let id_str = id.to_string();
|
|
||||||
|
|
||||||
let row = sqlx::query_as::<_, (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>)>(
|
|
||||||
"SELECT id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at FROM hand_runs WHERE id = ?"
|
|
||||||
)
|
|
||||||
.bind(&id_str)
|
|
||||||
.fetch_optional(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
||||||
|
|
||||||
match row {
|
|
||||||
Some(r) => Ok(Some(Self::row_to_hand_run(r)?)),
|
|
||||||
None => Ok(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// List hand runs with optional filter
|
|
||||||
pub async fn list_hand_runs(&self, filter: &HandRunFilter) -> Result<Vec<HandRun>> {
|
|
||||||
let mut query = String::from(
|
|
||||||
"SELECT id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at FROM hand_runs WHERE 1=1"
|
|
||||||
);
|
|
||||||
let mut bind_values: Vec<String> = Vec::new();
|
|
||||||
|
|
||||||
if let Some(ref hand_name) = filter.hand_name {
|
|
||||||
query.push_str(" AND hand_name = ?");
|
|
||||||
bind_values.push(hand_name.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(ref status) = filter.status {
|
|
||||||
query.push_str(" AND status = ?");
|
|
||||||
bind_values.push(status.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
query.push_str(" ORDER BY created_at DESC");
|
|
||||||
|
|
||||||
if let Some(limit) = filter.limit {
|
|
||||||
query.push_str(&format!(" LIMIT {}", limit));
|
|
||||||
}
|
|
||||||
if let Some(offset) = filter.offset {
|
|
||||||
query.push_str(&format!(" OFFSET {}", offset));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut sql_query = sqlx::query_as::<_, (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>)>(&query);
|
|
||||||
|
|
||||||
for val in &bind_values {
|
|
||||||
sql_query = sql_query.bind(val);
|
|
||||||
}
|
|
||||||
|
|
||||||
let rows = sql_query
|
|
||||||
.fetch_all(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
||||||
|
|
||||||
rows.into_iter()
|
|
||||||
.map(|r| Self::row_to_hand_run(r))
|
|
||||||
.collect()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Count hand runs matching filter
|
|
||||||
pub async fn count_hand_runs(&self, filter: &HandRunFilter) -> Result<u32> {
|
|
||||||
let mut query = String::from("SELECT COUNT(*) FROM hand_runs WHERE 1=1");
|
|
||||||
let mut bind_values: Vec<String> = Vec::new();
|
|
||||||
|
|
||||||
if let Some(ref hand_name) = filter.hand_name {
|
|
||||||
query.push_str(" AND hand_name = ?");
|
|
||||||
bind_values.push(hand_name.clone());
|
|
||||||
}
|
|
||||||
if let Some(ref status) = filter.status {
|
|
||||||
query.push_str(" AND status = ?");
|
|
||||||
bind_values.push(status.to_string());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut sql_query = sqlx::query_scalar::<_, i64>(&query);
|
|
||||||
for val in &bind_values {
|
|
||||||
sql_query = sql_query.bind(val);
|
|
||||||
}
|
|
||||||
|
|
||||||
let count = sql_query
|
|
||||||
.fetch_one(&self.pool)
|
|
||||||
.await
|
|
||||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
|
||||||
|
|
||||||
Ok(count as u32)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn row_to_hand_run(
|
|
||||||
row: (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>),
|
|
||||||
) -> Result<HandRun> {
|
|
||||||
let (id, hand_name, trigger_source, params, status, result, error, duration_ms, created_at, started_at, completed_at) = row;
|
|
||||||
|
|
||||||
let run_id: HandRunId = id.parse()
|
|
||||||
.map_err(|e| ZclawError::StorageError(format!("Invalid HandRunId: {}", e)))?;
|
|
||||||
let trigger: zclaw_types::TriggerSource = serde_json::from_str(&trigger_source)?;
|
|
||||||
let params_val: serde_json::Value = serde_json::from_str(¶ms)?;
|
|
||||||
let run_status: HandRunStatus = status.parse()
|
|
||||||
.map_err(|e| ZclawError::StorageError(e))?;
|
|
||||||
let result_val: Option<serde_json::Value> = result.map(|r| serde_json::from_str(&r)).transpose()?;
|
|
||||||
let error_val: Option<String> = error.as_ref()
|
|
||||||
.map(|e| serde_json::from_str::<String>(e))
|
|
||||||
.transpose()
|
|
||||||
.unwrap_or_else(|_| error.clone());
|
|
||||||
|
|
||||||
Ok(HandRun {
|
|
||||||
id: run_id,
|
|
||||||
hand_name,
|
|
||||||
trigger_source: trigger,
|
|
||||||
params: params_val,
|
|
||||||
status: run_status,
|
|
||||||
result: result_val,
|
|
||||||
error: error_val,
|
|
||||||
duration_ms: duration_ms.map(|d| d as u64),
|
|
||||||
created_at,
|
|
||||||
started_at,
|
|
||||||
completed_at,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|||||||
@@ -13,22 +13,12 @@ use super::OrchestrationActionDriver;
|
|||||||
pub struct SkillOrchestrationDriver {
|
pub struct SkillOrchestrationDriver {
|
||||||
/// Skill registry for executing skills
|
/// Skill registry for executing skills
|
||||||
skill_registry: Arc<zclaw_skills::SkillRegistry>,
|
skill_registry: Arc<zclaw_skills::SkillRegistry>,
|
||||||
/// Graph store for persisting/loading graphs by ID
|
|
||||||
graph_store: Option<Arc<dyn zclaw_skills::orchestration::GraphStore>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SkillOrchestrationDriver {
|
impl SkillOrchestrationDriver {
|
||||||
/// Create a new orchestration driver
|
/// Create a new orchestration driver
|
||||||
pub fn new(skill_registry: Arc<zclaw_skills::SkillRegistry>) -> Self {
|
pub fn new(skill_registry: Arc<zclaw_skills::SkillRegistry>) -> Self {
|
||||||
Self { skill_registry, graph_store: None }
|
Self { skill_registry }
|
||||||
}
|
|
||||||
|
|
||||||
/// Create with graph persistence
|
|
||||||
pub fn with_graph_store(
|
|
||||||
skill_registry: Arc<zclaw_skills::SkillRegistry>,
|
|
||||||
graph_store: Arc<dyn zclaw_skills::orchestration::GraphStore>,
|
|
||||||
) -> Self {
|
|
||||||
Self { skill_registry, graph_store: Some(graph_store) }
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,11 +38,8 @@ impl OrchestrationActionDriver for SkillOrchestrationDriver {
|
|||||||
serde_json::from_value::<SkillGraph>(graph_value.clone())
|
serde_json::from_value::<SkillGraph>(graph_value.clone())
|
||||||
.map_err(|e| format!("Failed to parse graph: {}", e))?
|
.map_err(|e| format!("Failed to parse graph: {}", e))?
|
||||||
} else if let Some(id) = graph_id {
|
} else if let Some(id) = graph_id {
|
||||||
// Load graph from store
|
// Load graph from registry (TODO: implement graph storage)
|
||||||
self.graph_store.as_ref()
|
return Err(format!("Graph loading by ID not yet implemented: {}", id));
|
||||||
.ok_or_else(|| "Graph store not configured. Cannot resolve graph_id.".to_string())?
|
|
||||||
.load(id).await
|
|
||||||
.ok_or_else(|| format!("Graph not found: {}", id))?
|
|
||||||
} else {
|
} else {
|
||||||
return Err("Either graph_id or graph must be provided".to_string());
|
return Err("Either graph_id or graph must be provided".to_string());
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -61,10 +61,6 @@ pub struct PipelineMetadata {
|
|||||||
/// Version string
|
/// Version string
|
||||||
#[serde(default = "default_version")]
|
#[serde(default = "default_version")]
|
||||||
pub version: String,
|
pub version: String,
|
||||||
|
|
||||||
/// Arbitrary key-value annotations (e.g., is_template: true)
|
|
||||||
#[serde(default)]
|
|
||||||
pub annotations: Option<std::collections::HashMap<String, serde_json::Value>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_version() -> String {
|
fn default_version() -> String {
|
||||||
|
|||||||
@@ -427,28 +427,6 @@ impl A2aRouter {
|
|||||||
pub fn agent_id(&self) -> &AgentId {
|
pub fn agent_id(&self) -> &AgentId {
|
||||||
&self.agent_id
|
&self.agent_id
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Discover agents that have a specific capability
|
|
||||||
pub async fn discover(&self, capability: &str) -> Result<Vec<A2aAgentProfile>> {
|
|
||||||
let cap_index = self.capability_index.read().await;
|
|
||||||
let profiles = self.profiles.read().await;
|
|
||||||
|
|
||||||
match cap_index.get(capability) {
|
|
||||||
Some(agent_ids) => {
|
|
||||||
let result: Vec<A2aAgentProfile> = agent_ids.iter()
|
|
||||||
.filter_map(|id| profiles.get(id).cloned())
|
|
||||||
.collect();
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
None => Ok(Vec::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get all registered agent profiles
|
|
||||||
pub async fn list_profiles(&self) -> Vec<A2aAgentProfile> {
|
|
||||||
let profiles = self.profiles.read().await;
|
|
||||||
profiles.values().cloned().collect()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Basic A2A client implementation
|
/// Basic A2A client implementation
|
||||||
|
|||||||
@@ -13,7 +13,6 @@
|
|||||||
//! Optionally flushes old messages to the growth/memory system before discarding.
|
//! Optionally flushes old messages to the growth/memory system before discarding.
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::sync::atomic::{AtomicU64, Ordering};
|
|
||||||
use zclaw_types::{AgentId, Message, SessionId};
|
use zclaw_types::{AgentId, Message, SessionId};
|
||||||
|
|
||||||
use crate::driver::{CompletionRequest, ContentBlock, LlmDriver};
|
use crate::driver::{CompletionRequest, ContentBlock, LlmDriver};
|
||||||
@@ -41,18 +40,9 @@ pub fn estimate_tokens(text: &str) -> usize {
|
|||||||
{
|
{
|
||||||
// CJK ideographs — ~1.5 tokens
|
// CJK ideographs — ~1.5 tokens
|
||||||
tokens += 1.5;
|
tokens += 1.5;
|
||||||
} else if (0xAC00..=0xD7AF).contains(&code) || (0x1100..=0x11FF).contains(&code) {
|
|
||||||
// Korean Hangul syllables + Jamo — ~1.5 tokens
|
|
||||||
tokens += 1.5;
|
|
||||||
} else if (0x3040..=0x309F).contains(&code) || (0x30A0..=0x30FF).contains(&code) {
|
|
||||||
// Japanese Hiragana + Katakana — ~1.5 tokens
|
|
||||||
tokens += 1.5;
|
|
||||||
} else if (0x3000..=0x303F).contains(&code) || (0xFF00..=0xFFEF).contains(&code) {
|
} else if (0x3000..=0x303F).contains(&code) || (0xFF00..=0xFFEF).contains(&code) {
|
||||||
// CJK / fullwidth punctuation — ~1.0 token
|
// CJK / fullwidth punctuation — ~1.0 token
|
||||||
tokens += 1.0;
|
tokens += 1.0;
|
||||||
} else if (0x1F000..=0x1FAFF).contains(&code) || (0x2600..=0x27BF).contains(&code) {
|
|
||||||
// Emoji & Symbols — ~2.0 tokens
|
|
||||||
tokens += 2.0;
|
|
||||||
} else if char == ' ' || char == '\n' || char == '\t' {
|
} else if char == ' ' || char == '\n' || char == '\t' {
|
||||||
// whitespace
|
// whitespace
|
||||||
tokens += 0.25;
|
tokens += 0.25;
|
||||||
@@ -98,54 +88,6 @@ pub fn estimate_messages_tokens(messages: &[Message]) -> usize {
|
|||||||
total
|
total
|
||||||
}
|
}
|
||||||
|
|
||||||
// ============================================================
|
|
||||||
// Calibration: adjust heuristic estimates using API feedback
|
|
||||||
// ============================================================
|
|
||||||
|
|
||||||
const F64_1_0_BITS: u64 = 4607182418800017408u64; // 1.0f64.to_bits()
|
|
||||||
|
|
||||||
/// Global calibration factor for token estimation (stored as f64 bits).
|
|
||||||
///
|
|
||||||
/// Updated via exponential moving average when API returns actual token counts.
|
|
||||||
/// Initial value is 1.0 (no adjustment).
|
|
||||||
static CALIBRATION_FACTOR_BITS: AtomicU64 = AtomicU64::new(F64_1_0_BITS);
|
|
||||||
|
|
||||||
/// Get the current calibration factor.
|
|
||||||
pub fn get_calibration_factor() -> f64 {
|
|
||||||
f64::from_bits(CALIBRATION_FACTOR_BITS.load(Ordering::Relaxed))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Update calibration factor using exponential moving average.
|
|
||||||
///
|
|
||||||
/// Compares estimated tokens with actual tokens from API response:
|
|
||||||
/// - `ratio = actual / estimated` so underestimates push factor UP
|
|
||||||
/// - EMA: `new = current * 0.7 + ratio * 0.3`
|
|
||||||
/// - Clamped to [0.5, 2.0] to prevent runaway values
|
|
||||||
pub fn update_calibration(estimated: usize, actual: u32) {
|
|
||||||
if actual == 0 || estimated == 0 {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
let ratio = actual as f64 / estimated as f64;
|
|
||||||
let current = get_calibration_factor();
|
|
||||||
let new_factor = (current * 0.7 + ratio * 0.3).clamp(0.5, 2.0);
|
|
||||||
CALIBRATION_FACTOR_BITS.store(new_factor.to_bits(), Ordering::Relaxed);
|
|
||||||
tracing::debug!(
|
|
||||||
"[Compaction] Calibration: estimated={}, actual={}, ratio={:.2}, factor {:.2} → {:.2}",
|
|
||||||
estimated, actual, ratio, current, new_factor
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Estimate total tokens for messages with calibration applied.
|
|
||||||
fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize {
|
|
||||||
let raw = estimate_messages_tokens(messages);
|
|
||||||
let factor = get_calibration_factor();
|
|
||||||
if (factor - 1.0).abs() < f64::EPSILON {
|
|
||||||
raw
|
|
||||||
} else {
|
|
||||||
((raw as f64 * factor).ceil()) as usize
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Compact a message list by summarizing old messages and keeping recent ones.
|
/// Compact a message list by summarizing old messages and keeping recent ones.
|
||||||
///
|
///
|
||||||
/// When `messages.len() > keep_recent`, the oldest messages are summarized
|
/// When `messages.len() > keep_recent`, the oldest messages are summarized
|
||||||
@@ -192,7 +134,7 @@ pub fn compact_messages(messages: Vec<Message>, keep_recent: usize) -> (Vec<Mess
|
|||||||
///
|
///
|
||||||
/// Returns the (possibly compacted) message list.
|
/// Returns the (possibly compacted) message list.
|
||||||
pub fn maybe_compact(messages: Vec<Message>, threshold: usize) -> Vec<Message> {
|
pub fn maybe_compact(messages: Vec<Message>, threshold: usize) -> Vec<Message> {
|
||||||
let tokens = estimate_messages_tokens_calibrated(&messages);
|
let tokens = estimate_messages_tokens(&messages);
|
||||||
if tokens < threshold {
|
if tokens < threshold {
|
||||||
return messages;
|
return messages;
|
||||||
}
|
}
|
||||||
@@ -266,7 +208,7 @@ pub async fn maybe_compact_with_config(
|
|||||||
driver: Option<&Arc<dyn LlmDriver>>,
|
driver: Option<&Arc<dyn LlmDriver>>,
|
||||||
growth: Option<&GrowthIntegration>,
|
growth: Option<&GrowthIntegration>,
|
||||||
) -> CompactionOutcome {
|
) -> CompactionOutcome {
|
||||||
let tokens = estimate_messages_tokens_calibrated(&messages);
|
let tokens = estimate_messages_tokens(&messages);
|
||||||
if tokens < threshold {
|
if tokens < threshold {
|
||||||
return CompactionOutcome {
|
return CompactionOutcome {
|
||||||
messages,
|
messages,
|
||||||
@@ -533,11 +475,10 @@ fn generate_summary(messages: &[Message]) -> String {
|
|||||||
|
|
||||||
let summary = sections.join("\n");
|
let summary = sections.join("\n");
|
||||||
|
|
||||||
// Enforce max length (char-safe for CJK)
|
// Enforce max length
|
||||||
let max_chars = 800;
|
let max_chars = 800;
|
||||||
if summary.chars().count() > max_chars {
|
if summary.len() > max_chars {
|
||||||
let truncated: String = summary.chars().take(max_chars).collect();
|
format!("{}...\n(摘要已截断)", &summary[..max_chars])
|
||||||
format!("{}...\n(摘要已截断)", truncated)
|
|
||||||
} else {
|
} else {
|
||||||
summary
|
summary
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -130,8 +130,7 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
let api_key = self.api_key.expose_secret().to_string();
|
let api_key = self.api_key.expose_secret().to_string();
|
||||||
|
|
||||||
Box::pin(stream! {
|
Box::pin(stream! {
|
||||||
println!("[OpenAI:stream] POST to {}/chat/completions", base_url);
|
tracing::debug!("[OpenAiDriver:stream] Starting HTTP request...");
|
||||||
println!("[OpenAI:stream] Request model={}, stream={}", stream_request.model, stream_request.stream);
|
|
||||||
let response = match self.client
|
let response = match self.client
|
||||||
.post(format!("{}/chat/completions", base_url))
|
.post(format!("{}/chat/completions", base_url))
|
||||||
.header("Authorization", format!("Bearer {}", api_key))
|
.header("Authorization", format!("Bearer {}", api_key))
|
||||||
@@ -142,11 +141,11 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(r) => {
|
Ok(r) => {
|
||||||
println!("[OpenAI:stream] Response status: {}, content-type: {:?}", r.status(), r.headers().get("content-type"));
|
tracing::debug!("[OpenAiDriver:stream] Got response, status: {}", r.status());
|
||||||
r
|
r
|
||||||
},
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!("[OpenAI:stream] HTTP request FAILED: {:?}", e);
|
tracing::error!("[OpenAiDriver:stream] HTTP request failed: {:?}", e);
|
||||||
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
yield Err(ZclawError::LlmError(format!("HTTP request failed: {}", e)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -155,7 +154,6 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let body = response.text().await.unwrap_or_default();
|
let body = response.text().await.unwrap_or_default();
|
||||||
println!("[OpenAI:stream] API error {}: {}", status, &body[..body.len().min(500)]);
|
|
||||||
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
yield Err(ZclawError::LlmError(format!("API error {}: {}", status, body)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -163,45 +161,21 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
let mut byte_stream = response.bytes_stream();
|
let mut byte_stream = response.bytes_stream();
|
||||||
let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
|
let mut accumulated_tool_calls: std::collections::HashMap<String, (String, String)> = std::collections::HashMap::new();
|
||||||
let mut current_tool_id: Option<String> = None;
|
let mut current_tool_id: Option<String> = None;
|
||||||
let mut sse_event_count: usize = 0;
|
|
||||||
let mut raw_bytes_total: usize = 0;
|
|
||||||
|
|
||||||
while let Some(chunk_result) = byte_stream.next().await {
|
while let Some(chunk_result) = byte_stream.next().await {
|
||||||
let chunk = match chunk_result {
|
let chunk = match chunk_result {
|
||||||
Ok(c) => c,
|
Ok(c) => c,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!("[OpenAI:stream] Byte stream error: {:?}", e);
|
|
||||||
yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
|
yield Err(ZclawError::LlmError(format!("Stream error: {}", e)));
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
raw_bytes_total += chunk.len();
|
|
||||||
let text = String::from_utf8_lossy(&chunk);
|
let text = String::from_utf8_lossy(&chunk);
|
||||||
// Log first 500 bytes of raw data for debugging SSE format
|
|
||||||
if raw_bytes_total <= 600 {
|
|
||||||
println!("[OpenAI:stream] RAW chunk ({} bytes): {:?}", text.len(), &text[..text.len().min(500)]);
|
|
||||||
}
|
|
||||||
for line in text.lines() {
|
for line in text.lines() {
|
||||||
let trimmed = line.trim();
|
if let Some(data) = line.strip_prefix("data: ") {
|
||||||
if trimmed.is_empty() || trimmed.starts_with(':') {
|
|
||||||
continue; // Skip empty lines and SSE comments
|
|
||||||
}
|
|
||||||
// Handle both "data: " (standard) and "data:" (no space)
|
|
||||||
let data = if let Some(d) = trimmed.strip_prefix("data: ") {
|
|
||||||
Some(d)
|
|
||||||
} else if let Some(d) = trimmed.strip_prefix("data:") {
|
|
||||||
Some(d.trim_start())
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
if let Some(data) = data {
|
|
||||||
sse_event_count += 1;
|
|
||||||
if sse_event_count <= 3 || data == "[DONE]" {
|
|
||||||
println!("[OpenAI:stream] SSE #{}: {}", sse_event_count, &data[..data.len().min(300)]);
|
|
||||||
}
|
|
||||||
if data == "[DONE]" {
|
if data == "[DONE]" {
|
||||||
println!("[OpenAI:stream] Received [DONE], total SSE events: {}, raw bytes: {}", sse_event_count, raw_bytes_total);
|
tracing::debug!("[OpenAI] Stream done, accumulated_tool_calls: {:?}", accumulated_tool_calls.len());
|
||||||
|
|
||||||
// Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name)
|
// Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name)
|
||||||
for (id, (name, args)) in &accumulated_tool_calls {
|
for (id, (name, args)) in &accumulated_tool_calls {
|
||||||
@@ -242,19 +216,10 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
// Handle text content
|
// Handle text content
|
||||||
if let Some(content) = &delta.content {
|
if let Some(content) = &delta.content {
|
||||||
if !content.is_empty() {
|
if !content.is_empty() {
|
||||||
tracing::debug!("[OpenAI:stream] TextDelta: {} chars", content.len());
|
|
||||||
yield Ok(StreamChunk::TextDelta { delta: content.clone() });
|
yield Ok(StreamChunk::TextDelta { delta: content.clone() });
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle reasoning_content (Kimi, Qwen, DeepSeek, GLM thinking)
|
|
||||||
if let Some(reasoning) = &delta.reasoning_content {
|
|
||||||
if !reasoning.is_empty() {
|
|
||||||
tracing::debug!("[OpenAI:stream] ThinkingDelta (reasoning_content): {} chars", reasoning.len());
|
|
||||||
yield Ok(StreamChunk::ThinkingDelta { delta: reasoning.clone() });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle tool calls
|
// Handle tool calls
|
||||||
if let Some(tool_calls) = &delta.tool_calls {
|
if let Some(tool_calls) = &delta.tool_calls {
|
||||||
tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
||||||
@@ -319,7 +284,6 @@ impl LlmDriver for OpenAiDriver {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
println!("[OpenAI:stream] Byte stream ended. Total: {} SSE events, {} raw bytes", sse_event_count, raw_bytes_total);
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -340,122 +304,55 @@ impl OpenAiDriver {
|
|||||||
request.system.clone()
|
request.system.clone()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Build messages with tool result truncation to prevent payload overflow.
|
let messages: Vec<OpenAiMessage> = request.messages
|
||||||
// Most LLM APIs have a 2-4MB HTTP payload limit.
|
.iter()
|
||||||
const MAX_TOOL_RESULT_BYTES: usize = 32_768; // 32KB per tool result
|
.filter_map(|msg| match msg {
|
||||||
const MAX_PAYLOAD_BYTES: usize = 1_800_000; // 1.8MB (under 2MB API limit)
|
zclaw_types::Message::User { content } => Some(OpenAiMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
let mut messages: Vec<OpenAiMessage> = Vec::new();
|
content: Some(content.clone()),
|
||||||
let mut pending_tool_calls: Option<Vec<OpenAiToolCall>> = None;
|
|
||||||
let mut pending_content: Option<String> = None;
|
|
||||||
let mut pending_reasoning: Option<String> = None;
|
|
||||||
|
|
||||||
let flush_pending = |tc: &mut Option<Vec<OpenAiToolCall>>,
|
|
||||||
c: &mut Option<String>,
|
|
||||||
r: &mut Option<String>,
|
|
||||||
out: &mut Vec<OpenAiMessage>| {
|
|
||||||
let calls = tc.take();
|
|
||||||
let content = c.take();
|
|
||||||
let reasoning = r.take();
|
|
||||||
|
|
||||||
if let Some(calls) = calls {
|
|
||||||
if !calls.is_empty() {
|
|
||||||
// Merge assistant content + reasoning into the tool call message
|
|
||||||
out.push(OpenAiMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: content.filter(|s| !s.is_empty()),
|
|
||||||
reasoning_content: reasoning.filter(|s| !s.is_empty()),
|
|
||||||
tool_calls: Some(calls),
|
|
||||||
tool_call_id: None,
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// No tool calls — emit a plain assistant message
|
|
||||||
if content.is_some() || reasoning.is_some() {
|
|
||||||
out.push(OpenAiMessage {
|
|
||||||
role: "assistant".to_string(),
|
|
||||||
content: content.filter(|s| !s.is_empty()),
|
|
||||||
reasoning_content: reasoning.filter(|s| !s.is_empty()),
|
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
tool_call_id: None,
|
}),
|
||||||
});
|
zclaw_types::Message::Assistant { content, thinking: _ } => Some(OpenAiMessage {
|
||||||
}
|
role: "assistant".to_string(),
|
||||||
};
|
content: Some(content.clone()),
|
||||||
|
tool_calls: None,
|
||||||
for msg in &request.messages {
|
}),
|
||||||
match msg {
|
zclaw_types::Message::System { content } => Some(OpenAiMessage {
|
||||||
zclaw_types::Message::User { content } => {
|
role: "system".to_string(),
|
||||||
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
content: Some(content.clone()),
|
||||||
messages.push(OpenAiMessage {
|
tool_calls: None,
|
||||||
role: "user".to_string(),
|
}),
|
||||||
content: Some(content.clone()),
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
reasoning_content: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
zclaw_types::Message::Assistant { content, thinking } => {
|
|
||||||
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
||||||
// Don't push immediately — wait to see if next messages are ToolUse
|
|
||||||
pending_content = Some(content.clone());
|
|
||||||
pending_reasoning = thinking.clone();
|
|
||||||
}
|
|
||||||
zclaw_types::Message::System { content } => {
|
|
||||||
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
||||||
messages.push(OpenAiMessage {
|
|
||||||
role: "system".to_string(),
|
|
||||||
content: Some(content.clone()),
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
reasoning_content: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
zclaw_types::Message::ToolUse { id, tool, input } => {
|
zclaw_types::Message::ToolUse { id, tool, input } => {
|
||||||
// Accumulate tool calls — they'll be merged with the pending assistant message
|
// Ensure arguments is always a valid JSON object, never null or invalid
|
||||||
let args = if input.is_null() {
|
let args = if input.is_null() {
|
||||||
"{}".to_string()
|
"{}".to_string()
|
||||||
} else {
|
} else {
|
||||||
serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string())
|
serde_json::to_string(input).unwrap_or_else(|_| "{}".to_string())
|
||||||
};
|
};
|
||||||
pending_tool_calls
|
Some(OpenAiMessage {
|
||||||
.get_or_insert_with(Vec::new)
|
role: "assistant".to_string(),
|
||||||
.push(OpenAiToolCall {
|
content: None,
|
||||||
|
tool_calls: Some(vec![OpenAiToolCall {
|
||||||
id: id.clone(),
|
id: id.clone(),
|
||||||
r#type: "function".to_string(),
|
r#type: "function".to_string(),
|
||||||
function: FunctionCall {
|
function: FunctionCall {
|
||||||
name: tool.to_string(),
|
name: tool.to_string(),
|
||||||
arguments: args,
|
arguments: args,
|
||||||
},
|
},
|
||||||
});
|
}]),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
zclaw_types::Message::ToolResult { tool_call_id, output, is_error, .. } => {
|
zclaw_types::Message::ToolResult { tool_call_id: _, output, is_error, .. } => Some(OpenAiMessage {
|
||||||
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
role: "tool".to_string(),
|
||||||
let content_str = if *is_error {
|
content: Some(if *is_error {
|
||||||
format!("Error: {}", output)
|
format!("Error: {}", output)
|
||||||
} else {
|
} else {
|
||||||
output.to_string()
|
output.to_string()
|
||||||
};
|
}),
|
||||||
// Truncate oversized tool results to prevent payload overflow
|
tool_calls: None,
|
||||||
let truncated = if content_str.len() > MAX_TOOL_RESULT_BYTES {
|
}),
|
||||||
let mut s = String::from(&content_str[..MAX_TOOL_RESULT_BYTES]);
|
})
|
||||||
s.push_str("\n\n... [内容已截断,原文过大]");
|
.collect();
|
||||||
s
|
|
||||||
} else {
|
|
||||||
content_str
|
|
||||||
};
|
|
||||||
messages.push(OpenAiMessage {
|
|
||||||
role: "tool".to_string(),
|
|
||||||
content: Some(truncated),
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: Some(tool_call_id.clone()),
|
|
||||||
reasoning_content: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Flush any remaining accumulated assistant content and/or tool calls
|
|
||||||
flush_pending(&mut pending_tool_calls, &mut pending_content, &mut pending_reasoning, &mut messages);
|
|
||||||
|
|
||||||
// Add system prompt if provided
|
// Add system prompt if provided
|
||||||
let mut messages = messages;
|
let mut messages = messages;
|
||||||
@@ -464,8 +361,6 @@ impl OpenAiDriver {
|
|||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: Some(system.clone()),
|
content: Some(system.clone()),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
tool_call_id: None,
|
|
||||||
reasoning_content: None,
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -481,7 +376,7 @@ impl OpenAiDriver {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let api_request = OpenAiRequest {
|
OpenAiRequest {
|
||||||
model: request.model.clone(), // Use model ID directly without any transformation
|
model: request.model.clone(), // Use model ID directly without any transformation
|
||||||
messages,
|
messages,
|
||||||
max_tokens: request.max_tokens,
|
max_tokens: request.max_tokens,
|
||||||
@@ -489,75 +384,7 @@ impl OpenAiDriver {
|
|||||||
stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) },
|
stop: if request.stop.is_empty() { None } else { Some(request.stop.clone()) },
|
||||||
stream: request.stream,
|
stream: request.stream,
|
||||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||||
};
|
|
||||||
|
|
||||||
// Pre-send payload size validation
|
|
||||||
if let Ok(serialized) = serde_json::to_string(&api_request) {
|
|
||||||
if serialized.len() > MAX_PAYLOAD_BYTES {
|
|
||||||
tracing::warn!(
|
|
||||||
target: "openai_driver",
|
|
||||||
"Request payload too large: {} bytes (limit: {}), truncating messages",
|
|
||||||
serialized.len(),
|
|
||||||
MAX_PAYLOAD_BYTES
|
|
||||||
);
|
|
||||||
return Self::truncate_messages_to_fit(api_request, MAX_PAYLOAD_BYTES);
|
|
||||||
}
|
|
||||||
tracing::debug!(
|
|
||||||
target: "openai_driver",
|
|
||||||
"Request payload size: {} bytes (limit: {})",
|
|
||||||
serialized.len(),
|
|
||||||
MAX_PAYLOAD_BYTES
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
api_request
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Emergency truncation: drop oldest non-system messages until payload fits
|
|
||||||
fn truncate_messages_to_fit(mut request: OpenAiRequest, _max_bytes: usize) -> OpenAiRequest {
|
|
||||||
// Keep system message (if any) and last 4 non-system messages
|
|
||||||
let has_system = request.messages.first()
|
|
||||||
.map(|m| m.role == "system")
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
let non_system: Vec<OpenAiMessage> = request.messages.into_iter()
|
|
||||||
.filter(|m| m.role != "system")
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Keep last N messages and truncate any remaining large tool results
|
|
||||||
let keep_count = 4.min(non_system.len());
|
|
||||||
let start = non_system.len() - keep_count;
|
|
||||||
let kept: Vec<OpenAiMessage> = non_system.into_iter()
|
|
||||||
.skip(start)
|
|
||||||
.map(|mut msg| {
|
|
||||||
// Additional per-message truncation for tool results
|
|
||||||
if msg.role == "tool" {
|
|
||||||
if let Some(ref content) = msg.content {
|
|
||||||
if content.len() > 16_384 {
|
|
||||||
let mut s = String::from(&content[..16_384]);
|
|
||||||
s.push_str("\n\n... [上下文压缩截断]");
|
|
||||||
msg.content = Some(s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
msg
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
|
||||||
if has_system {
|
|
||||||
messages.push(OpenAiMessage {
|
|
||||||
role: "system".to_string(),
|
|
||||||
content: Some("You are a helpful AI assistant. (注意:对话历史已被压缩以适应上下文大小限制)".to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
reasoning_content: None,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
messages.extend(kept);
|
|
||||||
|
|
||||||
request.messages = messages;
|
|
||||||
request
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn convert_response(&self, api_response: OpenAiResponse, model: String) -> CompletionResponse {
|
fn convert_response(&self, api_response: OpenAiResponse, model: String) -> CompletionResponse {
|
||||||
@@ -571,7 +398,6 @@ impl OpenAiDriver {
|
|||||||
// This is important because some providers return empty content with tool_calls
|
// This is important because some providers return empty content with tool_calls
|
||||||
let has_tool_calls = c.message.tool_calls.as_ref().map(|tc| !tc.is_empty()).unwrap_or(false);
|
let has_tool_calls = c.message.tool_calls.as_ref().map(|tc| !tc.is_empty()).unwrap_or(false);
|
||||||
let has_content = c.message.content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
let has_content = c.message.content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
||||||
let has_reasoning = c.message.reasoning_content.as_ref().map(|t| !t.is_empty()).unwrap_or(false);
|
|
||||||
|
|
||||||
let blocks = if has_tool_calls {
|
let blocks = if has_tool_calls {
|
||||||
// Tool calls take priority
|
// Tool calls take priority
|
||||||
@@ -587,11 +413,6 @@ impl OpenAiDriver {
|
|||||||
let text = c.message.content.as_ref().unwrap();
|
let text = c.message.content.as_ref().unwrap();
|
||||||
tracing::debug!("[OpenAiDriver:convert_response] Using text content: {} chars", text.len());
|
tracing::debug!("[OpenAiDriver:convert_response] Using text content: {} chars", text.len());
|
||||||
vec![ContentBlock::Text { text: text.clone() }]
|
vec![ContentBlock::Text { text: text.clone() }]
|
||||||
} else if has_reasoning {
|
|
||||||
// Content empty but reasoning_content present (Kimi, Qwen, DeepSeek)
|
|
||||||
let reasoning = c.message.reasoning_content.as_ref().unwrap();
|
|
||||||
tracing::debug!("[OpenAiDriver:convert_response] Using reasoning_content: {} chars", reasoning.len());
|
|
||||||
vec![ContentBlock::Text { text: reasoning.clone() }]
|
|
||||||
} else {
|
} else {
|
||||||
// No content or tool_calls
|
// No content or tool_calls
|
||||||
tracing::debug!("[OpenAiDriver:convert_response] No content or tool_calls, using empty text");
|
tracing::debug!("[OpenAiDriver:convert_response] No content or tool_calls, using empty text");
|
||||||
@@ -773,10 +594,6 @@ struct OpenAiMessage {
|
|||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
tool_calls: Option<Vec<OpenAiToolCall>>,
|
tool_calls: Option<Vec<OpenAiToolCall>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
tool_call_id: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
reasoning_content: Option<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize)]
|
#[derive(Serialize)]
|
||||||
@@ -839,8 +656,6 @@ struct OpenAiResponseMessage {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
reasoning_content: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
tool_calls: Option<Vec<OpenAiToolCallResponse>>,
|
tool_calls: Option<Vec<OpenAiToolCallResponse>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -890,8 +705,6 @@ struct OpenAiDelta {
|
|||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: Option<String>,
|
content: Option<String>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
reasoning_content: Option<String>,
|
|
||||||
#[serde(default)]
|
|
||||||
tool_calls: Option<Vec<OpenAiToolCallDelta>>,
|
tool_calls: Option<Vec<OpenAiToolCallDelta>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,22 @@
|
|||||||
//! enabling automatic memory retrieval before conversations and memory extraction
|
//! enabling automatic memory retrieval before conversations and memory extraction
|
||||||
//! after conversations.
|
//! after conversations.
|
||||||
//!
|
//!
|
||||||
//! **Note (2026-03-30)**: GrowthIntegration IS wired into the Kernel's middleware
|
//! # Usage
|
||||||
//! chain (MemoryMiddleware + CompactionMiddleware). In the Tauri desktop deployment,
|
//!
|
||||||
//! `kernel_commands::kernel_init()` bridges the persistent SqliteStorage to the Kernel
|
//! ```rust,ignore
|
||||||
//! via `set_viking()` + `set_extraction_driver()`, so the middleware chain and the
|
//! use zclaw_runtime::growth::GrowthIntegration;
|
||||||
//! Tauri intelligence_hooks share the same persistent storage backend.
|
//! use zclaw_growth::{VikingAdapter, MemoryExtractor, MemoryRetriever, PromptInjector};
|
||||||
|
//!
|
||||||
|
//! // Create growth integration
|
||||||
|
//! let viking = Arc::new(VikingAdapter::in_memory());
|
||||||
|
//! let growth = GrowthIntegration::new(viking);
|
||||||
|
//!
|
||||||
|
//! // Before conversation: enhance system prompt
|
||||||
|
//! let enhanced_prompt = growth.enhance_prompt(&agent_id, &base_prompt, &user_input).await?;
|
||||||
|
//!
|
||||||
|
//! // After conversation: extract and store memories
|
||||||
|
//! growth.process_conversation(&agent_id, &messages, session_id).await?;
|
||||||
|
//! ```
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use zclaw_growth::{
|
use zclaw_growth::{
|
||||||
|
|||||||
@@ -3,10 +3,8 @@
|
|||||||
//! LLM drivers, tool system, and agent loop implementation.
|
//! LLM drivers, tool system, and agent loop implementation.
|
||||||
|
|
||||||
/// Default User-Agent header sent with all outgoing HTTP requests.
|
/// Default User-Agent header sent with all outgoing HTTP requests.
|
||||||
/// Coding Plan providers (Kimi, Bailian/DashScope, Zhipu) validate the User-Agent against a
|
/// Some LLM providers (e.g. Moonshot, Qwen, DashScope Coding Plan) reject requests without one.
|
||||||
/// whitelist of known Coding Agents (e.g. claude-code, kimi-cli, roo-code, kilo-code).
|
pub const USER_AGENT: &str = "ZCLAW/0.1.0";
|
||||||
/// Must use the exact lowercase format to pass validation.
|
|
||||||
pub const USER_AGENT: &str = "claude-code/0.1.0";
|
|
||||||
|
|
||||||
pub mod driver;
|
pub mod driver;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
@@ -15,7 +13,6 @@ pub mod loop_guard;
|
|||||||
pub mod stream;
|
pub mod stream;
|
||||||
pub mod growth;
|
pub mod growth;
|
||||||
pub mod compaction;
|
pub mod compaction;
|
||||||
pub mod middleware;
|
|
||||||
|
|
||||||
// Re-export main types
|
// Re-export main types
|
||||||
pub use driver::{
|
pub use driver::{
|
||||||
@@ -27,7 +24,4 @@ pub use loop_runner::{AgentLoop, AgentLoopResult, LoopEvent};
|
|||||||
pub use loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
|
pub use loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
|
||||||
pub use stream::{StreamEvent, StreamSender};
|
pub use stream::{StreamEvent, StreamSender};
|
||||||
pub use growth::GrowthIntegration;
|
pub use growth::GrowthIntegration;
|
||||||
pub use zclaw_growth::VikingAdapter;
|
|
||||||
pub use zclaw_growth::EmbeddingClient;
|
|
||||||
pub use zclaw_growth::LlmDriverForExtraction;
|
|
||||||
pub use compaction::{CompactionConfig, CompactionOutcome};
|
pub use compaction::{CompactionConfig, CompactionOutcome};
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ use crate::tool::builtin::PathValidator;
|
|||||||
use crate::loop_guard::{LoopGuard, LoopGuardResult};
|
use crate::loop_guard::{LoopGuard, LoopGuardResult};
|
||||||
use crate::growth::GrowthIntegration;
|
use crate::growth::GrowthIntegration;
|
||||||
use crate::compaction::{self, CompactionConfig};
|
use crate::compaction::{self, CompactionConfig};
|
||||||
use crate::middleware::{self, MiddlewareChain};
|
|
||||||
use zclaw_memory::MemoryStore;
|
use zclaw_memory::MemoryStore;
|
||||||
|
|
||||||
/// Agent loop runner
|
/// Agent loop runner
|
||||||
@@ -35,10 +34,6 @@ pub struct AgentLoop {
|
|||||||
compaction_threshold: usize,
|
compaction_threshold: usize,
|
||||||
/// Compaction behavior configuration
|
/// Compaction behavior configuration
|
||||||
compaction_config: CompactionConfig,
|
compaction_config: CompactionConfig,
|
||||||
/// Optional middleware chain — when `Some`, cross-cutting logic is
|
|
||||||
/// delegated to the chain instead of the inline code below.
|
|
||||||
/// When `None`, the legacy inline path is used (100% backward compatible).
|
|
||||||
middleware_chain: Option<MiddlewareChain>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AgentLoop {
|
impl AgentLoop {
|
||||||
@@ -63,7 +58,6 @@ impl AgentLoop {
|
|||||||
growth: None,
|
growth: None,
|
||||||
compaction_threshold: 0,
|
compaction_threshold: 0,
|
||||||
compaction_config: CompactionConfig::default(),
|
compaction_config: CompactionConfig::default(),
|
||||||
middleware_chain: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,14 +124,6 @@ impl AgentLoop {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inject a middleware chain. When set, cross-cutting concerns (compaction,
|
|
||||||
/// loop guard, token calibration, etc.) are delegated to the chain instead
|
|
||||||
/// of the inline logic.
|
|
||||||
pub fn with_middleware_chain(mut self, chain: MiddlewareChain) -> Self {
|
|
||||||
self.middleware_chain = Some(chain);
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get growth integration reference
|
/// Get growth integration reference
|
||||||
pub fn growth(&self) -> Option<&GrowthIntegration> {
|
pub fn growth(&self) -> Option<&GrowthIntegration> {
|
||||||
self.growth.as_ref()
|
self.growth.as_ref()
|
||||||
@@ -145,30 +131,12 @@ impl AgentLoop {
|
|||||||
|
|
||||||
/// Create tool context for tool execution
|
/// Create tool context for tool execution
|
||||||
fn create_tool_context(&self, session_id: SessionId) -> ToolContext {
|
fn create_tool_context(&self, session_id: SessionId) -> ToolContext {
|
||||||
// If no path_validator is configured, create a default one with user home as workspace.
|
|
||||||
// This allows file_read/file_write tools to work without explicit workspace config,
|
|
||||||
// while still restricting access to the user's home directory for security.
|
|
||||||
let path_validator = self.path_validator.clone().unwrap_or_else(|| {
|
|
||||||
let home = std::env::var("USERPROFILE")
|
|
||||||
.or_else(|_| std::env::var("HOME"))
|
|
||||||
.unwrap_or_else(|_| ".".to_string());
|
|
||||||
let home_path = std::path::PathBuf::from(&home);
|
|
||||||
tracing::info!(
|
|
||||||
"[AgentLoop] No path_validator configured, using user home as workspace: {}",
|
|
||||||
home_path.display()
|
|
||||||
);
|
|
||||||
PathValidator::new().with_workspace(home_path)
|
|
||||||
});
|
|
||||||
|
|
||||||
let working_dir = path_validator.workspace_root()
|
|
||||||
.map(|p| p.to_string_lossy().to_string());
|
|
||||||
|
|
||||||
ToolContext {
|
ToolContext {
|
||||||
agent_id: self.agent_id.clone(),
|
agent_id: self.agent_id.clone(),
|
||||||
working_directory: working_dir,
|
working_directory: None,
|
||||||
session_id: Some(session_id.to_string()),
|
session_id: Some(session_id.to_string()),
|
||||||
skill_executor: self.skill_executor.clone(),
|
skill_executor: self.skill_executor.clone(),
|
||||||
path_validator: Some(path_validator),
|
path_validator: self.path_validator.clone(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -189,10 +157,8 @@ impl AgentLoop {
|
|||||||
// Get all messages for context
|
// Get all messages for context
|
||||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||||
|
|
||||||
let use_middleware = self.middleware_chain.is_some();
|
// Apply compaction if threshold is configured
|
||||||
|
if self.compaction_threshold > 0 {
|
||||||
// Apply compaction — skip inline path when middleware chain handles it
|
|
||||||
if !use_middleware && self.compaction_threshold > 0 {
|
|
||||||
let needs_async =
|
let needs_async =
|
||||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||||
if needs_async {
|
if needs_async {
|
||||||
@@ -212,44 +178,14 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enhance system prompt — skip when middleware chain handles it
|
// Enhance system prompt with growth memories
|
||||||
let mut enhanced_prompt = if use_middleware {
|
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
||||||
self.system_prompt.clone().unwrap_or_default()
|
|
||||||
} else if let Some(ref growth) = self.growth {
|
|
||||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||||
} else {
|
} else {
|
||||||
self.system_prompt.clone().unwrap_or_default()
|
self.system_prompt.clone().unwrap_or_default()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
|
||||||
if let Some(ref chain) = self.middleware_chain {
|
|
||||||
let mut mw_ctx = middleware::MiddlewareContext {
|
|
||||||
agent_id: self.agent_id.clone(),
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
user_input: input.clone(),
|
|
||||||
system_prompt: enhanced_prompt.clone(),
|
|
||||||
messages,
|
|
||||||
response_content: Vec::new(),
|
|
||||||
input_tokens: 0,
|
|
||||||
output_tokens: 0,
|
|
||||||
};
|
|
||||||
match chain.run_before_completion(&mut mw_ctx).await? {
|
|
||||||
middleware::MiddlewareDecision::Continue => {
|
|
||||||
messages = mw_ctx.messages;
|
|
||||||
enhanced_prompt = mw_ctx.system_prompt;
|
|
||||||
}
|
|
||||||
middleware::MiddlewareDecision::Stop(reason) => {
|
|
||||||
return Ok(AgentLoopResult {
|
|
||||||
response: reason,
|
|
||||||
input_tokens: 0,
|
|
||||||
output_tokens: 0,
|
|
||||||
iterations: 1,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let max_iterations = 10;
|
let max_iterations = 10;
|
||||||
let mut iterations = 0;
|
let mut iterations = 0;
|
||||||
let mut total_input_tokens = 0u32;
|
let mut total_input_tokens = 0u32;
|
||||||
@@ -286,14 +222,6 @@ impl AgentLoop {
|
|||||||
total_input_tokens += response.input_tokens;
|
total_input_tokens += response.input_tokens;
|
||||||
total_output_tokens += response.output_tokens;
|
total_output_tokens += response.output_tokens;
|
||||||
|
|
||||||
// Calibrate token estimation on first iteration
|
|
||||||
if iterations == 1 {
|
|
||||||
compaction::update_calibration(
|
|
||||||
compaction::estimate_messages_tokens(&messages),
|
|
||||||
response.input_tokens,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract tool calls from response
|
// Extract tool calls from response
|
||||||
let tool_calls: Vec<(String, String, serde_json::Value)> = response.content.iter()
|
let tool_calls: Vec<(String, String, serde_json::Value)> = response.content.iter()
|
||||||
.filter_map(|block| match block {
|
.filter_map(|block| match block {
|
||||||
@@ -302,49 +230,30 @@ impl AgentLoop {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
// Extract text and thinking separately
|
|
||||||
let text_parts: Vec<String> = response.content.iter()
|
|
||||||
.filter_map(|block| match block {
|
|
||||||
ContentBlock::Text { text } => Some(text.clone()),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let thinking_parts: Vec<String> = response.content.iter()
|
|
||||||
.filter_map(|block| match block {
|
|
||||||
ContentBlock::Thinking { thinking } => Some(thinking.clone()),
|
|
||||||
_ => None,
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let text_content = text_parts.join("\n");
|
|
||||||
let thinking_content = if thinking_parts.is_empty() { None } else { Some(thinking_parts.join("")) };
|
|
||||||
|
|
||||||
// If no tool calls, we have the final response
|
// If no tool calls, we have the final response
|
||||||
if tool_calls.is_empty() {
|
if tool_calls.is_empty() {
|
||||||
// Save final assistant message with thinking
|
// Extract text content
|
||||||
let msg = if let Some(thinking) = &thinking_content {
|
let text = response.content.iter()
|
||||||
Message::assistant_with_thinking(&text_content, thinking)
|
.filter_map(|block| match block {
|
||||||
} else {
|
ContentBlock::Text { text } => Some(text.clone()),
|
||||||
Message::assistant(&text_content)
|
ContentBlock::Thinking { thinking } => Some(format!("[思考] {}", thinking)),
|
||||||
};
|
_ => None,
|
||||||
self.memory.append_message(&session_id, &msg).await?;
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
// Save final assistant message
|
||||||
|
self.memory.append_message(&session_id, &Message::assistant(&text)).await?;
|
||||||
|
|
||||||
break AgentLoopResult {
|
break AgentLoopResult {
|
||||||
response: text_content,
|
response: text,
|
||||||
input_tokens: total_input_tokens,
|
input_tokens: total_input_tokens,
|
||||||
output_tokens: total_output_tokens,
|
output_tokens: total_output_tokens,
|
||||||
iterations,
|
iterations,
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// There are tool calls - push assistant message with thinking before tool calls
|
// There are tool calls - add assistant message with tool calls to history
|
||||||
// (required by Kimi and other thinking-enabled APIs)
|
|
||||||
let assistant_msg = if let Some(thinking) = &thinking_content {
|
|
||||||
Message::assistant_with_thinking(&text_content, thinking)
|
|
||||||
} else {
|
|
||||||
Message::assistant(&text_content)
|
|
||||||
};
|
|
||||||
messages.push(assistant_msg);
|
|
||||||
|
|
||||||
for (id, name, input) in &tool_calls {
|
for (id, name, input) in &tool_calls {
|
||||||
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
messages.push(Message::tool_use(id, zclaw_types::ToolId::new(name), input.clone()));
|
||||||
}
|
}
|
||||||
@@ -353,56 +262,24 @@ impl AgentLoop {
|
|||||||
let tool_context = self.create_tool_context(session_id.clone());
|
let tool_context = self.create_tool_context(session_id.clone());
|
||||||
let mut circuit_breaker_triggered = false;
|
let mut circuit_breaker_triggered = false;
|
||||||
for (id, name, input) in tool_calls {
|
for (id, name, input) in tool_calls {
|
||||||
// Check tool call safety — via middleware chain or inline loop guard
|
// Check loop guard before executing tool
|
||||||
if let Some(ref chain) = self.middleware_chain {
|
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
||||||
let mw_ctx_ref = middleware::MiddlewareContext {
|
match guard_result {
|
||||||
agent_id: self.agent_id.clone(),
|
LoopGuardResult::CircuitBreaker => {
|
||||||
session_id: session_id.clone(),
|
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
||||||
user_input: input.to_string(),
|
circuit_breaker_triggered = true;
|
||||||
system_prompt: enhanced_prompt.clone(),
|
break;
|
||||||
messages: messages.clone(),
|
|
||||||
response_content: Vec::new(),
|
|
||||||
input_tokens: total_input_tokens,
|
|
||||||
output_tokens: total_output_tokens,
|
|
||||||
};
|
|
||||||
match chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? {
|
|
||||||
middleware::ToolCallDecision::Allow => {}
|
|
||||||
middleware::ToolCallDecision::Block(msg) => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
|
||||||
let error_output = serde_json::json!({ "error": msg });
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
middleware::ToolCallDecision::ReplaceInput(new_input) => {
|
|
||||||
// Execute with replaced input
|
|
||||||
let tool_result = match self.execute_tool(&name, new_input, &tool_context).await {
|
|
||||||
Ok(result) => result,
|
|
||||||
Err(e) => serde_json::json!({ "error": e.to_string() }),
|
|
||||||
};
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
LoopGuardResult::Blocked => {
|
||||||
// Legacy inline path
|
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||||
let guard_result = self.loop_guard.lock().unwrap().check(&name, &input);
|
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||||
match guard_result {
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
LoopGuardResult::CircuitBreaker => {
|
continue;
|
||||||
tracing::warn!("[AgentLoop] Circuit breaker triggered by tool '{}'", name);
|
|
||||||
circuit_breaker_triggered = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
LoopGuardResult::Blocked => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
|
||||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
LoopGuardResult::Warn => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
|
||||||
}
|
|
||||||
LoopGuardResult::Allowed => {}
|
|
||||||
}
|
}
|
||||||
|
LoopGuardResult::Warn => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||||
|
}
|
||||||
|
LoopGuardResult::Allowed => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
|
let tool_result = match self.execute_tool(&name, input, &tool_context).await {
|
||||||
@@ -434,23 +311,8 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Post-completion processing — middleware chain or inline growth
|
// Process conversation for memory extraction (post-conversation)
|
||||||
if let Some(ref chain) = self.middleware_chain {
|
if let Some(ref growth) = self.growth {
|
||||||
let mw_ctx = middleware::MiddlewareContext {
|
|
||||||
agent_id: self.agent_id.clone(),
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
user_input: input.clone(),
|
|
||||||
system_prompt: enhanced_prompt.clone(),
|
|
||||||
messages: self.memory.get_messages(&session_id).await.unwrap_or_default(),
|
|
||||||
response_content: Vec::new(),
|
|
||||||
input_tokens: total_input_tokens,
|
|
||||||
output_tokens: total_output_tokens,
|
|
||||||
};
|
|
||||||
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
|
||||||
tracing::warn!("[AgentLoop] Middleware after_completion failed: {}", e);
|
|
||||||
}
|
|
||||||
} else if let Some(ref growth) = self.growth {
|
|
||||||
// Legacy inline path
|
|
||||||
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
|
if let Ok(all_messages) = self.memory.get_messages(&session_id).await {
|
||||||
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
|
if let Err(e) = growth.process_conversation(&self.agent_id, &all_messages, session_id.clone()).await {
|
||||||
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
|
tracing::warn!("[AgentLoop] Growth processing failed: {}", e);
|
||||||
@@ -477,10 +339,8 @@ impl AgentLoop {
|
|||||||
// Get all messages for context
|
// Get all messages for context
|
||||||
let mut messages = self.memory.get_messages(&session_id).await?;
|
let mut messages = self.memory.get_messages(&session_id).await?;
|
||||||
|
|
||||||
let use_middleware = self.middleware_chain.is_some();
|
// Apply compaction if threshold is configured
|
||||||
|
if self.compaction_threshold > 0 {
|
||||||
// Apply compaction — skip inline path when middleware chain handles it
|
|
||||||
if !use_middleware && self.compaction_threshold > 0 {
|
|
||||||
let needs_async =
|
let needs_async =
|
||||||
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
self.compaction_config.use_llm || self.compaction_config.memory_flush_enabled;
|
||||||
if needs_async {
|
if needs_async {
|
||||||
@@ -500,52 +360,20 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enhance system prompt — skip when middleware chain handles it
|
// Enhance system prompt with growth memories
|
||||||
let mut enhanced_prompt = if use_middleware {
|
let enhanced_prompt = if let Some(ref growth) = self.growth {
|
||||||
self.system_prompt.clone().unwrap_or_default()
|
|
||||||
} else if let Some(ref growth) = self.growth {
|
|
||||||
let base = self.system_prompt.as_deref().unwrap_or("");
|
let base = self.system_prompt.as_deref().unwrap_or("");
|
||||||
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
growth.enhance_prompt(&self.agent_id, base, &input).await?
|
||||||
} else {
|
} else {
|
||||||
self.system_prompt.clone().unwrap_or_default()
|
self.system_prompt.clone().unwrap_or_default()
|
||||||
};
|
};
|
||||||
|
|
||||||
// Run middleware before_completion hooks (compaction, memory inject, etc.)
|
|
||||||
if let Some(ref chain) = self.middleware_chain {
|
|
||||||
let mut mw_ctx = middleware::MiddlewareContext {
|
|
||||||
agent_id: self.agent_id.clone(),
|
|
||||||
session_id: session_id.clone(),
|
|
||||||
user_input: input.clone(),
|
|
||||||
system_prompt: enhanced_prompt.clone(),
|
|
||||||
messages,
|
|
||||||
response_content: Vec::new(),
|
|
||||||
input_tokens: 0,
|
|
||||||
output_tokens: 0,
|
|
||||||
};
|
|
||||||
match chain.run_before_completion(&mut mw_ctx).await? {
|
|
||||||
middleware::MiddlewareDecision::Continue => {
|
|
||||||
messages = mw_ctx.messages;
|
|
||||||
enhanced_prompt = mw_ctx.system_prompt;
|
|
||||||
}
|
|
||||||
middleware::MiddlewareDecision::Stop(reason) => {
|
|
||||||
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
|
||||||
response: reason,
|
|
||||||
input_tokens: 0,
|
|
||||||
output_tokens: 0,
|
|
||||||
iterations: 1,
|
|
||||||
})).await;
|
|
||||||
return Ok(rx);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clone necessary data for the async task
|
// Clone necessary data for the async task
|
||||||
let session_id_clone = session_id.clone();
|
let session_id_clone = session_id.clone();
|
||||||
let memory = self.memory.clone();
|
let memory = self.memory.clone();
|
||||||
let driver = self.driver.clone();
|
let driver = self.driver.clone();
|
||||||
let tools = self.tools.clone();
|
let tools = self.tools.clone();
|
||||||
let loop_guard_clone = self.loop_guard.lock().unwrap().clone();
|
let loop_guard_clone = self.loop_guard.lock().unwrap().clone();
|
||||||
let middleware_chain = self.middleware_chain.clone();
|
|
||||||
let skill_executor = self.skill_executor.clone();
|
let skill_executor = self.skill_executor.clone();
|
||||||
let path_validator = self.path_validator.clone();
|
let path_validator = self.path_validator.clone();
|
||||||
let agent_id = self.agent_id.clone();
|
let agent_id = self.agent_id.clone();
|
||||||
@@ -589,29 +417,19 @@ impl AgentLoop {
|
|||||||
let mut stream = driver.stream(request);
|
let mut stream = driver.stream(request);
|
||||||
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
|
let mut pending_tool_calls: Vec<(String, String, serde_json::Value)> = Vec::new();
|
||||||
let mut iteration_text = String::new();
|
let mut iteration_text = String::new();
|
||||||
let mut reasoning_text = String::new(); // Track reasoning separately for API requirement
|
|
||||||
|
|
||||||
// Process stream chunks
|
// Process stream chunks
|
||||||
tracing::debug!("[AgentLoop] Starting to process stream chunks");
|
tracing::debug!("[AgentLoop] Starting to process stream chunks");
|
||||||
let mut chunk_count: usize = 0;
|
|
||||||
let mut text_delta_count: usize = 0;
|
|
||||||
let mut thinking_delta_count: usize = 0;
|
|
||||||
while let Some(chunk_result) = stream.next().await {
|
while let Some(chunk_result) = stream.next().await {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunk) => {
|
Ok(chunk) => {
|
||||||
chunk_count += 1;
|
|
||||||
match &chunk {
|
match &chunk {
|
||||||
StreamChunk::TextDelta { delta } => {
|
StreamChunk::TextDelta { delta } => {
|
||||||
text_delta_count += 1;
|
|
||||||
tracing::debug!("[AgentLoop] TextDelta #{}: {} chars", text_delta_count, delta.len());
|
|
||||||
iteration_text.push_str(delta);
|
iteration_text.push_str(delta);
|
||||||
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
|
let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
|
||||||
}
|
}
|
||||||
StreamChunk::ThinkingDelta { delta } => {
|
StreamChunk::ThinkingDelta { delta } => {
|
||||||
thinking_delta_count += 1;
|
let _ = tx.send(LoopEvent::Delta(format!("[思考] {}", delta))).await;
|
||||||
tracing::debug!("[AgentLoop] ThinkingDelta #{}: {} chars", thinking_delta_count, delta.len());
|
|
||||||
// Accumulate reasoning separately — not mixed into iteration_text
|
|
||||||
reasoning_text.push_str(delta);
|
|
||||||
}
|
}
|
||||||
StreamChunk::ToolUseStart { id, name } => {
|
StreamChunk::ToolUseStart { id, name } => {
|
||||||
tracing::debug!("[AgentLoop] ToolUseStart: id={}, name={}", id, name);
|
tracing::debug!("[AgentLoop] ToolUseStart: id={}, name={}", id, name);
|
||||||
@@ -640,13 +458,6 @@ impl AgentLoop {
|
|||||||
tracing::debug!("[AgentLoop] Stream complete: input_tokens={}, output_tokens={}", it, ot);
|
tracing::debug!("[AgentLoop] Stream complete: input_tokens={}, output_tokens={}", it, ot);
|
||||||
total_input_tokens += *it;
|
total_input_tokens += *it;
|
||||||
total_output_tokens += *ot;
|
total_output_tokens += *ot;
|
||||||
// Calibrate token estimation on first iteration
|
|
||||||
if iteration == 1 {
|
|
||||||
compaction::update_calibration(
|
|
||||||
compaction::estimate_messages_tokens(&messages),
|
|
||||||
*it,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
StreamChunk::Error { message } => {
|
StreamChunk::Error { message } => {
|
||||||
tracing::error!("[AgentLoop] Stream error: {}", message);
|
tracing::error!("[AgentLoop] Stream error: {}", message);
|
||||||
@@ -660,59 +471,24 @@ impl AgentLoop {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
tracing::info!("[AgentLoop] Stream ended: {} total chunks (text={}, thinking={}, tools={}), iteration_text={} chars",
|
tracing::debug!("[AgentLoop] Stream ended, pending_tool_calls count: {}", pending_tool_calls.len());
|
||||||
chunk_count, text_delta_count, thinking_delta_count, pending_tool_calls.len(),
|
|
||||||
iteration_text.len());
|
|
||||||
if iteration_text.is_empty() {
|
|
||||||
tracing::warn!("[AgentLoop] WARNING: iteration_text is EMPTY after {} chunks! text_delta={}, thinking_delta={}",
|
|
||||||
chunk_count, text_delta_count, thinking_delta_count);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no tool calls, we have the final response
|
// If no tool calls, we have the final response
|
||||||
if pending_tool_calls.is_empty() {
|
if pending_tool_calls.is_empty() {
|
||||||
tracing::info!("[AgentLoop] No tool calls, returning final response: {} chars (reasoning: {} chars)", iteration_text.len(), reasoning_text.len());
|
tracing::debug!("[AgentLoop] No tool calls, returning final response");
|
||||||
// Save final assistant message with reasoning
|
// Save final assistant message
|
||||||
if let Err(e) = memory.append_message(&session_id_clone, &Message::assistant_with_thinking(
|
let _ = memory.append_message(&session_id_clone, &Message::assistant(&iteration_text)).await;
|
||||||
&iteration_text,
|
|
||||||
&reasoning_text,
|
|
||||||
)).await {
|
|
||||||
tracing::warn!("[AgentLoop] Failed to save final assistant message: {}", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
|
||||||
response: iteration_text.clone(),
|
response: iteration_text,
|
||||||
input_tokens: total_input_tokens,
|
input_tokens: total_input_tokens,
|
||||||
output_tokens: total_output_tokens,
|
output_tokens: total_output_tokens,
|
||||||
iterations: iteration,
|
iterations: iteration,
|
||||||
})).await;
|
})).await;
|
||||||
|
|
||||||
// Post-completion: middleware after_completion (memory extraction, etc.)
|
|
||||||
if let Some(ref chain) = middleware_chain {
|
|
||||||
let mw_ctx = middleware::MiddlewareContext {
|
|
||||||
agent_id: agent_id.clone(),
|
|
||||||
session_id: session_id_clone.clone(),
|
|
||||||
user_input: String::new(),
|
|
||||||
system_prompt: enhanced_prompt.clone(),
|
|
||||||
messages: memory.get_messages(&session_id_clone).await.unwrap_or_default(),
|
|
||||||
response_content: Vec::new(),
|
|
||||||
input_tokens: total_input_tokens,
|
|
||||||
output_tokens: total_output_tokens,
|
|
||||||
};
|
|
||||||
if let Err(e) = chain.run_after_completion(&mw_ctx).await {
|
|
||||||
tracing::warn!("[AgentLoop] Streaming middleware after_completion failed: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
break 'outer;
|
break 'outer;
|
||||||
}
|
}
|
||||||
|
|
||||||
tracing::debug!("[AgentLoop] Processing {} tool calls (reasoning: {} chars)", pending_tool_calls.len(), reasoning_text.len());
|
tracing::debug!("[AgentLoop] Processing {} tool calls", pending_tool_calls.len());
|
||||||
|
|
||||||
// Push assistant message with reasoning before tool calls (required by Kimi and other thinking-enabled APIs)
|
|
||||||
messages.push(Message::assistant_with_thinking(
|
|
||||||
&iteration_text,
|
|
||||||
&reasoning_text,
|
|
||||||
));
|
|
||||||
|
|
||||||
// There are tool calls - add to message history
|
// There are tool calls - add to message history
|
||||||
for (id, name, input) in &pending_tool_calls {
|
for (id, name, input) in &pending_tool_calls {
|
||||||
@@ -724,108 +500,31 @@ impl AgentLoop {
|
|||||||
for (id, name, input) in pending_tool_calls {
|
for (id, name, input) in pending_tool_calls {
|
||||||
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
tracing::debug!("[AgentLoop] Executing tool: name={}, input={:?}", name, input);
|
||||||
|
|
||||||
// Check tool call safety — via middleware chain or inline loop guard
|
// Check loop guard before executing tool
|
||||||
if let Some(ref chain) = middleware_chain {
|
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
||||||
let mw_ctx = middleware::MiddlewareContext {
|
match guard_result {
|
||||||
agent_id: agent_id.clone(),
|
LoopGuardResult::CircuitBreaker => {
|
||||||
session_id: session_id_clone.clone(),
|
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
||||||
user_input: input.to_string(),
|
break 'outer;
|
||||||
system_prompt: enhanced_prompt.clone(),
|
|
||||||
messages: messages.clone(),
|
|
||||||
response_content: Vec::new(),
|
|
||||||
input_tokens: total_input_tokens,
|
|
||||||
output_tokens: total_output_tokens,
|
|
||||||
};
|
|
||||||
match chain.run_before_tool_call(&mw_ctx, &name, &input).await {
|
|
||||||
Ok(middleware::ToolCallDecision::Allow) => {}
|
|
||||||
Ok(middleware::ToolCallDecision::Block(msg)) => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
|
||||||
let error_output = serde_json::json!({ "error": msg });
|
|
||||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Ok(middleware::ToolCallDecision::ReplaceInput(new_input)) => {
|
|
||||||
// Execute with replaced input (same path_validator logic below)
|
|
||||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
|
||||||
let home = std::env::var("USERPROFILE")
|
|
||||||
.or_else(|_| std::env::var("HOME"))
|
|
||||||
.unwrap_or_else(|_| ".".to_string());
|
|
||||||
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
|
||||||
});
|
|
||||||
let working_dir = pv.workspace_root()
|
|
||||||
.map(|p| p.to_string_lossy().to_string());
|
|
||||||
let tool_context = ToolContext {
|
|
||||||
agent_id: agent_id.clone(),
|
|
||||||
working_directory: working_dir,
|
|
||||||
session_id: Some(session_id_clone.to_string()),
|
|
||||||
skill_executor: skill_executor.clone(),
|
|
||||||
path_validator: Some(pv),
|
|
||||||
};
|
|
||||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
|
||||||
match tool.execute(new_input, &tool_context).await {
|
|
||||||
Ok(output) => {
|
|
||||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: output.clone() }).await;
|
|
||||||
(output, false)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
|
||||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
|
||||||
(error_output, true)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
let error_output = serde_json::json!({ "error": format!("Unknown tool: {}", name) });
|
|
||||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
|
||||||
(error_output, true)
|
|
||||||
};
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), result, is_error));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("[AgentLoop] Middleware error for tool '{}': {}", name, e);
|
|
||||||
let error_output = serde_json::json!({ "error": e.to_string() });
|
|
||||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
LoopGuardResult::Blocked => {
|
||||||
// Legacy inline loop guard path
|
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
||||||
let guard_result = loop_guard_clone.lock().unwrap().check(&name, &input);
|
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
||||||
match guard_result {
|
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
||||||
LoopGuardResult::CircuitBreaker => {
|
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||||
let _ = tx.send(LoopEvent::Error("检测到工具调用循环,已自动终止".to_string())).await;
|
continue;
|
||||||
break 'outer;
|
|
||||||
}
|
|
||||||
LoopGuardResult::Blocked => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by loop guard", name);
|
|
||||||
let error_output = serde_json::json!({ "error": "工具调用被循环防护拦截" });
|
|
||||||
let _ = tx.send(LoopEvent::ToolEnd { name: name.clone(), output: error_output.clone() }).await;
|
|
||||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
LoopGuardResult::Warn => {
|
|
||||||
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
|
||||||
}
|
|
||||||
LoopGuardResult::Allowed => {}
|
|
||||||
}
|
}
|
||||||
|
LoopGuardResult::Warn => {
|
||||||
|
tracing::warn!("[AgentLoop] Tool '{}' triggered loop guard warning", name);
|
||||||
|
}
|
||||||
|
LoopGuardResult::Allowed => {}
|
||||||
}
|
}
|
||||||
// Use pre-resolved path_validator (already has default fallback from create_tool_context logic)
|
|
||||||
let pv = path_validator.clone().unwrap_or_else(|| {
|
|
||||||
let home = std::env::var("USERPROFILE")
|
|
||||||
.or_else(|_| std::env::var("HOME"))
|
|
||||||
.unwrap_or_else(|_| ".".to_string());
|
|
||||||
PathValidator::new().with_workspace(std::path::PathBuf::from(&home))
|
|
||||||
});
|
|
||||||
let working_dir = pv.workspace_root()
|
|
||||||
.map(|p| p.to_string_lossy().to_string());
|
|
||||||
let tool_context = ToolContext {
|
let tool_context = ToolContext {
|
||||||
agent_id: agent_id.clone(),
|
agent_id: agent_id.clone(),
|
||||||
working_directory: working_dir,
|
working_directory: None,
|
||||||
session_id: Some(session_id_clone.to_string()),
|
session_id: Some(session_id_clone.to_string()),
|
||||||
skill_executor: skill_executor.clone(),
|
skill_executor: skill_executor.clone(),
|
||||||
path_validator: Some(pv),
|
path_validator: path_validator.clone(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
let (result, is_error) = if let Some(tool) = tools.get(&name) {
|
||||||
|
|||||||
@@ -1,252 +0,0 @@
|
|||||||
//! Agent middleware system — composable hooks for cross-cutting concerns.
|
|
||||||
//!
|
|
||||||
//! Inspired by [DeerFlow 2.0](https://github.com/bytedance/deer-flow)'s 9-layer middleware chain,
|
|
||||||
//! this module provides a standardised way to inject behaviour before/after LLM completions
|
|
||||||
//! and tool calls without modifying the core `AgentLoop` logic.
|
|
||||||
//!
|
|
||||||
//! # Priority convention
|
|
||||||
//!
|
|
||||||
//! | Range | Category | Example |
|
|
||||||
//! |---------|----------------|-----------------------------|
|
|
||||||
//! | 100-199 | Context shaping| Compaction, MemoryInject |
|
|
||||||
//! | 200-399 | Capability | SkillIndex, Guardrail |
|
|
||||||
//! | 400-599 | Safety | LoopGuard, Guardrail |
|
|
||||||
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::Value;
|
|
||||||
use zclaw_types::{AgentId, Result, SessionId};
|
|
||||||
use crate::driver::ContentBlock;
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Decisions returned by middleware hooks
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// Decision returned by `before_completion`.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum MiddlewareDecision {
|
|
||||||
/// Continue to the next middleware / proceed with the LLM call.
|
|
||||||
Continue,
|
|
||||||
/// Abort the agent loop and return *reason* to the caller.
|
|
||||||
Stop(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decision returned by `before_tool_call`.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum ToolCallDecision {
|
|
||||||
/// Allow the tool call to proceed unchanged.
|
|
||||||
Allow,
|
|
||||||
/// Block the call and return *message* as a tool-error to the LLM.
|
|
||||||
Block(String),
|
|
||||||
/// Allow the call but replace the tool input with *new_input*.
|
|
||||||
ReplaceInput(Value),
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Middleware context — shared mutable state passed through the chain
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// Carries the mutable state that middleware may inspect or modify.
|
|
||||||
pub struct MiddlewareContext {
|
|
||||||
/// The agent that owns this loop.
|
|
||||||
pub agent_id: AgentId,
|
|
||||||
/// Current session.
|
|
||||||
pub session_id: SessionId,
|
|
||||||
/// The raw user input that started this turn.
|
|
||||||
pub user_input: String,
|
|
||||||
|
|
||||||
// -- mutable state -------------------------------------------------------
|
|
||||||
/// System prompt — middleware may prepend/append context.
|
|
||||||
pub system_prompt: String,
|
|
||||||
/// Conversation messages sent to the LLM.
|
|
||||||
pub messages: Vec<zclaw_types::Message>,
|
|
||||||
/// Accumulated LLM content blocks from the current response.
|
|
||||||
pub response_content: Vec<ContentBlock>,
|
|
||||||
/// Token usage reported by the LLM driver (updated after each call).
|
|
||||||
pub input_tokens: u32,
|
|
||||||
pub output_tokens: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl std::fmt::Debug for MiddlewareContext {
|
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
||||||
f.debug_struct("MiddlewareContext")
|
|
||||||
.field("agent_id", &self.agent_id)
|
|
||||||
.field("session_id", &self.session_id)
|
|
||||||
.field("messages", &self.messages.len())
|
|
||||||
.field("input_tokens", &self.input_tokens)
|
|
||||||
.field("output_tokens", &self.output_tokens)
|
|
||||||
.finish()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Core trait
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// A composable middleware hook for the agent loop.
|
|
||||||
///
|
|
||||||
/// Each middleware focuses on one cross-cutting concern and is executed
|
|
||||||
/// in `priority` order (ascending). All hook methods have default no-op
|
|
||||||
/// implementations so implementors only override what they need.
|
|
||||||
#[async_trait]
|
|
||||||
pub trait AgentMiddleware: Send + Sync {
|
|
||||||
/// Human-readable name for logging / debugging.
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
|
|
||||||
/// Execution priority — lower values run first.
|
|
||||||
fn priority(&self) -> i32 {
|
|
||||||
500
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hook executed **before** the LLM completion request is sent.
|
|
||||||
///
|
|
||||||
/// Use this to inject context (memory, skill index, etc.) or to
|
|
||||||
/// trigger pre-processing (compaction, summarisation).
|
|
||||||
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
|
||||||
Ok(MiddlewareDecision::Continue)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hook executed **before** each tool call.
|
|
||||||
///
|
|
||||||
/// Return `Block` to prevent execution and feed an error back to
|
|
||||||
/// the LLM, or `ReplaceInput` to sanitise / modify the arguments.
|
|
||||||
async fn before_tool_call(
|
|
||||||
&self,
|
|
||||||
_ctx: &MiddlewareContext,
|
|
||||||
_tool_name: &str,
|
|
||||||
_tool_input: &Value,
|
|
||||||
) -> Result<ToolCallDecision> {
|
|
||||||
Ok(ToolCallDecision::Allow)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hook executed **after** each tool call.
|
|
||||||
async fn after_tool_call(
|
|
||||||
&self,
|
|
||||||
_ctx: &mut MiddlewareContext,
|
|
||||||
_tool_name: &str,
|
|
||||||
_result: &Value,
|
|
||||||
) -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Hook executed **after** the entire agent loop turn completes.
|
|
||||||
///
|
|
||||||
/// Use this for post-processing (memory extraction, telemetry, etc.).
|
|
||||||
async fn after_completion(&self, _ctx: &MiddlewareContext) -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Middleware chain — ordered collection with run methods
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// An ordered chain of `AgentMiddleware` instances.
|
|
||||||
pub struct MiddlewareChain {
|
|
||||||
middlewares: Vec<Arc<dyn AgentMiddleware>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MiddlewareChain {
|
|
||||||
/// Create an empty chain.
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self { middlewares: Vec::new() }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Register a middleware. The chain is kept sorted by `priority`
|
|
||||||
/// (ascending) and by registration order within the same priority.
|
|
||||||
pub fn register(&mut self, mw: Arc<dyn AgentMiddleware>) {
|
|
||||||
let p = mw.priority();
|
|
||||||
let pos = self.middlewares.iter().position(|m| m.priority() > p).unwrap_or(self.middlewares.len());
|
|
||||||
self.middlewares.insert(pos, mw);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run all `before_completion` hooks in order.
|
|
||||||
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
|
||||||
for mw in &self.middlewares {
|
|
||||||
match mw.before_completion(ctx).await? {
|
|
||||||
MiddlewareDecision::Continue => {}
|
|
||||||
MiddlewareDecision::Stop(reason) => {
|
|
||||||
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
|
|
||||||
return Ok(MiddlewareDecision::Stop(reason));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(MiddlewareDecision::Continue)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run all `before_tool_call` hooks in order.
|
|
||||||
pub async fn run_before_tool_call(
|
|
||||||
&self,
|
|
||||||
ctx: &MiddlewareContext,
|
|
||||||
tool_name: &str,
|
|
||||||
tool_input: &Value,
|
|
||||||
) -> Result<ToolCallDecision> {
|
|
||||||
for mw in &self.middlewares {
|
|
||||||
match mw.before_tool_call(ctx, tool_name, tool_input).await? {
|
|
||||||
ToolCallDecision::Allow => {}
|
|
||||||
other => {
|
|
||||||
tracing::info!("[MiddlewareChain] '{}' decided {:?} for tool '{}'", mw.name(), other, tool_name);
|
|
||||||
return Ok(other);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(ToolCallDecision::Allow)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run all `after_tool_call` hooks in order.
|
|
||||||
pub async fn run_after_tool_call(
|
|
||||||
&self,
|
|
||||||
ctx: &mut MiddlewareContext,
|
|
||||||
tool_name: &str,
|
|
||||||
result: &Value,
|
|
||||||
) -> Result<()> {
|
|
||||||
for mw in &self.middlewares {
|
|
||||||
mw.after_tool_call(ctx, tool_name, result).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run all `after_completion` hooks in order.
|
|
||||||
pub async fn run_after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
|
||||||
for mw in &self.middlewares {
|
|
||||||
mw.after_completion(ctx).await?;
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of registered middlewares.
|
|
||||||
pub fn len(&self) -> usize {
|
|
||||||
self.middlewares.len()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Whether the chain is empty.
|
|
||||||
pub fn is_empty(&self) -> bool {
|
|
||||||
self.middlewares.is_empty()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Clone for MiddlewareChain {
|
|
||||||
fn clone(&self) -> Self {
|
|
||||||
Self {
|
|
||||||
middlewares: self.middlewares.clone(), // Arc clone — cheap ref-count bump
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for MiddlewareChain {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Sub-modules — concrete middleware implementations
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
pub mod compaction;
|
|
||||||
pub mod guardrail;
|
|
||||||
pub mod loop_guard;
|
|
||||||
pub mod memory;
|
|
||||||
pub mod skill_index;
|
|
||||||
pub mod token_calibration;
|
|
||||||
@@ -1,61 +0,0 @@
|
|||||||
//! Compaction middleware — wraps the existing compaction module.
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use zclaw_types::Result;
|
|
||||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
|
||||||
use crate::compaction::{self, CompactionConfig};
|
|
||||||
use crate::growth::GrowthIntegration;
|
|
||||||
use crate::driver::LlmDriver;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Middleware that compresses conversation history when it exceeds a token threshold.
|
|
||||||
pub struct CompactionMiddleware {
|
|
||||||
threshold: usize,
|
|
||||||
config: CompactionConfig,
|
|
||||||
/// Optional LLM driver for async compaction (LLM summarisation, memory flush).
|
|
||||||
driver: Option<Arc<dyn LlmDriver>>,
|
|
||||||
/// Optional growth integration for memory flushing during compaction.
|
|
||||||
growth: Option<GrowthIntegration>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl CompactionMiddleware {
|
|
||||||
pub fn new(
|
|
||||||
threshold: usize,
|
|
||||||
config: CompactionConfig,
|
|
||||||
driver: Option<Arc<dyn LlmDriver>>,
|
|
||||||
growth: Option<GrowthIntegration>,
|
|
||||||
) -> Self {
|
|
||||||
Self { threshold, config, driver, growth }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl AgentMiddleware for CompactionMiddleware {
|
|
||||||
fn name(&self) -> &str { "compaction" }
|
|
||||||
fn priority(&self) -> i32 { 100 }
|
|
||||||
|
|
||||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
|
||||||
if self.threshold == 0 {
|
|
||||||
return Ok(MiddlewareDecision::Continue);
|
|
||||||
}
|
|
||||||
|
|
||||||
let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
|
|
||||||
if needs_async {
|
|
||||||
let outcome = compaction::maybe_compact_with_config(
|
|
||||||
ctx.messages.clone(),
|
|
||||||
self.threshold,
|
|
||||||
&self.config,
|
|
||||||
&ctx.agent_id,
|
|
||||||
&ctx.session_id,
|
|
||||||
self.driver.as_ref(),
|
|
||||||
self.growth.as_ref(),
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
ctx.messages = outcome.messages;
|
|
||||||
} else {
|
|
||||||
ctx.messages = compaction::maybe_compact(ctx.messages.clone(), self.threshold);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(MiddlewareDecision::Continue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,223 +0,0 @@
|
|||||||
//! Guardrail middleware — configurable safety rules for tool call evaluation.
|
|
||||||
//!
|
|
||||||
//! This middleware inspects tool calls before execution and can block or
|
|
||||||
//! modify them based on configurable rules. Inspired by DeerFlow's safety
|
|
||||||
//! evaluation hooks.
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::Value;
|
|
||||||
use std::collections::HashMap;
|
|
||||||
use zclaw_types::Result;
|
|
||||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
|
||||||
|
|
||||||
/// A single guardrail rule that can inspect and decide on tool calls.
|
|
||||||
pub trait GuardrailRule: Send + Sync {
|
|
||||||
/// Human-readable name for logging.
|
|
||||||
fn name(&self) -> &str;
|
|
||||||
|
|
||||||
/// Evaluate a tool call.
|
|
||||||
fn evaluate(&self, tool_name: &str, tool_input: &Value) -> GuardrailVerdict;
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Decision returned by a guardrail rule.
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub enum GuardrailVerdict {
|
|
||||||
/// Allow the tool call to proceed.
|
|
||||||
Allow,
|
|
||||||
/// Block the call and return *message* as an error to the LLM.
|
|
||||||
Block(String),
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Middleware that evaluates tool calls against a set of configurable safety rules.
|
|
||||||
///
|
|
||||||
/// Rules are grouped by tool name. When a tool call is made, all rules for
|
|
||||||
/// that tool are evaluated in order. If any rule returns `Block`, the call
|
|
||||||
/// is blocked. This is a "deny-by-exception" model — calls are allowed unless
|
|
||||||
/// a rule explicitly blocks them.
|
|
||||||
pub struct GuardrailMiddleware {
|
|
||||||
/// Rules keyed by tool name.
|
|
||||||
rules: HashMap<String, Vec<Box<dyn GuardrailRule>>>,
|
|
||||||
/// Default policy for tools with no specific rules: true = allow, false = block.
|
|
||||||
fail_open: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GuardrailMiddleware {
|
|
||||||
pub fn new(fail_open: bool) -> Self {
|
|
||||||
Self {
|
|
||||||
rules: HashMap::new(),
|
|
||||||
fail_open,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Register a guardrail rule for a specific tool.
|
|
||||||
pub fn add_rule(&mut self, tool_name: impl Into<String>, rule: Box<dyn GuardrailRule>) {
|
|
||||||
self.rules.entry(tool_name.into()).or_default().push(rule);
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Register built-in safety rules (shell_exec, file_write, web_fetch).
|
|
||||||
pub fn with_builtin_rules(mut self) -> Self {
|
|
||||||
self.add_rule("shell_exec", Box::new(ShellExecRule));
|
|
||||||
self.add_rule("file_write", Box::new(FileWriteRule));
|
|
||||||
self.add_rule("web_fetch", Box::new(WebFetchRule));
|
|
||||||
self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl AgentMiddleware for GuardrailMiddleware {
|
|
||||||
fn name(&self) -> &str { "guardrail" }
|
|
||||||
fn priority(&self) -> i32 { 400 }
|
|
||||||
|
|
||||||
async fn before_tool_call(
|
|
||||||
&self,
|
|
||||||
_ctx: &MiddlewareContext,
|
|
||||||
tool_name: &str,
|
|
||||||
tool_input: &Value,
|
|
||||||
) -> Result<ToolCallDecision> {
|
|
||||||
if let Some(rules) = self.rules.get(tool_name) {
|
|
||||||
for rule in rules {
|
|
||||||
match rule.evaluate(tool_name, tool_input) {
|
|
||||||
GuardrailVerdict::Allow => {}
|
|
||||||
GuardrailVerdict::Block(msg) => {
|
|
||||||
tracing::warn!(
|
|
||||||
"[GuardrailMiddleware] Rule '{}' blocked tool '{}': {}",
|
|
||||||
rule.name(),
|
|
||||||
tool_name,
|
|
||||||
msg
|
|
||||||
);
|
|
||||||
return Ok(ToolCallDecision::Block(msg));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if !self.fail_open {
|
|
||||||
// fail-closed: unknown tools are blocked
|
|
||||||
tracing::warn!(
|
|
||||||
"[GuardrailMiddleware] No rules for tool '{}', fail-closed policy blocks it",
|
|
||||||
tool_name
|
|
||||||
);
|
|
||||||
return Ok(ToolCallDecision::Block(format!(
|
|
||||||
"工具 '{}' 未注册安全规则,fail-closed 策略阻止执行",
|
|
||||||
tool_name
|
|
||||||
)));
|
|
||||||
}
|
|
||||||
Ok(ToolCallDecision::Allow)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// Built-in rules
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
/// Rule that blocks dangerous shell commands.
|
|
||||||
pub struct ShellExecRule;
|
|
||||||
|
|
||||||
impl GuardrailRule for ShellExecRule {
|
|
||||||
fn name(&self) -> &str { "shell_exec_dangerous_commands" }
|
|
||||||
|
|
||||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
|
||||||
let cmd = tool_input["command"].as_str().unwrap_or("");
|
|
||||||
let dangerous = [
|
|
||||||
"rm -rf /",
|
|
||||||
"rm -rf ~",
|
|
||||||
"del /s /q C:\\",
|
|
||||||
"format ",
|
|
||||||
"mkfs.",
|
|
||||||
"dd if=",
|
|
||||||
":(){ :|:& };:", // fork bomb
|
|
||||||
"> /dev/sda",
|
|
||||||
"shutdown",
|
|
||||||
"reboot",
|
|
||||||
];
|
|
||||||
let cmd_lower = cmd.to_lowercase();
|
|
||||||
for pattern in &dangerous {
|
|
||||||
if cmd_lower.contains(pattern) {
|
|
||||||
return GuardrailVerdict::Block(format!(
|
|
||||||
"危险命令被安全护栏拦截: 包含 '{}'",
|
|
||||||
pattern
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
GuardrailVerdict::Allow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Rule that blocks writes to critical system directories.
|
|
||||||
pub struct FileWriteRule;
|
|
||||||
|
|
||||||
impl GuardrailRule for FileWriteRule {
|
|
||||||
fn name(&self) -> &str { "file_write_critical_dirs" }
|
|
||||||
|
|
||||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
|
||||||
let path = tool_input["path"].as_str().unwrap_or("");
|
|
||||||
let critical_prefixes = [
|
|
||||||
"/etc/",
|
|
||||||
"/usr/",
|
|
||||||
"/bin/",
|
|
||||||
"/sbin/",
|
|
||||||
"/boot/",
|
|
||||||
"/System/",
|
|
||||||
"/Library/",
|
|
||||||
"C:\\Windows\\",
|
|
||||||
"C:\\Program Files\\",
|
|
||||||
"C:\\ProgramData\\",
|
|
||||||
];
|
|
||||||
let path_lower = path.to_lowercase();
|
|
||||||
for prefix in &critical_prefixes {
|
|
||||||
if path_lower.starts_with(&prefix.to_lowercase()) {
|
|
||||||
return GuardrailVerdict::Block(format!(
|
|
||||||
"写入系统关键目录被拦截: {}",
|
|
||||||
path
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
GuardrailVerdict::Allow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Rule that blocks web requests to internal/private network addresses.
|
|
||||||
pub struct WebFetchRule;
|
|
||||||
|
|
||||||
impl GuardrailRule for WebFetchRule {
|
|
||||||
fn name(&self) -> &str { "web_fetch_private_network" }
|
|
||||||
|
|
||||||
fn evaluate(&self, _tool_name: &str, tool_input: &Value) -> GuardrailVerdict {
|
|
||||||
let url = tool_input["url"].as_str().unwrap_or("");
|
|
||||||
let blocked = [
|
|
||||||
"localhost",
|
|
||||||
"127.0.0.1",
|
|
||||||
"0.0.0.0",
|
|
||||||
"10.",
|
|
||||||
"172.16.",
|
|
||||||
"172.17.",
|
|
||||||
"172.18.",
|
|
||||||
"172.19.",
|
|
||||||
"172.20.",
|
|
||||||
"172.21.",
|
|
||||||
"172.22.",
|
|
||||||
"172.23.",
|
|
||||||
"172.24.",
|
|
||||||
"172.25.",
|
|
||||||
"172.26.",
|
|
||||||
"172.27.",
|
|
||||||
"172.28.",
|
|
||||||
"172.29.",
|
|
||||||
"172.30.",
|
|
||||||
"172.31.",
|
|
||||||
"192.168.",
|
|
||||||
"::1",
|
|
||||||
"169.254.",
|
|
||||||
"metadata.google",
|
|
||||||
"metadata.azure",
|
|
||||||
];
|
|
||||||
let url_lower = url.to_lowercase();
|
|
||||||
for prefix in &blocked {
|
|
||||||
if url_lower.contains(prefix) {
|
|
||||||
return GuardrailVerdict::Block(format!(
|
|
||||||
"请求内网/私有地址被拦截: {}",
|
|
||||||
url
|
|
||||||
));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
GuardrailVerdict::Allow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
//! Loop guard middleware — extracts loop detection into a middleware hook.
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::Value;
|
|
||||||
use zclaw_types::Result;
|
|
||||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
|
||||||
use crate::loop_guard::{LoopGuard, LoopGuardConfig, LoopGuardResult};
|
|
||||||
use std::sync::Mutex;
|
|
||||||
|
|
||||||
/// Middleware that detects and blocks repetitive tool-call loops.
|
|
||||||
pub struct LoopGuardMiddleware {
|
|
||||||
guard: Mutex<LoopGuard>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LoopGuardMiddleware {
|
|
||||||
pub fn new(config: LoopGuardConfig) -> Self {
|
|
||||||
Self {
|
|
||||||
guard: Mutex::new(LoopGuard::new(config)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn with_defaults() -> Self {
|
|
||||||
Self {
|
|
||||||
guard: Mutex::new(LoopGuard::default()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl AgentMiddleware for LoopGuardMiddleware {
|
|
||||||
fn name(&self) -> &str { "loop_guard" }
|
|
||||||
fn priority(&self) -> i32 { 500 }
|
|
||||||
|
|
||||||
async fn before_tool_call(
|
|
||||||
&self,
|
|
||||||
_ctx: &MiddlewareContext,
|
|
||||||
tool_name: &str,
|
|
||||||
tool_input: &Value,
|
|
||||||
) -> Result<ToolCallDecision> {
|
|
||||||
let result = self.guard.lock().unwrap().check(tool_name, tool_input);
|
|
||||||
match result {
|
|
||||||
LoopGuardResult::CircuitBreaker => {
|
|
||||||
tracing::warn!("[LoopGuardMiddleware] Circuit breaker triggered by tool '{}'", tool_name);
|
|
||||||
Ok(ToolCallDecision::Block("检测到工具调用循环,已自动终止".to_string()))
|
|
||||||
}
|
|
||||||
LoopGuardResult::Blocked => {
|
|
||||||
tracing::warn!("[LoopGuardMiddleware] Tool '{}' blocked", tool_name);
|
|
||||||
Ok(ToolCallDecision::Block("工具调用被循环防护拦截".to_string()))
|
|
||||||
}
|
|
||||||
LoopGuardResult::Warn => {
|
|
||||||
tracing::warn!("[LoopGuardMiddleware] Tool '{}' triggered warning", tool_name);
|
|
||||||
Ok(ToolCallDecision::Allow)
|
|
||||||
}
|
|
||||||
LoopGuardResult::Allowed => Ok(ToolCallDecision::Allow),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
//! Memory middleware — unified pre/post hooks for memory retrieval and extraction.
|
|
||||||
//!
|
|
||||||
//! This middleware unifies the memory lifecycle:
|
|
||||||
//! - `before_completion`: retrieves relevant memories and injects them into the system prompt
|
|
||||||
//! - `after_completion`: extracts learnings from the conversation and stores them
|
|
||||||
//!
|
|
||||||
//! It replaces both the inline `GrowthIntegration` calls in `AgentLoop` and the
|
|
||||||
//! `intelligence_hooks` calls in the Tauri desktop layer.
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use zclaw_types::Result;
|
|
||||||
use crate::growth::GrowthIntegration;
|
|
||||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
|
||||||
|
|
||||||
/// Middleware that handles memory retrieval (pre-completion) and extraction (post-completion).
|
|
||||||
///
|
|
||||||
/// Wraps `GrowthIntegration` and delegates:
|
|
||||||
/// - `before_completion` → `enhance_prompt()` for memory injection
|
|
||||||
/// - `after_completion` → `process_conversation()` for memory extraction
|
|
||||||
pub struct MemoryMiddleware {
|
|
||||||
growth: GrowthIntegration,
|
|
||||||
/// Minimum seconds between extractions for the same agent (debounce).
|
|
||||||
debounce_secs: u64,
|
|
||||||
/// Timestamp of last extraction per agent (for debouncing).
|
|
||||||
last_extraction: std::sync::Mutex<std::collections::HashMap<String, std::time::Instant>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl MemoryMiddleware {
|
|
||||||
pub fn new(growth: GrowthIntegration) -> Self {
|
|
||||||
Self {
|
|
||||||
growth,
|
|
||||||
debounce_secs: 30,
|
|
||||||
last_extraction: std::sync::Mutex::new(std::collections::HashMap::new()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Set the debounce interval in seconds.
|
|
||||||
pub fn with_debounce_secs(mut self, secs: u64) -> Self {
|
|
||||||
self.debounce_secs = secs;
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if enough time has passed since the last extraction for this agent.
|
|
||||||
fn should_extract(&self, agent_id: &str) -> bool {
|
|
||||||
let now = std::time::Instant::now();
|
|
||||||
let mut map = self.last_extraction.lock().unwrap();
|
|
||||||
if let Some(last) = map.get(agent_id) {
|
|
||||||
if now.duration_since(*last).as_secs() < self.debounce_secs {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
map.insert(agent_id.to_string(), now);
|
|
||||||
true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl AgentMiddleware for MemoryMiddleware {
|
|
||||||
fn name(&self) -> &str { "memory" }
|
|
||||||
fn priority(&self) -> i32 { 150 }
|
|
||||||
|
|
||||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
|
||||||
match self.growth.enhance_prompt(
|
|
||||||
&ctx.agent_id,
|
|
||||||
&ctx.system_prompt,
|
|
||||||
&ctx.user_input,
|
|
||||||
).await {
|
|
||||||
Ok(enhanced) => {
|
|
||||||
ctx.system_prompt = enhanced;
|
|
||||||
Ok(MiddlewareDecision::Continue)
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
// Non-fatal: memory retrieval failure should not block the loop
|
|
||||||
tracing::warn!("[MemoryMiddleware] Prompt enhancement failed: {}", e);
|
|
||||||
Ok(MiddlewareDecision::Continue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
|
||||||
// Debounce: skip extraction if called too recently for this agent
|
|
||||||
let agent_key = ctx.agent_id.to_string();
|
|
||||||
if !self.should_extract(&agent_key) {
|
|
||||||
tracing::debug!(
|
|
||||||
"[MemoryMiddleware] Skipping extraction for agent {} (debounced)",
|
|
||||||
agent_key
|
|
||||||
);
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
if ctx.messages.is_empty() {
|
|
||||||
return Ok(());
|
|
||||||
}
|
|
||||||
|
|
||||||
match self.growth.process_conversation(
|
|
||||||
&ctx.agent_id,
|
|
||||||
&ctx.messages,
|
|
||||||
ctx.session_id.clone(),
|
|
||||||
).await {
|
|
||||||
Ok(count) => {
|
|
||||||
tracing::info!(
|
|
||||||
"[MemoryMiddleware] Extracted {} memories for agent {}",
|
|
||||||
count,
|
|
||||||
agent_key
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
// Non-fatal: extraction failure should not affect the response
|
|
||||||
tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
//! Skill index middleware — injects a lightweight skill index into the system prompt.
|
|
||||||
//!
|
|
||||||
//! Instead of embedding full skill descriptions (which can consume ~2000 tokens for 70+ skills),
|
|
||||||
//! this middleware injects only skill IDs and one-line triggers (~600 tokens). The LLM can then
|
|
||||||
//! call the `skill_load` tool on demand to retrieve full skill details when needed.
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use zclaw_types::Result;
|
|
||||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
|
||||||
use crate::tool::{SkillIndexEntry, SkillExecutor};
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
/// Middleware that injects a lightweight skill index into the system prompt.
|
|
||||||
///
|
|
||||||
/// The index format is compact:
|
|
||||||
/// ```text
|
|
||||||
/// ## Skills (index — use skill_load for details)
|
|
||||||
/// - finance-tracker: 财务分析、财报解读 [数据分析]
|
|
||||||
/// - senior-developer: 代码开发、架构设计 [开发工程]
|
|
||||||
/// ```
|
|
||||||
pub struct SkillIndexMiddleware {
|
|
||||||
/// Pre-built skill index entries, constructed at chain creation time.
|
|
||||||
entries: Vec<SkillIndexEntry>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SkillIndexMiddleware {
|
|
||||||
pub fn new(entries: Vec<SkillIndexEntry>) -> Self {
|
|
||||||
Self { entries }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Build index entries from a skill executor that supports listing.
|
|
||||||
pub fn from_executor(executor: &Arc<dyn SkillExecutor>) -> Self {
|
|
||||||
Self {
|
|
||||||
entries: executor.list_skill_index(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl AgentMiddleware for SkillIndexMiddleware {
|
|
||||||
fn name(&self) -> &str { "skill_index" }
|
|
||||||
fn priority(&self) -> i32 { 200 }
|
|
||||||
|
|
||||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
|
||||||
if self.entries.is_empty() {
|
|
||||||
return Ok(MiddlewareDecision::Continue);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut index = String::from("\n\n## Skills (index — call skill_load for details)\n\n");
|
|
||||||
for entry in &self.entries {
|
|
||||||
let triggers = if entry.triggers.is_empty() {
|
|
||||||
String::new()
|
|
||||||
} else {
|
|
||||||
format!(" — {}", entry.triggers.join(", "))
|
|
||||||
};
|
|
||||||
index.push_str(&format!("- **{}**: {}{}\n", entry.id, entry.description, triggers));
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.system_prompt.push_str(&index);
|
|
||||||
Ok(MiddlewareDecision::Continue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,52 +0,0 @@
|
|||||||
//! Token calibration middleware — calibrates token estimation after first LLM response.
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use zclaw_types::Result;
|
|
||||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
|
||||||
use crate::compaction;
|
|
||||||
|
|
||||||
/// Middleware that calibrates the global token estimation factor based on
|
|
||||||
/// actual API-returned token counts from the first LLM response.
|
|
||||||
pub struct TokenCalibrationMiddleware {
|
|
||||||
/// Whether calibration has already been applied in this session.
|
|
||||||
calibrated: std::sync::atomic::AtomicBool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl TokenCalibrationMiddleware {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
calibrated: std::sync::atomic::AtomicBool::new(false),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for TokenCalibrationMiddleware {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl AgentMiddleware for TokenCalibrationMiddleware {
|
|
||||||
fn name(&self) -> &str { "token_calibration" }
|
|
||||||
fn priority(&self) -> i32 { 700 }
|
|
||||||
|
|
||||||
async fn before_completion(&self, _ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
|
||||||
// Calibration happens in after_completion when we have actual token counts.
|
|
||||||
// Before-completion is a no-op.
|
|
||||||
Ok(MiddlewareDecision::Continue)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn after_completion(&self, ctx: &MiddlewareContext) -> Result<()> {
|
|
||||||
if ctx.input_tokens > 0 && !self.calibrated.load(std::sync::atomic::Ordering::Relaxed) {
|
|
||||||
let estimated = compaction::estimate_messages_tokens(&ctx.messages);
|
|
||||||
compaction::update_calibration(estimated, ctx.input_tokens);
|
|
||||||
self.calibrated.store(true, std::sync::atomic::Ordering::Relaxed);
|
|
||||||
tracing::debug!(
|
|
||||||
"[TokenCalibrationMiddleware] Calibrated: estimated={}, actual={}",
|
|
||||||
estimated, ctx.input_tokens
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -37,39 +37,6 @@ pub trait SkillExecutor: Send + Sync {
|
|||||||
session_id: &str,
|
session_id: &str,
|
||||||
input: Value,
|
input: Value,
|
||||||
) -> Result<Value>;
|
) -> Result<Value>;
|
||||||
|
|
||||||
/// Return metadata for on-demand skill loading.
|
|
||||||
/// Default returns `None` (skill detail not available).
|
|
||||||
fn get_skill_detail(&self, skill_id: &str) -> Option<SkillDetail> {
|
|
||||||
let _ = skill_id;
|
|
||||||
None
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return lightweight index of all available skills.
|
|
||||||
/// Default returns empty (no index available).
|
|
||||||
fn list_skill_index(&self) -> Vec<SkillIndexEntry> {
|
|
||||||
Vec::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Lightweight skill index entry for system prompt injection.
|
|
||||||
#[derive(Debug, Clone, serde::Serialize)]
|
|
||||||
pub struct SkillIndexEntry {
|
|
||||||
pub id: String,
|
|
||||||
pub description: String,
|
|
||||||
pub triggers: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Full skill detail returned by `skill_load` tool.
|
|
||||||
#[derive(Debug, Clone, serde::Serialize)]
|
|
||||||
pub struct SkillDetail {
|
|
||||||
pub id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub description: String,
|
|
||||||
pub category: Option<String>,
|
|
||||||
pub input_schema: Option<Value>,
|
|
||||||
pub triggers: Vec<String>,
|
|
||||||
pub capabilities: Vec<String>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Context provided to tool execution
|
/// Context provided to tool execution
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ mod file_write;
|
|||||||
mod shell_exec;
|
mod shell_exec;
|
||||||
mod web_fetch;
|
mod web_fetch;
|
||||||
mod execute_skill;
|
mod execute_skill;
|
||||||
mod skill_load;
|
|
||||||
mod path_validator;
|
mod path_validator;
|
||||||
|
|
||||||
pub use file_read::FileReadTool;
|
pub use file_read::FileReadTool;
|
||||||
@@ -13,7 +12,6 @@ pub use file_write::FileWriteTool;
|
|||||||
pub use shell_exec::ShellExecTool;
|
pub use shell_exec::ShellExecTool;
|
||||||
pub use web_fetch::WebFetchTool;
|
pub use web_fetch::WebFetchTool;
|
||||||
pub use execute_skill::ExecuteSkillTool;
|
pub use execute_skill::ExecuteSkillTool;
|
||||||
pub use skill_load::SkillLoadTool;
|
|
||||||
pub use path_validator::{PathValidator, PathValidatorConfig};
|
pub use path_validator::{PathValidator, PathValidatorConfig};
|
||||||
|
|
||||||
use crate::tool::ToolRegistry;
|
use crate::tool::ToolRegistry;
|
||||||
@@ -25,5 +23,4 @@ pub fn register_builtin_tools(registry: &mut ToolRegistry) {
|
|||||||
registry.register(Box::new(ShellExecTool::new()));
|
registry.register(Box::new(ShellExecTool::new()));
|
||||||
registry.register(Box::new(WebFetchTool::new()));
|
registry.register(Box::new(WebFetchTool::new()));
|
||||||
registry.register(Box::new(ExecuteSkillTool::new()));
|
registry.register(Box::new(ExecuteSkillTool::new()));
|
||||||
registry.register(Box::new(SkillLoadTool::new()));
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,11 +160,6 @@ impl PathValidator {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the workspace root directory
|
|
||||||
pub fn workspace_root(&self) -> Option<&PathBuf> {
|
|
||||||
self.workspace_root.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Validate a path for read access
|
/// Validate a path for read access
|
||||||
pub fn validate_read(&self, path: &str) -> Result<PathBuf> {
|
pub fn validate_read(&self, path: &str) -> Result<PathBuf> {
|
||||||
let canonical = self.resolve_and_validate(path)?;
|
let canonical = self.resolve_and_validate(path)?;
|
||||||
|
|||||||
@@ -1,81 +0,0 @@
|
|||||||
//! Skill load tool — on-demand retrieval of full skill details.
|
|
||||||
//!
|
|
||||||
//! When the `SkillIndexMiddleware` is active, the system prompt contains only a lightweight
|
|
||||||
//! skill index. This tool allows the LLM to load full skill details (description, input schema,
|
|
||||||
//! capabilities) on demand, exactly when the LLM decides a particular skill is relevant.
|
|
||||||
|
|
||||||
use async_trait::async_trait;
|
|
||||||
use serde_json::{json, Value};
|
|
||||||
use zclaw_types::{Result, ZclawError};
|
|
||||||
|
|
||||||
use crate::tool::{Tool, ToolContext};
|
|
||||||
|
|
||||||
pub struct SkillLoadTool;
|
|
||||||
|
|
||||||
impl SkillLoadTool {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl Tool for SkillLoadTool {
|
|
||||||
fn name(&self) -> &str {
|
|
||||||
"skill_load"
|
|
||||||
}
|
|
||||||
|
|
||||||
fn description(&self) -> &str {
|
|
||||||
"Load full details for a skill by its ID. Use this when you need to understand a skill's \
|
|
||||||
input parameters, capabilities, or usage instructions before calling execute_skill. \
|
|
||||||
Returns the skill description, input schema, and trigger conditions."
|
|
||||||
}
|
|
||||||
|
|
||||||
fn input_schema(&self) -> Value {
|
|
||||||
json!({
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"skill_id": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The ID of the skill to load details for"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["skill_id"]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
|
||||||
let skill_id = input["skill_id"].as_str()
|
|
||||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?;
|
|
||||||
|
|
||||||
let executor = context.skill_executor.as_ref()
|
|
||||||
.ok_or_else(|| ZclawError::ToolError("Skill executor not available".into()))?;
|
|
||||||
|
|
||||||
match executor.get_skill_detail(skill_id) {
|
|
||||||
Some(detail) => {
|
|
||||||
let mut result = json!({
|
|
||||||
"id": detail.id,
|
|
||||||
"name": detail.name,
|
|
||||||
"description": detail.description,
|
|
||||||
"triggers": detail.triggers,
|
|
||||||
});
|
|
||||||
if let Some(schema) = &detail.input_schema {
|
|
||||||
result["input_schema"] = schema.clone();
|
|
||||||
}
|
|
||||||
if let Some(cat) = &detail.category {
|
|
||||||
result["category"] = json!(cat);
|
|
||||||
}
|
|
||||||
if !detail.capabilities.is_empty() {
|
|
||||||
result["capabilities"] = json!(detail.capabilities);
|
|
||||||
}
|
|
||||||
Ok(result)
|
|
||||||
}
|
|
||||||
None => Err(ZclawError::ToolError(format!("Skill not found: {}", skill_id))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for SkillLoadTool {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -9,12 +9,8 @@ name = "zclaw-saas"
|
|||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
zclaw-types = { workspace = true }
|
|
||||||
|
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
tokio-stream = { workspace = true }
|
|
||||||
futures = { workspace = true }
|
futures = { workspace = true }
|
||||||
async-trait = { workspace = true }
|
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
toml = { workspace = true }
|
toml = { workspace = true }
|
||||||
@@ -35,6 +31,8 @@ url = "2"
|
|||||||
|
|
||||||
axum = { workspace = true }
|
axum = { workspace = true }
|
||||||
axum-extra = { workspace = true }
|
axum-extra = { workspace = true }
|
||||||
|
bytes = { workspace = true }
|
||||||
|
async-stream = { workspace = true }
|
||||||
tower = { workspace = true }
|
tower = { workspace = true }
|
||||||
tower-http = { workspace = true }
|
tower-http = { workspace = true }
|
||||||
jsonwebtoken = { workspace = true }
|
jsonwebtoken = { workspace = true }
|
||||||
@@ -42,9 +40,9 @@ argon2 = { workspace = true }
|
|||||||
totp-rs = { workspace = true }
|
totp-rs = { workspace = true }
|
||||||
urlencoding = "2"
|
urlencoding = "2"
|
||||||
data-encoding = "2"
|
data-encoding = "2"
|
||||||
regex = "1"
|
aes-gcm = { workspace = true }
|
||||||
aes-gcm = "0.10"
|
utoipa = { version = "5", features = ["axum_extras"] }
|
||||||
bytes = "1"
|
utoipa-swagger-ui = { version = "5", features = ["axum"] }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
tempfile = { workspace = true }
|
tempfile = { workspace = true }
|
||||||
|
|||||||
@@ -1,339 +0,0 @@
|
|||||||
-- Migration: Initial schema with TIMESTAMPTZ
|
|
||||||
-- Extracted from inline SCHEMA_SQL in db.rs, with TEXT timestamps converted to TIMESTAMPTZ.
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS saas_schema_version (
|
|
||||||
version INTEGER PRIMARY KEY
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS accounts (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
username TEXT NOT NULL UNIQUE,
|
|
||||||
email TEXT NOT NULL UNIQUE,
|
|
||||||
password_hash TEXT NOT NULL,
|
|
||||||
display_name TEXT NOT NULL DEFAULT '',
|
|
||||||
avatar_url TEXT,
|
|
||||||
role TEXT NOT NULL DEFAULT 'user',
|
|
||||||
status TEXT NOT NULL DEFAULT 'active',
|
|
||||||
totp_secret TEXT,
|
|
||||||
totp_enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
|
||||||
last_login_at TIMESTAMPTZ,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_accounts_email ON accounts(email);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_accounts_role ON accounts(role);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS api_tokens (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
token_hash TEXT NOT NULL,
|
|
||||||
token_prefix TEXT NOT NULL,
|
|
||||||
permissions TEXT NOT NULL DEFAULT '[]',
|
|
||||||
last_used_at TIMESTAMPTZ,
|
|
||||||
expires_at TIMESTAMPTZ,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
revoked_at TIMESTAMPTZ,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_api_tokens_account ON api_tokens(account_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_api_tokens_hash ON api_tokens(token_hash);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS roles (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
description TEXT,
|
|
||||||
permissions TEXT NOT NULL DEFAULT '[]',
|
|
||||||
is_system BOOLEAN NOT NULL DEFAULT FALSE,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS permission_templates (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
description TEXT,
|
|
||||||
permissions TEXT NOT NULL DEFAULT '[]',
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS operation_logs (
|
|
||||||
id BIGSERIAL PRIMARY KEY,
|
|
||||||
account_id TEXT,
|
|
||||||
action TEXT NOT NULL,
|
|
||||||
target_type TEXT,
|
|
||||||
target_id TEXT,
|
|
||||||
details TEXT,
|
|
||||||
ip_address TEXT,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_op_logs_account ON operation_logs(account_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_op_logs_action ON operation_logs(action);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_op_logs_time ON operation_logs(created_at);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS providers (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
name TEXT NOT NULL UNIQUE,
|
|
||||||
display_name TEXT NOT NULL,
|
|
||||||
api_key TEXT,
|
|
||||||
base_url TEXT NOT NULL,
|
|
||||||
api_protocol TEXT NOT NULL DEFAULT 'openai',
|
|
||||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
|
||||||
rate_limit_rpm BIGINT,
|
|
||||||
rate_limit_tpm BIGINT,
|
|
||||||
config_json TEXT DEFAULT '{}',
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS models (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
provider_id TEXT NOT NULL,
|
|
||||||
model_id TEXT NOT NULL,
|
|
||||||
alias TEXT NOT NULL,
|
|
||||||
context_window BIGINT NOT NULL DEFAULT 8192,
|
|
||||||
max_output_tokens BIGINT NOT NULL DEFAULT 4096,
|
|
||||||
supports_streaming BOOLEAN NOT NULL DEFAULT TRUE,
|
|
||||||
supports_vision BOOLEAN NOT NULL DEFAULT FALSE,
|
|
||||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
|
||||||
pricing_input DOUBLE PRECISION DEFAULT 0,
|
|
||||||
pricing_output DOUBLE PRECISION DEFAULT 0,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
UNIQUE(provider_id, model_id),
|
|
||||||
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_models_provider ON models(provider_id);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS account_api_keys (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
provider_id TEXT NOT NULL,
|
|
||||||
key_value TEXT NOT NULL,
|
|
||||||
key_label TEXT,
|
|
||||||
permissions TEXT NOT NULL DEFAULT '[]',
|
|
||||||
enabled BOOLEAN NOT NULL DEFAULT TRUE,
|
|
||||||
last_used_at TIMESTAMPTZ,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
revoked_at TIMESTAMPTZ,
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
|
||||||
FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_account_api_keys_account ON account_api_keys(account_id);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS usage_records (
|
|
||||||
id BIGSERIAL PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
provider_id TEXT NOT NULL,
|
|
||||||
model_id TEXT NOT NULL,
|
|
||||||
input_tokens INTEGER NOT NULL DEFAULT 0,
|
|
||||||
output_tokens INTEGER NOT NULL DEFAULT 0,
|
|
||||||
latency_ms INTEGER,
|
|
||||||
status TEXT NOT NULL DEFAULT 'success',
|
|
||||||
error_message TEXT,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_usage_account ON usage_records(account_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_usage_time ON usage_records(created_at);
|
|
||||||
-- idx_usage_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
|
|
||||||
-- CREATE INDEX IF NOT EXISTS idx_usage_day ON usage_records((created_at::date));
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS relay_tasks (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
provider_id TEXT NOT NULL,
|
|
||||||
model_id TEXT NOT NULL,
|
|
||||||
request_hash TEXT NOT NULL,
|
|
||||||
status TEXT NOT NULL DEFAULT 'queued',
|
|
||||||
priority INTEGER NOT NULL DEFAULT 0,
|
|
||||||
attempt_count INTEGER NOT NULL DEFAULT 0,
|
|
||||||
max_attempts INTEGER NOT NULL DEFAULT 3,
|
|
||||||
request_body TEXT NOT NULL,
|
|
||||||
response_body TEXT,
|
|
||||||
input_tokens INTEGER DEFAULT 0,
|
|
||||||
output_tokens INTEGER DEFAULT 0,
|
|
||||||
error_message TEXT,
|
|
||||||
queued_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
started_at TIMESTAMPTZ,
|
|
||||||
completed_at TIMESTAMPTZ,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_relay_status ON relay_tasks(status);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_relay_account ON relay_tasks(account_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_relay_provider ON relay_tasks(provider_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_relay_time ON relay_tasks(created_at);
|
|
||||||
-- idx_relay_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
|
|
||||||
-- CREATE INDEX IF NOT EXISTS idx_relay_day ON relay_tasks((created_at::date));
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS config_items (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
category TEXT NOT NULL,
|
|
||||||
key_path TEXT NOT NULL,
|
|
||||||
value_type TEXT NOT NULL,
|
|
||||||
current_value TEXT,
|
|
||||||
default_value TEXT,
|
|
||||||
source TEXT NOT NULL DEFAULT 'local',
|
|
||||||
description TEXT,
|
|
||||||
requires_restart BOOLEAN NOT NULL DEFAULT FALSE,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
UNIQUE(category, key_path)
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_config_category ON config_items(category);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS config_sync_log (
|
|
||||||
id BIGSERIAL PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
client_fingerprint TEXT NOT NULL,
|
|
||||||
action TEXT NOT NULL,
|
|
||||||
config_keys TEXT NOT NULL,
|
|
||||||
client_values TEXT,
|
|
||||||
saas_values TEXT,
|
|
||||||
resolution TEXT,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_sync_account ON config_sync_log(account_id);
|
|
||||||
|
|
||||||
CREATE TABLE IF NOT EXISTS devices (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
device_id TEXT NOT NULL,
|
|
||||||
device_name TEXT,
|
|
||||||
platform TEXT,
|
|
||||||
app_version TEXT,
|
|
||||||
last_seen_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_devices_account ON devices(account_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_devices_device_id ON devices(device_id);
|
|
||||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_devices_unique ON devices(account_id, device_id);
|
|
||||||
|
|
||||||
-- Prompt template master table
|
|
||||||
CREATE TABLE IF NOT EXISTS prompt_templates (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
name TEXT NOT NULL UNIQUE,
|
|
||||||
category TEXT NOT NULL,
|
|
||||||
description TEXT,
|
|
||||||
source TEXT NOT NULL DEFAULT 'builtin',
|
|
||||||
current_version INTEGER NOT NULL DEFAULT 1,
|
|
||||||
status TEXT NOT NULL DEFAULT 'active',
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_prompt_status ON prompt_templates(status);
|
|
||||||
|
|
||||||
-- Prompt versions table (immutable)
|
|
||||||
CREATE TABLE IF NOT EXISTS prompt_versions (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
template_id TEXT NOT NULL,
|
|
||||||
version INTEGER NOT NULL,
|
|
||||||
system_prompt TEXT,
|
|
||||||
user_prompt_template TEXT,
|
|
||||||
variables TEXT NOT NULL DEFAULT '[]',
|
|
||||||
changelog TEXT,
|
|
||||||
min_app_version TEXT,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
UNIQUE(template_id, version)
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_prompt_ver_template ON prompt_versions(template_id);
|
|
||||||
|
|
||||||
-- Client prompt sync status
|
|
||||||
CREATE TABLE IF NOT EXISTS prompt_sync_status (
|
|
||||||
device_id TEXT NOT NULL,
|
|
||||||
template_id TEXT NOT NULL,
|
|
||||||
synced_version INTEGER NOT NULL,
|
|
||||||
synced_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
PRIMARY KEY(device_id, template_id)
|
|
||||||
);
|
|
||||||
|
|
||||||
-- Provider Key Pool table
|
|
||||||
CREATE TABLE IF NOT EXISTS provider_keys (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
provider_id TEXT NOT NULL,
|
|
||||||
key_label TEXT NOT NULL,
|
|
||||||
key_value TEXT NOT NULL,
|
|
||||||
priority INTEGER NOT NULL DEFAULT 0,
|
|
||||||
max_rpm BIGINT,
|
|
||||||
max_tpm BIGINT,
|
|
||||||
quota_reset_interval TEXT,
|
|
||||||
is_active BOOLEAN NOT NULL DEFAULT TRUE,
|
|
||||||
last_429_at TIMESTAMPTZ,
|
|
||||||
cooldown_until TIMESTAMPTZ,
|
|
||||||
total_requests BIGINT NOT NULL DEFAULT 0,
|
|
||||||
total_tokens BIGINT NOT NULL DEFAULT 0,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_pkeys_provider ON provider_keys(provider_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_pkeys_active ON provider_keys(provider_id, is_active);
|
|
||||||
|
|
||||||
-- Key usage sliding window
|
|
||||||
CREATE TABLE IF NOT EXISTS key_usage_window (
|
|
||||||
key_id TEXT NOT NULL,
|
|
||||||
window_minute TEXT NOT NULL,
|
|
||||||
request_count INTEGER NOT NULL DEFAULT 0,
|
|
||||||
token_count BIGINT NOT NULL DEFAULT 0,
|
|
||||||
PRIMARY KEY(key_id, window_minute)
|
|
||||||
);
|
|
||||||
|
|
||||||
-- Agent config template table
|
|
||||||
CREATE TABLE IF NOT EXISTS agent_templates (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
description TEXT,
|
|
||||||
category TEXT NOT NULL DEFAULT 'general',
|
|
||||||
source TEXT NOT NULL DEFAULT 'builtin',
|
|
||||||
model TEXT,
|
|
||||||
system_prompt TEXT,
|
|
||||||
tools TEXT NOT NULL DEFAULT '[]'::text,
|
|
||||||
capabilities TEXT NOT NULL DEFAULT '[]'::text,
|
|
||||||
temperature DOUBLE PRECISION,
|
|
||||||
max_tokens INTEGER,
|
|
||||||
visibility TEXT NOT NULL DEFAULT 'public',
|
|
||||||
status TEXT NOT NULL DEFAULT 'active',
|
|
||||||
current_version INTEGER NOT NULL DEFAULT 1,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_agent_tmpl_status ON agent_templates(status);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_agent_tmpl_visibility ON agent_templates(visibility);
|
|
||||||
|
|
||||||
-- Desktop telemetry report table (token usage statistics, no content)
|
|
||||||
CREATE TABLE IF NOT EXISTS telemetry_reports (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
device_id TEXT NOT NULL,
|
|
||||||
app_version TEXT,
|
|
||||||
model_id TEXT NOT NULL,
|
|
||||||
input_tokens BIGINT NOT NULL DEFAULT 0,
|
|
||||||
output_tokens BIGINT NOT NULL DEFAULT 0,
|
|
||||||
latency_ms INTEGER,
|
|
||||||
success BOOLEAN NOT NULL DEFAULT TRUE,
|
|
||||||
error_type TEXT,
|
|
||||||
connection_mode TEXT NOT NULL DEFAULT 'tauri',
|
|
||||||
reported_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_telemetry_account ON telemetry_reports(account_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_telemetry_time ON telemetry_reports(reported_at);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_telemetry_model ON telemetry_reports(model_id);
|
|
||||||
-- idx_telemetry_day: Skipping because ::date on TIMESTAMPTZ is not IMMUTABLE
|
|
||||||
-- CREATE INDEX IF NOT EXISTS idx_telemetry_day ON telemetry_reports((reported_at::date));
|
|
||||||
|
|
||||||
-- Refresh Token storage (single-use, JWT jti tracking)
|
|
||||||
CREATE TABLE IF NOT EXISTS refresh_tokens (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
account_id TEXT NOT NULL,
|
|
||||||
jti TEXT NOT NULL UNIQUE,
|
|
||||||
token_hash TEXT NOT NULL,
|
|
||||||
expires_at TIMESTAMPTZ NOT NULL,
|
|
||||||
used_at TIMESTAMPTZ,
|
|
||||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
|
||||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_refresh_account ON refresh_tokens(account_id);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_refresh_jti ON refresh_tokens(jti);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_refresh_expires ON refresh_tokens(expires_at);
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
-- Migration: Seed roles (super_admin, admin, user)
|
|
||||||
-- Timestamps use NOW() to match TIMESTAMPTZ columns from initial schema.
|
|
||||||
|
|
||||||
INSERT INTO roles (id, name, description, permissions, is_system, created_at, updated_at)
|
|
||||||
VALUES
|
|
||||||
('super_admin', '超级管理员', '拥有所有权限', '["admin:full","account:admin","provider:manage","model:manage","relay:admin","config:write","prompt:read","prompt:write","prompt:publish","prompt:admin"]', TRUE, NOW(), NOW()),
|
|
||||||
('admin', '管理员', '管理账号和配置', '["account:read","account:admin","provider:manage","model:read","model:manage","relay:use","relay:admin","config:read","config:write","prompt:read","prompt:write","prompt:publish"]', TRUE, NOW(), NOW()),
|
|
||||||
('user', '普通用户', '基础使用权限', '["model:read","relay:use","config:read","prompt:read"]', TRUE, NOW(), NOW())
|
|
||||||
ON CONFLICT (id) DO NOTHING;
|
|
||||||
@@ -8,7 +8,6 @@ use crate::state::AppState;
|
|||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use crate::auth::types::AuthContext;
|
use crate::auth::types::AuthContext;
|
||||||
use crate::auth::handlers::{log_operation, check_permission};
|
use crate::auth::handlers::{log_operation, check_permission};
|
||||||
use crate::models::{OperationLogRow, DashboardStatsRow, DashboardTodayRow};
|
|
||||||
use super::{types::*, service};
|
use super::{types::*, service};
|
||||||
|
|
||||||
fn require_admin(ctx: &AuthContext) -> SaasResult<()> {
|
fn require_admin(ctx: &AuthContext) -> SaasResult<()> {
|
||||||
@@ -38,7 +37,7 @@ pub async fn get_account(
|
|||||||
service::get_account(&state.db, &id).await.map(Json)
|
service::get_account(&state.db, &id).await.map(Json)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// PATCH /api/v1/accounts/:id (admin or self for limited fields)
|
/// PUT /api/v1/accounts/:id (admin or self for limited fields)
|
||||||
pub async fn update_account(
|
pub async fn update_account(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Path(id): Path<String>,
|
Path(id): Path<String>,
|
||||||
@@ -81,15 +80,12 @@ pub async fn update_status(
|
|||||||
Ok(Json(serde_json::json!({"ok": true})))
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/tokens?page=1&page_size=20
|
/// GET /api/v1/tokens
|
||||||
pub async fn list_tokens(
|
pub async fn list_tokens(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
) -> SaasResult<Json<Vec<TokenInfo>>> {
|
||||||
) -> SaasResult<Json<PaginatedResponse<TokenInfo>>> {
|
service::list_api_tokens(&state.db, &ctx.account_id).await.map(Json)
|
||||||
let page = params.get("page").and_then(|v| v.parse().ok());
|
|
||||||
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
|
|
||||||
service::list_api_tokens(&state.db, &ctx.account_id, page, page_size).await.map(Json)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/tokens
|
/// POST /api/v1/tokens
|
||||||
@@ -98,24 +94,9 @@ pub async fn create_token(
|
|||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<CreateTokenRequest>,
|
Json(req): Json<CreateTokenRequest>,
|
||||||
) -> SaasResult<Json<TokenInfo>> {
|
) -> SaasResult<Json<TokenInfo>> {
|
||||||
// 权限校验: 创建的 token 不能超出创建者已有的权限
|
let token = service::create_api_token(&state.db, &ctx.account_id, &req).await?;
|
||||||
let allowed_permissions: Vec<String> = req.permissions
|
|
||||||
.into_iter()
|
|
||||||
.filter(|p| ctx.permissions.contains(p))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
if allowed_permissions.is_empty() {
|
|
||||||
return Err(SaasError::InvalidInput("请求的权限均不被允许".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let filtered_req = CreateTokenRequest {
|
|
||||||
name: req.name,
|
|
||||||
permissions: allowed_permissions,
|
|
||||||
expires_days: req.expires_days,
|
|
||||||
};
|
|
||||||
let token = service::create_api_token(&state.db, &ctx.account_id, &filtered_req).await?;
|
|
||||||
log_operation(&state.db, &ctx.account_id, "token.create", "api_token", &token.id,
|
log_operation(&state.db, &ctx.account_id, "token.create", "api_token", &token.id,
|
||||||
Some(serde_json::json!({"name": &filtered_req.name})), ctx.client_ip.as_deref()).await?;
|
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
||||||
Ok(Json(token))
|
Ok(Json(token))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -135,35 +116,52 @@ pub async fn list_operation_logs(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
Query(params): Query<std::collections::HashMap<String, String>>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<PaginatedResponse<serde_json::Value>>> {
|
) -> SaasResult<Json<Vec<serde_json::Value>>> {
|
||||||
require_admin(&ctx)?;
|
require_admin(&ctx)?;
|
||||||
let page: u32 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1).max(1);
|
let page: i64 = params.get("page").and_then(|v| v.parse().ok()).unwrap_or(1);
|
||||||
let page_size: u32 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50).min(100);
|
let page_size: i64 = params.get("page_size").and_then(|v| v.parse().ok()).unwrap_or(50);
|
||||||
let offset = ((page - 1) * page_size) as i64;
|
let offset = (page - 1) * page_size;
|
||||||
|
let action_filter = params.get("action").map(|s| s.as_str());
|
||||||
|
let target_type_filter = params.get("target_type").map(|s| s.as_str());
|
||||||
|
|
||||||
let total: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM operation_logs")
|
let mut sql = String::from(
|
||||||
.fetch_one(&state.db).await?;
|
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
|
||||||
|
FROM operation_logs"
|
||||||
|
);
|
||||||
|
let mut param_idx: usize = 1;
|
||||||
|
if action_filter.is_some() || target_type_filter.is_some() {
|
||||||
|
sql.push_str(" WHERE 1=1");
|
||||||
|
if action_filter.is_some() {
|
||||||
|
sql.push_str(&format!(" AND action = ${}", param_idx));
|
||||||
|
param_idx += 1;
|
||||||
|
}
|
||||||
|
if target_type_filter.is_some() {
|
||||||
|
sql.push_str(&format!(" AND target_type = ${}", param_idx));
|
||||||
|
param_idx += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql.push_str(&format!(" ORDER BY created_at DESC LIMIT ${} OFFSET ${}", param_idx, param_idx + 1));
|
||||||
|
|
||||||
let rows: Vec<OperationLogRow> =
|
let mut query = sqlx::query_as::<_, (i64, Option<String>, String, Option<String>, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>)>(&sql);
|
||||||
sqlx::query_as(
|
if let Some(action) = action_filter {
|
||||||
"SELECT id, account_id, action, target_type, target_id, details, ip_address, created_at
|
query = query.bind(action);
|
||||||
FROM operation_logs ORDER BY created_at DESC LIMIT $1 OFFSET $2"
|
}
|
||||||
)
|
if let Some(target_type) = target_type_filter {
|
||||||
.bind(page_size as i64)
|
query = query.bind(target_type);
|
||||||
.bind(offset)
|
}
|
||||||
.fetch_all(&state.db)
|
query = query.bind(page_size).bind(offset);
|
||||||
.await?;
|
let rows = query.fetch_all(&state.db).await?;
|
||||||
|
|
||||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
|
let items: Vec<serde_json::Value> = rows.into_iter().map(|(id, account_id, action, target_type, target_id, details, ip_address, created_at)| {
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
"id": r.id, "account_id": r.account_id, "action": r.action,
|
"id": id, "account_id": account_id, "action": action,
|
||||||
"target_type": r.target_type, "target_id": r.target_id,
|
"target_type": target_type, "target_id": target_id,
|
||||||
"details": r.details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
|
"details": details.and_then(|d| serde_json::from_str::<serde_json::Value>(&d).ok()),
|
||||||
"ip_address": r.ip_address, "created_at": r.created_at,
|
"ip_address": ip_address, "created_at": created_at.to_rfc3339(),
|
||||||
})
|
})
|
||||||
}).collect();
|
}).collect();
|
||||||
|
|
||||||
Ok(Json(PaginatedResponse { items, total, page, page_size }))
|
Ok(Json(items))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/stats/dashboard — 仪表盘聚合统计 (需要 admin 权限)
|
/// GET /api/v1/stats/dashboard — 仪表盘聚合统计 (需要 admin 权限)
|
||||||
@@ -173,41 +171,27 @@ pub async fn dashboard_stats(
|
|||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
require_admin(&ctx)?;
|
require_admin(&ctx)?;
|
||||||
|
|
||||||
// 查询 1: 账号 + Provider + Model 聚合 (一次查询)
|
let row: (i64, i64, i64, i64, i64, i64, i64) = sqlx::query_as(
|
||||||
let stats_row: DashboardStatsRow = sqlx::query_as(
|
|
||||||
"SELECT
|
"SELECT
|
||||||
(SELECT COUNT(*) FROM accounts) as total_accounts,
|
(SELECT COUNT(*) FROM accounts),
|
||||||
(SELECT COUNT(*) FROM accounts WHERE status = 'active') as active_accounts,
|
(SELECT COUNT(*) FROM accounts WHERE status = 'active'),
|
||||||
(SELECT COUNT(*) FROM providers WHERE enabled = true) as active_providers,
|
(SELECT COUNT(*) FROM relay_tasks WHERE DATE(created_at) = CURRENT_DATE),
|
||||||
(SELECT COUNT(*) FROM models WHERE enabled = true) as active_models"
|
(SELECT COUNT(*) FROM providers WHERE enabled = true),
|
||||||
).fetch_one(&state.db).await?;
|
(SELECT COUNT(*) FROM models WHERE enabled = true),
|
||||||
|
(SELECT COALESCE(SUM(input_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE),
|
||||||
// 查询 2: 今日中转统计 — 使用范围查询走 B-tree 索引
|
(SELECT COALESCE(SUM(output_tokens), 0) FROM usage_records WHERE DATE(created_at) = CURRENT_DATE)"
|
||||||
let today_start = chrono::Utc::now()
|
)
|
||||||
.date_naive()
|
.fetch_one(&state.db)
|
||||||
.and_hms_opt(0, 0, 0).unwrap()
|
.await?;
|
||||||
.and_utc()
|
|
||||||
.to_rfc3339();
|
|
||||||
let tomorrow_start = (chrono::Utc::now() + chrono::Duration::days(1))
|
|
||||||
.date_naive()
|
|
||||||
.and_hms_opt(0, 0, 0).unwrap()
|
|
||||||
.and_utc()
|
|
||||||
.to_rfc3339();
|
|
||||||
let today_row: DashboardTodayRow = sqlx::query_as(
|
|
||||||
"SELECT
|
|
||||||
(SELECT COUNT(*) FROM relay_tasks WHERE created_at >= $1 AND created_at < $2) as tasks_today,
|
|
||||||
COALESCE((SELECT SUM(input_tokens) FROM usage_records WHERE created_at >= $1 AND created_at < $2), 0) as tokens_input,
|
|
||||||
COALESCE((SELECT SUM(output_tokens) FROM usage_records WHERE created_at >= $1 AND created_at < $2), 0) as tokens_output"
|
|
||||||
).bind(&today_start).bind(&tomorrow_start).fetch_one(&state.db).await?;
|
|
||||||
|
|
||||||
Ok(Json(serde_json::json!({
|
Ok(Json(serde_json::json!({
|
||||||
"total_accounts": stats_row.total_accounts,
|
"total_accounts": row.0,
|
||||||
"active_accounts": stats_row.active_accounts,
|
"active_accounts": row.1,
|
||||||
"tasks_today": today_row.tasks_today,
|
"tasks_today": row.2,
|
||||||
"active_providers": stats_row.active_providers,
|
"active_providers": row.3,
|
||||||
"active_models": stats_row.active_models,
|
"active_models": row.4,
|
||||||
"tokens_today_input": today_row.tokens_input,
|
"tokens_today_input": row.5,
|
||||||
"tokens_today_output": today_row.tokens_output,
|
"tokens_today_output": row.6,
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,16 +201,9 @@ pub async fn dashboard_stats(
|
|||||||
pub async fn register_device(
|
pub async fn register_device(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<serde_json::Value>,
|
Json(req): Json<super::types::RegisterDeviceRequest>,
|
||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
let device_id = req.get("device_id")
|
let now = chrono::Utc::now();
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
|
|
||||||
let device_name = req.get("device_name").and_then(|v| v.as_str()).unwrap_or("Unknown");
|
|
||||||
let platform = req.get("platform").and_then(|v| v.as_str()).unwrap_or("unknown");
|
|
||||||
let app_version = req.get("app_version").and_then(|v| v.as_str()).unwrap_or("");
|
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
|
||||||
let device_uuid = uuid::Uuid::new_v4().to_string();
|
let device_uuid = uuid::Uuid::new_v4().to_string();
|
||||||
|
|
||||||
// UPSERT: 已存在则更新 last_seen_at,不存在则插入
|
// UPSERT: 已存在则更新 last_seen_at,不存在则插入
|
||||||
@@ -234,62 +211,40 @@ pub async fn register_device(
|
|||||||
"INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at)
|
"INSERT INTO devices (id, account_id, device_id, device_name, platform, app_version, last_seen_at, created_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $7)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $7)
|
||||||
ON CONFLICT(account_id, device_id) DO UPDATE SET
|
ON CONFLICT(account_id, device_id) DO UPDATE SET
|
||||||
device_name = $4, platform = $5, app_version = $6, last_seen_at = $7"
|
device_name = EXCLUDED.device_name, platform = EXCLUDED.platform, app_version = EXCLUDED.app_version, last_seen_at = EXCLUDED.last_seen_at"
|
||||||
)
|
)
|
||||||
.bind(&device_uuid)
|
.bind(&device_uuid)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.bind(device_id)
|
.bind(&req.device_id)
|
||||||
.bind(device_name)
|
.bind(&req.device_name)
|
||||||
.bind(platform)
|
.bind(&req.platform)
|
||||||
.bind(app_version)
|
.bind(&req.app_version)
|
||||||
.bind(&now)
|
.bind(&now)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
log_operation(&state.db, &ctx.account_id, "device.register", "device", device_id,
|
log_operation(&state.db, &ctx.account_id, "device.register", "device", &req.device_id,
|
||||||
Some(serde_json::json!({"device_name": device_name, "platform": platform})),
|
Some(serde_json::json!({"device_name": req.device_name, "platform": req.platform})),
|
||||||
ctx.client_ip.as_deref()).await?;
|
ctx.client_ip.as_deref()).await?;
|
||||||
|
|
||||||
Ok(Json(serde_json::json!({"ok": true, "device_id": device_id})))
|
Ok(Json(serde_json::json!({"ok": true, "device_id": req.device_id})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/devices/heartbeat — 设备心跳
|
/// POST /api/v1/devices/heartbeat — 设备心跳
|
||||||
pub async fn device_heartbeat(
|
pub async fn device_heartbeat(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Json(req): Json<serde_json::Value>,
|
Json(req): Json<super::types::DeviceHeartbeatRequest>,
|
||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<serde_json::Value>> {
|
||||||
let device_id = req.get("device_id")
|
let now = chrono::Utc::now();
|
||||||
.and_then(|v| v.as_str())
|
let result = sqlx::query(
|
||||||
.ok_or_else(|| SaasError::InvalidInput("缺少 device_id".into()))?;
|
"UPDATE devices SET last_seen_at = $1 WHERE account_id = $2 AND device_id = $3"
|
||||||
|
)
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
.bind(&now)
|
||||||
|
.bind(&ctx.account_id)
|
||||||
// Also update platform/app_version if provided (supports client upgrades)
|
.bind(&req.device_id)
|
||||||
let platform = req.get("platform").and_then(|v| v.as_str());
|
.execute(&state.db)
|
||||||
let app_version = req.get("app_version").and_then(|v| v.as_str());
|
.await?;
|
||||||
|
|
||||||
let result = if platform.is_some() || app_version.is_some() {
|
|
||||||
sqlx::query(
|
|
||||||
"UPDATE devices SET last_seen_at = $1, platform = COALESCE($4, platform), app_version = COALESCE($5, app_version) WHERE account_id = $2 AND device_id = $3"
|
|
||||||
)
|
|
||||||
.bind(&now)
|
|
||||||
.bind(&ctx.account_id)
|
|
||||||
.bind(device_id)
|
|
||||||
.bind(platform)
|
|
||||||
.bind(app_version)
|
|
||||||
.execute(&state.db)
|
|
||||||
.await?
|
|
||||||
} else {
|
|
||||||
sqlx::query(
|
|
||||||
"UPDATE devices SET last_seen_at = $1 WHERE account_id = $2 AND device_id = $3"
|
|
||||||
)
|
|
||||||
.bind(&now)
|
|
||||||
.bind(&ctx.account_id)
|
|
||||||
.bind(device_id)
|
|
||||||
.execute(&state.db)
|
|
||||||
.await?
|
|
||||||
};
|
|
||||||
|
|
||||||
if result.rows_affected() == 0 {
|
if result.rows_affected() == 0 {
|
||||||
return Err(SaasError::NotFound("设备未注册".into()));
|
return Err(SaasError::NotFound("设备未注册".into()));
|
||||||
@@ -298,13 +253,27 @@ pub async fn device_heartbeat(
|
|||||||
Ok(Json(serde_json::json!({"ok": true})))
|
Ok(Json(serde_json::json!({"ok": true})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/devices?page=1&page_size=20 — 列出当前用户的设备
|
/// GET /api/v1/devices — 列出当前用户的设备
|
||||||
pub async fn list_devices(
|
pub async fn list_devices(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Extension(ctx): Extension<AuthContext>,
|
Extension(ctx): Extension<AuthContext>,
|
||||||
Query(params): Query<std::collections::HashMap<String, String>>,
|
) -> SaasResult<Json<Vec<super::types::DeviceInfo>>> {
|
||||||
) -> SaasResult<Json<PaginatedResponse<serde_json::Value>>> {
|
let rows: Vec<(String, String, Option<String>, Option<String>, Option<String>, chrono::DateTime<chrono::Utc>, chrono::DateTime<chrono::Utc>)> =
|
||||||
let page = params.get("page").and_then(|v| v.parse().ok());
|
sqlx::query_as(
|
||||||
let page_size = params.get("page_size").and_then(|v| v.parse().ok());
|
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
|
||||||
service::list_devices(&state.db, &ctx.account_id, page, page_size).await.map(Json)
|
FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC"
|
||||||
|
)
|
||||||
|
.bind(&ctx.account_id)
|
||||||
|
.fetch_all(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let items: Vec<super::types::DeviceInfo> = rows.into_iter().map(|r| {
|
||||||
|
super::types::DeviceInfo {
|
||||||
|
id: r.0, device_id: r.1,
|
||||||
|
device_name: r.2, platform: r.3, app_version: r.4,
|
||||||
|
last_seen_at: r.5.to_rfc3339(), created_at: r.6.to_rfc3339(),
|
||||||
|
}
|
||||||
|
}).collect();
|
||||||
|
|
||||||
|
Ok(Json(items))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ pub mod types;
|
|||||||
pub mod service;
|
pub mod service;
|
||||||
pub mod handlers;
|
pub mod handlers;
|
||||||
|
|
||||||
use axum::routing::{delete, get, patch, post};
|
use axum::routing::{delete, get, patch, post, put};
|
||||||
|
|
||||||
pub fn routes() -> axum::Router<crate::state::AppState> {
|
pub fn routes() -> axum::Router<crate::state::AppState> {
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/api/v1/accounts", get(handlers::list_accounts))
|
.route("/api/v1/accounts", get(handlers::list_accounts))
|
||||||
.route("/api/v1/accounts/:id", get(handlers::get_account))
|
.route("/api/v1/accounts/{id}", get(handlers::get_account))
|
||||||
.route("/api/v1/accounts/:id", patch(handlers::update_account))
|
.route("/api/v1/accounts/{id}", put(handlers::update_account))
|
||||||
.route("/api/v1/accounts/:id/status", patch(handlers::update_status))
|
.route("/api/v1/accounts/{id}/status", patch(handlers::update_status))
|
||||||
.route("/api/v1/tokens", get(handlers::list_tokens))
|
.route("/api/v1/tokens", get(handlers::list_tokens))
|
||||||
.route("/api/v1/tokens", post(handlers::create_token))
|
.route("/api/v1/tokens", post(handlers::create_token))
|
||||||
.route("/api/v1/tokens/:id", delete(handlers::revoke_token))
|
.route("/api/v1/tokens/{id}", delete(handlers::revoke_token))
|
||||||
.route("/api/v1/logs/operations", get(handlers::list_operation_logs))
|
.route("/api/v1/logs/operations", get(handlers::list_operation_logs))
|
||||||
.route("/api/v1/stats/dashboard", get(handlers::dashboard_stats))
|
.route("/api/v1/stats/dashboard", get(handlers::dashboard_stats))
|
||||||
.route("/api/v1/devices", get(handlers::list_devices))
|
.route("/api/v1/devices", get(handlers::list_devices))
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
use sqlx::PgPool;
|
use sqlx::PgPool;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use crate::common::{PaginatedResponse, normalize_pagination};
|
|
||||||
use crate::models::{AccountRow, ApiTokenRow, DeviceRow};
|
|
||||||
use super::types::*;
|
use super::types::*;
|
||||||
|
|
||||||
pub async fn list_accounts(
|
pub async fn list_accounts(
|
||||||
@@ -14,116 +12,60 @@ pub async fn list_accounts(
|
|||||||
let page_size = query.page_size.unwrap_or(20).min(100);
|
let page_size = query.page_size.unwrap_or(20).min(100);
|
||||||
let offset = (page - 1) * page_size;
|
let offset = (page - 1) * page_size;
|
||||||
|
|
||||||
// Static SQL per combination -- no format!() string interpolation
|
let mut where_clauses = Vec::new();
|
||||||
let (total, rows) = match (&query.role, &query.status, &query.search) {
|
let mut params: Vec<String> = Vec::new();
|
||||||
// role + status + search
|
let mut param_idx: usize = 1;
|
||||||
(Some(role), Some(status), Some(search)) => {
|
|
||||||
let pattern = format!("%{}%", search);
|
if let Some(role) = &query.role {
|
||||||
let total: i64 = sqlx::query_scalar(
|
where_clauses.push(format!("role = ${}", param_idx));
|
||||||
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND status = $2 AND (username LIKE $3 OR email LIKE $3 OR display_name LIKE $3)"
|
params.push(role.clone());
|
||||||
).bind(role).bind(status).bind(&pattern).fetch_one(db).await?;
|
param_idx += 1;
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
}
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
if let Some(status) = &query.status {
|
||||||
FROM accounts WHERE role = $1 AND status = $2 AND (username LIKE $3 OR email LIKE $3 OR display_name LIKE $3)
|
where_clauses.push(format!("status = ${}", param_idx));
|
||||||
ORDER BY created_at DESC LIMIT $4 OFFSET $5"
|
params.push(status.clone());
|
||||||
).bind(role).bind(status).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
param_idx += 1;
|
||||||
(total, rows)
|
}
|
||||||
}
|
if let Some(search) = &query.search {
|
||||||
// role + status
|
where_clauses.push(format!("(username LIKE ${} OR email LIKE ${} OR display_name LIKE ${})", param_idx, param_idx + 1, param_idx + 2));
|
||||||
(Some(role), Some(status), None) => {
|
let pattern = format!("%{}%", search);
|
||||||
let total: i64 = sqlx::query_scalar(
|
params.push(pattern.clone());
|
||||||
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND status = $2"
|
params.push(pattern.clone());
|
||||||
).bind(role).bind(status).fetch_one(db).await?;
|
params.push(pattern);
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
param_idx += 3;
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
}
|
||||||
FROM accounts WHERE role = $1 AND status = $2
|
|
||||||
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
let where_sql = if where_clauses.is_empty() {
|
||||||
).bind(role).bind(status).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
String::new()
|
||||||
(total, rows)
|
} else {
|
||||||
}
|
format!("WHERE {}", where_clauses.join(" AND "))
|
||||||
// role + search
|
|
||||||
(Some(role), None, Some(search)) => {
|
|
||||||
let pattern = format!("%{}%", search);
|
|
||||||
let total: i64 = sqlx::query_scalar(
|
|
||||||
"SELECT COUNT(*) FROM accounts WHERE role = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)"
|
|
||||||
).bind(role).bind(&pattern).fetch_one(db).await?;
|
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
|
||||||
FROM accounts WHERE role = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)
|
|
||||||
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
|
||||||
).bind(role).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
|
||||||
(total, rows)
|
|
||||||
}
|
|
||||||
// status + search
|
|
||||||
(None, Some(status), Some(search)) => {
|
|
||||||
let pattern = format!("%{}%", search);
|
|
||||||
let total: i64 = sqlx::query_scalar(
|
|
||||||
"SELECT COUNT(*) FROM accounts WHERE status = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)"
|
|
||||||
).bind(status).bind(&pattern).fetch_one(db).await?;
|
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
|
||||||
FROM accounts WHERE status = $1 AND (username LIKE $2 OR email LIKE $2 OR display_name LIKE $2)
|
|
||||||
ORDER BY created_at DESC LIMIT $3 OFFSET $4"
|
|
||||||
).bind(status).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
|
||||||
(total, rows)
|
|
||||||
}
|
|
||||||
// role only
|
|
||||||
(Some(role), None, None) => {
|
|
||||||
let total: i64 = sqlx::query_scalar(
|
|
||||||
"SELECT COUNT(*) FROM accounts WHERE role = $1"
|
|
||||||
).bind(role).fetch_one(db).await?;
|
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
|
||||||
FROM accounts WHERE role = $1
|
|
||||||
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
|
||||||
).bind(role).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
|
||||||
(total, rows)
|
|
||||||
}
|
|
||||||
// status only
|
|
||||||
(None, Some(status), None) => {
|
|
||||||
let total: i64 = sqlx::query_scalar(
|
|
||||||
"SELECT COUNT(*) FROM accounts WHERE status = $1"
|
|
||||||
).bind(status).fetch_one(db).await?;
|
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
|
||||||
FROM accounts WHERE status = $1
|
|
||||||
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
|
||||||
).bind(status).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
|
||||||
(total, rows)
|
|
||||||
}
|
|
||||||
// search only
|
|
||||||
(None, None, Some(search)) => {
|
|
||||||
let pattern = format!("%{}%", search);
|
|
||||||
let total: i64 = sqlx::query_scalar(
|
|
||||||
"SELECT COUNT(*) FROM accounts WHERE (username LIKE $1 OR email LIKE $1 OR display_name LIKE $1)"
|
|
||||||
).bind(&pattern).fetch_one(db).await?;
|
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
|
||||||
FROM accounts WHERE (username LIKE $1 OR email LIKE $1 OR display_name LIKE $1)
|
|
||||||
ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
|
||||||
).bind(&pattern).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
|
||||||
(total, rows)
|
|
||||||
}
|
|
||||||
// no filter
|
|
||||||
(None, None, None) => {
|
|
||||||
let total: i64 = sqlx::query_scalar(
|
|
||||||
"SELECT COUNT(*) FROM accounts"
|
|
||||||
).fetch_one(db).await?;
|
|
||||||
let rows = sqlx::query_as::<_, AccountRow>(
|
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
|
||||||
FROM accounts ORDER BY created_at DESC LIMIT $1 OFFSET $2"
|
|
||||||
).bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
|
||||||
(total, rows)
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let count_sql = format!("SELECT COUNT(*) as count FROM accounts {}", where_sql);
|
||||||
|
let mut count_query = sqlx::query_scalar::<_, i64>(&count_sql);
|
||||||
|
for p in ¶ms {
|
||||||
|
count_query = count_query.bind(p);
|
||||||
|
}
|
||||||
|
let total: i64 = count_query.fetch_one(db).await?;
|
||||||
|
|
||||||
|
let data_sql = format!(
|
||||||
|
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||||
|
FROM accounts {} ORDER BY created_at DESC LIMIT ${} OFFSET ${}",
|
||||||
|
where_sql, param_idx, param_idx + 1
|
||||||
|
);
|
||||||
|
let mut data_query = sqlx::query_as::<_, (String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)>(&data_sql);
|
||||||
|
for p in ¶ms {
|
||||||
|
data_query = data_query.bind(p);
|
||||||
|
}
|
||||||
|
let rows = data_query.bind(page_size as i64).bind(offset as i64).fetch_all(db).await?;
|
||||||
|
|
||||||
let items: Vec<serde_json::Value> = rows
|
let items: Vec<serde_json::Value> = rows
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|r| {
|
.map(|(id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at)| {
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
"id": r.id, "username": r.username, "email": r.email, "display_name": r.display_name,
|
"id": id, "username": username, "email": email, "display_name": display_name,
|
||||||
"role": r.role, "status": r.status, "totp_enabled": r.totp_enabled,
|
"role": role, "status": status, "totp_enabled": totp_enabled,
|
||||||
"last_login_at": r.last_login_at, "created_at": r.created_at,
|
"last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
@@ -132,7 +74,7 @@ pub async fn list_accounts(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
|
pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json::Value> {
|
||||||
let row: Option<AccountRow> =
|
let row: Option<(String, String, String, String, String, String, bool, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
"SELECT id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at
|
||||||
FROM accounts WHERE id = $1"
|
FROM accounts WHERE id = $1"
|
||||||
@@ -141,12 +83,13 @@ pub async fn get_account(db: &PgPool, account_id: &str) -> SaasResult<serde_json
|
|||||||
.fetch_optional(db)
|
.fetch_optional(db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("账号 {} 不存在", account_id)))?;
|
let (id, username, email, display_name, role, status, totp_enabled, last_login_at, created_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound(format!("账号 {} 不存在", account_id)))?;
|
||||||
|
|
||||||
Ok(serde_json::json!({
|
Ok(serde_json::json!({
|
||||||
"id": r.id, "username": r.username, "email": r.email, "display_name": r.display_name,
|
"id": id, "username": username, "email": email, "display_name": display_name,
|
||||||
"role": r.role, "status": r.status, "totp_enabled": r.totp_enabled,
|
"role": role, "status": status, "totp_enabled": totp_enabled,
|
||||||
"last_login_at": r.last_login_at, "created_at": r.created_at,
|
"last_login_at": last_login_at.map(|t| t.to_rfc3339()), "created_at": created_at.to_rfc3339(),
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -155,27 +98,31 @@ pub async fn update_account(
|
|||||||
account_id: &str,
|
account_id: &str,
|
||||||
req: &UpdateAccountRequest,
|
req: &UpdateAccountRequest,
|
||||||
) -> SaasResult<serde_json::Value> {
|
) -> SaasResult<serde_json::Value> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
|
let mut updates = Vec::new();
|
||||||
|
let mut params: Vec<String> = Vec::new();
|
||||||
|
let mut param_idx: usize = 1;
|
||||||
|
|
||||||
// COALESCE pattern: all updatable fields in a single static SQL.
|
if let Some(ref v) = req.display_name { updates.push(format!("display_name = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||||
// NULL parameters leave the column unchanged.
|
if let Some(ref v) = req.email { updates.push(format!("email = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||||
sqlx::query(
|
if let Some(ref v) = req.role { updates.push(format!("role = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||||
"UPDATE accounts SET
|
if let Some(ref v) = req.avatar_url { updates.push(format!("avatar_url = ${}", param_idx)); params.push(v.clone()); param_idx += 1; }
|
||||||
display_name = COALESCE($1, display_name),
|
|
||||||
email = COALESCE($2, email),
|
|
||||||
role = COALESCE($3, role),
|
|
||||||
avatar_url = COALESCE($4, avatar_url),
|
|
||||||
updated_at = $5
|
|
||||||
WHERE id = $6"
|
|
||||||
)
|
|
||||||
.bind(req.display_name.as_deref())
|
|
||||||
.bind(req.email.as_deref())
|
|
||||||
.bind(req.role.as_deref())
|
|
||||||
.bind(req.avatar_url.as_deref())
|
|
||||||
.bind(&now)
|
|
||||||
.bind(account_id)
|
|
||||||
.execute(db).await?;
|
|
||||||
|
|
||||||
|
if updates.is_empty() {
|
||||||
|
return get_account(db, account_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
updates.push(format!("updated_at = ${}", param_idx));
|
||||||
|
param_idx += 1;
|
||||||
|
params.push(account_id.to_string());
|
||||||
|
|
||||||
|
let sql = format!("UPDATE accounts SET {} WHERE id = ${}", updates.join(", "), param_idx);
|
||||||
|
let mut query = sqlx::query(&sql);
|
||||||
|
for p in ¶ms {
|
||||||
|
query = query.bind(p);
|
||||||
|
}
|
||||||
|
query = query.bind(now);
|
||||||
|
query.execute(db).await?;
|
||||||
get_account(db, account_id).await
|
get_account(db, account_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,7 +135,7 @@ pub async fn update_account_status(
|
|||||||
if !valid.contains(&status) {
|
if !valid.contains(&status) {
|
||||||
return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid)));
|
return Err(SaasError::InvalidInput(format!("无效状态: {},有效值: {:?}", status, valid)));
|
||||||
}
|
}
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
let result = sqlx::query("UPDATE accounts SET status = $1, updated_at = $2 WHERE id = $3")
|
let result = sqlx::query("UPDATE accounts SET status = $1, updated_at = $2 WHERE id = $3")
|
||||||
.bind(status).bind(&now).bind(account_id)
|
.bind(status).bind(&now).bind(account_id)
|
||||||
.execute(db).await?;
|
.execute(db).await?;
|
||||||
@@ -213,10 +160,12 @@ pub async fn create_api_token(
|
|||||||
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
|
let token_hash = hex::encode(Sha256::digest(raw_token.as_bytes()));
|
||||||
let token_prefix = raw_token[..8].to_string();
|
let token_prefix = raw_token[..8].to_string();
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
|
let now_str = now.to_rfc3339();
|
||||||
let expires_at = req.expires_days.map(|d| {
|
let expires_at = req.expires_days.map(|d| {
|
||||||
(chrono::Utc::now() + chrono::Duration::days(d)).to_rfc3339()
|
chrono::Utc::now() + chrono::Duration::days(d)
|
||||||
});
|
});
|
||||||
|
let expires_at_str = expires_at.as_ref().map(|t| t.to_rfc3339());
|
||||||
let permissions = serde_json::to_string(&req.permissions)?;
|
let permissions = serde_json::to_string(&req.permissions)?;
|
||||||
let token_id = uuid::Uuid::new_v4().to_string();
|
let token_id = uuid::Uuid::new_v4().to_string();
|
||||||
|
|
||||||
@@ -241,8 +190,8 @@ pub async fn create_api_token(
|
|||||||
token_prefix,
|
token_prefix,
|
||||||
permissions: req.permissions.clone(),
|
permissions: req.permissions.clone(),
|
||||||
last_used_at: None,
|
last_used_at: None,
|
||||||
expires_at,
|
expires_at: expires_at_str,
|
||||||
created_at: now,
|
created_at: now_str,
|
||||||
token: Some(raw_token),
|
token: Some(raw_token),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -250,76 +199,24 @@ pub async fn create_api_token(
|
|||||||
pub async fn list_api_tokens(
|
pub async fn list_api_tokens(
|
||||||
db: &PgPool,
|
db: &PgPool,
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
page: Option<u32>,
|
) -> SaasResult<Vec<TokenInfo>> {
|
||||||
page_size: Option<u32>,
|
let rows: Vec<(String, String, String, String, Option<chrono::DateTime<chrono::Utc>>, Option<chrono::DateTime<chrono::Utc>>, chrono::DateTime<chrono::Utc>)> =
|
||||||
) -> SaasResult<PaginatedResponse<TokenInfo>> {
|
|
||||||
let (p, ps, offset) = normalize_pagination(page, page_size);
|
|
||||||
|
|
||||||
let total: (i64,) = sqlx::query_as(
|
|
||||||
"SELECT COUNT(*) FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL"
|
|
||||||
)
|
|
||||||
.bind(account_id)
|
|
||||||
.fetch_one(db)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let rows: Vec<ApiTokenRow> =
|
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
|
"SELECT id, name, token_prefix, permissions, last_used_at, expires_at, created_at
|
||||||
FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC LIMIT $2 OFFSET $3"
|
FROM api_tokens WHERE account_id = $1 AND revoked_at IS NULL ORDER BY created_at DESC"
|
||||||
)
|
)
|
||||||
.bind(account_id)
|
.bind(account_id)
|
||||||
.bind(ps as i64)
|
|
||||||
.bind(offset)
|
|
||||||
.fetch_all(db)
|
.fetch_all(db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let items = rows.into_iter().map(|r| {
|
Ok(rows.into_iter().map(|(id, name, token_prefix, perms, last_used, expires, created)| {
|
||||||
let permissions: Vec<String> = serde_json::from_str(&r.permissions).unwrap_or_default();
|
let permissions: Vec<String> = serde_json::from_str(&perms).unwrap_or_default();
|
||||||
TokenInfo { id: r.id, name: r.name, token_prefix: r.token_prefix, permissions, last_used_at: r.last_used_at, expires_at: r.expires_at, created_at: r.created_at, token: None, }
|
TokenInfo { id, name, token_prefix, permissions, last_used_at: last_used.map(|t| t.to_rfc3339()), expires_at: expires.map(|t| t.to_rfc3339()), created_at: created.to_rfc3339(), token: None, }
|
||||||
}).collect();
|
}).collect())
|
||||||
|
|
||||||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn list_devices(
|
|
||||||
db: &PgPool,
|
|
||||||
account_id: &str,
|
|
||||||
page: Option<u32>,
|
|
||||||
page_size: Option<u32>,
|
|
||||||
) -> SaasResult<PaginatedResponse<serde_json::Value>> {
|
|
||||||
let (p, ps, offset) = normalize_pagination(page, page_size);
|
|
||||||
|
|
||||||
let total: (i64,) = sqlx::query_as(
|
|
||||||
"SELECT COUNT(*) FROM devices WHERE account_id = $1"
|
|
||||||
)
|
|
||||||
.bind(account_id)
|
|
||||||
.fetch_one(db)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let rows: Vec<DeviceRow> =
|
|
||||||
sqlx::query_as(
|
|
||||||
"SELECT id, device_id, device_name, platform, app_version, last_seen_at, created_at
|
|
||||||
FROM devices WHERE account_id = $1 ORDER BY last_seen_at DESC LIMIT $2 OFFSET $3"
|
|
||||||
)
|
|
||||||
.bind(account_id)
|
|
||||||
.bind(ps as i64)
|
|
||||||
.bind(offset)
|
|
||||||
.fetch_all(db)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let items: Vec<serde_json::Value> = rows.into_iter().map(|r| {
|
|
||||||
serde_json::json!({
|
|
||||||
"id": r.id, "device_id": r.device_id,
|
|
||||||
"device_name": r.device_name, "platform": r.platform, "app_version": r.app_version,
|
|
||||||
"last_seen_at": r.last_seen_at, "created_at": r.created_at,
|
|
||||||
})
|
|
||||||
}).collect();
|
|
||||||
|
|
||||||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn revoke_api_token(db: &PgPool, token_id: &str, account_id: &str) -> SaasResult<()> {
|
pub async fn revoke_api_token(db: &PgPool, token_id: &str, account_id: &str) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
let result = sqlx::query(
|
let result = sqlx::query(
|
||||||
"UPDATE api_tokens SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
"UPDATE api_tokens SET revoked_at = $1 WHERE id = $2 AND account_id = $3 AND revoked_at IS NULL"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,10 +2,7 @@
|
|||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
// Re-export from common module
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
pub use crate::common::PaginatedResponse;
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct UpdateAccountRequest {
|
pub struct UpdateAccountRequest {
|
||||||
pub display_name: Option<String>,
|
pub display_name: Option<String>,
|
||||||
pub email: Option<String>,
|
pub email: Option<String>,
|
||||||
@@ -13,12 +10,12 @@ pub struct UpdateAccountRequest {
|
|||||||
pub avatar_url: Option<String>,
|
pub avatar_url: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
pub struct UpdateStatusRequest {
|
pub struct UpdateStatusRequest {
|
||||||
pub status: String,
|
pub status: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, utoipa::ToSchema, utoipa::IntoParams)]
|
||||||
pub struct ListAccountsQuery {
|
pub struct ListAccountsQuery {
|
||||||
pub page: Option<u32>,
|
pub page: Option<u32>,
|
||||||
pub page_size: Option<u32>,
|
pub page_size: Option<u32>,
|
||||||
@@ -27,14 +24,36 @@ pub struct ListAccountsQuery {
|
|||||||
pub search: Option<String>,
|
pub search: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
pub struct PaginatedResponse<T: Serialize> {
|
||||||
|
pub items: Vec<T>,
|
||||||
|
pub total: i64,
|
||||||
|
pub page: u32,
|
||||||
|
pub page_size: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Concrete type alias for OpenAPI schema generation.
|
||||||
|
///
|
||||||
|
/// NOTE: This is intentionally a concrete (non-generic) type because utoipa
|
||||||
|
/// requires concrete types for schema generation. It is functionally
|
||||||
|
/// identical to `Paginated<AccountPublic>`.
|
||||||
|
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||||
|
#[allow(clippy::manual_non_exhaustive)] // kept for OpenAPI schema
|
||||||
|
pub struct AccountPublicPaginatedResponse {
|
||||||
|
pub items: Vec<crate::auth::types::AccountPublic>,
|
||||||
|
pub total: i64,
|
||||||
|
pub page: u32,
|
||||||
|
pub page_size: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
pub struct CreateTokenRequest {
|
pub struct CreateTokenRequest {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub permissions: Vec<String>,
|
pub permissions: Vec<String>,
|
||||||
pub expires_days: Option<i64>,
|
pub expires_days: Option<i64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||||
pub struct TokenInfo {
|
pub struct TokenInfo {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
@@ -46,3 +65,35 @@ pub struct TokenInfo {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub token: Option<String>,
|
pub token: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ Device Types ============
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
|
pub struct RegisterDeviceRequest {
|
||||||
|
pub device_id: String,
|
||||||
|
#[serde(default = "default_device_name")]
|
||||||
|
pub device_name: String,
|
||||||
|
#[serde(default = "default_platform")]
|
||||||
|
pub platform: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub app_version: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_device_name() -> String { "Unknown".into() }
|
||||||
|
fn default_platform() -> String { "unknown".into() }
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
|
pub struct DeviceHeartbeatRequest {
|
||||||
|
pub device_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||||
|
pub struct DeviceInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub device_id: String,
|
||||||
|
pub device_name: Option<String>,
|
||||||
|
pub platform: Option<String>,
|
||||||
|
pub app_version: Option<String>,
|
||||||
|
pub last_seen_at: String,
|
||||||
|
pub created_at: String,
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,104 +0,0 @@
|
|||||||
//! Agent 配置模板 HTTP 处理器
|
|
||||||
|
|
||||||
use axum::{
|
|
||||||
extract::{Extension, Path, Query, State},
|
|
||||||
Json,
|
|
||||||
};
|
|
||||||
use crate::state::AppState;
|
|
||||||
use crate::error::SaasResult;
|
|
||||||
use crate::auth::types::AuthContext;
|
|
||||||
use crate::auth::handlers::{log_operation, check_permission};
|
|
||||||
use super::types::*;
|
|
||||||
use super::service;
|
|
||||||
|
|
||||||
/// GET /api/v1/agent-templates — 列出 Agent 模板
|
|
||||||
pub async fn list_templates(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Extension(ctx): Extension<AuthContext>,
|
|
||||||
Query(query): Query<AgentTemplateListQuery>,
|
|
||||||
) -> SaasResult<Json<crate::common::PaginatedResponse<AgentTemplateInfo>>> {
|
|
||||||
check_permission(&ctx, "model:read")?;
|
|
||||||
Ok(Json(service::list_templates(&state.db, &query).await?))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// POST /api/v1/agent-templates — 创建 Agent 模板
|
|
||||||
pub async fn create_template(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Extension(ctx): Extension<AuthContext>,
|
|
||||||
Json(req): Json<CreateAgentTemplateRequest>,
|
|
||||||
) -> SaasResult<Json<AgentTemplateInfo>> {
|
|
||||||
check_permission(&ctx, "model:manage")?;
|
|
||||||
|
|
||||||
let category = req.category.as_deref().unwrap_or("general");
|
|
||||||
let source = req.source.as_deref().unwrap_or("custom");
|
|
||||||
let visibility = req.visibility.as_deref().unwrap_or("public");
|
|
||||||
let tools = req.tools.as_deref().unwrap_or(&[]);
|
|
||||||
let capabilities = req.capabilities.as_deref().unwrap_or(&[]);
|
|
||||||
|
|
||||||
let result = service::create_template(
|
|
||||||
&state.db, &req.name, req.description.as_deref(),
|
|
||||||
category, source, req.model.as_deref(),
|
|
||||||
req.system_prompt.as_deref(),
|
|
||||||
tools, capabilities,
|
|
||||||
req.temperature, req.max_tokens, visibility,
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
log_operation(&state.db, &ctx.account_id, "agent_template.create", "agent_template", &result.id,
|
|
||||||
Some(serde_json::json!({"name": req.name})), ctx.client_ip.as_deref()).await?;
|
|
||||||
|
|
||||||
Ok(Json(result))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// GET /api/v1/agent-templates/:id — 获取单个 Agent 模板
|
|
||||||
pub async fn get_template(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Extension(ctx): Extension<AuthContext>,
|
|
||||||
Path(id): Path<String>,
|
|
||||||
) -> SaasResult<Json<AgentTemplateInfo>> {
|
|
||||||
check_permission(&ctx, "model:read")?;
|
|
||||||
Ok(Json(service::get_template(&state.db, &id).await?))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// POST /api/v1/agent-templates/:id — 更新 Agent 模板
|
|
||||||
pub async fn update_template(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Extension(ctx): Extension<AuthContext>,
|
|
||||||
Path(id): Path<String>,
|
|
||||||
Json(req): Json<UpdateAgentTemplateRequest>,
|
|
||||||
) -> SaasResult<Json<AgentTemplateInfo>> {
|
|
||||||
check_permission(&ctx, "model:manage")?;
|
|
||||||
|
|
||||||
let result = service::update_template(
|
|
||||||
&state.db, &id,
|
|
||||||
req.description.as_deref(),
|
|
||||||
req.model.as_deref(),
|
|
||||||
req.system_prompt.as_deref(),
|
|
||||||
req.tools.as_deref(),
|
|
||||||
req.capabilities.as_deref(),
|
|
||||||
req.temperature,
|
|
||||||
req.max_tokens,
|
|
||||||
req.visibility.as_deref(),
|
|
||||||
req.status.as_deref(),
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
log_operation(&state.db, &ctx.account_id, "agent_template.update", "agent_template", &id,
|
|
||||||
None, ctx.client_ip.as_deref()).await?;
|
|
||||||
|
|
||||||
Ok(Json(result))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// DELETE /api/v1/agent-templates/:id — 归档 Agent 模板
|
|
||||||
pub async fn archive_template(
|
|
||||||
State(state): State<AppState>,
|
|
||||||
Extension(ctx): Extension<AuthContext>,
|
|
||||||
Path(id): Path<String>,
|
|
||||||
) -> SaasResult<Json<AgentTemplateInfo>> {
|
|
||||||
check_permission(&ctx, "model:manage")?;
|
|
||||||
|
|
||||||
let result = service::archive_template(&state.db, &id).await?;
|
|
||||||
|
|
||||||
log_operation(&state.db, &ctx.account_id, "agent_template.archive", "agent_template", &id,
|
|
||||||
None, ctx.client_ip.as_deref()).await?;
|
|
||||||
|
|
||||||
Ok(Json(result))
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
//! Agent 配置模板管理模块
|
|
||||||
|
|
||||||
pub mod types;
|
|
||||||
pub mod service;
|
|
||||||
pub mod handlers;
|
|
||||||
|
|
||||||
use axum::routing::{delete, get, post};
|
|
||||||
use crate::state::AppState;
|
|
||||||
|
|
||||||
/// Agent 模板管理路由 (需要认证)
|
|
||||||
pub fn routes() -> axum::Router<AppState> {
|
|
||||||
axum::Router::new()
|
|
||||||
.route("/api/v1/agent-templates", get(handlers::list_templates).post(handlers::create_template))
|
|
||||||
.route("/api/v1/agent-templates/:id", get(handlers::get_template))
|
|
||||||
.route("/api/v1/agent-templates/:id", post(handlers::update_template))
|
|
||||||
.route("/api/v1/agent-templates/:id", delete(handlers::archive_template))
|
|
||||||
}
|
|
||||||
@@ -1,170 +0,0 @@
|
|||||||
//! Agent 配置模板业务逻辑
|
|
||||||
|
|
||||||
use sqlx::PgPool;
|
|
||||||
use crate::error::{SaasError, SaasResult};
|
|
||||||
use super::types::*;
|
|
||||||
|
|
||||||
fn row_to_template(
|
|
||||||
row: (String, String, Option<String>, String, String, Option<String>, Option<String>,
|
|
||||||
String, String, Option<f64>, Option<i32>, String, String, i32, String, String),
|
|
||||||
) -> AgentTemplateInfo {
|
|
||||||
AgentTemplateInfo {
|
|
||||||
id: row.0, name: row.1, description: row.2, category: row.3, source: row.4,
|
|
||||||
model: row.5, system_prompt: row.6, tools: serde_json::from_str(&row.7).unwrap_or_default(),
|
|
||||||
capabilities: serde_json::from_str(&row.8).unwrap_or_default(),
|
|
||||||
temperature: row.9, max_tokens: row.10, visibility: row.11, status: row.12,
|
|
||||||
current_version: row.13, created_at: row.14, updated_at: row.15,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Row type for agent_template queries (avoids multi-line turbofish parsing issues)
|
|
||||||
type AgentTemplateRow = (String, String, Option<String>, String, String, Option<String>, Option<String>, String, String, Option<f64>, Option<i32>, String, String, i32, String, String);
|
|
||||||
|
|
||||||
/// 创建 Agent 模板
|
|
||||||
pub async fn create_template(
|
|
||||||
db: &PgPool,
|
|
||||||
name: &str,
|
|
||||||
description: Option<&str>,
|
|
||||||
category: &str,
|
|
||||||
source: &str,
|
|
||||||
model: Option<&str>,
|
|
||||||
system_prompt: Option<&str>,
|
|
||||||
tools: &[String],
|
|
||||||
capabilities: &[String],
|
|
||||||
temperature: Option<f64>,
|
|
||||||
max_tokens: Option<i32>,
|
|
||||||
visibility: &str,
|
|
||||||
) -> SaasResult<AgentTemplateInfo> {
|
|
||||||
let id = uuid::Uuid::new_v4().to_string();
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
|
||||||
let tools_json = serde_json::to_string(tools).unwrap_or_else(|_| "[]".to_string());
|
|
||||||
let caps_json = serde_json::to_string(capabilities).unwrap_or_else(|_| "[]".to_string());
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
"INSERT INTO agent_templates (id, name, description, category, source, model, system_prompt,
|
|
||||||
tools, capabilities, temperature, max_tokens, visibility, status, current_version, created_at, updated_at)
|
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, 'active', 1, $13, $13)"
|
|
||||||
)
|
|
||||||
.bind(&id).bind(name).bind(description).bind(category).bind(source)
|
|
||||||
.bind(model).bind(system_prompt).bind(&tools_json).bind(&caps_json)
|
|
||||||
.bind(temperature).bind(max_tokens).bind(visibility).bind(&now)
|
|
||||||
.execute(db).await.map_err(|e| {
|
|
||||||
if e.to_string().contains("unique") {
|
|
||||||
SaasError::AlreadyExists(format!("Agent 模板 '{}' 已存在", name))
|
|
||||||
} else {
|
|
||||||
SaasError::Database(e)
|
|
||||||
}
|
|
||||||
})?;
|
|
||||||
|
|
||||||
get_template(db, &id).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 获取单个模板
|
|
||||||
pub async fn get_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo> {
|
|
||||||
let row: Option<AgentTemplateRow> = sqlx::query_as(
|
|
||||||
"SELECT id, name, description, category, source, model, system_prompt,
|
|
||||||
tools, capabilities, temperature, max_tokens, visibility, status,
|
|
||||||
current_version, created_at, updated_at
|
|
||||||
FROM agent_templates WHERE id = $1"
|
|
||||||
).bind(id).fetch_optional(db).await?;
|
|
||||||
|
|
||||||
row.map(row_to_template)
|
|
||||||
.ok_or_else(|| SaasError::NotFound(format!("Agent 模板 {} 不存在", id)))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 列出模板(分页 + 过滤)
|
|
||||||
/// Static SQL + conditional filter pattern: ($N IS NULL OR col = $N).
|
|
||||||
/// When the parameter is NULL the whole OR evaluates to TRUE (no filter).
|
|
||||||
pub async fn list_templates(
|
|
||||||
db: &PgPool,
|
|
||||||
query: &AgentTemplateListQuery,
|
|
||||||
) -> SaasResult<crate::common::PaginatedResponse<AgentTemplateInfo>> {
|
|
||||||
let page = query.page.unwrap_or(1).max(1);
|
|
||||||
let page_size = query.page_size.unwrap_or(20).min(100);
|
|
||||||
let offset = ((page - 1) * page_size) as i64;
|
|
||||||
|
|
||||||
let count_sql = "SELECT COUNT(*) FROM agent_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR visibility = $3) AND ($4 IS NULL OR status = $4)";
|
|
||||||
let data_sql = "SELECT id, name, description, category, source, model, system_prompt,
|
|
||||||
tools, capabilities, temperature, max_tokens, visibility, status,
|
|
||||||
current_version, created_at, updated_at
|
|
||||||
FROM agent_templates WHERE ($1 IS NULL OR category = $1) AND ($2 IS NULL OR source = $2) AND ($3 IS NULL OR visibility = $3) AND ($4 IS NULL OR status = $4) ORDER BY created_at DESC LIMIT $5 OFFSET $6";
|
|
||||||
|
|
||||||
let total: i64 = sqlx::query_scalar(count_sql)
|
|
||||||
.bind(&query.category)
|
|
||||||
.bind(&query.source)
|
|
||||||
.bind(&query.visibility)
|
|
||||||
.bind(&query.status)
|
|
||||||
.fetch_one(db).await?;
|
|
||||||
|
|
||||||
let rows: Vec<AgentTemplateRow> = sqlx::query_as(data_sql)
|
|
||||||
.bind(&query.category)
|
|
||||||
.bind(&query.source)
|
|
||||||
.bind(&query.visibility)
|
|
||||||
.bind(&query.status)
|
|
||||||
.bind(page_size as i64)
|
|
||||||
.bind(offset)
|
|
||||||
.fetch_all(db).await?;
|
|
||||||
let items = rows.into_iter().map(row_to_template).collect();
|
|
||||||
|
|
||||||
Ok(crate::common::PaginatedResponse { items, total, page, page_size })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 更新模板
|
|
||||||
/// COALESCE pattern: all updatable fields in a single static SQL.
|
|
||||||
/// NULL parameters leave the column unchanged.
|
|
||||||
pub async fn update_template(
|
|
||||||
db: &PgPool,
|
|
||||||
id: &str,
|
|
||||||
description: Option<&str>,
|
|
||||||
model: Option<&str>,
|
|
||||||
system_prompt: Option<&str>,
|
|
||||||
tools: Option<&[String]>,
|
|
||||||
capabilities: Option<&[String]>,
|
|
||||||
temperature: Option<f64>,
|
|
||||||
max_tokens: Option<i32>,
|
|
||||||
visibility: Option<&str>,
|
|
||||||
status: Option<&str>,
|
|
||||||
) -> SaasResult<AgentTemplateInfo> {
|
|
||||||
// Confirm existence
|
|
||||||
get_template(db, id).await?;
|
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
|
||||||
|
|
||||||
// Serialize JSON fields upfront so we can bind Option<&str> consistently
|
|
||||||
let tools_json = tools.map(|t| serde_json::to_string(t).unwrap_or_else(|_| "[]".to_string()));
|
|
||||||
let caps_json = capabilities.map(|c| serde_json::to_string(c).unwrap_or_else(|_| "[]".to_string()));
|
|
||||||
|
|
||||||
sqlx::query(
|
|
||||||
"UPDATE agent_templates SET
|
|
||||||
description = COALESCE($1, description),
|
|
||||||
model = COALESCE($2, model),
|
|
||||||
system_prompt = COALESCE($3, system_prompt),
|
|
||||||
tools = COALESCE($4, tools),
|
|
||||||
capabilities = COALESCE($5, capabilities),
|
|
||||||
temperature = COALESCE($6, temperature),
|
|
||||||
max_tokens = COALESCE($7, max_tokens),
|
|
||||||
visibility = COALESCE($8, visibility),
|
|
||||||
status = COALESCE($9, status),
|
|
||||||
updated_at = $10
|
|
||||||
WHERE id = $11"
|
|
||||||
)
|
|
||||||
.bind(description)
|
|
||||||
.bind(model)
|
|
||||||
.bind(system_prompt)
|
|
||||||
.bind(tools_json.as_deref())
|
|
||||||
.bind(caps_json.as_deref())
|
|
||||||
.bind(temperature)
|
|
||||||
.bind(max_tokens)
|
|
||||||
.bind(visibility)
|
|
||||||
.bind(status)
|
|
||||||
.bind(&now)
|
|
||||||
.bind(id)
|
|
||||||
.execute(db).await?;
|
|
||||||
|
|
||||||
get_template(db, id).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 归档模板
|
|
||||||
pub async fn archive_template(db: &PgPool, id: &str) -> SaasResult<AgentTemplateInfo> {
|
|
||||||
update_template(db, id, None, None, None, None, None, None, None, None, Some("archived")).await
|
|
||||||
}
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
//! Agent 配置模板类型定义
|
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
|
||||||
|
|
||||||
// --- Agent Template ---
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
||||||
pub struct AgentTemplateInfo {
|
|
||||||
pub id: String,
|
|
||||||
pub name: String,
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub category: String,
|
|
||||||
pub source: String,
|
|
||||||
pub model: Option<String>,
|
|
||||||
pub system_prompt: Option<String>,
|
|
||||||
pub tools: Vec<String>,
|
|
||||||
pub capabilities: Vec<String>,
|
|
||||||
pub temperature: Option<f64>,
|
|
||||||
pub max_tokens: Option<i32>,
|
|
||||||
pub visibility: String,
|
|
||||||
pub status: String,
|
|
||||||
pub current_version: i32,
|
|
||||||
pub created_at: String,
|
|
||||||
pub updated_at: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct CreateAgentTemplateRequest {
|
|
||||||
pub name: String,
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub category: Option<String>,
|
|
||||||
pub source: Option<String>,
|
|
||||||
pub model: Option<String>,
|
|
||||||
pub system_prompt: Option<String>,
|
|
||||||
pub tools: Option<Vec<String>>,
|
|
||||||
pub capabilities: Option<Vec<String>>,
|
|
||||||
pub temperature: Option<f64>,
|
|
||||||
pub max_tokens: Option<i32>,
|
|
||||||
pub visibility: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct UpdateAgentTemplateRequest {
|
|
||||||
pub description: Option<String>,
|
|
||||||
pub model: Option<String>,
|
|
||||||
pub system_prompt: Option<String>,
|
|
||||||
pub tools: Option<Vec<String>>,
|
|
||||||
pub capabilities: Option<Vec<String>>,
|
|
||||||
pub temperature: Option<f64>,
|
|
||||||
pub max_tokens: Option<i32>,
|
|
||||||
pub visibility: Option<String>,
|
|
||||||
pub status: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- List ---
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct AgentTemplateListQuery {
|
|
||||||
pub category: Option<String>,
|
|
||||||
pub source: Option<String>,
|
|
||||||
pub visibility: Option<String>,
|
|
||||||
pub status: Option<String>,
|
|
||||||
pub page: Option<u32>,
|
|
||||||
pub page_size: Option<u32>,
|
|
||||||
}
|
|
||||||
@@ -5,45 +5,32 @@ use std::net::SocketAddr;
|
|||||||
use secrecy::ExposeSecret;
|
use secrecy::ExposeSecret;
|
||||||
use crate::state::AppState;
|
use crate::state::AppState;
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use crate::models::{AccountAuthRow, AccountLoginRow};
|
|
||||||
use super::{
|
use super::{
|
||||||
jwt::{create_token, create_refresh_token, verify_token, verify_token_skip_expiry},
|
jwt::create_token,
|
||||||
password::{hash_password_async, verify_password_async},
|
password::{hash_password, verify_password},
|
||||||
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic, RefreshRequest},
|
types::{AuthContext, LoginRequest, LoginResponse, RegisterRequest, ChangePasswordRequest, AccountPublic},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// POST /api/v1/auth/register
|
/// POST /api/v1/auth/register
|
||||||
/// 注册成功后自动签发 JWT,返回与 login 一致的 LoginResponse
|
|
||||||
pub async fn register(
|
pub async fn register(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
Json(req): Json<RegisterRequest>,
|
Json(req): Json<RegisterRequest>,
|
||||||
) -> SaasResult<(StatusCode, Json<LoginResponse>)> {
|
) -> SaasResult<(StatusCode, Json<LoginResponse>)> {
|
||||||
if req.username.len() < 3 {
|
// 4.6: 用户名格式验证 — 3-32 字符,仅允许字母数字下划线
|
||||||
return Err(SaasError::InvalidInput("用户名至少 3 个字符".into()));
|
if req.username.len() < 3 || req.username.len() > 32 {
|
||||||
|
return Err(SaasError::InvalidInput("用户名长度需在 3-32 个字符之间".into()));
|
||||||
}
|
}
|
||||||
if req.username.len() > 32 {
|
if !req.username.chars().all(|c| c.is_alphanumeric() || c == '_') {
|
||||||
return Err(SaasError::InvalidInput("用户名最多 32 个字符".into()));
|
return Err(SaasError::InvalidInput("用户名仅允许字母、数字和下划线".into()));
|
||||||
}
|
}
|
||||||
static USERNAME_RE: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
|
// 4.7: 邮箱格式验证
|
||||||
let username_re = USERNAME_RE.get_or_init(|| regex::Regex::new(r"^[a-zA-Z0-9_-]+$").unwrap());
|
if !req.email.contains('@') || !req.email.split('@').nth(1).map_or(false, |d| d.contains('.')) {
|
||||||
if !username_re.is_match(&req.username) {
|
|
||||||
return Err(SaasError::InvalidInput("用户名只能包含字母、数字、下划线和连字符".into()));
|
|
||||||
}
|
|
||||||
if !req.email.contains('@') || !req.email.contains('.') {
|
|
||||||
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
return Err(SaasError::InvalidInput("邮箱格式不正确".into()));
|
||||||
}
|
}
|
||||||
if req.password.len() < 8 {
|
if req.password.len() < 8 {
|
||||||
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
|
return Err(SaasError::InvalidInput("密码至少 8 个字符".into()));
|
||||||
}
|
}
|
||||||
if req.password.len() > 128 {
|
|
||||||
return Err(SaasError::InvalidInput("密码最多 128 个字符".into()));
|
|
||||||
}
|
|
||||||
if let Some(ref name) = req.display_name {
|
|
||||||
if name.len() > 64 {
|
|
||||||
return Err(SaasError::InvalidInput("显示名称最多 64 个字符".into()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let existing: Vec<(String,)> = sqlx::query_as(
|
let existing: Vec<(String,)> = sqlx::query_as(
|
||||||
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
|
"SELECT id FROM accounts WHERE username = $1 OR email = $2"
|
||||||
@@ -57,11 +44,11 @@ pub async fn register(
|
|||||||
return Err(SaasError::AlreadyExists("用户名或邮箱已存在".into()));
|
return Err(SaasError::AlreadyExists("用户名或邮箱已存在".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
let password_hash = hash_password_async(req.password.clone()).await?;
|
let password_hash = hash_password(&req.password)?;
|
||||||
let account_id = uuid::Uuid::new_v4().to_string();
|
let account_id = uuid::Uuid::new_v4().to_string();
|
||||||
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
|
let role = "user".to_string(); // 注册固定为普通用户,角色由管理员分配
|
||||||
let display_name = req.display_name.unwrap_or_default();
|
let display_name = req.display_name.unwrap_or_default();
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
|
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
"INSERT INTO accounts (id, username, email, password_hash, display_name, role, status, created_at, updated_at)
|
||||||
@@ -73,45 +60,32 @@ pub async fn register(
|
|||||||
.bind(&password_hash)
|
.bind(&password_hash)
|
||||||
.bind(&display_name)
|
.bind(&display_name)
|
||||||
.bind(&role)
|
.bind(&role)
|
||||||
.bind(&now)
|
.bind(now)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let client_ip = addr.ip().to_string();
|
let client_ip = addr.ip().to_string();
|
||||||
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
|
log_operation(&state.db, &account_id, "account.create", "account", &account_id, None, Some(&client_ip)).await?;
|
||||||
|
|
||||||
// 注册成功后自动签发 JWT + Refresh Token
|
// Generate JWT token for auto-login after registration
|
||||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
|
||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let token = create_token(
|
let token = create_token(
|
||||||
&account_id, &role, permissions.clone(),
|
&account_id, &role, vec![],
|
||||||
state.jwt_secret.expose_secret(),
|
state.jwt_secret.expose_secret(), config.auth.jwt_expiration_hours,
|
||||||
config.auth.jwt_expiration_hours,
|
|
||||||
)?;
|
)?;
|
||||||
let refresh_token = create_refresh_token(
|
|
||||||
&account_id, &role, permissions,
|
|
||||||
state.jwt_secret.expose_secret(),
|
|
||||||
config.auth.refresh_token_hours,
|
|
||||||
)?;
|
|
||||||
drop(config);
|
|
||||||
|
|
||||||
store_refresh_token(
|
|
||||||
&state.db, &account_id, &refresh_token,
|
|
||||||
state.jwt_secret.expose_secret(), 168,
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
Ok((StatusCode::CREATED, Json(LoginResponse {
|
Ok((StatusCode::CREATED, Json(LoginResponse {
|
||||||
token,
|
token,
|
||||||
refresh_token,
|
|
||||||
account: AccountPublic {
|
account: AccountPublic {
|
||||||
id: account_id,
|
id: account_id,
|
||||||
username: req.username,
|
username: req.username,
|
||||||
email: req.email,
|
email: req.email,
|
||||||
display_name,
|
display_name,
|
||||||
role,
|
role,
|
||||||
|
permissions: vec![],
|
||||||
status: "active".into(),
|
status: "active".into(),
|
||||||
totp_enabled: false,
|
totp_enabled: false,
|
||||||
created_at: now,
|
created_at: now.to_rfc3339(),
|
||||||
},
|
},
|
||||||
})))
|
})))
|
||||||
}
|
}
|
||||||
@@ -122,170 +96,109 @@ pub async fn login(
|
|||||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||||
Json(req): Json<LoginRequest>,
|
Json(req): Json<LoginRequest>,
|
||||||
) -> SaasResult<Json<LoginResponse>> {
|
) -> SaasResult<Json<LoginResponse>> {
|
||||||
// 一次查询获取用户信息 + password_hash + totp_secret(合并原来的 3 次查询)
|
let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
|
||||||
let row: Option<AccountLoginRow> =
|
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled,
|
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||||
password_hash, totp_secret, created_at
|
|
||||||
FROM accounts WHERE username = $1 OR email = $1"
|
FROM accounts WHERE username = $1 OR email = $1"
|
||||||
)
|
)
|
||||||
.bind(&req.username)
|
.bind(&req.username)
|
||||||
.fetch_optional(&state.db)
|
.fetch_optional(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let r = row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
|
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||||
|
row.ok_or_else(|| SaasError::AuthError("用户名或密码错误".into()))?;
|
||||||
|
let created_at = created_at.to_rfc3339();
|
||||||
|
|
||||||
if r.status != "active" {
|
if status != "active" {
|
||||||
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", r.status)));
|
return Err(SaasError::Forbidden(format!("账号已{},请联系管理员", status)));
|
||||||
}
|
}
|
||||||
|
|
||||||
if !verify_password_async(req.password.clone(), r.password_hash.clone()).await? {
|
let (password_hash,): (String,) = sqlx::query_as(
|
||||||
|
"SELECT password_hash FROM accounts WHERE id = $1"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.fetch_one(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !verify_password(&req.password, &password_hash)? {
|
||||||
return Err(SaasError::AuthError("用户名或密码错误".into()));
|
return Err(SaasError::AuthError("用户名或密码错误".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// TOTP 验证: 如果用户已启用 2FA,必须提供有效 TOTP 码
|
// TOTP 验证: 如果用户已启用 2FA,必须提供有效 TOTP 码
|
||||||
if r.totp_enabled {
|
if totp_enabled {
|
||||||
let code = req.totp_code.as_deref()
|
let code = req.totp_code.as_deref()
|
||||||
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
|
.ok_or_else(|| SaasError::Totp("此账号已启用双因素认证,请提供 TOTP 码".into()))?;
|
||||||
|
|
||||||
let secret = r.totp_secret.clone().ok_or_else(|| {
|
let (totp_secret,): (Option<String>,) = sqlx::query_as(
|
||||||
|
"SELECT totp_secret FROM accounts WHERE id = $1"
|
||||||
|
)
|
||||||
|
.bind(&id)
|
||||||
|
.fetch_one(&state.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let secret = totp_secret.ok_or_else(|| {
|
||||||
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
|
SaasError::Internal("TOTP 已启用但密钥丢失,请联系管理员".into())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// 解密 TOTP secret (兼容旧的明文格式)
|
// 解密 TOTP 密钥(兼容迁移期间的明文数据)
|
||||||
let config = state.config.read().await;
|
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
|
||||||
let enc_key = config.totp_encryption_key()
|
|
||||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
|
||||||
let secret = super::totp::decrypt_totp_for_login(&secret, &enc_key)?;
|
|
||||||
|
|
||||||
if !super::totp::verify_totp_code(&secret, code) {
|
if !super::totp::verify_totp_code(&decrypted_secret, code) {
|
||||||
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
|
return Err(SaasError::Totp("TOTP 码错误或已过期".into()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &r.role).await?;
|
let permissions = get_role_permissions(&state.db, &role).await?;
|
||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let token = create_token(
|
let token = create_token(
|
||||||
&r.id, &r.role, permissions.clone(),
|
&id, &role, permissions.clone(),
|
||||||
state.jwt_secret.expose_secret(),
|
state.jwt_secret.expose_secret(),
|
||||||
config.auth.jwt_expiration_hours,
|
config.auth.jwt_expiration_hours,
|
||||||
)?;
|
)?;
|
||||||
let refresh_token = create_refresh_token(
|
|
||||||
&r.id, &r.role, permissions,
|
|
||||||
state.jwt_secret.expose_secret(),
|
|
||||||
config.auth.refresh_token_hours,
|
|
||||||
)?;
|
|
||||||
drop(config);
|
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
|
sqlx::query("UPDATE accounts SET last_login_at = $1 WHERE id = $2")
|
||||||
.bind(&now).bind(&r.id)
|
.bind(now).bind(&id)
|
||||||
.execute(&state.db).await?;
|
.execute(&state.db).await?;
|
||||||
let client_ip = addr.ip().to_string();
|
let client_ip = addr.ip().to_string();
|
||||||
log_operation(&state.db, &r.id, "account.login", "account", &r.id, None, Some(&client_ip)).await?;
|
log_operation(&state.db, &id, "account.login", "account", &id, None, Some(&client_ip)).await?;
|
||||||
|
|
||||||
store_refresh_token(
|
|
||||||
&state.db, &r.id, &refresh_token,
|
|
||||||
state.jwt_secret.expose_secret(), 168,
|
|
||||||
).await?;
|
|
||||||
|
|
||||||
Ok(Json(LoginResponse {
|
Ok(Json(LoginResponse {
|
||||||
token,
|
token,
|
||||||
refresh_token,
|
|
||||||
account: AccountPublic {
|
account: AccountPublic {
|
||||||
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
|
id, username, email, display_name, role, permissions, status, totp_enabled, created_at,
|
||||||
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
|
|
||||||
},
|
},
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /api/v1/auth/refresh
|
/// POST /api/v1/auth/refresh
|
||||||
/// 使用 refresh_token 换取新的 access + refresh token 对
|
|
||||||
/// refresh_token 一次性使用,使用后立即失效
|
|
||||||
pub async fn refresh(
|
pub async fn refresh(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Json(req): Json<RefreshRequest>,
|
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<serde_json::Value>> {
|
) -> SaasResult<Json<LoginResponse>> {
|
||||||
// 1. 验证 refresh token 签名 (跳过过期检查,但有 7 天窗口限制)
|
|
||||||
let claims = verify_token_skip_expiry(&req.refresh_token, state.jwt_secret.expose_secret())?;
|
|
||||||
|
|
||||||
// 2. 确认是 refresh 类型 token
|
|
||||||
if claims.token_type != "refresh" {
|
|
||||||
return Err(SaasError::AuthError("无效的 refresh token".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
let jti = claims.jti.as_deref()
|
|
||||||
.ok_or_else(|| SaasError::AuthError("refresh token 缺少 jti".into()))?;
|
|
||||||
|
|
||||||
// 3. 从 DB 查找 refresh token,确保未被使用
|
|
||||||
let row: Option<(String,)> = sqlx::query_as(
|
|
||||||
"SELECT account_id FROM refresh_tokens WHERE jti = $1 AND used_at IS NULL AND expires_at > $2"
|
|
||||||
)
|
|
||||||
.bind(jti)
|
|
||||||
.bind(&chrono::Utc::now().to_rfc3339())
|
|
||||||
.fetch_optional(&state.db)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let token_account_id = row
|
|
||||||
.ok_or_else(|| SaasError::AuthError("refresh token 已使用、已过期或不存在".into()))?
|
|
||||||
.0;
|
|
||||||
|
|
||||||
// 4. 验证 token 中的 account_id 与 DB 中的一致
|
|
||||||
if token_account_id != claims.sub {
|
|
||||||
return Err(SaasError::AuthError("refresh token 账号不匹配".into()));
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. 标记旧 refresh token 为已使用 (一次性)
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
|
||||||
sqlx::query("UPDATE refresh_tokens SET used_at = $1 WHERE jti = $2")
|
|
||||||
.bind(&now).bind(jti)
|
|
||||||
.execute(&state.db).await?;
|
|
||||||
|
|
||||||
// 6. 获取最新角色权限
|
|
||||||
let (role,): (String,) = sqlx::query_as(
|
|
||||||
"SELECT role FROM accounts WHERE id = $1 AND status = 'active'"
|
|
||||||
)
|
|
||||||
.bind(&claims.sub)
|
|
||||||
.fetch_optional(&state.db)
|
|
||||||
.await?
|
|
||||||
.ok_or_else(|| SaasError::AuthError("账号不存在或已禁用".into()))?;
|
|
||||||
|
|
||||||
let permissions = get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
|
||||||
|
|
||||||
// 7. 创建新的 access token + refresh token
|
|
||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let new_access = create_token(
|
let token = create_token(
|
||||||
&claims.sub, &role, permissions.clone(),
|
&ctx.account_id, &ctx.role, ctx.permissions.clone(),
|
||||||
state.jwt_secret.expose_secret(),
|
state.jwt_secret.expose_secret(),
|
||||||
config.auth.jwt_expiration_hours,
|
config.auth.jwt_expiration_hours,
|
||||||
)?;
|
)?;
|
||||||
let new_refresh = create_refresh_token(
|
|
||||||
&claims.sub, &role, permissions.clone(),
|
|
||||||
state.jwt_secret.expose_secret(),
|
|
||||||
config.auth.refresh_token_hours,
|
|
||||||
)?;
|
|
||||||
drop(config);
|
|
||||||
|
|
||||||
// 8. 存储新 refresh token 到 DB
|
// 查询账号信息以返回完整 LoginResponse
|
||||||
let new_claims = verify_token(&new_refresh, state.jwt_secret.expose_secret())?;
|
let row = sqlx::query_as::<_, (String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)>(
|
||||||
let new_jti = new_claims.jti.unwrap_or_default();
|
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||||
let new_id = uuid::Uuid::new_v4().to_string();
|
FROM accounts WHERE id = $1"
|
||||||
let refresh_expires = (chrono::Utc::now() + chrono::Duration::hours(168)).to_rfc3339();
|
)
|
||||||
sqlx::query(
|
.bind(&ctx.account_id)
|
||||||
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
|
.fetch_optional(&state.db)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)"
|
.await?
|
||||||
)
|
.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||||
.bind(&new_id).bind(&claims.sub).bind(&new_jti)
|
|
||||||
.bind(sha256_hex(&new_refresh)).bind(&refresh_expires).bind(&now)
|
|
||||||
.execute(&state.db).await?;
|
|
||||||
|
|
||||||
// 9. 清理过期/已使用的 refresh tokens 已迁移到 Scheduler 定期执行
|
let (id, username, email, display_name, role, status, totp_enabled, created_at) = row;
|
||||||
// 不再在每次 refresh 时阻塞请求
|
let created_at = created_at.to_rfc3339();
|
||||||
|
Ok(Json(LoginResponse {
|
||||||
Ok(Json(serde_json::json!({
|
token,
|
||||||
"token": new_access,
|
account: AccountPublic { id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at },
|
||||||
"refresh_token": new_refresh,
|
}))
|
||||||
})))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
/// GET /api/v1/auth/me — 返回当前认证用户的公开信息
|
||||||
@@ -293,7 +206,7 @@ pub async fn me(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
axum::extract::Extension(ctx): axum::extract::Extension<AuthContext>,
|
||||||
) -> SaasResult<Json<AccountPublic>> {
|
) -> SaasResult<Json<AccountPublic>> {
|
||||||
let row: Option<AccountAuthRow> =
|
let row: Option<(String, String, String, String, String, String, bool, chrono::DateTime<chrono::Utc>)> =
|
||||||
sqlx::query_as(
|
sqlx::query_as(
|
||||||
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
"SELECT id, username, email, display_name, role, status, totp_enabled, created_at
|
||||||
FROM accounts WHERE id = $1"
|
FROM accounts WHERE id = $1"
|
||||||
@@ -302,11 +215,12 @@ pub async fn me(
|
|||||||
.fetch_optional(&state.db)
|
.fetch_optional(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let r = row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
let (id, username, email, display_name, role, status, totp_enabled, created_at) =
|
||||||
|
row.ok_or_else(|| SaasError::NotFound("账号不存在".into()))?;
|
||||||
|
let created_at = created_at.to_rfc3339();
|
||||||
|
|
||||||
Ok(Json(AccountPublic {
|
Ok(Json(AccountPublic {
|
||||||
id: r.id, username: r.username, email: r.email, display_name: r.display_name,
|
id, username, email, display_name, role, permissions: ctx.permissions, status, totp_enabled, created_at,
|
||||||
role: r.role, status: r.status, totp_enabled: r.totp_enabled, created_at: r.created_at,
|
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -329,16 +243,16 @@ pub async fn change_password(
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// 验证旧密码
|
// 验证旧密码
|
||||||
if !verify_password_async(req.old_password.clone(), password_hash.clone()).await? {
|
if !verify_password(&req.old_password, &password_hash)? {
|
||||||
return Err(SaasError::AuthError("旧密码错误".into()));
|
return Err(SaasError::AuthError("旧密码错误".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 更新密码
|
// 更新密码
|
||||||
let new_hash = hash_password_async(req.new_password.clone()).await?;
|
let new_hash = hash_password(&req.new_password)?;
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
|
sqlx::query("UPDATE accounts SET password_hash = $1, updated_at = $2 WHERE id = $3")
|
||||||
.bind(&new_hash)
|
.bind(&new_hash)
|
||||||
.bind(&now)
|
.bind(now)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -349,16 +263,7 @@ pub async fn change_password(
|
|||||||
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
|
Ok(Json(serde_json::json!({"ok": true, "message": "密码修改成功"})))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get_role_permissions(
|
pub(crate) async fn get_role_permissions(db: &sqlx::PgPool, role: &str) -> SaasResult<Vec<String>> {
|
||||||
db: &sqlx::PgPool,
|
|
||||||
cache: &dashmap::DashMap<String, Vec<String>>,
|
|
||||||
role: &str,
|
|
||||||
) -> SaasResult<Vec<String>> {
|
|
||||||
// Check cache first
|
|
||||||
if let Some(cached) = cache.get(role) {
|
|
||||||
return Ok(cached.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
let row: Option<(String,)> = sqlx::query_as(
|
let row: Option<(String,)> = sqlx::query_as(
|
||||||
"SELECT permissions FROM roles WHERE id = $1"
|
"SELECT permissions FROM roles WHERE id = $1"
|
||||||
)
|
)
|
||||||
@@ -367,11 +272,10 @@ pub(crate) async fn get_role_permissions(
|
|||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let permissions_str = row
|
let permissions_str = row
|
||||||
.ok_or_else(|| SaasError::Internal(format!("角色 {} 不存在", role)))?
|
.ok_or_else(|| SaasError::Forbidden(format!("角色 {} 不存在或无权限", role)))?
|
||||||
.0;
|
.0;
|
||||||
|
|
||||||
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
|
let permissions: Vec<String> = serde_json::from_str(&permissions_str)?;
|
||||||
cache.insert(role.to_string(), permissions.clone());
|
|
||||||
Ok(permissions)
|
Ok(permissions)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -396,7 +300,7 @@ pub async fn log_operation(
|
|||||||
details: Option<serde_json::Value>,
|
details: Option<serde_json::Value>,
|
||||||
ip_address: Option<&str>,
|
ip_address: Option<&str>,
|
||||||
) -> SaasResult<()> {
|
) -> SaasResult<()> {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
sqlx::query(
|
sqlx::query(
|
||||||
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
"INSERT INTO operation_logs (account_id, action, target_type, target_id, details, ip_address, created_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||||
@@ -407,52 +311,54 @@ pub async fn log_operation(
|
|||||||
.bind(target_id)
|
.bind(target_id)
|
||||||
.bind(details.map(|d| d.to_string()))
|
.bind(details.map(|d| d.to_string()))
|
||||||
.bind(ip_address)
|
.bind(ip_address)
|
||||||
.bind(&now)
|
.bind(now)
|
||||||
.execute(db)
|
.execute(db)
|
||||||
.await?;
|
.await?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 存储 refresh token 到 DB
|
#[cfg(test)]
|
||||||
async fn store_refresh_token(
|
mod tests {
|
||||||
db: &sqlx::PgPool,
|
use super::*;
|
||||||
account_id: &str,
|
use crate::auth::types::AuthContext;
|
||||||
refresh_token: &str,
|
|
||||||
secret: &str,
|
|
||||||
refresh_hours: i64,
|
|
||||||
) -> SaasResult<()> {
|
|
||||||
let claims = verify_token(refresh_token, secret)?;
|
|
||||||
let jti = claims.jti.unwrap_or_default();
|
|
||||||
let id = uuid::Uuid::new_v4().to_string();
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
|
||||||
let expires_at = (chrono::Utc::now() + chrono::Duration::hours(refresh_hours)).to_rfc3339();
|
|
||||||
|
|
||||||
sqlx::query(
|
fn ctx(permissions: Vec<&str>) -> AuthContext {
|
||||||
"INSERT INTO refresh_tokens (id, account_id, jti, token_hash, expires_at, created_at)
|
AuthContext {
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)"
|
account_id: "test-id".into(),
|
||||||
)
|
role: "user".into(),
|
||||||
.bind(&id).bind(account_id).bind(&jti)
|
permissions: permissions.into_iter().map(String::from).collect(),
|
||||||
.bind(sha256_hex(refresh_token)).bind(&expires_at).bind(&now)
|
client_ip: None,
|
||||||
.execute(db).await?;
|
}
|
||||||
Ok(())
|
}
|
||||||
}
|
|
||||||
|
|
||||||
/// 清理过期和已使用的 refresh tokens
|
#[test]
|
||||||
/// 注意: 现已迁移到 Worker/Scheduler 定期执行,此函数保留作为备用
|
fn test_check_permission_admin_full() {
|
||||||
#[allow(dead_code)]
|
let c = ctx(vec!["admin:full"]);
|
||||||
async fn cleanup_expired_refresh_tokens(db: &sqlx::PgPool) -> SaasResult<()> {
|
assert!(check_permission(&c, "config:write").is_ok());
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
assert!(check_permission(&c, "account:admin").is_ok());
|
||||||
// 删除过期超过 30 天的已使用 token (减少 DB 膨胀)
|
assert!(check_permission(&c, "any:permission").is_ok());
|
||||||
sqlx::query(
|
}
|
||||||
"DELETE FROM refresh_tokens WHERE (used_at IS NOT NULL AND used_at < $1) OR (expires_at < $1)"
|
|
||||||
)
|
|
||||||
.bind(&now)
|
|
||||||
.execute(db).await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// SHA-256 hex digest
|
#[test]
|
||||||
fn sha256_hex(input: &str) -> String {
|
fn test_check_permission_has_permission() {
|
||||||
use sha2::{Sha256, Digest};
|
let c = ctx(vec!["config:write", "model:read"]);
|
||||||
hex::encode(Sha256::digest(input.as_bytes()))
|
assert!(check_permission(&c, "config:write").is_ok());
|
||||||
|
assert!(check_permission(&c, "model:read").is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_check_permission_missing() {
|
||||||
|
let c = ctx(vec!["model:read"]);
|
||||||
|
let result = check_permission(&c, "config:write");
|
||||||
|
assert!(result.is_err());
|
||||||
|
let err = result.unwrap_err().to_string();
|
||||||
|
assert!(err.contains("config:write"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_check_permission_empty_list() {
|
||||||
|
let c = ctx(vec![]);
|
||||||
|
assert!(check_permission(&c, "config:write").is_err());
|
||||||
|
assert!(check_permission(&c, "admin:full").is_err());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,52 +9,34 @@ use crate::error::SaasResult;
|
|||||||
/// JWT Claims
|
/// JWT Claims
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct Claims {
|
pub struct Claims {
|
||||||
/// JWT ID — 唯一标识,用于 token 追踪和吊销
|
|
||||||
pub jti: Option<String>,
|
|
||||||
pub sub: String,
|
pub sub: String,
|
||||||
|
pub aud: String,
|
||||||
|
pub iss: String,
|
||||||
pub role: String,
|
pub role: String,
|
||||||
pub permissions: Vec<String>,
|
pub permissions: Vec<String>,
|
||||||
/// token 类型: "access" 或 "refresh"
|
|
||||||
#[serde(default = "default_token_type")]
|
|
||||||
pub token_type: String,
|
|
||||||
pub iat: i64,
|
pub iat: i64,
|
||||||
pub exp: i64,
|
pub exp: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_token_type() -> String {
|
const JWT_AUDIENCE: &str = "zclaw-saas";
|
||||||
"access".to_string()
|
const JWT_ISSUER: &str = "zclaw-saas";
|
||||||
}
|
|
||||||
|
|
||||||
impl Claims {
|
impl Claims {
|
||||||
pub fn new_access(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
|
pub fn new(account_id: &str, role: &str, permissions: Vec<String>, expiration_hours: i64) -> Self {
|
||||||
let now = Utc::now();
|
let now = Utc::now();
|
||||||
Self {
|
Self {
|
||||||
jti: Some(uuid::Uuid::new_v4().to_string()),
|
|
||||||
sub: account_id.to_string(),
|
sub: account_id.to_string(),
|
||||||
|
aud: JWT_AUDIENCE.to_string(),
|
||||||
|
iss: JWT_ISSUER.to_string(),
|
||||||
role: role.to_string(),
|
role: role.to_string(),
|
||||||
permissions,
|
permissions,
|
||||||
token_type: "access".to_string(),
|
|
||||||
iat: now.timestamp(),
|
iat: now.timestamp(),
|
||||||
exp: (now + Duration::hours(expiration_hours)).timestamp(),
|
exp: (now + Duration::hours(expiration_hours)).timestamp(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 创建 refresh token claims (有效期更长,用于一次性刷新)
|
|
||||||
pub fn new_refresh(account_id: &str, role: &str, permissions: Vec<String>, refresh_hours: i64) -> Self {
|
|
||||||
let now = Utc::now();
|
|
||||||
Self {
|
|
||||||
jti: Some(uuid::Uuid::new_v4().to_string()),
|
|
||||||
sub: account_id.to_string(),
|
|
||||||
role: role.to_string(),
|
|
||||||
permissions,
|
|
||||||
token_type: "refresh".to_string(),
|
|
||||||
iat: now.timestamp(),
|
|
||||||
exp: (now + Duration::hours(refresh_hours)).timestamp(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 创建 Access JWT Token
|
/// 创建 JWT Token
|
||||||
pub fn create_token(
|
pub fn create_token(
|
||||||
account_id: &str,
|
account_id: &str,
|
||||||
role: &str,
|
role: &str,
|
||||||
@@ -62,24 +44,7 @@ pub fn create_token(
|
|||||||
secret: &str,
|
secret: &str,
|
||||||
expiration_hours: i64,
|
expiration_hours: i64,
|
||||||
) -> SaasResult<String> {
|
) -> SaasResult<String> {
|
||||||
let claims = Claims::new_access(account_id, role, permissions, expiration_hours);
|
let claims = Claims::new(account_id, role, permissions, expiration_hours);
|
||||||
let token = encode(
|
|
||||||
&Header::default(),
|
|
||||||
&claims,
|
|
||||||
&EncodingKey::from_secret(secret.as_bytes()),
|
|
||||||
)?;
|
|
||||||
Ok(token)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 创建 Refresh JWT Token (独立 jti,有效期更长)
|
|
||||||
pub fn create_refresh_token(
|
|
||||||
account_id: &str,
|
|
||||||
role: &str,
|
|
||||||
permissions: Vec<String>,
|
|
||||||
secret: &str,
|
|
||||||
refresh_hours: i64,
|
|
||||||
) -> SaasResult<String> {
|
|
||||||
let claims = Claims::new_refresh(account_id, role, permissions, refresh_hours);
|
|
||||||
let token = encode(
|
let token = encode(
|
||||||
&Header::default(),
|
&Header::default(),
|
||||||
&claims,
|
&claims,
|
||||||
@@ -90,60 +55,18 @@ pub fn create_refresh_token(
|
|||||||
|
|
||||||
/// 验证 JWT Token
|
/// 验证 JWT Token
|
||||||
pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> {
|
pub fn verify_token(token: &str, secret: &str) -> SaasResult<Claims> {
|
||||||
let token_data = decode::<Claims>(
|
|
||||||
token,
|
|
||||||
&DecodingKey::from_secret(secret.as_bytes()),
|
|
||||||
&Validation::default(),
|
|
||||||
)?;
|
|
||||||
Ok(token_data.claims)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 验证 JWT Token 但跳过过期检查(仅用于 refresh token 刷新)
|
|
||||||
/// 限制: 原始 token 的 iat 必须在 7 天内
|
|
||||||
pub fn verify_token_skip_expiry(token: &str, secret: &str) -> SaasResult<Claims> {
|
|
||||||
let mut validation = Validation::default();
|
let mut validation = Validation::default();
|
||||||
validation.validate_exp = false;
|
validation.set_audience(&[JWT_AUDIENCE]);
|
||||||
|
validation.set_issuer(&[JWT_ISSUER]);
|
||||||
|
|
||||||
let token_data = decode::<Claims>(
|
let token_data = decode::<Claims>(
|
||||||
token,
|
token,
|
||||||
&DecodingKey::from_secret(secret.as_bytes()),
|
&DecodingKey::from_secret(secret.as_bytes()),
|
||||||
&validation,
|
&validation,
|
||||||
)?;
|
)?;
|
||||||
let claims = &token_data.claims;
|
|
||||||
|
|
||||||
// 限制刷新窗口: token 签发时间必须在 7 天内
|
|
||||||
let now = Utc::now().timestamp();
|
|
||||||
let max_refresh_window = 7 * 24 * 3600; // 7 天
|
|
||||||
if now - claims.iat > max_refresh_window {
|
|
||||||
return Err(jsonwebtoken::errors::Error::from(
|
|
||||||
jsonwebtoken::errors::ErrorKind::ExpiredSignature
|
|
||||||
).into());
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(token_data.claims)
|
Ok(token_data.claims)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Token 对: access token + refresh token
|
|
||||||
#[derive(Debug, serde::Serialize)]
|
|
||||||
pub struct TokenPair {
|
|
||||||
pub access_token: String,
|
|
||||||
pub refresh_token: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 创建 access + refresh token 对
|
|
||||||
pub fn create_token_pair(
|
|
||||||
account_id: &str,
|
|
||||||
role: &str,
|
|
||||||
permissions: Vec<String>,
|
|
||||||
secret: &str,
|
|
||||||
access_hours: i64,
|
|
||||||
refresh_hours: i64,
|
|
||||||
) -> SaasResult<TokenPair> {
|
|
||||||
Ok(TokenPair {
|
|
||||||
access_token: create_token(account_id, role, permissions.clone(), secret, access_hours)?,
|
|
||||||
refresh_token: create_refresh_token(account_id, role, permissions, secret, refresh_hours)?,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@@ -162,8 +85,6 @@ mod tests {
|
|||||||
assert_eq!(claims.sub, "account-123");
|
assert_eq!(claims.sub, "account-123");
|
||||||
assert_eq!(claims.role, "admin");
|
assert_eq!(claims.role, "admin");
|
||||||
assert_eq!(claims.permissions, vec!["model:read"]);
|
assert_eq!(claims.permissions, vec!["model:read"]);
|
||||||
assert!(claims.jti.is_some());
|
|
||||||
assert_eq!(claims.token_type, "access");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -178,17 +99,4 @@ mod tests {
|
|||||||
let result = verify_token(&token, "wrong-secret");
|
let result = verify_token(&token, "wrong-secret");
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_refresh_token_has_different_jti() {
|
|
||||||
let access = create_token("acct-1", "user", vec![], TEST_SECRET, 1).unwrap();
|
|
||||||
let refresh = create_refresh_token("acct-1", "user", vec![], TEST_SECRET, 168).unwrap();
|
|
||||||
|
|
||||||
let access_claims = verify_token(&access, TEST_SECRET).unwrap();
|
|
||||||
let refresh_claims = verify_token(&refresh, TEST_SECRET).unwrap();
|
|
||||||
|
|
||||||
assert_ne!(access_claims.jti, refresh_claims.jti);
|
|
||||||
assert_eq!(access_claims.token_type, "access");
|
|
||||||
assert_eq!(refresh_claims.token_type, "refresh");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
|||||||
.ok_or(SaasError::Unauthorized)?;
|
.ok_or(SaasError::Unauthorized)?;
|
||||||
|
|
||||||
// 合并 token 权限与角色权限(去重)
|
// 合并 token 权限与角色权限(去重)
|
||||||
let role_permissions = handlers::get_role_permissions(&state.db, &state.role_permissions_cache, &role).await?;
|
let role_permissions = handlers::get_role_permissions(&state.db, &role).await?;
|
||||||
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
|
let token_permissions: Vec<String> = serde_json::from_str(&permissions_json).unwrap_or_default();
|
||||||
let mut permissions = role_permissions;
|
let mut permissions = role_permissions;
|
||||||
for p in token_permissions {
|
for p in token_permissions {
|
||||||
@@ -70,9 +70,9 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
|||||||
// 异步更新 last_used_at(不阻塞请求)
|
// 异步更新 last_used_at(不阻塞请求)
|
||||||
let db = state.db.clone();
|
let db = state.db.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
||||||
.bind(&now).bind(&token_hash)
|
.bind(now).bind(&token_hash)
|
||||||
.execute(&db).await;
|
.execute(&db).await;
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -84,23 +84,11 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 从请求中提取客户端 IP
|
/// 从请求中提取客户端 IP(仅信任直连 IP,不信任可伪造的 proxy header)
|
||||||
fn extract_client_ip(req: &Request) -> Option<String> {
|
fn extract_client_ip(req: &Request) -> Option<String> {
|
||||||
// 优先从 ConnectInfo 获取
|
req.extensions()
|
||||||
if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
|
.get::<ConnectInfo<SocketAddr>>()
|
||||||
return Some(addr.ip().to_string());
|
.map(|addr| addr.ip().to_string())
|
||||||
}
|
|
||||||
// 回退到 X-Forwarded-For / X-Real-IP
|
|
||||||
if let Some(forwarded) = req.headers()
|
|
||||||
.get("x-forwarded-for")
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
{
|
|
||||||
return Some(forwarded.split(',').next()?.trim().to_string());
|
|
||||||
}
|
|
||||||
req.headers()
|
|
||||||
.get("x-real-ip")
|
|
||||||
.and_then(|v| v.to_str().ok())
|
|
||||||
.map(|s| s.to_string())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 认证中间件: 从 JWT 或 API Token 提取身份
|
/// 认证中间件: 从 JWT 或 API Token 提取身份
|
||||||
@@ -121,8 +109,7 @@ pub async fn auth_middleware(
|
|||||||
verify_api_token(&state, token, client_ip.clone()).await
|
verify_api_token(&state, token, client_ip.clone()).await
|
||||||
} else {
|
} else {
|
||||||
// JWT 路径
|
// JWT 路径
|
||||||
let verify_result = jwt::verify_token(token, state.jwt_secret.expose_secret());
|
jwt::verify_token(token, state.jwt_secret.expose_secret())
|
||||||
verify_result
|
|
||||||
.map(|claims| AuthContext {
|
.map(|claims| AuthContext {
|
||||||
account_id: claims.sub,
|
account_id: claims.sub,
|
||||||
role: claims.role,
|
role: claims.role,
|
||||||
@@ -154,7 +141,6 @@ pub fn routes() -> axum::Router<AppState> {
|
|||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
.route("/api/v1/auth/register", post(handlers::register))
|
.route("/api/v1/auth/register", post(handlers::register))
|
||||||
.route("/api/v1/auth/login", post(handlers::login))
|
.route("/api/v1/auth/login", post(handlers::login))
|
||||||
.route("/api/v1/auth/refresh", post(handlers::refresh))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 需要认证的路由
|
/// 需要认证的路由
|
||||||
@@ -162,6 +148,7 @@ pub fn protected_routes() -> axum::Router<AppState> {
|
|||||||
use axum::routing::{get, post, put};
|
use axum::routing::{get, post, put};
|
||||||
|
|
||||||
axum::Router::new()
|
axum::Router::new()
|
||||||
|
.route("/api/v1/auth/refresh", post(handlers::refresh))
|
||||||
.route("/api/v1/auth/me", get(handlers::me))
|
.route("/api/v1/auth/me", get(handlers::me))
|
||||||
.route("/api/v1/auth/password", put(handlers::change_password))
|
.route("/api/v1/auth/password", put(handlers::change_password))
|
||||||
.route("/api/v1/auth/totp/setup", post(totp::setup_totp))
|
.route("/api/v1/auth/totp/setup", post(totp::setup_totp))
|
||||||
|
|||||||
@@ -1,8 +1,4 @@
|
|||||||
//! 密码哈希 (Argon2id)
|
//! 密码哈希 (Argon2id)
|
||||||
//!
|
|
||||||
//! Argon2 是 CPU 密集型操作(~100-500ms),不能在 tokio worker 线程上直接执行,
|
|
||||||
//! 否则会阻塞整个异步运行时。所有 async 上下文必须使用 `hash_password_async`
|
|
||||||
//! 和 `verify_password_async`。
|
|
||||||
|
|
||||||
use argon2::{
|
use argon2::{
|
||||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||||
@@ -11,7 +7,7 @@ use argon2::{
|
|||||||
|
|
||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
|
|
||||||
/// 哈希密码(同步版本,仅用于测试和启动时 seed)
|
/// 哈希密码
|
||||||
pub fn hash_password(password: &str) -> SaasResult<String> {
|
pub fn hash_password(password: &str) -> SaasResult<String> {
|
||||||
let salt = SaltString::generate(&mut OsRng);
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
let argon2 = Argon2::default();
|
let argon2 = Argon2::default();
|
||||||
@@ -21,7 +17,7 @@ pub fn hash_password(password: &str) -> SaasResult<String> {
|
|||||||
Ok(hash.to_string())
|
Ok(hash.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 验证密码(同步版本,仅用于测试)
|
/// 验证密码
|
||||||
pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
|
pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
|
||||||
let parsed_hash = PasswordHash::new(hash)
|
let parsed_hash = PasswordHash::new(hash)
|
||||||
.map_err(|e| SaasError::PasswordHash(e.to_string()))?;
|
.map_err(|e| SaasError::PasswordHash(e.to_string()))?;
|
||||||
@@ -30,20 +26,6 @@ pub fn verify_password(password: &str, hash: &str) -> SaasResult<bool> {
|
|||||||
.is_ok())
|
.is_ok())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 异步哈希密码 — 在 spawn_blocking 线程池中执行 Argon2
|
|
||||||
pub async fn hash_password_async(password: String) -> SaasResult<String> {
|
|
||||||
tokio::task::spawn_blocking(move || hash_password(&password))
|
|
||||||
.await
|
|
||||||
.map_err(|e| SaasError::Internal(format!("spawn_blocking error: {e}")))?
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 异步验证密码 — 在 spawn_blocking 线程池中执行 Argon2
|
|
||||||
pub async fn verify_password_async(password: String, hash: String) -> SaasResult<bool> {
|
|
||||||
tokio::task::spawn_blocking(move || verify_password(&password, &hash))
|
|
||||||
.await
|
|
||||||
.map_err(|e| SaasError::Internal(format!("spawn_blocking error: {e}")))?
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ use crate::state::AppState;
|
|||||||
use crate::error::{SaasError, SaasResult};
|
use crate::error::{SaasError, SaasResult};
|
||||||
use crate::auth::types::AuthContext;
|
use crate::auth::types::AuthContext;
|
||||||
use crate::auth::handlers::log_operation;
|
use crate::auth::handlers::log_operation;
|
||||||
use crate::crypto;
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// TOTP 设置响应
|
/// TOTP 设置响应
|
||||||
@@ -47,21 +46,6 @@ fn base32_decode(data: &str) -> Option<Vec<u8>> {
|
|||||||
data_encoding::BASE32.decode(data.as_bytes()).ok()
|
data_encoding::BASE32.decode(data.as_bytes()).ok()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 加密 TOTP secret (AES-256-GCM,随机 nonce)
|
|
||||||
/// 存储格式: enc:<base64(nonce||ciphertext)>
|
|
||||||
/// 委托给 crypto::encrypt_value 统一加密
|
|
||||||
fn encrypt_totp_secret(plaintext: &str, key: &[u8; 32]) -> Result<String, SaasError> {
|
|
||||||
crate::crypto::encrypt_value(plaintext, key)
|
|
||||||
.map_err(|e| SaasError::Internal(e.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 解密 TOTP secret (仅支持新格式: 随机 nonce)
|
|
||||||
/// 旧的固定 nonce 格式应通过启动时迁移转换。
|
|
||||||
fn decrypt_totp_secret(encrypted: &str, key: &[u8; 32]) -> Result<String, SaasError> {
|
|
||||||
crate::crypto::decrypt_value(encrypted, key)
|
|
||||||
.map_err(|e| SaasError::Internal(e.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 生成 TOTP 密钥并返回 otpauth URI
|
/// 生成 TOTP 密钥并返回 otpauth URI
|
||||||
pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse {
|
pub fn generate_totp_secret(issuer: &str, account_name: &str) -> TotpSetupResponse {
|
||||||
let secret = generate_random_secret();
|
let secret = generate_random_secret();
|
||||||
@@ -119,11 +103,8 @@ pub async fn setup_totp(
|
|||||||
let config = state.config.read().await;
|
let config = state.config.read().await;
|
||||||
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
|
let setup = generate_totp_secret(&config.auth.totp_issuer, &username);
|
||||||
|
|
||||||
// 加密后存储密钥 (但不启用,需要 /verify 确认)
|
// 加密 TOTP 密钥后存储 (但不启用,需要 /verify 确认)
|
||||||
let enc_key = config.totp_encryption_key()
|
let encrypted_secret = state.field_encryption.encrypt(&setup.secret)?;
|
||||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
|
||||||
let encrypted_secret = encrypt_totp_secret(&setup.secret, &enc_key)?;
|
|
||||||
|
|
||||||
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
|
sqlx::query("UPDATE accounts SET totp_secret = $1 WHERE id = $2")
|
||||||
.bind(&encrypted_secret)
|
.bind(&encrypted_secret)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
@@ -156,37 +137,21 @@ pub async fn verify_totp(
|
|||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let encrypted_secret = totp_secret.ok_or_else(|| {
|
let secret = totp_secret.ok_or_else(|| {
|
||||||
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
|
SaasError::InvalidInput("请先调用 /totp/setup 获取密钥".into())
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// 解密 secret (兼容旧的明文格式)
|
// 解密 TOTP 密钥(兼容迁移期间的明文数据)
|
||||||
let config = state.config.read().await;
|
let decrypted_secret = state.field_encryption.decrypt_or_plaintext(&secret);
|
||||||
let enc_key = config.totp_encryption_key()
|
|
||||||
.map_err(|e| SaasError::Internal(e.to_string()))?;
|
|
||||||
let secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
|
||||||
decrypt_totp_secret(&encrypted_secret, &enc_key)?
|
|
||||||
} else {
|
|
||||||
// 旧格式: 明文存储,需要迁移
|
|
||||||
encrypted_secret.clone()
|
|
||||||
};
|
|
||||||
|
|
||||||
if !verify_totp_code(&secret, code) {
|
if !verify_totp_code(&decrypted_secret, code) {
|
||||||
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
return Err(SaasError::Totp("TOTP 码验证失败".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证成功 → 启用 TOTP,同时确保密钥已加密
|
// 验证成功 → 启用 TOTP
|
||||||
let final_secret = if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
let now = chrono::Utc::now();
|
||||||
encrypted_secret
|
sqlx::query("UPDATE accounts SET totp_enabled = true, updated_at = $1 WHERE id = $2")
|
||||||
} else {
|
.bind(now)
|
||||||
// 迁移: 加密旧明文密钥
|
|
||||||
encrypt_totp_secret(&secret, &enc_key)?
|
|
||||||
};
|
|
||||||
|
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
|
||||||
sqlx::query("UPDATE accounts SET totp_enabled = true, totp_secret = $1, updated_at = $2 WHERE id = $3")
|
|
||||||
.bind(&final_secret)
|
|
||||||
.bind(&now)
|
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -212,14 +177,14 @@ pub async fn disable_totp(
|
|||||||
.fetch_one(&state.db)
|
.fetch_one(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
if !crate::auth::password::verify_password_async(req.password.clone(), password_hash.clone()).await? {
|
if !crate::auth::password::verify_password(&req.password, &password_hash)? {
|
||||||
return Err(SaasError::AuthError("密码错误".into()));
|
return Err(SaasError::AuthError("密码错误".into()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// 清除 TOTP
|
// 清除 TOTP
|
||||||
let now = chrono::Utc::now().to_rfc3339();
|
let now = chrono::Utc::now();
|
||||||
sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
|
sqlx::query("UPDATE accounts SET totp_enabled = false, totp_secret = NULL, updated_at = $1 WHERE id = $2")
|
||||||
.bind(&now)
|
.bind(now)
|
||||||
.bind(&ctx.account_id)
|
.bind(&ctx.account_id)
|
||||||
.execute(&state.db)
|
.execute(&state.db)
|
||||||
.await?;
|
.await?;
|
||||||
@@ -230,13 +195,64 @@ pub async fn disable_totp(
|
|||||||
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
|
Ok(Json(serde_json::json!({"ok": true, "totp_enabled": false, "message": "TOTP 已禁用"})))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 解密 TOTP secret (供 login handler 使用)
|
#[cfg(test)]
|
||||||
/// 返回解密后的明文 secret
|
mod tests {
|
||||||
pub fn decrypt_totp_for_login(encrypted_secret: &str, enc_key: &[u8; 32]) -> SaasResult<String> {
|
use super::*;
|
||||||
if encrypted_secret.starts_with(crypto::ENCRYPTED_PREFIX) {
|
|
||||||
decrypt_totp_secret(encrypted_secret, enc_key)
|
#[test]
|
||||||
} else {
|
fn test_generate_totp_secret_format() {
|
||||||
// 兼容旧的明文格式
|
let result = generate_totp_secret("TestIssuer", "user@example.com");
|
||||||
Ok(encrypted_secret.to_string())
|
assert!(result.otpauth_uri.starts_with("otpauth://totp/"));
|
||||||
|
assert!(result.otpauth_uri.contains("secret="));
|
||||||
|
assert!(result.otpauth_uri.contains("issuer=TestIssuer"));
|
||||||
|
assert!(result.otpauth_uri.contains("algorithm=SHA1"));
|
||||||
|
assert!(result.otpauth_uri.contains("digits=6"));
|
||||||
|
assert!(result.otpauth_uri.contains("period=30"));
|
||||||
|
// Base32 编码的 20 字节 = 32 字符
|
||||||
|
assert_eq!(result.secret.len(), 32);
|
||||||
|
assert_eq!(result.issuer, "TestIssuer");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_generate_totp_secret_special_chars() {
|
||||||
|
let result = generate_totp_secret("My App", "user@domain:8080");
|
||||||
|
// 特殊字符应被 URL 编码
|
||||||
|
assert!(!result.otpauth_uri.contains("user@domain:8080"));
|
||||||
|
assert!(result.otpauth_uri.contains("user%40domain"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_verify_totp_code_valid() {
|
||||||
|
// 使用 generate_random_secret 创建合法 secret,然后生成并验证码
|
||||||
|
let secret = generate_random_secret();
|
||||||
|
let secret_bytes = data_encoding::BASE32.decode(secret.as_bytes()).unwrap();
|
||||||
|
let totp = totp_rs::TOTP::new(
|
||||||
|
totp_rs::Algorithm::SHA1, 6, 1, 30, secret_bytes,
|
||||||
|
).unwrap();
|
||||||
|
let valid_code = totp.generate(chrono::Utc::now().timestamp() as u64);
|
||||||
|
assert!(verify_totp_code(&secret, &valid_code));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_verify_totp_code_invalid() {
|
||||||
|
let secret = generate_random_secret();
|
||||||
|
assert!(!verify_totp_code(&secret, "000000"));
|
||||||
|
assert!(!verify_totp_code(&secret, "999999"));
|
||||||
|
assert!(!verify_totp_code(&secret, "abcdef"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_verify_totp_code_invalid_secret() {
|
||||||
|
assert!(!verify_totp_code("not-valid-base32!!!", "123456"));
|
||||||
|
assert!(!verify_totp_code("", "123456"));
|
||||||
|
assert!(!verify_totp_code("短", "123456"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_verify_totp_code_empty() {
|
||||||
|
let secret = "JBSWY3DPEHPK3PXP";
|
||||||
|
assert!(!verify_totp_code(secret, ""));
|
||||||
|
assert!(!verify_totp_code(secret, "12345"));
|
||||||
|
assert!(!verify_totp_code(secret, "1234567"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// 登录请求
|
/// 登录请求
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
pub struct LoginRequest {
|
pub struct LoginRequest {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub password: String,
|
pub password: String,
|
||||||
@@ -11,15 +11,14 @@ pub struct LoginRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 登录响应
|
/// 登录响应
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, utoipa::ToSchema)]
|
||||||
pub struct LoginResponse {
|
pub struct LoginResponse {
|
||||||
pub token: String,
|
pub token: String,
|
||||||
pub refresh_token: String,
|
|
||||||
pub account: AccountPublic,
|
pub account: AccountPublic,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 注册请求
|
/// 注册请求
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
pub struct RegisterRequest {
|
pub struct RegisterRequest {
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub email: String,
|
pub email: String,
|
||||||
@@ -28,20 +27,21 @@ pub struct RegisterRequest {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// 修改密码请求
|
/// 修改密码请求
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize, utoipa::ToSchema)]
|
||||||
pub struct ChangePasswordRequest {
|
pub struct ChangePasswordRequest {
|
||||||
pub old_password: String,
|
pub old_password: String,
|
||||||
pub new_password: String,
|
pub new_password: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// 公开账号信息 (无敏感数据)
|
/// 公开账号信息 (无敏感数据)
|
||||||
#[derive(Debug, Clone, Serialize)]
|
#[derive(Debug, Clone, Serialize, utoipa::ToSchema)]
|
||||||
pub struct AccountPublic {
|
pub struct AccountPublic {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub username: String,
|
pub username: String,
|
||||||
pub email: String,
|
pub email: String,
|
||||||
pub display_name: String,
|
pub display_name: String,
|
||||||
pub role: String,
|
pub role: String,
|
||||||
|
pub permissions: Vec<String>,
|
||||||
pub status: String,
|
pub status: String,
|
||||||
pub totp_enabled: bool,
|
pub totp_enabled: bool,
|
||||||
pub created_at: String,
|
pub created_at: String,
|
||||||
@@ -55,9 +55,3 @@ pub struct AuthContext {
|
|||||||
pub permissions: Vec<String>,
|
pub permissions: Vec<String>,
|
||||||
pub client_ip: Option<String>,
|
pub client_ip: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Token 刷新请求
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
pub struct RefreshRequest {
|
|
||||||
pub refresh_token: String,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
//! 公共类型和工具函数
|
|
||||||
|
|
||||||
use serde::Serialize;
|
|
||||||
|
|
||||||
/// 分页响应通用包装
|
|
||||||
#[derive(Debug, Serialize)]
|
|
||||||
pub struct PaginatedResponse<T: Serialize> {
|
|
||||||
pub items: Vec<T>,
|
|
||||||
pub total: i64,
|
|
||||||
pub page: u32,
|
|
||||||
pub page_size: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// 分页上限
|
|
||||||
pub const MAX_PAGE_SIZE: u32 = 100;
|
|
||||||
|
|
||||||
/// 默认分页大小
|
|
||||||
pub const DEFAULT_PAGE_SIZE: u32 = 20;
|
|
||||||
|
|
||||||
/// 规范化分页参数,返回 (page, page_size, offset)
|
|
||||||
pub fn normalize_pagination(page: Option<u32>, page_size: Option<u32>) -> (u32, u32, i64) {
|
|
||||||
let p = page.unwrap_or(1).max(1);
|
|
||||||
let ps = page_size.unwrap_or(DEFAULT_PAGE_SIZE).min(MAX_PAGE_SIZE).max(1);
|
|
||||||
let offset = ((p - 1) * ps) as i64;
|
|
||||||
(p, ps, offset)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_normalize_pagination_defaults() {
|
|
||||||
let (page, size, offset) = normalize_pagination(None, None);
|
|
||||||
assert_eq!(page, 1);
|
|
||||||
assert_eq!(size, DEFAULT_PAGE_SIZE);
|
|
||||||
assert_eq!(offset, 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_normalize_pagination_clamp() {
|
|
||||||
let (page, size, offset) = normalize_pagination(None, Some(999));
|
|
||||||
assert_eq!(size, MAX_PAGE_SIZE);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_normalize_pagination_offset() {
|
|
||||||
let (page, size, offset) = normalize_pagination(Some(3), Some(10));
|
|
||||||
assert_eq!(offset, 20);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user