Compare commits
25 Commits
73ff5e8c5e
...
8898bb399e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8898bb399e | ||
|
|
28299807b6 | ||
|
|
d40c4605b2 | ||
|
|
7e4b787d5c | ||
|
|
837abec48a | ||
|
|
11e3d37468 | ||
|
|
8263b236fd | ||
|
|
08268b32b8 | ||
|
|
1bf0d3a73d | ||
|
|
07099e3ef0 | ||
|
|
dce9035584 | ||
|
|
c8dc654fd4 | ||
|
|
b1e3a27043 | ||
|
|
becfda3fbf | ||
|
|
830e9fa301 | ||
|
|
ef60f9a183 | ||
|
|
b66087de0e | ||
|
|
d06ecded34 | ||
|
|
9487cd7f72 | ||
|
|
c6bd4aea27 | ||
|
|
17a2501808 | ||
|
|
cc7ee3189d | ||
|
|
62df7feac1 | ||
|
|
a851a2854f | ||
|
|
59fc7debd6 |
@@ -9,7 +9,7 @@
|
||||
ZCLAW 是面向中文用户的 AI Agent 桌面端,核心能力包括:
|
||||
|
||||
- **智能对话** - 多模型支持、流式响应、上下文管理
|
||||
- **自主能力** - 8 个 Hands(浏览器、数据采集、研究、预测等)
|
||||
- **自主能力** - 11 个 Hands(9 启用 + 2 禁用: Predictor, Lead)
|
||||
- **技能系统** - 可扩展的 SKILL.md 技能定义
|
||||
- **工作流编排** - 多步骤自动化任务
|
||||
- **安全审计** - 完整的操作日志和权限控制
|
||||
@@ -69,7 +69,7 @@ ZCLAW/
|
||||
| 桌面框架 | Tauri 2.x |
|
||||
| 样式方案 | Tailwind CSS |
|
||||
| 配置格式 | TOML |
|
||||
| 后端核心 | Rust Workspace (9 crates) |
|
||||
| 后端核心 | Rust Workspace (10 crates) |
|
||||
| SaaS 后端 | Axum + PostgreSQL (zclaw-saas) |
|
||||
| 管理后台 | Next.js (admin/) |
|
||||
|
||||
|
||||
178
Cargo.lock
generated
178
Cargo.lock
generated
@@ -1314,7 +1314,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_yaml",
|
||||
"sha2",
|
||||
"sqlx",
|
||||
"sqlx 0.7.4",
|
||||
"tauri",
|
||||
"tauri-build",
|
||||
"tauri-plugin-opener",
|
||||
@@ -2262,6 +2262,8 @@ version = "0.15.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
|
||||
dependencies = [
|
||||
"allocator-api2",
|
||||
"equivalent",
|
||||
"foldhash 0.1.5",
|
||||
]
|
||||
|
||||
@@ -2285,6 +2287,15 @@ dependencies = [
|
||||
"hashbrown 0.14.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashlink"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
|
||||
dependencies = [
|
||||
"hashbrown 0.15.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "headers"
|
||||
version = "0.4.1"
|
||||
@@ -3716,6 +3727,15 @@ version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
|
||||
|
||||
[[package]]
|
||||
name = "pgvector"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc58e2d255979a31caa7cabfa7aac654af0354220719ab7a68520ae7a91e8c0b"
|
||||
dependencies = [
|
||||
"sqlx 0.8.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "phf"
|
||||
version = "0.8.0"
|
||||
@@ -4571,6 +4591,7 @@ dependencies = [
|
||||
"pkcs1",
|
||||
"pkcs8",
|
||||
"rand_core 0.6.4",
|
||||
"sha2",
|
||||
"signature",
|
||||
"spki",
|
||||
"subtle",
|
||||
@@ -5271,13 +5292,24 @@ version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa"
|
||||
dependencies = [
|
||||
"sqlx-core",
|
||||
"sqlx-macros",
|
||||
"sqlx-core 0.7.4",
|
||||
"sqlx-macros 0.7.4",
|
||||
"sqlx-mysql",
|
||||
"sqlx-postgres",
|
||||
"sqlx-postgres 0.7.4",
|
||||
"sqlx-sqlite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fefb893899429669dcdd979aff487bd78f4064e5e7907e4269081e0ef7d97dc"
|
||||
dependencies = [
|
||||
"sqlx-core 0.8.6",
|
||||
"sqlx-macros 0.8.6",
|
||||
"sqlx-postgres 0.8.6",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-core"
|
||||
version = "0.7.4"
|
||||
@@ -5288,6 +5320,7 @@ dependencies = [
|
||||
"atoi",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"crc",
|
||||
"crossbeam-queue",
|
||||
"either",
|
||||
@@ -5297,7 +5330,7 @@ dependencies = [
|
||||
"futures-intrusive",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"hashlink",
|
||||
"hashlink 0.8.4",
|
||||
"hex",
|
||||
"indexmap 2.13.0",
|
||||
"log",
|
||||
@@ -5317,6 +5350,38 @@ dependencies = [
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-core"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"crc",
|
||||
"crossbeam-queue",
|
||||
"either",
|
||||
"event-listener 5.4.1",
|
||||
"futures-core",
|
||||
"futures-intrusive",
|
||||
"futures-io",
|
||||
"futures-util",
|
||||
"hashbrown 0.15.5",
|
||||
"hashlink 0.10.0",
|
||||
"indexmap 2.13.0",
|
||||
"log",
|
||||
"memchr",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"thiserror 2.0.18",
|
||||
"tracing",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-macros"
|
||||
version = "0.7.4"
|
||||
@@ -5325,11 +5390,24 @@ checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"sqlx-core",
|
||||
"sqlx-macros-core",
|
||||
"sqlx-core 0.7.4",
|
||||
"sqlx-macros-core 0.7.4",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-macros"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2d452988ccaacfbf5e0bdbc348fb91d7c8af5bee192173ac3636b5fb6e6715d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"sqlx-core 0.8.6",
|
||||
"sqlx-macros-core 0.8.6",
|
||||
"syn 2.0.117",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-macros-core"
|
||||
version = "0.7.4"
|
||||
@@ -5346,9 +5424,9 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"sqlx-core",
|
||||
"sqlx-core 0.7.4",
|
||||
"sqlx-mysql",
|
||||
"sqlx-postgres",
|
||||
"sqlx-postgres 0.7.4",
|
||||
"sqlx-sqlite",
|
||||
"syn 1.0.109",
|
||||
"tempfile",
|
||||
@@ -5356,6 +5434,28 @@ dependencies = [
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-macros-core"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19a9c1841124ac5a61741f96e1d9e2ec77424bf323962dd894bdb93f37d5219b"
|
||||
dependencies = [
|
||||
"dotenvy",
|
||||
"either",
|
||||
"heck 0.5.0",
|
||||
"hex",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"sqlx-core 0.8.6",
|
||||
"sqlx-postgres 0.8.6",
|
||||
"syn 2.0.117",
|
||||
"url",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-mysql"
|
||||
version = "0.7.4"
|
||||
@@ -5367,6 +5467,7 @@ dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"crc",
|
||||
"digest",
|
||||
"dotenvy",
|
||||
@@ -5391,7 +5492,7 @@ dependencies = [
|
||||
"sha1",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"sqlx-core",
|
||||
"sqlx-core 0.7.4",
|
||||
"stringprep",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
@@ -5408,6 +5509,7 @@ dependencies = [
|
||||
"base64 0.21.7",
|
||||
"bitflags 2.11.0",
|
||||
"byteorder",
|
||||
"chrono",
|
||||
"crc",
|
||||
"dotenvy",
|
||||
"etcetera",
|
||||
@@ -5429,13 +5531,50 @@ dependencies = [
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"sqlx-core",
|
||||
"sqlx-core 0.7.4",
|
||||
"stringprep",
|
||||
"thiserror 1.0.69",
|
||||
"tracing",
|
||||
"whoami",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-postgres"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"base64 0.22.1",
|
||||
"bitflags 2.11.0",
|
||||
"byteorder",
|
||||
"crc",
|
||||
"dotenvy",
|
||||
"etcetera",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-util",
|
||||
"hex",
|
||||
"hkdf",
|
||||
"hmac",
|
||||
"home",
|
||||
"itoa",
|
||||
"log",
|
||||
"md-5",
|
||||
"memchr",
|
||||
"once_cell",
|
||||
"rand 0.8.5",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"smallvec",
|
||||
"sqlx-core 0.8.6",
|
||||
"stringprep",
|
||||
"thiserror 2.0.18",
|
||||
"tracing",
|
||||
"whoami",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sqlx-sqlite"
|
||||
version = "0.7.4"
|
||||
@@ -5443,6 +5582,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
|
||||
dependencies = [
|
||||
"atoi",
|
||||
"chrono",
|
||||
"flume",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@@ -5453,7 +5593,7 @@ dependencies = [
|
||||
"log",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
"sqlx-core",
|
||||
"sqlx-core 0.7.4",
|
||||
"tracing",
|
||||
"url",
|
||||
"urlencoding",
|
||||
@@ -8211,7 +8351,7 @@ dependencies = [
|
||||
"libsqlite3-sys",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
"sqlx 0.7.4",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tokio-test",
|
||||
@@ -8227,11 +8367,9 @@ dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"hmac",
|
||||
"reqwest 0.12.28",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha1",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -8272,12 +8410,14 @@ dependencies = [
|
||||
name = "zclaw-memory"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"futures",
|
||||
"libsqlite3-sys",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
"sqlx 0.7.4",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
@@ -8362,9 +8502,11 @@ dependencies = [
|
||||
"aes-gcm",
|
||||
"anyhow",
|
||||
"argon2",
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
"axum",
|
||||
"axum-extra",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"chrono",
|
||||
"dashmap",
|
||||
@@ -8372,15 +8514,17 @@ dependencies = [
|
||||
"futures",
|
||||
"hex",
|
||||
"jsonwebtoken",
|
||||
"pgvector",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest 0.12.28",
|
||||
"rsa",
|
||||
"secrecy",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sha2",
|
||||
"socket2 0.5.10",
|
||||
"sqlx",
|
||||
"sqlx 0.7.4",
|
||||
"tempfile",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
|
||||
@@ -57,7 +57,7 @@ chrono = { version = "0.4", features = ["serde"] }
|
||||
uuid = { version = "1", features = ["v4", "v5", "serde"] }
|
||||
|
||||
# Database
|
||||
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres"] }
|
||||
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres", "chrono"] }
|
||||
libsqlite3-sys = { version = "0.27", features = ["bundled"] }
|
||||
|
||||
# HTTP client (for LLM drivers)
|
||||
@@ -84,6 +84,7 @@ rand = "0.8"
|
||||
# Crypto
|
||||
sha2 = "0.10"
|
||||
aes-gcm = "0.10"
|
||||
rsa = { version = "0.9", features = ["pem"] }
|
||||
|
||||
# Home directory
|
||||
dirs = "6"
|
||||
|
||||
@@ -16,6 +16,9 @@ import {
|
||||
SunOutlined,
|
||||
MoonOutlined,
|
||||
ApiOutlined,
|
||||
BookOutlined,
|
||||
CrownOutlined,
|
||||
SafetyOutlined,
|
||||
} from '@ant-design/icons'
|
||||
import { Avatar, Dropdown, Tooltip, Drawer } from 'antd'
|
||||
import { useAuthStore } from '@/stores/authStore'
|
||||
@@ -37,11 +40,14 @@ interface NavItem {
|
||||
const navItems: NavItem[] = [
|
||||
{ path: '/', name: '仪表盘', icon: <DashboardOutlined />, group: '核心' },
|
||||
{ path: '/accounts', name: '账号管理', icon: <TeamOutlined />, permission: 'account:admin', group: '资源管理' },
|
||||
{ path: '/roles', name: '角色与权限', icon: <SafetyOutlined />, permission: 'account:admin', group: '资源管理' },
|
||||
{ path: '/model-services', name: '模型服务', icon: <CloudServerOutlined />, permission: 'provider:manage', group: '资源管理' },
|
||||
{ path: '/agent-templates', name: 'Agent 模板', icon: <RobotOutlined />, permission: 'model:read', group: '资源管理' },
|
||||
{ path: '/api-keys', name: 'API 密钥', icon: <ApiOutlined />, permission: 'provider:manage', group: '资源管理' },
|
||||
{ path: '/usage', name: '用量统计', icon: <BarChartOutlined />, permission: 'admin:full', group: '运维' },
|
||||
{ path: '/relay', name: '中转任务', icon: <SwapOutlined />, permission: 'relay:use', group: '运维' },
|
||||
{ path: '/knowledge', name: '知识库', icon: <BookOutlined />, permission: 'knowledge:read', group: '资源管理' },
|
||||
{ path: '/billing', name: '计费管理', icon: <CrownOutlined />, permission: 'billing:read', group: '核心' },
|
||||
{ path: '/logs', name: '操作日志', icon: <FileTextOutlined />, permission: 'admin:full', group: '运维' },
|
||||
{ path: '/config', name: '系统配置', icon: <SettingOutlined />, permission: 'config:read', group: '系统' },
|
||||
{ path: '/prompts', name: '提示词管理', icon: <MessageOutlined />, permission: 'prompt:read', group: '系统' },
|
||||
@@ -197,6 +203,7 @@ function MobileDrawer({
|
||||
const breadcrumbMap: Record<string, string> = {
|
||||
'/': '仪表盘',
|
||||
'/accounts': '账号管理',
|
||||
'/roles': '角色与权限',
|
||||
'/model-services': '模型服务',
|
||||
'/providers': '模型服务',
|
||||
'/models': '模型服务',
|
||||
@@ -204,6 +211,8 @@ const breadcrumbMap: Record<string, string> = {
|
||||
'/agent-templates': 'Agent 模板',
|
||||
'/usage': '用量统计',
|
||||
'/relay': '中转任务',
|
||||
'/knowledge': '知识库',
|
||||
'/billing': '计费管理',
|
||||
'/config': '系统配置',
|
||||
'/prompts': '提示词管理',
|
||||
'/logs': '操作日志',
|
||||
|
||||
352
admin-v2/src/pages/Billing.tsx
Normal file
352
admin-v2/src/pages/Billing.tsx
Normal file
@@ -0,0 +1,352 @@
|
||||
// ============================================================
|
||||
// 计费管理 — 计划/订阅/用量/支付
|
||||
// ============================================================
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import {
|
||||
Button, message, Tag, Modal, Card, Row, Col, Statistic, Typography,
|
||||
Progress, Space, Radio, Spin, Empty, Divider,
|
||||
} from 'antd'
|
||||
import {
|
||||
CrownOutlined, CheckCircleOutlined, ThunderboltOutlined,
|
||||
RocketOutlined, TeamOutlined, AlipayCircleOutlined,
|
||||
WechatOutlined, LoadingOutlined,
|
||||
} from '@ant-design/icons'
|
||||
import { PageHeader } from '@/components/PageHeader'
|
||||
import { ErrorState } from '@/components/ErrorState'
|
||||
import { billingService } from '@/services/billing'
|
||||
import type { BillingPlan, SubscriptionInfo, PaymentResult } from '@/services/billing'
|
||||
|
||||
const { Text, Title } = Typography
|
||||
|
||||
// === 计划卡片 ===
|
||||
|
||||
const planIcons: Record<string, React.ReactNode> = {
|
||||
free: <RocketOutlined style={{ fontSize: 24 }} />,
|
||||
pro: <ThunderboltOutlined style={{ fontSize: 24 }} />,
|
||||
team: <TeamOutlined style={{ fontSize: 24 }} />,
|
||||
}
|
||||
|
||||
const planColors: Record<string, string> = {
|
||||
free: '#8c8c8c',
|
||||
pro: '#863bff',
|
||||
team: '#47bfff',
|
||||
}
|
||||
|
||||
function PlanCard({
|
||||
plan,
|
||||
isCurrent,
|
||||
onSelect,
|
||||
}: {
|
||||
plan: BillingPlan
|
||||
isCurrent: boolean
|
||||
onSelect: (plan: BillingPlan) => void
|
||||
}) {
|
||||
const color = planColors[plan.name] || '#666'
|
||||
const limits = plan.limits as Record<string, unknown> | undefined
|
||||
const maxRelay = (limits?.max_relay_requests_monthly as number) ?? '∞'
|
||||
const maxHand = (limits?.max_hand_executions_monthly as number) ?? '∞'
|
||||
const maxPipeline = (limits?.max_pipeline_runs_monthly as number) ?? '∞'
|
||||
|
||||
return (
|
||||
<Card
|
||||
className={`relative overflow-hidden transition-all duration-200 hover:shadow-lg ${
|
||||
isCurrent ? 'ring-2 ring-offset-2' : ''
|
||||
}`}
|
||||
style={isCurrent ? { borderColor: color, '--tw-ring-color': color } as React.CSSProperties : {}}
|
||||
>
|
||||
{isCurrent && (
|
||||
<div
|
||||
className="absolute top-0 right-0 px-3 py-1 text-xs font-medium text-white rounded-bl-lg"
|
||||
style={{ background: color }}
|
||||
>
|
||||
当前计划
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="text-center mb-4">
|
||||
<div style={{ color }} className="mb-2">
|
||||
{planIcons[plan.name] || <CrownOutlined style={{ fontSize: 24 }} />}
|
||||
</div>
|
||||
<Title level={4} style={{ margin: 0 }}>{plan.display_name}</Title>
|
||||
{plan.description && (
|
||||
<Text type="secondary" className="text-sm">{plan.description}</Text>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="text-center mb-4">
|
||||
<span className="text-3xl font-bold" style={{ color }}>
|
||||
¥{plan.price_cents === 0 ? '0' : (plan.price_cents / 100).toFixed(0)}
|
||||
</span>
|
||||
<Text type="secondary"> /{plan.interval === 'month' ? '月' : '年'}</Text>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2 text-sm">
|
||||
<div className="flex items-center gap-2">
|
||||
<CheckCircleOutlined style={{ color: '#52c41a' }} />
|
||||
<span>中转请求: {maxRelay === Infinity ? '无限' : `${maxRelay} 次/月`}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<CheckCircleOutlined style={{ color: '#52c41a' }} />
|
||||
<span>Hand 执行: {maxHand === Infinity ? '无限' : `${maxHand} 次/月`}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<CheckCircleOutlined style={{ color: '#52c41a' }} />
|
||||
<span>Pipeline 运行: {maxPipeline === Infinity ? '无限' : `${maxPipeline} 次/月`}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<CheckCircleOutlined style={{ color: '#52c41a' }} />
|
||||
<span>知识库: {plan.name === 'free' ? '基础' : '高级'}</span>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<CheckCircleOutlined style={{ color: '#52c41a' }} />
|
||||
<span>优先级队列: {plan.name === 'team' ? '最高' : plan.name === 'pro' ? '高' : '标准'}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Divider />
|
||||
|
||||
<Button
|
||||
block
|
||||
type={isCurrent ? 'default' : 'primary'}
|
||||
disabled={isCurrent}
|
||||
onClick={() => onSelect(plan)}
|
||||
style={!isCurrent ? { background: color, borderColor: color } : {}}
|
||||
>
|
||||
{isCurrent ? '当前计划' : '升级'}
|
||||
</Button>
|
||||
</Card>
|
||||
)
|
||||
}
|
||||
|
||||
// === 用量进度条 ===
|
||||
|
||||
function UsageBar({ label, current, max }: { label: string; current: number; max: number | null }) {
|
||||
const pct = max ? Math.min((current / max) * 100, 100) : 0
|
||||
const displayMax = max ? max.toLocaleString() : '∞'
|
||||
|
||||
return (
|
||||
<div className="mb-3">
|
||||
<div className="flex justify-between text-xs text-neutral-500 dark:text-neutral-400 mb-1">
|
||||
<span>{label}</span>
|
||||
<span>{current.toLocaleString()} / {displayMax}</span>
|
||||
</div>
|
||||
<Progress
|
||||
percent={pct}
|
||||
showInfo={false}
|
||||
strokeColor={pct >= 90 ? '#ff4d4f' : pct >= 70 ? '#faad14' : '#863bff'}
|
||||
size="small"
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// === 主页面 ===
|
||||
|
||||
export default function Billing() {
|
||||
const queryClient = useQueryClient()
|
||||
const [payModalOpen, setPayModalOpen] = useState(false)
|
||||
const [selectedPlan, setSelectedPlan] = useState<BillingPlan | null>(null)
|
||||
const [payMethod, setPayMethod] = useState<'alipay' | 'wechat'>('alipay')
|
||||
const [payResult, setPayResult] = useState<PaymentResult | null>(null)
|
||||
const [pollingPayment, setPollingPayment] = useState<string | null>(null)
|
||||
|
||||
const { data: plans = [], isLoading: plansLoading, error: plansError, refetch } = useQuery({
|
||||
queryKey: ['billing-plans'],
|
||||
queryFn: ({ signal }) => billingService.listPlans(signal),
|
||||
})
|
||||
|
||||
const { data: subInfo, isLoading: subLoading } = useQuery({
|
||||
queryKey: ['billing-subscription'],
|
||||
queryFn: ({ signal }) => billingService.getSubscription(signal),
|
||||
})
|
||||
|
||||
// 支付状态轮询
|
||||
const { data: paymentStatus } = useQuery({
|
||||
queryKey: ['payment-status', pollingPayment],
|
||||
queryFn: ({ signal }) => billingService.getPaymentStatus(pollingPayment!, signal),
|
||||
enabled: !!pollingPayment,
|
||||
refetchInterval: pollingPayment ? 3000 : false,
|
||||
})
|
||||
|
||||
// 支付成功后刷新
|
||||
if (paymentStatus?.status === 'succeeded' && pollingPayment) {
|
||||
setPollingPayment(null)
|
||||
setPayModalOpen(false)
|
||||
setPayResult(null)
|
||||
message.success('支付成功!计划已更新')
|
||||
queryClient.invalidateQueries({ queryKey: ['billing-subscription'] })
|
||||
}
|
||||
|
||||
const createPaymentMutation = useMutation({
|
||||
mutationFn: (data: { plan_id: string; payment_method: 'alipay' | 'wechat' }) =>
|
||||
billingService.createPayment(data),
|
||||
onSuccess: (result) => {
|
||||
setPayResult(result)
|
||||
setPollingPayment(result.payment_id)
|
||||
// 打开支付链接
|
||||
window.open(result.pay_url, '_blank', 'width=480,height=640')
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '创建支付失败'),
|
||||
})
|
||||
|
||||
const handleSelectPlan = (plan: BillingPlan) => {
|
||||
if (plan.price_cents === 0) return
|
||||
setSelectedPlan(plan)
|
||||
setPayResult(null)
|
||||
setPayModalOpen(true)
|
||||
}
|
||||
|
||||
const handleConfirmPay = () => {
|
||||
if (!selectedPlan) return
|
||||
createPaymentMutation.mutate({
|
||||
plan_id: selectedPlan.id,
|
||||
payment_method: payMethod,
|
||||
})
|
||||
}
|
||||
|
||||
if (plansError) {
|
||||
return (
|
||||
<>
|
||||
<PageHeader title="计费管理" description="管理订阅计划和用量配额" />
|
||||
<ErrorState message={(plansError as Error).message} onRetry={() => refetch()} />
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
const currentPlanName = subInfo?.plan?.name || 'free'
|
||||
const usage = subInfo?.usage
|
||||
|
||||
return (
|
||||
<div>
|
||||
<PageHeader title="计费管理" description="管理订阅计划和用量配额" />
|
||||
|
||||
{/* 当前计划 + 用量 */}
|
||||
{subInfo && usage && (
|
||||
<Card className="mb-6" title={<span className="text-sm font-semibold">当前用量</span>}>
|
||||
<Row gutter={[24, 16]}>
|
||||
<Col xs={24} md={8}>
|
||||
<UsageBar
|
||||
label="中转请求"
|
||||
current={usage.relay_requests}
|
||||
max={usage.max_relay_requests}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} md={8}>
|
||||
<UsageBar
|
||||
label="Hand 执行"
|
||||
current={usage.hand_executions}
|
||||
max={usage.max_hand_executions}
|
||||
/>
|
||||
</Col>
|
||||
<Col xs={24} md={8}>
|
||||
<UsageBar
|
||||
label="Pipeline 运行"
|
||||
current={usage.pipeline_runs}
|
||||
max={usage.max_pipeline_runs}
|
||||
/>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{subInfo.subscription && (
|
||||
<div className="mt-4 text-xs text-neutral-400">
|
||||
订阅周期: {new Date(subInfo.subscription.current_period_start).toLocaleDateString()} — {new Date(subInfo.subscription.current_period_end).toLocaleDateString()}
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* 计划选择 */}
|
||||
<Title level={5} className="mb-4">选择计划</Title>
|
||||
|
||||
{plansLoading ? (
|
||||
<div className="flex justify-center py-8"><Spin /></div>
|
||||
) : (
|
||||
<Row gutter={[16, 16]}>
|
||||
{plans.map((plan) => (
|
||||
<Col key={plan.id} xs={24} sm={12} lg={8}>
|
||||
<PlanCard
|
||||
plan={plan}
|
||||
isCurrent={plan.name === currentPlanName}
|
||||
onSelect={handleSelectPlan}
|
||||
/>
|
||||
</Col>
|
||||
))}
|
||||
</Row>
|
||||
)}
|
||||
|
||||
{/* 支付弹窗 */}
|
||||
<Modal
|
||||
title={selectedPlan ? `升级到 ${selectedPlan.display_name}` : '支付'}
|
||||
open={payModalOpen}
|
||||
onCancel={() => {
|
||||
setPayModalOpen(false)
|
||||
setPollingPayment(null)
|
||||
setPayResult(null)
|
||||
}}
|
||||
footer={payResult ? null : undefined}
|
||||
onOk={handleConfirmPay}
|
||||
okText={createPaymentMutation.isPending ? '处理中...' : '确认支付'}
|
||||
confirmLoading={createPaymentMutation.isPending}
|
||||
>
|
||||
{payResult ? (
|
||||
<div className="text-center py-4">
|
||||
<LoadingOutlined style={{ fontSize: 32, color: '#863bff' }} className="mb-4" />
|
||||
<Title level={5}>等待支付确认...</Title>
|
||||
<Text type="secondary">
|
||||
支付窗口已打开,请在新窗口完成支付。
|
||||
<br />
|
||||
支付金额: ¥{(payResult.amount_cents / 100).toFixed(2)}
|
||||
</Text>
|
||||
<div className="mt-4">
|
||||
<Button onClick={() => { setPollingPayment(null); setPayModalOpen(false); setPayResult(null) }}>
|
||||
关闭
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div>
|
||||
{selectedPlan && (
|
||||
<div className="text-center mb-6">
|
||||
<div className="text-2xl font-bold" style={{ color: planColors[selectedPlan.name] || '#666' }}>
|
||||
¥{(selectedPlan.price_cents / 100).toFixed(0)}
|
||||
</div>
|
||||
<Text type="secondary">/{selectedPlan.interval === 'month' ? '月' : '年'}</Text>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Title level={5} className="text-center mb-4">选择支付方式</Title>
|
||||
|
||||
<Radio.Group
|
||||
value={payMethod}
|
||||
onChange={(e) => setPayMethod(e.target.value)}
|
||||
className="w-full"
|
||||
>
|
||||
<Space direction="vertical" className="w-full" size={12}>
|
||||
<Radio value="alipay" className="w-full">
|
||||
<div className="flex items-center gap-3 p-3 border rounded-lg w-full hover:border-blue-400 transition-colors">
|
||||
<AlipayCircleOutlined style={{ fontSize: 28, color: '#1677ff' }} />
|
||||
<div>
|
||||
<div className="font-medium">支付宝</div>
|
||||
<div className="text-xs text-neutral-400">推荐个人用户</div>
|
||||
</div>
|
||||
</div>
|
||||
</Radio>
|
||||
<Radio value="wechat" className="w-full">
|
||||
<div className="flex items-center gap-3 p-3 border rounded-lg w-full hover:border-green-400 transition-colors">
|
||||
<WechatOutlined style={{ fontSize: 28, color: '#07c160' }} />
|
||||
<div>
|
||||
<div className="font-medium">微信支付</div>
|
||||
<div className="text-xs text-neutral-400">扫码支付</div>
|
||||
</div>
|
||||
</div>
|
||||
</Radio>
|
||||
</Space>
|
||||
</Radio.Group>
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
750
admin-v2/src/pages/Knowledge.tsx
Normal file
750
admin-v2/src/pages/Knowledge.tsx
Normal file
@@ -0,0 +1,750 @@
|
||||
// ============================================================
|
||||
// 知识库管理
|
||||
// ============================================================
|
||||
|
||||
import { useState, useMemo, useEffect } from 'react'
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import {
|
||||
Button, message, Tag, Modal, Form, Input, Select, Space, Popconfirm,
|
||||
Card, Statistic, Row, Col, Tabs, Tree, Typography, Empty, Spin, InputNumber,
|
||||
Table, Tooltip,
|
||||
} from 'antd'
|
||||
import {
|
||||
PlusOutlined, SearchOutlined, BookOutlined, FolderOutlined,
|
||||
DeleteOutlined, EditOutlined, EyeOutlined, BarChartOutlined,
|
||||
HistoryOutlined, RollbackOutlined,
|
||||
WarningOutlined,
|
||||
} from '@ant-design/icons'
|
||||
import type { ProColumns } from '@ant-design/pro-components'
|
||||
import { ProTable } from '@ant-design/pro-components'
|
||||
import { knowledgeService } from '@/services/knowledge'
|
||||
import type { CategoryResponse, KnowledgeItem, SearchResult } from '@/services/knowledge'
|
||||
|
||||
const { TextArea } = Input
|
||||
const { Text, Title } = Typography
|
||||
|
||||
// === 分类树 + 条目列表 Tab ===
|
||||
|
||||
function CategoriesPanel() {
|
||||
const queryClient = useQueryClient()
|
||||
const [createOpen, setCreateOpen] = useState(false)
|
||||
const [editItem, setEditItem] = useState<CategoryResponse | null>(null)
|
||||
const [createForm] = Form.useForm()
|
||||
const [editForm] = Form.useForm()
|
||||
|
||||
const { data: categories = [], isLoading } = useQuery({
|
||||
queryKey: ['knowledge-categories'],
|
||||
queryFn: ({ signal }) => knowledgeService.listCategories(signal),
|
||||
})
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (data: Parameters<typeof knowledgeService.createCategory>[0]) =>
|
||||
knowledgeService.createCategory(data),
|
||||
onSuccess: () => {
|
||||
message.success('分类已创建')
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-categories'] })
|
||||
setCreateOpen(false)
|
||||
createForm.resetFields()
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '创建失败'),
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (id: string) => knowledgeService.deleteCategory(id),
|
||||
onSuccess: () => {
|
||||
message.success('分类已删除')
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-categories'] })
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '删除失败'),
|
||||
})
|
||||
|
||||
const updateMutation = useMutation({
|
||||
mutationFn: ({ id, ...data }: { id: string } & Record<string, unknown>) =>
|
||||
knowledgeService.updateCategory(id, data),
|
||||
onSuccess: () => {
|
||||
message.success('分类已更新')
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-categories'] })
|
||||
setEditItem(null)
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '更新失败'),
|
||||
})
|
||||
|
||||
// 编辑弹窗打开时同步表单值(Ant Design Form initialValues 仅首次挂载生效)
|
||||
useEffect(() => {
|
||||
if (editItem) {
|
||||
editForm.setFieldsValue({
|
||||
name: editItem.name,
|
||||
description: editItem.description,
|
||||
parent_id: editItem.parent_id,
|
||||
icon: editItem.icon,
|
||||
})
|
||||
}
|
||||
}, [editItem, editForm])
|
||||
|
||||
// 获取当前编辑分类及其所有后代的 ID(防止循环引用)
|
||||
const getDescendantIds = (id: string, cats: CategoryResponse[]): string[] => {
|
||||
const ids: string[] = [id]
|
||||
for (const c of cats) {
|
||||
if (c.parent_id === id) {
|
||||
ids.push(...getDescendantIds(c.id, cats))
|
||||
}
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
const treeData = useMemo(
|
||||
() => buildTreeData(categories, (id) => {
|
||||
Modal.confirm({
|
||||
title: '确认删除',
|
||||
content: '删除后无法恢复,请确保分类下没有子分类和条目。',
|
||||
okType: 'danger',
|
||||
onOk: () => deleteMutation.mutate(id),
|
||||
})
|
||||
}, (id) => {
|
||||
setEditItem(categories.find((c) => c.id === id) || null)
|
||||
}),
|
||||
[categories, deleteMutation],
|
||||
)
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<Title level={5} style={{ margin: 0 }}>分类管理</Title>
|
||||
<Button type="primary" icon={<PlusOutlined />} onClick={() => setCreateOpen(true)}>
|
||||
新建分类
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{isLoading ? (
|
||||
<div className="flex justify-center py-8"><Spin /></div>
|
||||
) : categories.length === 0 ? (
|
||||
<Empty description="暂无分类,请新建一个" />
|
||||
) : (
|
||||
<Tree
|
||||
treeData={treeData}
|
||||
defaultExpandAll
|
||||
showLine={{ showLeafIcon: false }}
|
||||
showIcon
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* 新建分类弹窗 */}
|
||||
<Modal
|
||||
title="新建分类"
|
||||
open={createOpen}
|
||||
onCancel={() => { setCreateOpen(false); createForm.resetFields() }}
|
||||
onOk={() => createForm.submit()}
|
||||
confirmLoading={createMutation.isPending}
|
||||
>
|
||||
<Form form={createForm} layout="vertical" onFinish={(v) => createMutation.mutate(v)}>
|
||||
<Form.Item name="name" label="分类名称" rules={[{ required: true, message: '请输入分类名称' }]}>
|
||||
<Input placeholder="例如:产品知识、技术文档" />
|
||||
</Form.Item>
|
||||
<Form.Item name="description" label="描述">
|
||||
<TextArea rows={2} placeholder="可选描述" />
|
||||
</Form.Item>
|
||||
<Form.Item name="parent_id" label="父分类">
|
||||
<Select placeholder="无(顶级分类)" allowClear>
|
||||
{flattenCategories(categories).map((c) => (
|
||||
<Select.Option key={c.id} value={c.id}>{c.name}</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item name="icon" label="图标">
|
||||
<Input placeholder="可选,如 📚" />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
|
||||
{/* 编辑分类弹窗 */}
|
||||
<Modal
|
||||
title="编辑分类"
|
||||
open={!!editItem}
|
||||
onCancel={() => { setEditItem(null); editForm.resetFields() }}
|
||||
onOk={() => editForm.submit()}
|
||||
confirmLoading={updateMutation.isPending}
|
||||
>
|
||||
<Form
|
||||
form={editForm}
|
||||
layout="vertical"
|
||||
initialValues={editItem ? { name: editItem.name, description: editItem.description, parent_id: editItem.parent_id, icon: editItem.icon } : undefined}
|
||||
onFinish={(v) => editItem && updateMutation.mutate({ id: editItem.id, ...v })}
|
||||
>
|
||||
<Form.Item name="name" label="分类名称" rules={[{ required: true }]}>
|
||||
<Input />
|
||||
</Form.Item>
|
||||
<Form.Item name="description" label="描述">
|
||||
<TextArea rows={2} />
|
||||
</Form.Item>
|
||||
<Form.Item name="parent_id" label="父分类">
|
||||
<Select placeholder="无(顶级分类)" allowClear>
|
||||
{editItem && flattenCategories(categories)
|
||||
.filter((c) => !getDescendantIds(editItem.id, categories).includes(c.id))
|
||||
.map((c) => (
|
||||
<Select.Option key={c.id} value={c.id}>{c.name}</Select.Option>
|
||||
))}
|
||||
</Form.Item>
|
||||
<Form.Item name="icon" label="图标">
|
||||
<Input placeholder="如 📚" />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// === 条目列表 ===
|
||||
|
||||
function ItemsPanel() {
|
||||
const queryClient = useQueryClient()
|
||||
const [createOpen, setCreateOpen] = useState(false)
|
||||
const [detailItem, setDetailItem] = useState<string | null>(null)
|
||||
const [versionModalOpen, setVersionModalOpen] = useState(false)
|
||||
const [rollingBackVersion, setRollingBackVersion] = useState<number | null>(null)
|
||||
const [page, setPage] = useState(1)
|
||||
const [pageSize, setPageSize] = useState(20)
|
||||
const [filters, setFilters] = useState<{ category_id?: string; status?: string; keyword?: string }>({})
|
||||
const [form] = Form.useForm()
|
||||
|
||||
const { data: categories = [] } = useQuery({
|
||||
queryKey: ['knowledge-categories'],
|
||||
queryFn: ({ signal }) => knowledgeService.listCategories(signal),
|
||||
})
|
||||
|
||||
const { data: detailData, isLoading: detailLoading } = useQuery({
|
||||
queryKey: ['knowledge-item-detail', detailItem],
|
||||
queryFn: ({ signal }) => knowledgeService.getItem(detailItem!, signal),
|
||||
enabled: !!detailItem,
|
||||
})
|
||||
|
||||
const { data: versions } = useQuery({
|
||||
queryKey: ['knowledge-item-versions', detailItem],
|
||||
queryFn: ({ signal }) => knowledgeService.getVersions(detailItem!, signal),
|
||||
enabled: !!detailItem,
|
||||
})
|
||||
|
||||
const { data, isLoading } = useQuery({
|
||||
queryKey: ['knowledge-items', page, pageSize, filters],
|
||||
queryFn: ({ signal }) =>
|
||||
knowledgeService.listItems({ page, page_size: pageSize, ...filters }, signal),
|
||||
})
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (data: Parameters<typeof knowledgeService.createItem>[0]) =>
|
||||
knowledgeService.createItem(data),
|
||||
onSuccess: () => {
|
||||
message.success('条目已创建')
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-items'] })
|
||||
setCreateOpen(false)
|
||||
form.resetFields()
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '创建失败'),
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (id: string) => knowledgeService.deleteItem(id),
|
||||
onSuccess: () => {
|
||||
message.success('已删除')
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-items'] })
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '删除失败'),
|
||||
})
|
||||
|
||||
const rollbackMutation = useMutation({
|
||||
mutationFn: ({ itemId, version }: { itemId: string; version: number }) =>
|
||||
knowledgeService.rollbackVersion(itemId, version),
|
||||
onSuccess: () => {
|
||||
message.success('已回滚')
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-items'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-item-detail'] })
|
||||
queryClient.invalidateQueries({ queryKey: ['knowledge-item-versions'] })
|
||||
setVersionModalOpen(false)
|
||||
setRollingBackVersion(null)
|
||||
},
|
||||
onError: (err: Error) => {
|
||||
message.error(err.message || '回滚失败')
|
||||
setRollingBackVersion(null)
|
||||
},
|
||||
})
|
||||
|
||||
const statusColors: Record<string, string> = { active: 'green', draft: 'orange', archived: 'default' }
|
||||
const statusLabels: Record<string, string> = { active: '活跃', draft: '草稿', archived: '已归档' }
|
||||
|
||||
const columns: ProColumns<KnowledgeItem>[] = [
|
||||
{
|
||||
title: '标题',
|
||||
dataIndex: 'keyword',
|
||||
width: 250,
|
||||
render: (_, r) => (
|
||||
<Button type="link" size="small" onClick={() => setDetailItem(r.id)}>
|
||||
{r.title}
|
||||
</Button>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '状态',
|
||||
dataIndex: 'status',
|
||||
width: 80,
|
||||
valueEnum: Object.fromEntries(
|
||||
Object.entries(statusLabels).map(([k, v]) => [k, { text: v, status: statusColors[k] === 'green' ? 'Success' : statusColors[k] === 'orange' ? 'Warning' : 'Default' }]),
|
||||
),
|
||||
},
|
||||
{ title: '版本', dataIndex: 'version', width: 60, search: false },
|
||||
{ title: '优先级', dataIndex: 'priority', width: 70, search: false },
|
||||
{
|
||||
title: '标签',
|
||||
dataIndex: 'tags',
|
||||
width: 200,
|
||||
search: false,
|
||||
render: (_, r) => (
|
||||
<Space size={[4, 4]} wrap>
|
||||
{r.tags?.map((t) => <Tag key={t}>{t}</Tag>)}
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
{ title: '更新时间', dataIndex: 'updated_at', width: 160, valueType: 'dateTime', search: false },
|
||||
{
|
||||
title: '操作',
|
||||
width: 150,
|
||||
search: false,
|
||||
render: (_, r) => (
|
||||
<Space>
|
||||
<Button type="link" size="small" icon={<EyeOutlined />} onClick={() => setDetailItem(r.id)} />
|
||||
<Tooltip title="版本历史">
|
||||
<Button type="link" size="small" icon={<HistoryOutlined />} onClick={() => { setDetailItem(r.id); setVersionModalOpen(true) }} />
|
||||
</Tooltip>
|
||||
<Popconfirm title="确认删除?" onConfirm={() => deleteMutation.mutate(r.id)}>
|
||||
<Button type="link" size="small" danger icon={<DeleteOutlined />} />
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
return (
|
||||
<div>
|
||||
<ProTable<KnowledgeItem>
|
||||
columns={columns}
|
||||
dataSource={data?.items || []}
|
||||
loading={isLoading}
|
||||
rowKey="id"
|
||||
search={{
|
||||
onReset: () => { setFilters({}); setPage(1) },
|
||||
onSearch: (values) => { setFilters(values); setPage(1) },
|
||||
}}
|
||||
toolBarRender={() => [
|
||||
<Button key="create" type="primary" icon={<PlusOutlined />} onClick={() => setCreateOpen(true)}>
|
||||
新建条目
|
||||
</Button>,
|
||||
]}
|
||||
pagination={{
|
||||
current: page,
|
||||
pageSize,
|
||||
total: data?.total || 0,
|
||||
showSizeChanger: true,
|
||||
onChange: (p, ps) => { setPage(p); setPageSize(ps) },
|
||||
}}
|
||||
options={{ density: false, fullScreen: false, reload: () => queryClient.invalidateQueries({ queryKey: ['knowledge-items'] }) }}
|
||||
/>
|
||||
|
||||
{/* 创建弹窗 */}
|
||||
<Modal
|
||||
title="新建知识条目"
|
||||
open={createOpen}
|
||||
onCancel={() => { setCreateOpen(false); form.resetFields() }}
|
||||
onOk={() => form.submit()}
|
||||
confirmLoading={createMutation.isPending}
|
||||
width={640}
|
||||
>
|
||||
<Form form={form} layout="vertical" onFinish={(v) => createMutation.mutate(v)}>
|
||||
<Form.Item name="category_id" label="分类" rules={[{ required: true, message: '请选择分类' }]}>
|
||||
<Select placeholder="选择分类">
|
||||
{flattenCategories(categories).map((c) => (
|
||||
<Select.Option key={c.id} value={c.id}>{c.name}</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
<Form.Item name="title" label="标题" rules={[{ required: true, message: '请输入标题' }]}>
|
||||
<Input placeholder="知识条目标题" />
|
||||
</Form.Item>
|
||||
<Form.Item name="content" label="内容" rules={[{ required: true, message: '请输入内容' }]}>
|
||||
<TextArea rows={8} placeholder="支持 Markdown 格式" />
|
||||
</Form.Item>
|
||||
<Row gutter={16}>
|
||||
<Col span={12}>
|
||||
<Form.Item name="keywords" label="关键词">
|
||||
<Select mode="tags" placeholder="输入后回车添加" />
|
||||
</Form.Item>
|
||||
</Col>
|
||||
<Col span={12}>
|
||||
<Form.Item name="tags" label="标签">
|
||||
<Select mode="tags" placeholder="输入后回车添加" />
|
||||
</Form.Item>
|
||||
</Col>
|
||||
</Row>
|
||||
<Form.Item name="priority" label="优先级" initialValue={0}>
|
||||
<InputNumber min={0} max={100} />
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
|
||||
{/* 详情弹窗 */}
|
||||
<Modal
|
||||
title={detailData?.title || '条目详情'}
|
||||
open={!!detailItem && !versionModalOpen}
|
||||
onCancel={() => setDetailItem(null)}
|
||||
footer={null}
|
||||
width={720}
|
||||
>
|
||||
{detailData && (
|
||||
<div>
|
||||
<div className="mb-4 flex gap-2">
|
||||
<Tag color={statusColors[detailData.status]}>{statusLabels[detailData.status] || detailData.status}</Tag>
|
||||
<Tag>版本 {detailData.version}</Tag>
|
||||
<Tag>优先级 {detailData.priority}</Tag>
|
||||
</div>
|
||||
<div className="mb-4 whitespace-pre-wrap bg-neutral-50 dark:bg-neutral-900 p-4 rounded-lg max-h-96 overflow-y-auto text-sm">
|
||||
{detailData.content}
|
||||
</div>
|
||||
<div className="flex gap-2 flex-wrap">
|
||||
{detailData.tags?.map((t) => <Tag key={t} color="blue">{t}</Tag>)}
|
||||
{detailData.keywords?.map((k) => <Tag key={k} color="cyan">{k}</Tag>)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
|
||||
{/* 版本历史弹窗 */}
|
||||
<Modal
|
||||
title={`版本历史 - ${detailData?.title || ''}`}
|
||||
open={versionModalOpen}
|
||||
onCancel={() => { setVersionModalOpen(false); setDetailItem(null) }}
|
||||
footer={null}
|
||||
width={720}
|
||||
>
|
||||
<Table
|
||||
dataSource={versions?.versions || []}
|
||||
rowKey="id"
|
||||
loading={!versions}
|
||||
size="small"
|
||||
pagination={{ pageSize: 10 }}
|
||||
columns={[
|
||||
{ title: '版本', dataIndex: 'version', width: 70 },
|
||||
{ title: '标题', dataIndex: 'title', ellipsis: true },
|
||||
{ title: '摘要', dataIndex: 'change_summary', width: 200, ellipsis: true },
|
||||
{ title: '创建者', dataIndex: 'created_by', width: 100 },
|
||||
{ title: '创建时间', dataIndex: 'created_at', width: 160 },
|
||||
{
|
||||
title: '操作',
|
||||
width: 80,
|
||||
render: (_, r) => (
|
||||
<Popconfirm
|
||||
title={`确认回滚到版本 ${r.version}?`}
|
||||
description="回滚将创建新版本,当前版本内容会被替换。"
|
||||
onConfirm={() => {
|
||||
setRollingBackVersion(r.version)
|
||||
rollbackMutation.mutate({ itemId: detailItem!, version: r.version })
|
||||
}}
|
||||
>
|
||||
<Button type="link" size="small" icon={<RollbackOutlined />} loading={rollingBackVersion === r.version}>
|
||||
回滚
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
),
|
||||
},
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// === 搜索面板 ===
|
||||
|
||||
function SearchPanel() {
|
||||
const [query, setQuery] = useState('')
|
||||
const [results, setResults] = useState<SearchResult[]>([])
|
||||
const [searching, setSearching] = useState(false)
|
||||
const [hasSearched, setHasSearched] = useState(false)
|
||||
|
||||
const handleSearch = async () => {
|
||||
if (!query.trim()) return
|
||||
setSearching(true)
|
||||
try {
|
||||
const data = await knowledgeService.search({ query: query.trim(), limit: 10 })
|
||||
setResults(data)
|
||||
setHasSearched(true)
|
||||
} catch {
|
||||
message.error('搜索失败')
|
||||
} finally {
|
||||
setSearching(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Title level={5}>语义搜索</Title>
|
||||
<Space.Compact className="w-full mb-4">
|
||||
<Input
|
||||
size="large"
|
||||
placeholder="输入搜索关键词..."
|
||||
value={query}
|
||||
onChange={(e) => setQuery(e.target.value)}
|
||||
onPressEnter={handleSearch}
|
||||
prefix={<SearchOutlined />}
|
||||
/>
|
||||
<Button size="large" type="primary" loading={searching} onClick={handleSearch}>
|
||||
搜索
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
|
||||
{results.length === 0 && !searching && !hasSearched && (
|
||||
<Empty description="输入关键词搜索知识库" />
|
||||
)}
|
||||
|
||||
{results.length === 0 && !searching && hasSearched && (
|
||||
<Empty description="未找到匹配的知识条目" />
|
||||
)}
|
||||
|
||||
<div className="space-y-3">
|
||||
{results.map((r) => (
|
||||
<Card key={r.chunk_id} size="small" hoverable>
|
||||
<div className="flex justify-between items-start mb-2">
|
||||
<Text strong>{r.item_title}</Text>
|
||||
<Tag>{r.category_name}</Tag>
|
||||
</div>
|
||||
<div className="text-sm text-neutral-600 dark:text-neutral-400 line-clamp-3 mb-2">
|
||||
{r.content}
|
||||
</div>
|
||||
<div className="flex gap-1 flex-wrap">
|
||||
{r.keywords?.slice(0, 5).map((k) => (
|
||||
<Tag key={k} color="cyan" style={{ fontSize: 11 }}>{k}</Tag>
|
||||
))}
|
||||
</div>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// === 分析看板 ===
|
||||
|
||||
function AnalyticsPanel() {
|
||||
const { data: overview, isLoading: overviewLoading } = useQuery({
|
||||
queryKey: ['knowledge-analytics'],
|
||||
queryFn: ({ signal }) => knowledgeService.getOverview(signal),
|
||||
})
|
||||
|
||||
const { data: trends } = useQuery({
|
||||
queryKey: ['knowledge-trends'],
|
||||
queryFn: ({ signal }) => knowledgeService.getTrends(signal),
|
||||
})
|
||||
|
||||
const { data: topItems } = useQuery({
|
||||
queryKey: ['knowledge-top-items'],
|
||||
queryFn: ({ signal }) => knowledgeService.getTopItems(signal),
|
||||
})
|
||||
|
||||
const { data: quality } = useQuery({
|
||||
queryKey: ['knowledge-quality'],
|
||||
queryFn: ({ signal }) => knowledgeService.getQuality(signal),
|
||||
})
|
||||
|
||||
const { data: gaps } = useQuery({
|
||||
queryKey: ['knowledge-gaps'],
|
||||
queryFn: ({ signal }) => knowledgeService.getGaps(signal),
|
||||
})
|
||||
|
||||
if (overviewLoading) return <div className="flex justify-center py-8"><Spin /></div>
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Title level={5} className="mb-4">知识库概览</Title>
|
||||
<Row gutter={[16, 16]}>
|
||||
<Col span={6}>
|
||||
<Card><Statistic title="总条目数" value={overview?.total_items || 0} /></Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card><Statistic title="活跃条目" value={overview?.active_items || 0} valueStyle={{ color: '#52c41a' }} /></Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card><Statistic title="分类数" value={overview?.total_categories || 0} /></Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card><Statistic title="本周新增" value={overview?.weekly_new_items || 0} valueStyle={{ color: '#1890ff' }} /></Card>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
<Row gutter={[16, 16]} className="mt-4">
|
||||
<Col span={6}>
|
||||
<Card><Statistic title="总引用次数" value={overview?.total_references || 0} /></Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card>
|
||||
<Statistic title="注入率" value={((overview?.injection_rate || 0) * 100).toFixed(1)} suffix="%" />
|
||||
</Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card>
|
||||
<Statistic title="正面反馈率" value={((overview?.positive_feedback_rate || 0) * 100).toFixed(1)} suffix="%" valueStyle={{ color: '#52c41a' }} />
|
||||
</Card>
|
||||
</Col>
|
||||
<Col span={6}>
|
||||
<Card><Statistic title="过期条目" value={overview?.stale_items_count || 0} valueStyle={{ color: '#faad14' }} /></Card>
|
||||
</Col>
|
||||
</Row>
|
||||
|
||||
{/* 趋势数据表格 */}
|
||||
<Card title="检索趋势(近30天)" className="mt-4" size="small">
|
||||
<Table
|
||||
dataSource={trends?.trends || []}
|
||||
rowKey="date"
|
||||
loading={!trends}
|
||||
size="small"
|
||||
pagination={{ pageSize: 10 }}
|
||||
columns={[
|
||||
{ title: '日期', dataIndex: 'date', width: 120 },
|
||||
{ title: '检索次数', dataIndex: 'count', width: 100 },
|
||||
{ title: '注入次数', dataIndex: 'injected_count', width: 100 },
|
||||
]}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* Top Items 表格 */}
|
||||
<Card title="高频引用 Top 20" className="mt-4" size="small">
|
||||
<Table
|
||||
dataSource={topItems?.items || []}
|
||||
rowKey="id"
|
||||
loading={!topItems}
|
||||
size="small"
|
||||
pagination={{ pageSize: 10 }}
|
||||
columns={[
|
||||
{ title: '标题', dataIndex: 'title', ellipsis: true },
|
||||
{ title: '分类', dataIndex: 'category', width: 120 },
|
||||
{ title: '引用次数', dataIndex: 'ref_count', width: 100 },
|
||||
]}
|
||||
/>
|
||||
</Card>
|
||||
|
||||
{/* 质量指标 */}
|
||||
{quality?.categories?.length > 0 && (
|
||||
<Card title="分类质量指标" className="mt-4" size="small">
|
||||
<Table
|
||||
dataSource={quality.categories}
|
||||
rowKey="category"
|
||||
size="small"
|
||||
pagination={false}
|
||||
columns={[
|
||||
{ title: '分类', dataIndex: 'category', width: 150 },
|
||||
{ title: '总条目', dataIndex: 'total', width: 80 },
|
||||
{ title: '活跃', dataIndex: 'active', width: 80 },
|
||||
{ title: '有关键词', dataIndex: 'with_keywords', width: 100 },
|
||||
{ title: '平均优先级', dataIndex: 'avg_priority', width: 100, render: (v: number) => v?.toFixed(1) },
|
||||
]}
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
|
||||
{/* 知识缺口 */}
|
||||
{gaps?.gaps?.length > 0 && (
|
||||
<Card
|
||||
title={
|
||||
<Space>
|
||||
<WarningOutlined style={{ color: '#faad14' }} />
|
||||
<span>知识缺口检测</span>
|
||||
</Space>
|
||||
}
|
||||
className="mt-4"
|
||||
size="small"
|
||||
>
|
||||
<Table
|
||||
dataSource={gaps.gaps}
|
||||
rowKey="query"
|
||||
size="small"
|
||||
pagination={{ pageSize: 10 }}
|
||||
columns={[
|
||||
{ title: '查询', dataIndex: 'query', ellipsis: true },
|
||||
{ title: '次数', dataIndex: 'count', width: 80 },
|
||||
{ title: '平均分', dataIndex: 'avg_score', width: 100, render: (v: number) => v?.toFixed(2) },
|
||||
]}
|
||||
/>
|
||||
</Card>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// === 主页面 ===
|
||||
|
||||
export default function Knowledge() {
|
||||
return (
|
||||
<div className="p-6">
|
||||
<Tabs
|
||||
defaultActiveKey="items"
|
||||
items={[
|
||||
{
|
||||
key: 'items',
|
||||
label: '知识条目',
|
||||
icon: <BookOutlined />,
|
||||
children: <ItemsPanel />,
|
||||
},
|
||||
{
|
||||
key: 'categories',
|
||||
label: '分类管理',
|
||||
icon: <FolderOutlined />,
|
||||
children: <CategoriesPanel />,
|
||||
},
|
||||
{
|
||||
key: 'search',
|
||||
label: '搜索',
|
||||
icon: <SearchOutlined />,
|
||||
children: <SearchPanel />,
|
||||
},
|
||||
{
|
||||
key: 'analytics',
|
||||
label: '分析看板',
|
||||
icon: <BarChartOutlined />,
|
||||
children: <AnalyticsPanel />,
|
||||
},
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// === 辅助函数 ===
|
||||
|
||||
function flattenCategories(cats: CategoryResponse[]): { id: string; name: string }[] {
|
||||
const result: { id: string; name: string }[] = []
|
||||
for (const c of cats) {
|
||||
result.push({ id: c.id, name: c.name })
|
||||
if (c.children?.length) {
|
||||
result.push(...flattenCategories(c.children))
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
interface TreeNode {
|
||||
key: string
|
||||
title: React.ReactNode
|
||||
icon?: React.ReactNode
|
||||
children?: TreeNode[]
|
||||
}
|
||||
|
||||
function buildTreeData(cats: CategoryResponse[], onDelete: (id: string) => void, onEdit: (id: string) => void): TreeNode[] {
|
||||
return cats.map((c) => ({
|
||||
key: c.id,
|
||||
title: (
|
||||
<div className="flex items-center gap-2">
|
||||
<span>{c.icon || '📁'} {c.name}</span>
|
||||
<Tag>{c.item_count}</Tag>
|
||||
<Button type="link" size="small" icon={<EditOutlined />} onClick={() => onEdit(c.id)} />
|
||||
<Button type="link" size="small" danger onClick={() => onDelete(c.id)}>
|
||||
<DeleteOutlined />
|
||||
</Button>
|
||||
</div>
|
||||
),
|
||||
children: c.children?.length ? buildTreeData(c.children, onDelete, onEdit) : undefined,
|
||||
}))
|
||||
}
|
||||
509
admin-v2/src/pages/Roles.tsx
Normal file
509
admin-v2/src/pages/Roles.tsx
Normal file
@@ -0,0 +1,509 @@
|
||||
// ============================================================
|
||||
// 角色与权限模板管理
|
||||
// ============================================================
|
||||
|
||||
import { useState } from 'react'
|
||||
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
|
||||
import {
|
||||
Button,
|
||||
message,
|
||||
Tag,
|
||||
Modal,
|
||||
Form,
|
||||
Input,
|
||||
Select,
|
||||
Space,
|
||||
Popconfirm,
|
||||
Tabs,
|
||||
Tooltip,
|
||||
} from 'antd'
|
||||
import { PlusOutlined, SafetyOutlined, CheckCircleOutlined } from '@ant-design/icons'
|
||||
import type { ProColumns } from '@ant-design/pro-components'
|
||||
import { ProTable } from '@ant-design/pro-components'
|
||||
import { roleService } from '@/services/roles'
|
||||
import { PageHeader } from '@/components/PageHeader'
|
||||
import type {
|
||||
Role,
|
||||
PermissionTemplate,
|
||||
CreateRoleRequest,
|
||||
UpdateRoleRequest,
|
||||
CreateTemplateRequest,
|
||||
} from '@/types'
|
||||
|
||||
// ============================================================
|
||||
// 常见权限选项
|
||||
// ============================================================
|
||||
const permissionOptions = [
|
||||
{ value: 'account:admin', label: 'account:admin' },
|
||||
{ value: 'provider:manage', label: 'provider:manage' },
|
||||
{ value: 'model:read', label: 'model:read' },
|
||||
{ value: 'model:write', label: 'model:write' },
|
||||
{ value: 'relay:use', label: 'relay:use' },
|
||||
{ value: 'knowledge:read', label: 'knowledge:read' },
|
||||
{ value: 'knowledge:write', label: 'knowledge:write' },
|
||||
{ value: 'billing:read', label: 'billing:read' },
|
||||
{ value: 'billing:write', label: 'billing:write' },
|
||||
{ value: 'config:read', label: 'config:read' },
|
||||
{ value: 'config:write', label: 'config:write' },
|
||||
{ value: 'prompt:read', label: 'prompt:read' },
|
||||
{ value: 'prompt:write', label: 'prompt:write' },
|
||||
{ value: 'admin:full', label: 'admin:full' },
|
||||
]
|
||||
|
||||
// ============================================================
|
||||
// Roles Tab
|
||||
// ============================================================
|
||||
function RolesTab() {
|
||||
const queryClient = useQueryClient()
|
||||
const [form] = Form.useForm()
|
||||
const [modalOpen, setModalOpen] = useState(false)
|
||||
const [editingId, setEditingId] = useState<string | null>(null)
|
||||
|
||||
const { data, isLoading } = useQuery({
|
||||
queryKey: ['roles'],
|
||||
queryFn: ({ signal }) => roleService.list(signal),
|
||||
})
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (data: CreateRoleRequest) => roleService.create(data),
|
||||
onSuccess: () => {
|
||||
message.success('角色已创建')
|
||||
queryClient.invalidateQueries({ queryKey: ['roles'] })
|
||||
setModalOpen(false)
|
||||
form.resetFields()
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '创建失败'),
|
||||
})
|
||||
|
||||
const updateMutation = useMutation({
|
||||
mutationFn: ({ id, data }: { id: string; data: UpdateRoleRequest }) =>
|
||||
roleService.update(id, data),
|
||||
onSuccess: () => {
|
||||
message.success('角色已更新')
|
||||
queryClient.invalidateQueries({ queryKey: ['roles'] })
|
||||
setModalOpen(false)
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '更新失败'),
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (id: string) => roleService.delete(id),
|
||||
onSuccess: () => {
|
||||
message.success('角色已删除')
|
||||
queryClient.invalidateQueries({ queryKey: ['roles'] })
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '删除失败'),
|
||||
})
|
||||
|
||||
const handleSave = async () => {
|
||||
const values = await form.validateFields()
|
||||
if (editingId) {
|
||||
updateMutation.mutate({ id: editingId, data: values })
|
||||
} else {
|
||||
createMutation.mutate(values)
|
||||
}
|
||||
}
|
||||
|
||||
const openEdit = async (record: Role) => {
|
||||
setEditingId(record.id)
|
||||
const permissions = await roleService.getPermissions(record.id).catch(() => record.permissions)
|
||||
form.setFieldsValue({ ...record, permissions })
|
||||
setModalOpen(true)
|
||||
}
|
||||
|
||||
const openCreate = () => {
|
||||
setEditingId(null)
|
||||
form.resetFields()
|
||||
setModalOpen(true)
|
||||
}
|
||||
|
||||
const closeModal = () => {
|
||||
setModalOpen(false)
|
||||
setEditingId(null)
|
||||
form.resetFields()
|
||||
}
|
||||
|
||||
const columns: ProColumns<Role>[] = [
|
||||
{
|
||||
title: '角色名称',
|
||||
dataIndex: 'name',
|
||||
width: 160,
|
||||
render: (_, record) => (
|
||||
<span className="font-medium text-neutral-900 dark:text-neutral-100">
|
||||
{record.name}
|
||||
</span>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '描述',
|
||||
dataIndex: 'description',
|
||||
width: 240,
|
||||
ellipsis: true,
|
||||
render: (_, record) => record.description || '-',
|
||||
},
|
||||
{
|
||||
title: '权限数',
|
||||
dataIndex: 'permissions',
|
||||
width: 100,
|
||||
render: (_, record) => (
|
||||
<Tooltip title={record.permissions?.join(', ') || '无权限'}>
|
||||
<Tag>{record.permissions?.length ?? 0} 项</Tag>
|
||||
</Tooltip>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '关联账号',
|
||||
dataIndex: 'account_count',
|
||||
width: 100,
|
||||
render: (_, record) => record.account_count ?? 0,
|
||||
},
|
||||
{
|
||||
title: '创建时间',
|
||||
dataIndex: 'created_at',
|
||||
width: 180,
|
||||
render: (_, record) =>
|
||||
record.created_at ? new Date(record.created_at).toLocaleString('zh-CN') : '-',
|
||||
},
|
||||
{
|
||||
title: '操作',
|
||||
width: 160,
|
||||
render: (_, record) => (
|
||||
<Space>
|
||||
<Button size="small" onClick={() => openEdit(record)}>
|
||||
编辑
|
||||
</Button>
|
||||
<Popconfirm
|
||||
title="确定删除此角色?"
|
||||
description="删除后关联的账号将失去此角色权限"
|
||||
onConfirm={() => deleteMutation.mutate(record.id)}
|
||||
>
|
||||
<Button size="small" danger>
|
||||
删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
return (
|
||||
<div>
|
||||
<ProTable<Role>
|
||||
columns={columns}
|
||||
dataSource={data ?? []}
|
||||
loading={isLoading}
|
||||
rowKey="id"
|
||||
search={false}
|
||||
toolBarRender={() => [
|
||||
<Button key="add" type="primary" icon={<PlusOutlined />} onClick={openCreate}>
|
||||
新建角色
|
||||
</Button>,
|
||||
]}
|
||||
pagination={{ showSizeChanger: false }}
|
||||
/>
|
||||
|
||||
<Modal
|
||||
title={editingId ? '编辑角色' : '新建角色'}
|
||||
open={modalOpen}
|
||||
onOk={handleSave}
|
||||
onCancel={closeModal}
|
||||
confirmLoading={createMutation.isPending || updateMutation.isPending}
|
||||
width={560}
|
||||
>
|
||||
<Form form={form} layout="vertical" className="mt-4">
|
||||
<Form.Item
|
||||
name="name"
|
||||
label="角色名称"
|
||||
rules={[{ required: true, message: '请输入角色名称' }]}
|
||||
>
|
||||
<Input placeholder="如 editor, viewer" />
|
||||
</Form.Item>
|
||||
<Form.Item name="description" label="描述">
|
||||
<Input.TextArea rows={2} placeholder="角色用途说明" />
|
||||
</Form.Item>
|
||||
<Form.Item name="permissions" label="权限">
|
||||
<Select
|
||||
mode="multiple"
|
||||
placeholder="选择权限"
|
||||
options={permissionOptions}
|
||||
maxTagCount={5}
|
||||
allowClear
|
||||
filterOption={(input, option) =>
|
||||
(option?.label as string)?.toLowerCase().includes(input.toLowerCase())
|
||||
}
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Permission Templates Tab
|
||||
// ============================================================
|
||||
function TemplatesTab() {
|
||||
const queryClient = useQueryClient()
|
||||
const [form] = Form.useForm()
|
||||
const [modalOpen, setModalOpen] = useState(false)
|
||||
const [applyOpen, setApplyOpen] = useState(false)
|
||||
const [applyForm] = Form.useForm()
|
||||
const [selectedTemplate, setSelectedTemplate] = useState<PermissionTemplate | null>(null)
|
||||
|
||||
const { data, isLoading } = useQuery({
|
||||
queryKey: ['permission-templates'],
|
||||
queryFn: ({ signal }) => roleService.listTemplates(signal),
|
||||
})
|
||||
|
||||
const createMutation = useMutation({
|
||||
mutationFn: (data: CreateTemplateRequest) => roleService.createTemplate(data),
|
||||
onSuccess: () => {
|
||||
message.success('模板已创建')
|
||||
queryClient.invalidateQueries({ queryKey: ['permission-templates'] })
|
||||
setModalOpen(false)
|
||||
form.resetFields()
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '创建失败'),
|
||||
})
|
||||
|
||||
const deleteMutation = useMutation({
|
||||
mutationFn: (id: string) => roleService.deleteTemplate(id),
|
||||
onSuccess: () => {
|
||||
message.success('模板已删除')
|
||||
queryClient.invalidateQueries({ queryKey: ['permission-templates'] })
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '删除失败'),
|
||||
})
|
||||
|
||||
const applyMutation = useMutation({
|
||||
mutationFn: ({ templateId, accountIds }: { templateId: string; accountIds: string[] }) =>
|
||||
roleService.applyTemplate(templateId, accountIds),
|
||||
onSuccess: () => {
|
||||
message.success('模板已应用到所选账号')
|
||||
queryClient.invalidateQueries({ queryKey: ['permission-templates'] })
|
||||
setApplyOpen(false)
|
||||
applyForm.resetFields()
|
||||
setSelectedTemplate(null)
|
||||
},
|
||||
onError: (err: Error) => message.error(err.message || '应用失败'),
|
||||
})
|
||||
|
||||
const openApply = (record: PermissionTemplate) => {
|
||||
setSelectedTemplate(record)
|
||||
applyForm.resetFields()
|
||||
setApplyOpen(true)
|
||||
}
|
||||
|
||||
const handleApply = async () => {
|
||||
const values = await applyForm.validateFields()
|
||||
if (!selectedTemplate) return
|
||||
const accountIds = values.account_ids
|
||||
?.split(',')
|
||||
.map((s: string) => s.trim())
|
||||
.filter(Boolean)
|
||||
if (!accountIds?.length) {
|
||||
message.warning('请输入至少一个账号 ID')
|
||||
return
|
||||
}
|
||||
applyMutation.mutate({ templateId: selectedTemplate.id, accountIds })
|
||||
}
|
||||
|
||||
const columns: ProColumns<PermissionTemplate>[] = [
|
||||
{
|
||||
title: '模板名称',
|
||||
dataIndex: 'name',
|
||||
width: 180,
|
||||
render: (_, record) => (
|
||||
<span className="font-medium text-neutral-900 dark:text-neutral-100">
|
||||
{record.name}
|
||||
</span>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '描述',
|
||||
dataIndex: 'description',
|
||||
width: 240,
|
||||
ellipsis: true,
|
||||
render: (_, record) => record.description || '-',
|
||||
},
|
||||
{
|
||||
title: '权限数',
|
||||
dataIndex: 'permissions',
|
||||
width: 100,
|
||||
render: (_, record) => (
|
||||
<Tooltip title={record.permissions?.join(', ') || '无权限'}>
|
||||
<Tag>{record.permissions?.length ?? 0} 项</Tag>
|
||||
</Tooltip>
|
||||
),
|
||||
},
|
||||
{
|
||||
title: '创建时间',
|
||||
dataIndex: 'created_at',
|
||||
width: 180,
|
||||
render: (_, record) =>
|
||||
record.created_at ? new Date(record.created_at).toLocaleString('zh-CN') : '-',
|
||||
},
|
||||
{
|
||||
title: '操作',
|
||||
width: 180,
|
||||
render: (_, record) => (
|
||||
<Space>
|
||||
<Button
|
||||
size="small"
|
||||
icon={<CheckCircleOutlined />}
|
||||
onClick={() => openApply(record)}
|
||||
>
|
||||
应用
|
||||
</Button>
|
||||
<Popconfirm
|
||||
title="确定删除此模板?"
|
||||
description="删除后已应用的账号不受影响"
|
||||
onConfirm={() => deleteMutation.mutate(record.id)}
|
||||
>
|
||||
<Button size="small" danger>
|
||||
删除
|
||||
</Button>
|
||||
</Popconfirm>
|
||||
</Space>
|
||||
),
|
||||
},
|
||||
]
|
||||
|
||||
return (
|
||||
<div>
|
||||
<ProTable<PermissionTemplate>
|
||||
columns={columns}
|
||||
dataSource={data ?? []}
|
||||
loading={isLoading}
|
||||
rowKey="id"
|
||||
search={false}
|
||||
toolBarRender={() => [
|
||||
<Button
|
||||
key="add"
|
||||
type="primary"
|
||||
icon={<PlusOutlined />}
|
||||
onClick={() => {
|
||||
form.resetFields()
|
||||
setModalOpen(true)
|
||||
}}
|
||||
>
|
||||
新建模板
|
||||
</Button>,
|
||||
]}
|
||||
pagination={{ showSizeChanger: false }}
|
||||
/>
|
||||
|
||||
{/* Create Template Modal */}
|
||||
<Modal
|
||||
title="新建权限模板"
|
||||
open={modalOpen}
|
||||
onOk={async () => {
|
||||
const values = await form.validateFields()
|
||||
createMutation.mutate(values)
|
||||
}}
|
||||
onCancel={() => {
|
||||
setModalOpen(false)
|
||||
form.resetFields()
|
||||
}}
|
||||
confirmLoading={createMutation.isPending}
|
||||
width={560}
|
||||
>
|
||||
<Form form={form} layout="vertical" className="mt-4">
|
||||
<Form.Item
|
||||
name="name"
|
||||
label="模板名称"
|
||||
rules={[{ required: true, message: '请输入模板名称' }]}
|
||||
>
|
||||
<Input placeholder="如 basic-user, power-user" />
|
||||
</Form.Item>
|
||||
<Form.Item name="description" label="描述">
|
||||
<Input.TextArea rows={2} placeholder="模板用途说明" />
|
||||
</Form.Item>
|
||||
<Form.Item name="permissions" label="权限">
|
||||
<Select
|
||||
mode="multiple"
|
||||
placeholder="选择权限"
|
||||
options={permissionOptions}
|
||||
maxTagCount={5}
|
||||
allowClear
|
||||
filterOption={(input, option) =>
|
||||
(option?.label as string)?.toLowerCase().includes(input.toLowerCase())
|
||||
}
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
|
||||
{/* Apply Template Modal */}
|
||||
<Modal
|
||||
title={`应用模板: ${selectedTemplate?.name ?? ''}`}
|
||||
open={applyOpen}
|
||||
onOk={handleApply}
|
||||
onCancel={() => {
|
||||
setApplyOpen(false)
|
||||
setSelectedTemplate(null)
|
||||
applyForm.resetFields()
|
||||
}}
|
||||
confirmLoading={applyMutation.isPending}
|
||||
width={480}
|
||||
>
|
||||
<Form form={applyForm} layout="vertical" className="mt-4">
|
||||
<div className="mb-4 text-sm text-neutral-500 dark:text-neutral-400">
|
||||
将模板的 {selectedTemplate?.permissions?.length ?? 0} 项权限应用到指定账号。
|
||||
请输入账号 ID,多个 ID 用逗号分隔。
|
||||
</div>
|
||||
<Form.Item
|
||||
name="account_ids"
|
||||
label="账号 ID"
|
||||
rules={[{ required: true, message: '请输入账号 ID' }]}
|
||||
>
|
||||
<Input.TextArea
|
||||
rows={3}
|
||||
placeholder="如: acc_abc123, acc_def456"
|
||||
/>
|
||||
</Form.Item>
|
||||
</Form>
|
||||
</Modal>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// Main Page: Roles & Permissions
|
||||
// ============================================================
|
||||
export default function Roles() {
|
||||
return (
|
||||
<div>
|
||||
<PageHeader
|
||||
title="角色与权限"
|
||||
description="管理角色、权限模板,并将权限批量应用到账号"
|
||||
/>
|
||||
|
||||
<Tabs
|
||||
defaultActiveKey="roles"
|
||||
items={[
|
||||
{
|
||||
key: 'roles',
|
||||
label: (
|
||||
<span className="flex items-center gap-1.5">
|
||||
<SafetyOutlined />
|
||||
角色
|
||||
</span>
|
||||
),
|
||||
children: <RolesTab />,
|
||||
},
|
||||
{
|
||||
key: 'templates',
|
||||
label: (
|
||||
<span className="flex items-center gap-1.5">
|
||||
<CheckCircleOutlined />
|
||||
权限模板
|
||||
</span>
|
||||
),
|
||||
children: <TemplatesTab />,
|
||||
},
|
||||
]}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -21,13 +21,16 @@ export const router = createBrowserRouter([
|
||||
children: [
|
||||
{ index: true, lazy: () => import('@/pages/Dashboard').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'accounts', lazy: () => import('@/pages/Accounts').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'roles', lazy: () => import('@/pages/Roles').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'model-services', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'providers', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'models', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'agent-templates', lazy: () => import('@/pages/AgentTemplates').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'api-keys', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'usage', lazy: () => import('@/pages/Usage').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'billing', lazy: () => import('@/pages/Billing').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'relay', lazy: () => import('@/pages/Relay').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'knowledge', lazy: () => import('@/pages/Knowledge').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'config', lazy: () => import('@/pages/Config').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'prompts', lazy: () => import('@/pages/Prompts').then((m) => ({ Component: m.default })) },
|
||||
{ path: 'logs', lazy: () => import('@/pages/Logs').then((m) => ({ Component: m.default })) },
|
||||
|
||||
101
admin-v2/src/services/billing.ts
Normal file
101
admin-v2/src/services/billing.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
import request, { withSignal } from './request'
|
||||
|
||||
// === Types ===
|
||||
|
||||
export interface BillingPlan {
|
||||
id: string
|
||||
name: string
|
||||
display_name: string
|
||||
description: string | null
|
||||
price_cents: number
|
||||
currency: string
|
||||
interval: string
|
||||
features: Record<string, unknown>
|
||||
limits: Record<string, unknown>
|
||||
is_default: boolean
|
||||
sort_order: number
|
||||
status: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface Subscription {
|
||||
id: string
|
||||
account_id: string
|
||||
plan_id: string
|
||||
status: string
|
||||
current_period_start: string
|
||||
current_period_end: string
|
||||
trial_end: string | null
|
||||
canceled_at: string | null
|
||||
cancel_at_period_end: boolean
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface UsageQuota {
|
||||
id: string
|
||||
account_id: string
|
||||
period_start: string
|
||||
period_end: string
|
||||
input_tokens: number
|
||||
output_tokens: number
|
||||
relay_requests: number
|
||||
hand_executions: number
|
||||
pipeline_runs: number
|
||||
max_input_tokens: number | null
|
||||
max_output_tokens: number | null
|
||||
max_relay_requests: number | null
|
||||
max_hand_executions: number | null
|
||||
max_pipeline_runs: number | null
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface SubscriptionInfo {
|
||||
plan: BillingPlan
|
||||
subscription: Subscription | null
|
||||
usage: UsageQuota
|
||||
}
|
||||
|
||||
export interface PaymentResult {
|
||||
payment_id: string
|
||||
trade_no: string
|
||||
pay_url: string
|
||||
amount_cents: number
|
||||
}
|
||||
|
||||
export interface PaymentStatus {
|
||||
id: string
|
||||
method: string
|
||||
amount_cents: number
|
||||
currency: string
|
||||
status: string
|
||||
}
|
||||
|
||||
// === Service ===
|
||||
|
||||
export const billingService = {
|
||||
listPlans: (signal?: AbortSignal) =>
|
||||
request.get<BillingPlan[]>('/billing/plans', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getPlan: (id: string, signal?: AbortSignal) =>
|
||||
request.get<BillingPlan>(`/billing/plans/${id}`, withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getSubscription: (signal?: AbortSignal) =>
|
||||
request.get<SubscriptionInfo>('/billing/subscription', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getUsage: (signal?: AbortSignal) =>
|
||||
request.get<UsageQuota>('/billing/usage', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
createPayment: (data: { plan_id: string; payment_method: 'alipay' | 'wechat' }) =>
|
||||
request.post<PaymentResult>('/billing/payments', data).then((r) => r.data),
|
||||
|
||||
getPaymentStatus: (id: string, signal?: AbortSignal) =>
|
||||
request.get<PaymentStatus>(`/billing/payments/${id}`, withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
}
|
||||
162
admin-v2/src/services/knowledge.ts
Normal file
162
admin-v2/src/services/knowledge.ts
Normal file
@@ -0,0 +1,162 @@
|
||||
import request, { withSignal } from './request'
|
||||
|
||||
// === Types ===
|
||||
|
||||
export interface CategoryResponse {
|
||||
id: string
|
||||
name: string
|
||||
description: string | null
|
||||
parent_id: string | null
|
||||
icon: string | null
|
||||
sort_order: number
|
||||
item_count: number
|
||||
children: CategoryResponse[]
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface KnowledgeItem {
|
||||
id: string
|
||||
category_id: string
|
||||
title: string
|
||||
content: string
|
||||
keywords: string[]
|
||||
related_questions: string[]
|
||||
priority: number
|
||||
status: string
|
||||
version: number
|
||||
source: string
|
||||
tags: string[]
|
||||
created_by: string
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
export interface SearchResult {
|
||||
chunk_id: string
|
||||
item_id: string
|
||||
item_title: string
|
||||
category_name: string
|
||||
content: string
|
||||
score: number
|
||||
keywords: string[]
|
||||
}
|
||||
|
||||
export interface AnalyticsOverview {
|
||||
total_items: number
|
||||
active_items: number
|
||||
total_categories: number
|
||||
weekly_new_items: number
|
||||
total_references: number
|
||||
avg_reference_per_item: number
|
||||
hit_rate: number
|
||||
injection_rate: number
|
||||
positive_feedback_rate: number
|
||||
stale_items_count: number
|
||||
}
|
||||
|
||||
export interface ListItemsResponse {
|
||||
items: KnowledgeItem[]
|
||||
total: number
|
||||
page: number
|
||||
page_size: number
|
||||
}
|
||||
|
||||
// === Service ===
|
||||
|
||||
export const knowledgeService = {
|
||||
// 分类
|
||||
listCategories: (signal?: AbortSignal) =>
|
||||
request.get<CategoryResponse[]>('/knowledge/categories', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
createCategory: (data: { name: string; description?: string; parent_id?: string; icon?: string }) =>
|
||||
request.post('/knowledge/categories', data).then((r) => r.data),
|
||||
|
||||
deleteCategory: (id: string) =>
|
||||
request.delete(`/knowledge/categories/${id}`).then((r) => r.data),
|
||||
|
||||
updateCategory: (id: string, data: { name?: string; description?: string; parent_id?: string; icon?: string }) =>
|
||||
request.put(`/knowledge/categories/${id}`, data).then((r) => r.data),
|
||||
|
||||
reorderCategories: (items: Array<{ id: string; sort_order: number }>) =>
|
||||
request.patch('/knowledge/categories/reorder', { items }).then((r) => r.data),
|
||||
|
||||
getCategoryItems: (id: string, params?: { page?: number; page_size?: number; status?: string }, signal?: AbortSignal) =>
|
||||
request.get<ListItemsResponse>(`/knowledge/categories/${id}/items`, withSignal({ params }, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
// 条目
|
||||
listItems: (params: { page?: number; page_size?: number; category_id?: string; status?: string; keyword?: string }, signal?: AbortSignal) =>
|
||||
request.get<ListItemsResponse>('/knowledge/items', withSignal({ params }, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getItem: (id: string, signal?: AbortSignal) =>
|
||||
request.get<KnowledgeItem>(`/knowledge/items/${id}`, withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
createItem: (data: {
|
||||
category_id: string
|
||||
title: string
|
||||
content: string
|
||||
keywords?: string[]
|
||||
related_questions?: string[]
|
||||
priority?: number
|
||||
tags?: string[]
|
||||
}) => request.post('/knowledge/items', data).then((r) => r.data),
|
||||
|
||||
updateItem: (id: string, data: Record<string, unknown>) =>
|
||||
request.put(`/knowledge/items/${id}`, data).then((r) => r.data),
|
||||
|
||||
deleteItem: (id: string) =>
|
||||
request.delete(`/knowledge/items/${id}`).then((r) => r.data),
|
||||
|
||||
batchCreate: (items: Array<{
|
||||
category_id: string
|
||||
title: string
|
||||
content: string
|
||||
keywords?: string[]
|
||||
tags?: string[]
|
||||
}>) => request.post('/knowledge/items/batch', items).then((r) => r.data),
|
||||
|
||||
// 搜索
|
||||
search: (data: { query: string; category_id?: string; limit?: number }) =>
|
||||
request.post<SearchResult[]>('/knowledge/search', data).then((r) => r.data),
|
||||
|
||||
// 分析
|
||||
getOverview: (signal?: AbortSignal) =>
|
||||
request.get<AnalyticsOverview>('/knowledge/analytics/overview', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getTrends: (signal?: AbortSignal) =>
|
||||
request.get('/knowledge/analytics/trends', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getTopItems: (signal?: AbortSignal) =>
|
||||
request.get('/knowledge/analytics/top-items', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getQuality: (signal?: AbortSignal) =>
|
||||
request.get('/knowledge/analytics/quality', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
getGaps: (signal?: AbortSignal) =>
|
||||
request.get('/knowledge/analytics/gaps', withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
// 版本
|
||||
getVersions: (itemId: string, signal?: AbortSignal) =>
|
||||
request.get(`/knowledge/items/${itemId}/versions`, withSignal({}, signal))
|
||||
.then((r) => r.data),
|
||||
|
||||
rollbackVersion: (itemId: string, version: number) =>
|
||||
request.post(`/knowledge/items/${itemId}/rollback/${version}`).then((r) => r.data),
|
||||
|
||||
// 推荐搜索
|
||||
recommend: (data: { query: string; category_id?: string; limit?: number }) =>
|
||||
request.post<SearchResult[]>('/knowledge/recommend', data).then((r) => r.data),
|
||||
|
||||
// 导入
|
||||
importItems: (data: { category_id: string; files: Array<{ content: string; title?: string; keywords?: string[]; tags?: string[] }> }) =>
|
||||
request.post('/knowledge/items/import', data).then((r) => r.data),
|
||||
}
|
||||
50
admin-v2/src/services/roles.ts
Normal file
50
admin-v2/src/services/roles.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
// ============================================================
|
||||
// 角色与权限模板 服务层
|
||||
// ============================================================
|
||||
|
||||
import request, { withSignal } from './request'
|
||||
import type {
|
||||
Role,
|
||||
PermissionTemplate,
|
||||
CreateRoleRequest,
|
||||
UpdateRoleRequest,
|
||||
CreateTemplateRequest,
|
||||
} from '@/types'
|
||||
|
||||
export const roleService = {
|
||||
// ── Roles ─────────────────────────────────────────────────
|
||||
list: (signal?: AbortSignal) =>
|
||||
request.get<Role[]>('/roles', withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
get: (id: string, signal?: AbortSignal) =>
|
||||
request.get<Role>(`/roles/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
create: (data: CreateRoleRequest, signal?: AbortSignal) =>
|
||||
request.post<Role>('/roles', data, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
update: (id: string, data: UpdateRoleRequest, signal?: AbortSignal) =>
|
||||
request.put<Role>(`/roles/${id}`, data, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
delete: (id: string, signal?: AbortSignal) =>
|
||||
request.delete(`/roles/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
// ── Role Permissions ──────────────────────────────────────
|
||||
getPermissions: (roleId: string, signal?: AbortSignal) =>
|
||||
request.get<string[]>(`/roles/${roleId}/permissions`, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
// ── Permission Templates ──────────────────────────────────
|
||||
listTemplates: (signal?: AbortSignal) =>
|
||||
request.get<PermissionTemplate[]>('/permission-templates', withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
getTemplate: (id: string, signal?: AbortSignal) =>
|
||||
request.get<PermissionTemplate>(`/permission-templates/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
createTemplate: (data: CreateTemplateRequest, signal?: AbortSignal) =>
|
||||
request.post<PermissionTemplate>('/permission-templates', data, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
deleteTemplate: (id: string, signal?: AbortSignal) =>
|
||||
request.delete(`/permission-templates/${id}`, withSignal({}, signal)).then((r) => r.data),
|
||||
|
||||
applyTemplate: (templateId: string, accountIds: string[], signal?: AbortSignal) =>
|
||||
request.post(`/permission-templates/${templateId}/apply`, { account_ids: accountIds }, withSignal({}, signal)).then((r) => r.data),
|
||||
}
|
||||
@@ -282,3 +282,45 @@ export interface DailyUsageStat {
|
||||
output_tokens: number
|
||||
unique_devices: number
|
||||
}
|
||||
|
||||
/** 角色 */
|
||||
export interface Role {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
permissions: string[]
|
||||
account_count?: number
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
/** 权限模板 */
|
||||
export interface PermissionTemplate {
|
||||
id: string
|
||||
name: string
|
||||
description: string
|
||||
permissions: string[]
|
||||
created_at: string
|
||||
updated_at: string
|
||||
}
|
||||
|
||||
/** 创建角色请求 */
|
||||
export interface CreateRoleRequest {
|
||||
name: string
|
||||
description?: string
|
||||
permissions?: string[]
|
||||
}
|
||||
|
||||
/** 更新角色请求 */
|
||||
export interface UpdateRoleRequest {
|
||||
name?: string
|
||||
description?: string
|
||||
permissions?: string[]
|
||||
}
|
||||
|
||||
/** 创建权限模板请求 */
|
||||
export interface CreateTemplateRequest {
|
||||
name: string
|
||||
description?: string
|
||||
permissions?: string[]
|
||||
}
|
||||
|
||||
219
admin-v2/tests/pages/Config.test.tsx
Normal file
219
admin-v2/tests/pages/Config.test.tsx
Normal file
@@ -0,0 +1,219 @@
|
||||
// ============================================================
|
||||
// Config 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { http, HttpResponse } from 'msw'
|
||||
import { setupServer } from 'msw/node'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
import Config from '@/pages/Config'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockConfigItems = [
|
||||
{
|
||||
id: 'cfg-001',
|
||||
category: 'general',
|
||||
key_path: 'general.app_name',
|
||||
value_type: 'string',
|
||||
current_value: 'ZCLAW',
|
||||
default_value: 'ZCLAW',
|
||||
source: 'database',
|
||||
description: '应用程序名称',
|
||||
requires_restart: false,
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
updated_at: '2026-01-01T00:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'cfg-002',
|
||||
category: 'general',
|
||||
key_path: 'general.debug_mode',
|
||||
value_type: 'boolean',
|
||||
current_value: 'false',
|
||||
default_value: 'false',
|
||||
source: 'default',
|
||||
description: '调试模式开关',
|
||||
requires_restart: true,
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
updated_at: '2026-01-01T00:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'cfg-003',
|
||||
category: 'general',
|
||||
key_path: 'general.max_connections',
|
||||
value_type: 'integer',
|
||||
current_value: null,
|
||||
default_value: '100',
|
||||
source: 'default',
|
||||
description: '最大连接数',
|
||||
requires_restart: false,
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
updated_at: '2026-01-01T00:00:00Z',
|
||||
},
|
||||
]
|
||||
|
||||
const mockResponse = {
|
||||
items: mockConfigItems,
|
||||
total: 3,
|
||||
page: 1,
|
||||
page_size: 50,
|
||||
}
|
||||
|
||||
// ── MSW server ───────────────────────────────────────────────
|
||||
|
||||
const server = setupServer()
|
||||
|
||||
beforeEach(() => {
|
||||
server.listen({ onUnhandledRequest: 'bypass' })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
server.close()
|
||||
})
|
||||
|
||||
// ── Helper: render with QueryClient ──────────────────────────
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('Config page', () => {
|
||||
it('renders page header', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/config/items', () => {
|
||||
return HttpResponse.json(mockResponse)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Config />)
|
||||
|
||||
expect(screen.getByText('系统配置')).toBeInTheDocument()
|
||||
expect(screen.getByText('管理系统运行参数和功能开关')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('fetches and displays config items', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/config/items', () => {
|
||||
return HttpResponse.json(mockResponse)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Config />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('general.app_name')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('general.debug_mode')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows loading spinner while fetching', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/config/items', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockResponse)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Config />)
|
||||
|
||||
// Ant Design Spin component renders a .ant-spin element
|
||||
const spinner = document.querySelector('.ant-spin')
|
||||
expect(spinner).toBeTruthy()
|
||||
|
||||
// Wait for loading to complete so afterEach cleanup is clean
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('general.app_name')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows error state on API failure', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/config/items', () => {
|
||||
return HttpResponse.json(
|
||||
{ error: 'internal_error', message: '服务器内部错误' },
|
||||
{ status: 500 },
|
||||
)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Config />)
|
||||
|
||||
// Config page does not have a dedicated ErrorState; the ProTable simply
|
||||
// renders empty when the query fails. We verify the page header is still
|
||||
// rendered and the table body has no data rows (shows "暂无数据").
|
||||
await waitFor(() => {
|
||||
const emptyElements = screen.queryAllByText('暂无数据')
|
||||
expect(emptyElements.length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
// Page header is still present even on error
|
||||
expect(screen.getByText('系统配置')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders config key_path and current_value columns', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/config/items', () => {
|
||||
return HttpResponse.json(mockResponse)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Config />)
|
||||
|
||||
// key_path values are rendered in <code> elements
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('general.app_name')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('general.debug_mode')).toBeInTheDocument()
|
||||
|
||||
// current_value "ZCLAW" appears in both the current_value column and default_value column
|
||||
const zclawElements = screen.getAllByText('ZCLAW')
|
||||
expect(zclawElements.length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
|
||||
it('renders requires_restart column with tags', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/config/items', () => {
|
||||
return HttpResponse.json(mockResponse)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Config />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('general.app_name')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// requires_restart=true renders "是" (orange tag)
|
||||
expect(screen.getByText('是')).toBeInTheDocument()
|
||||
// requires_restart=false renders "否" (may appear multiple times for two items)
|
||||
const noTags = screen.getAllByText('否')
|
||||
expect(noTags.length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
|
||||
it('renders category tabs', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/config/items', () => {
|
||||
return HttpResponse.json(mockResponse)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Config />)
|
||||
|
||||
expect(screen.getByText('通用')).toBeInTheDocument()
|
||||
expect(screen.getByText('认证')).toBeInTheDocument()
|
||||
expect(screen.getByText('中转')).toBeInTheDocument()
|
||||
expect(screen.getByText('模型')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
242
admin-v2/tests/pages/Dashboard.test.tsx
Normal file
242
admin-v2/tests/pages/Dashboard.test.tsx
Normal file
@@ -0,0 +1,242 @@
|
||||
// ============================================================
|
||||
// Dashboard 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { http, HttpResponse } from 'msw'
|
||||
import { setupServer } from 'msw/node'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
import Dashboard from '@/pages/Dashboard'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockStats = {
|
||||
total_accounts: 12,
|
||||
active_accounts: 8,
|
||||
tasks_today: 156,
|
||||
active_providers: 3,
|
||||
active_models: 7,
|
||||
tokens_today_input: 24000,
|
||||
tokens_today_output: 8500,
|
||||
}
|
||||
|
||||
const mockLogs = {
|
||||
items: [
|
||||
{
|
||||
id: 1,
|
||||
account_id: 'acc-001',
|
||||
action: 'login',
|
||||
target_type: 'account',
|
||||
target_id: 'acc-001',
|
||||
details: null,
|
||||
ip_address: '192.168.1.1',
|
||||
created_at: '2026-03-30T10:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
account_id: 'acc-002',
|
||||
action: 'create_provider',
|
||||
target_type: 'provider',
|
||||
target_id: 'prov-001',
|
||||
details: { name: 'OpenAI' },
|
||||
ip_address: '10.0.0.1',
|
||||
created_at: '2026-03-30T09:30:00Z',
|
||||
},
|
||||
],
|
||||
total: 2,
|
||||
page: 1,
|
||||
page_size: 10,
|
||||
}
|
||||
|
||||
// ── MSW server ───────────────────────────────────────────────
|
||||
|
||||
const server = setupServer()
|
||||
|
||||
beforeEach(() => {
|
||||
server.listen({ onUnhandledRequest: 'bypass' })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
server.close()
|
||||
})
|
||||
|
||||
// ── Helper: render with QueryClient ──────────────────────────
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('Dashboard page', () => {
|
||||
it('renders page header', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', () => {
|
||||
return HttpResponse.json(mockStats)
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
expect(screen.getByText('仪表盘')).toBeInTheDocument()
|
||||
expect(screen.getByText('系统概览与最近活动')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders stat cards with correct values', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', () => {
|
||||
return HttpResponse.json(mockStats)
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('12')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Stat titles
|
||||
expect(screen.getByText('总账号')).toBeInTheDocument()
|
||||
expect(screen.getByText('活跃服务商')).toBeInTheDocument()
|
||||
expect(screen.getByText('活跃模型')).toBeInTheDocument()
|
||||
expect(screen.getByText('今日请求')).toBeInTheDocument()
|
||||
expect(screen.getByText('今日 Token')).toBeInTheDocument()
|
||||
|
||||
// Token total: 24000 + 8500 = 32500
|
||||
expect(screen.getByText('32,500')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders recent logs table with action labels', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', () => {
|
||||
return HttpResponse.json(mockStats)
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
// Wait for action labels from constants/status.ts
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('登录')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('创建服务商')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders target types in logs table', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', () => {
|
||||
return HttpResponse.json(mockStats)
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('登录')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('account')).toBeInTheDocument()
|
||||
expect(screen.getByText('provider')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows loading spinner before stats load', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockStats)
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
// Ant Design Spin component renders a .ant-spin element
|
||||
const spinner = document.querySelector('.ant-spin')
|
||||
expect(spinner).toBeTruthy()
|
||||
|
||||
// Wait for loading to complete so afterEach cleanup is clean
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('总账号')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows error state when stats request fails', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', () => {
|
||||
return HttpResponse.json(
|
||||
{ error: 'internal_error', message: '服务器内部错误' },
|
||||
{ status: 500 },
|
||||
)
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('renders stat cards with zero values when stats are null', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', () => {
|
||||
return HttpResponse.json({})
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json({ items: [], total: 0, page: 1, page_size: 10 })
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
// All stats should fallback to 0
|
||||
await waitFor(() => {
|
||||
const zeros = screen.getAllByText('0')
|
||||
expect(zeros.length).toBeGreaterThanOrEqual(2)
|
||||
})
|
||||
})
|
||||
|
||||
it('renders recent logs section header', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/stats/dashboard', () => {
|
||||
return HttpResponse.json(mockStats)
|
||||
}),
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Dashboard />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('最近操作日志')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
219
admin-v2/tests/pages/Login.test.tsx
Normal file
219
admin-v2/tests/pages/Login.test.tsx
Normal file
@@ -0,0 +1,219 @@
|
||||
// ============================================================
|
||||
// Login 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import userEvent from '@testing-library/user-event'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
import { MemoryRouter } from 'react-router-dom'
|
||||
|
||||
import Login from '@/pages/Login'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockLoginResponse = {
|
||||
token: 'jwt-token-123',
|
||||
refresh_token: 'refresh-token-456',
|
||||
account: {
|
||||
id: 'acc-001',
|
||||
username: 'testadmin',
|
||||
email: 'admin@zclaw.ai',
|
||||
display_name: 'Admin',
|
||||
role: 'super_admin',
|
||||
status: 'active',
|
||||
totp_enabled: false,
|
||||
last_login_at: null,
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
llm_routing: 'relay',
|
||||
},
|
||||
}
|
||||
|
||||
const mockAccount = {
|
||||
id: 'acc-001',
|
||||
username: 'testadmin',
|
||||
email: 'admin@zclaw.ai',
|
||||
display_name: 'Admin',
|
||||
role: 'super_admin',
|
||||
status: 'active',
|
||||
totp_enabled: false,
|
||||
last_login_at: null,
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
llm_routing: 'relay',
|
||||
}
|
||||
|
||||
// ── Hoisted mocks ────────────────────────────────────────────
|
||||
|
||||
const { mockLogin, mockNavigate, mockAuthServiceLogin } = vi.hoisted(() => ({
|
||||
mockLogin: vi.fn(),
|
||||
mockNavigate: vi.fn(),
|
||||
mockAuthServiceLogin: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/stores/authStore', () => ({
|
||||
useAuthStore: Object.assign(
|
||||
vi.fn((selector: (s: Record<string, unknown>) => unknown) =>
|
||||
selector({ login: mockLogin }),
|
||||
),
|
||||
{ getState: () => ({ token: null, refreshToken: null, logout: vi.fn() }) },
|
||||
),
|
||||
}))
|
||||
|
||||
vi.mock('@/services/auth', () => ({
|
||||
authService: {
|
||||
login: mockAuthServiceLogin,
|
||||
},
|
||||
}))
|
||||
|
||||
vi.mock('react-router-dom', async () => {
|
||||
const actual = await vi.importActual<typeof import('react-router-dom')>('react-router-dom')
|
||||
return {
|
||||
...actual,
|
||||
useNavigate: () => mockNavigate,
|
||||
}
|
||||
})
|
||||
|
||||
beforeEach(() => {
|
||||
mockLogin.mockClear()
|
||||
mockNavigate.mockClear()
|
||||
mockAuthServiceLogin.mockClear()
|
||||
})
|
||||
|
||||
// ── Helper: render with providers ────────────────────────────
|
||||
|
||||
function renderLogin(initialEntries = ['/login']) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
<MemoryRouter initialEntries={initialEntries}>
|
||||
<Login />
|
||||
</MemoryRouter>
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
/** Click the LoginForm submit button (Ant Design renders "登 录" with a space) */
|
||||
function getSubmitButton(): HTMLElement {
|
||||
const btn = document.querySelector<HTMLButtonElement>(
|
||||
'button.ant-btn-primary[type="button"]',
|
||||
)
|
||||
if (!btn) throw new Error('Submit button not found')
|
||||
return btn
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('Login page', () => {
|
||||
it('renders the login form with username and password fields', () => {
|
||||
renderLogin()
|
||||
|
||||
expect(screen.getByText('登录到 ZCLAW')).toBeInTheDocument()
|
||||
expect(screen.getByPlaceholderText('请输入用户名')).toBeInTheDocument()
|
||||
expect(screen.getByPlaceholderText('请输入密码')).toBeInTheDocument()
|
||||
const submitButton = getSubmitButton()
|
||||
expect(submitButton).toBeTruthy()
|
||||
})
|
||||
|
||||
it('shows the ZCLAW brand logo', () => {
|
||||
renderLogin()
|
||||
|
||||
expect(screen.getByText('Z')).toBeInTheDocument()
|
||||
expect(screen.getByText(/ZCLAW Admin/)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('successful login calls authStore.login and navigates to /', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockAuthServiceLogin.mockResolvedValue(mockLoginResponse)
|
||||
|
||||
renderLogin()
|
||||
|
||||
await user.type(screen.getByPlaceholderText('请输入用户名'), 'testadmin')
|
||||
await user.type(screen.getByPlaceholderText('请输入密码'), 'password123')
|
||||
await user.click(getSubmitButton())
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockLogin).toHaveBeenCalledWith(
|
||||
'jwt-token-123',
|
||||
'refresh-token-456',
|
||||
mockAccount,
|
||||
)
|
||||
})
|
||||
|
||||
expect(mockNavigate).toHaveBeenCalledWith('/', { replace: true })
|
||||
})
|
||||
|
||||
it('navigates to redirect path after login', async () => {
|
||||
const user = userEvent.setup()
|
||||
mockAuthServiceLogin.mockResolvedValue(mockLoginResponse)
|
||||
|
||||
renderLogin(['/login?from=/settings'])
|
||||
|
||||
await user.type(screen.getByPlaceholderText('请输入用户名'), 'testadmin')
|
||||
await user.type(screen.getByPlaceholderText('请输入密码'), 'password123')
|
||||
await user.click(getSubmitButton())
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockNavigate).toHaveBeenCalledWith('/settings', { replace: true })
|
||||
})
|
||||
})
|
||||
|
||||
it('shows TOTP field when server returns TOTP-related error', async () => {
|
||||
const user = userEvent.setup()
|
||||
const error = new Error('请输入两步验证码 (TOTP)')
|
||||
Object.assign(error, { status: 403 })
|
||||
mockAuthServiceLogin.mockRejectedValue(error)
|
||||
|
||||
renderLogin()
|
||||
|
||||
// Initially no TOTP field
|
||||
expect(screen.queryByPlaceholderText('请输入 6 位验证码')).not.toBeInTheDocument()
|
||||
|
||||
await user.type(screen.getByPlaceholderText('请输入用户名'), 'testadmin')
|
||||
await user.type(screen.getByPlaceholderText('请输入密码'), 'password123')
|
||||
await user.click(getSubmitButton())
|
||||
|
||||
// After TOTP error, TOTP field appears
|
||||
await waitFor(() => {
|
||||
expect(screen.getByPlaceholderText('请输入 6 位验证码')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows error message on invalid credentials', async () => {
|
||||
const user = userEvent.setup()
|
||||
const error = new Error('用户名或密码错误')
|
||||
mockAuthServiceLogin.mockRejectedValue(error)
|
||||
|
||||
renderLogin()
|
||||
|
||||
await user.type(screen.getByPlaceholderText('请输入用户名'), 'wrong')
|
||||
await user.type(screen.getByPlaceholderText('请输入密码'), 'wrong')
|
||||
await user.click(getSubmitButton())
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('用户名或密码错误')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('does not call authStore.login on failed login', async () => {
|
||||
const user = userEvent.setup()
|
||||
const error = new Error('用户名或密码错误')
|
||||
mockAuthServiceLogin.mockRejectedValue(error)
|
||||
|
||||
renderLogin()
|
||||
|
||||
await user.type(screen.getByPlaceholderText('请输入用户名'), 'wrong')
|
||||
await user.type(screen.getByPlaceholderText('请输入密码'), 'wrong')
|
||||
await user.click(getSubmitButton())
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('用户名或密码错误')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(mockLogin).not.toHaveBeenCalled()
|
||||
expect(mockNavigate).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
210
admin-v2/tests/pages/Logs.test.tsx
Normal file
210
admin-v2/tests/pages/Logs.test.tsx
Normal file
@@ -0,0 +1,210 @@
|
||||
// ============================================================
|
||||
// Logs 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { http, HttpResponse } from 'msw'
|
||||
import { setupServer } from 'msw/node'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
import Logs from '@/pages/Logs'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockLogs = {
|
||||
items: [
|
||||
{
|
||||
id: 1,
|
||||
account_id: 'acc-001',
|
||||
action: 'login',
|
||||
target_type: 'account',
|
||||
target_id: 'acc-001',
|
||||
details: null,
|
||||
ip_address: '192.168.1.1',
|
||||
created_at: '2026-03-30T10:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
account_id: 'acc-002',
|
||||
action: 'create_provider',
|
||||
target_type: 'provider',
|
||||
target_id: 'prov-001',
|
||||
details: { name: 'OpenAI' },
|
||||
ip_address: '10.0.0.1',
|
||||
created_at: '2026-03-30T09:30:00Z',
|
||||
},
|
||||
{
|
||||
id: 3,
|
||||
account_id: 'acc-001',
|
||||
action: 'delete_model',
|
||||
target_type: 'model',
|
||||
target_id: 'mdl-001',
|
||||
details: null,
|
||||
ip_address: '192.168.1.1',
|
||||
created_at: '2026-03-29T14:00:00Z',
|
||||
},
|
||||
],
|
||||
total: 3,
|
||||
page: 1,
|
||||
page_size: 20,
|
||||
}
|
||||
|
||||
// ── MSW server ───────────────────────────────────────────────
|
||||
|
||||
const server = setupServer()
|
||||
|
||||
beforeEach(() => {
|
||||
server.listen({ onUnhandledRequest: 'bypass' })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
server.close()
|
||||
})
|
||||
|
||||
// ── Helper: render with QueryClient ──────────────────────────
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('Logs page', () => {
|
||||
it('renders page header', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Logs />)
|
||||
|
||||
expect(screen.getByText('操作日志')).toBeInTheDocument()
|
||||
expect(screen.getByText('系统审计与操作记录')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('fetches and displays log entries', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Logs />)
|
||||
|
||||
// Wait for action labels rendered from constants/status.ts
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('登录')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('创建服务商')).toBeInTheDocument()
|
||||
expect(screen.getByText('删除模型')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows loading spinner while fetching', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/logs/operations', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Logs />)
|
||||
|
||||
// Ant Design Spin component renders a .ant-spin element
|
||||
const spinner = document.querySelector('.ant-spin')
|
||||
expect(spinner).toBeTruthy()
|
||||
|
||||
// Wait for loading to complete so afterEach cleanup is clean
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('登录')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows ErrorState on API failure with retry button', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(
|
||||
{ error: 'internal_error', message: '服务器内部错误' },
|
||||
{ status: 500 },
|
||||
)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Logs />)
|
||||
|
||||
// ErrorState renders the error message
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
|
||||
})
|
||||
// Ant Design Button splits two-character text with a space: "重 试"
|
||||
const retryButton = screen.getByRole('button', { name: /重.?试/ })
|
||||
expect(retryButton).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders action as a colored tag', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Logs />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('登录')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Verify the action tags have the correct Ant Design color classes
|
||||
const loginTag = screen.getByText('登录').closest('.ant-tag')
|
||||
expect(loginTag).toBeTruthy()
|
||||
// actionColors.login = 'green' → Ant Design renders ant-tag-green or ant-tag-color-green
|
||||
expect(loginTag?.className).toMatch(/green/)
|
||||
})
|
||||
|
||||
it('renders IP address column', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Logs />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('登录')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// 192.168.1.1 appears twice (two log entries from the same IP)
|
||||
const ip1Elements = screen.getAllByText('192.168.1.1')
|
||||
expect(ip1Elements.length).toBeGreaterThanOrEqual(1)
|
||||
expect(screen.getByText('10.0.0.1')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders target_type column', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/logs/operations', () => {
|
||||
return HttpResponse.json(mockLogs)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Logs />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('登录')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText('account')).toBeInTheDocument()
|
||||
expect(screen.getByText('provider')).toBeInTheDocument()
|
||||
expect(screen.getByText('model')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
184
admin-v2/tests/pages/ModelServices.test.tsx
Normal file
184
admin-v2/tests/pages/ModelServices.test.tsx
Normal file
@@ -0,0 +1,184 @@
|
||||
// ============================================================
|
||||
// ModelServices 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { http, HttpResponse } from 'msw'
|
||||
import { setupServer } from 'msw/node'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
import ModelServices from '@/pages/ModelServices'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockProviders = {
|
||||
items: [
|
||||
{
|
||||
id: 'prov-001',
|
||||
name: 'openai',
|
||||
display_name: 'OpenAI',
|
||||
base_url: 'https://api.openai.com/v1',
|
||||
api_protocol: 'openai',
|
||||
enabled: true,
|
||||
rate_limit_rpm: 500,
|
||||
rate_limit_tpm: null,
|
||||
created_at: '2026-01-01T00:00:00Z',
|
||||
updated_at: '2026-03-15T10:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'prov-002',
|
||||
name: 'anthropic',
|
||||
display_name: 'Anthropic',
|
||||
base_url: 'https://api.anthropic.com',
|
||||
api_protocol: 'anthropic',
|
||||
enabled: false,
|
||||
rate_limit_rpm: 200,
|
||||
rate_limit_tpm: null,
|
||||
created_at: '2026-02-01T00:00:00Z',
|
||||
updated_at: '2026-03-01T00:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'prov-003',
|
||||
name: 'deepseek',
|
||||
display_name: 'DeepSeek',
|
||||
base_url: 'https://api.deepseek.com/v1',
|
||||
api_protocol: 'openai',
|
||||
enabled: true,
|
||||
rate_limit_rpm: null,
|
||||
rate_limit_tpm: null,
|
||||
created_at: '2026-03-01T00:00:00Z',
|
||||
updated_at: '2026-03-01T00:00:00Z',
|
||||
},
|
||||
],
|
||||
total: 3,
|
||||
page: 1,
|
||||
page_size: 20,
|
||||
}
|
||||
|
||||
// ── MSW server ───────────────────────────────────────────────
|
||||
|
||||
const server = setupServer()
|
||||
|
||||
beforeEach(() => {
|
||||
server.listen({ onUnhandledRequest: 'bypass' })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
server.close()
|
||||
})
|
||||
|
||||
// ── Helper: render with QueryClient ──────────────────────────
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('ModelServices page', () => {
|
||||
it('renders page header', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/providers', () => {
|
||||
return HttpResponse.json(mockProviders)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<ModelServices />)
|
||||
|
||||
expect(screen.getByText('模型服务')).toBeInTheDocument()
|
||||
expect(screen.getByText('管理 AI 服务商、模型配置和 Key 池')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('fetches and displays providers', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/providers', () => {
|
||||
return HttpResponse.json(mockProviders)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<ModelServices />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('OpenAI')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText('Anthropic')).toBeInTheDocument()
|
||||
expect(screen.getByText('DeepSeek')).toBeInTheDocument()
|
||||
|
||||
// Provider identifiers rendered as code
|
||||
// openai also appears in base_url, so use getAllByText
|
||||
expect(screen.getAllByText('openai').length).toBeGreaterThanOrEqual(1)
|
||||
expect(screen.getAllByText('anthropic').length).toBeGreaterThanOrEqual(1)
|
||||
expect(screen.getAllByText('deepseek').length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
|
||||
it('shows loading spinner before data arrives', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/providers', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockProviders)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<ModelServices />)
|
||||
|
||||
const spinner = document.querySelector('.ant-spin')
|
||||
expect(spinner).toBeTruthy()
|
||||
|
||||
// Wait for loading to complete so afterEach cleanup is clean
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('OpenAI')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('renders provider status as tag', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/providers', () => {
|
||||
return HttpResponse.json(mockProviders)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<ModelServices />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('OpenAI')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// enabled: true -> "启用" tag, enabled: false -> "禁用" tag
|
||||
const enabledTags = screen.getAllByText('启用')
|
||||
expect(enabledTags.length).toBe(2) // openai + deepseek
|
||||
|
||||
expect(screen.getByText('禁用')).toBeInTheDocument() // anthropic
|
||||
})
|
||||
|
||||
it('shows empty table on API failure', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/providers', () => {
|
||||
return HttpResponse.json(
|
||||
{ error: 'internal_error', message: '获取服务商列表失败' },
|
||||
{ status: 500 },
|
||||
)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<ModelServices />)
|
||||
|
||||
// Page header should still render
|
||||
expect(screen.getByText('模型服务')).toBeInTheDocument()
|
||||
|
||||
// Provider names should NOT be rendered
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('OpenAI')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
178
admin-v2/tests/pages/Prompts.test.tsx
Normal file
178
admin-v2/tests/pages/Prompts.test.tsx
Normal file
@@ -0,0 +1,178 @@
|
||||
// ============================================================
|
||||
// Prompts 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { http, HttpResponse } from 'msw'
|
||||
import { setupServer } from 'msw/node'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
import Prompts from '@/pages/Prompts'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockPrompts = {
|
||||
items: [
|
||||
{
|
||||
id: 'pt-001',
|
||||
name: 'system-default',
|
||||
category: 'system',
|
||||
description: 'Default system prompt for all agents',
|
||||
source: 'builtin' as const,
|
||||
current_version: 3,
|
||||
status: 'active' as const,
|
||||
created_at: '2026-01-15T08:00:00Z',
|
||||
updated_at: '2026-03-20T12:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'pt-002',
|
||||
name: 'custom-research',
|
||||
category: 'tool',
|
||||
description: 'Custom research prompt template',
|
||||
source: 'custom' as const,
|
||||
current_version: 1,
|
||||
status: 'active' as const,
|
||||
created_at: '2026-03-01T10:00:00Z',
|
||||
updated_at: '2026-03-01T10:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'pt-003',
|
||||
name: 'legacy-summary',
|
||||
category: 'system',
|
||||
description: 'Legacy summary prompt',
|
||||
source: 'builtin' as const,
|
||||
current_version: 5,
|
||||
status: 'archived' as const,
|
||||
created_at: '2025-06-01T00:00:00Z',
|
||||
updated_at: '2026-02-28T00:00:00Z',
|
||||
},
|
||||
],
|
||||
total: 3,
|
||||
page: 1,
|
||||
page_size: 20,
|
||||
}
|
||||
|
||||
// ── MSW server ───────────────────────────────────────────────
|
||||
|
||||
const server = setupServer()
|
||||
|
||||
beforeEach(() => {
|
||||
server.listen({ onUnhandledRequest: 'bypass' })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
server.close()
|
||||
})
|
||||
|
||||
// ── Helper: render with QueryClient ──────────────────────────
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('Prompts page', () => {
|
||||
it('renders page title and create button', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/prompts', () => {
|
||||
return HttpResponse.json(mockPrompts)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Prompts />)
|
||||
|
||||
expect(screen.getByText('提示词管理')).toBeInTheDocument()
|
||||
expect(screen.getByText('管理系统提示词模板和版本历史')).toBeInTheDocument()
|
||||
expect(screen.getByText('新建提示词')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('fetches and displays prompt templates', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/prompts', () => {
|
||||
return HttpResponse.json(mockPrompts)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Prompts />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('system-default')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText('custom-research')).toBeInTheDocument()
|
||||
expect(screen.getByText('legacy-summary')).toBeInTheDocument()
|
||||
|
||||
// Category "tool" appears once in data
|
||||
expect(screen.getByText('tool')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows loading spinner before data arrives', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/prompts', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockPrompts)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Prompts />)
|
||||
|
||||
const spinner = document.querySelector('.ant-spin')
|
||||
expect(spinner).toBeTruthy()
|
||||
|
||||
// Wait for loading to complete so afterEach cleanup is clean
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('system-default')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('renders source as tag with correct labels', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/prompts', () => {
|
||||
return HttpResponse.json(mockPrompts)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Prompts />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('system-default')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// sourceLabels: { builtin: '内置', custom: '自定义' }
|
||||
// '内置' appears twice (2 builtin items), '自定义' appears once
|
||||
const builtinTags = screen.getAllByText('内置')
|
||||
expect(builtinTags.length).toBe(2)
|
||||
expect(screen.getByText('自定义')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows error state on API failure', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/prompts', () => {
|
||||
return HttpResponse.json(
|
||||
{ error: 'internal_error', message: '获取提示词列表失败' },
|
||||
{ status: 500 },
|
||||
)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Prompts />)
|
||||
|
||||
// React Query error propagation: ProTable receives empty data
|
||||
// but the query error should be visible via the table state
|
||||
// Check that no prompt names are rendered
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText('system-default')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
})
|
||||
234
admin-v2/tests/pages/Relay.test.tsx
Normal file
234
admin-v2/tests/pages/Relay.test.tsx
Normal file
@@ -0,0 +1,234 @@
|
||||
// ============================================================
|
||||
// Relay 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { http, HttpResponse } from 'msw'
|
||||
import { setupServer } from 'msw/node'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
import Relay from '@/pages/Relay'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockRelayTasks = {
|
||||
items: [
|
||||
{
|
||||
id: 'task-001-abcdef',
|
||||
account_id: 'acc-001',
|
||||
provider_id: 'prov-001',
|
||||
model_id: 'gpt-4o',
|
||||
status: 'completed',
|
||||
priority: 0,
|
||||
attempt_count: 1,
|
||||
max_attempts: 3,
|
||||
input_tokens: 1500,
|
||||
output_tokens: 800,
|
||||
error_message: null,
|
||||
queued_at: '2026-03-30T10:00:00Z',
|
||||
started_at: '2026-03-30T10:00:01Z',
|
||||
completed_at: '2026-03-30T10:00:05Z',
|
||||
created_at: '2026-03-30T10:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'task-002-ghijkl',
|
||||
account_id: 'acc-002',
|
||||
provider_id: 'prov-002',
|
||||
model_id: 'claude-3.5-sonnet',
|
||||
status: 'failed',
|
||||
priority: 0,
|
||||
attempt_count: 3,
|
||||
max_attempts: 3,
|
||||
input_tokens: 2000,
|
||||
output_tokens: 0,
|
||||
error_message: 'Rate limit exceeded',
|
||||
queued_at: '2026-03-30T09:00:00Z',
|
||||
started_at: '2026-03-30T09:00:01Z',
|
||||
completed_at: '2026-03-30T09:01:00Z',
|
||||
created_at: '2026-03-30T09:00:00Z',
|
||||
},
|
||||
{
|
||||
id: 'task-003-mnopqr',
|
||||
account_id: 'acc-001',
|
||||
provider_id: 'prov-001',
|
||||
model_id: 'gpt-4o-mini',
|
||||
status: 'queued',
|
||||
priority: 1,
|
||||
attempt_count: 0,
|
||||
max_attempts: 3,
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
error_message: null,
|
||||
queued_at: '2026-03-30T11:00:00Z',
|
||||
started_at: null,
|
||||
completed_at: null,
|
||||
created_at: '2026-03-30T11:00:00Z',
|
||||
},
|
||||
],
|
||||
total: 3,
|
||||
page: 1,
|
||||
page_size: 20,
|
||||
}
|
||||
|
||||
// ── MSW server ───────────────────────────────────────────────
|
||||
|
||||
const server = setupServer()
|
||||
|
||||
beforeEach(() => {
|
||||
server.listen({ onUnhandledRequest: 'bypass' })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
server.close()
|
||||
})
|
||||
|
||||
// ── Helper: render with QueryClient ──────────────────────────
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('Relay page', () => {
|
||||
it('renders page header', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/relay/tasks', () => {
|
||||
return HttpResponse.json(mockRelayTasks)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Relay />)
|
||||
|
||||
expect(screen.getByText('中转任务')).toBeInTheDocument()
|
||||
expect(screen.getByText('查看和管理 AI 模型中转请求')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('fetches and displays relay tasks', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/relay/tasks', () => {
|
||||
return HttpResponse.json(mockRelayTasks)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Relay />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('已完成')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.getByText('失败')).toBeInTheDocument()
|
||||
expect(screen.getByText('排队中')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows loading spinner while fetching', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/relay/tasks', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockRelayTasks)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Relay />)
|
||||
|
||||
// Ant Design Spin component renders a .ant-spin element
|
||||
const spinner = document.querySelector('.ant-spin')
|
||||
expect(spinner).toBeTruthy()
|
||||
|
||||
// Wait for loading to complete so afterEach cleanup is clean
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('已完成')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows ErrorState on API failure with retry button', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/relay/tasks', () => {
|
||||
return HttpResponse.json(
|
||||
{ error: 'internal_error', message: '服务器内部错误' },
|
||||
{ status: 500 },
|
||||
)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Relay />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
|
||||
})
|
||||
// Ant Design Button splits two-character text with a space: "重 试"
|
||||
const retryButton = screen.getByRole('button', { name: /重.?试/ })
|
||||
expect(retryButton).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders status as colored tag', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/relay/tasks', () => {
|
||||
return HttpResponse.json(mockRelayTasks)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Relay />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('已完成')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Verify the status tags have correct Ant Design color classes
|
||||
const completedTag = screen.getByText('已完成').closest('.ant-tag')
|
||||
expect(completedTag).toBeTruthy()
|
||||
// statusColors.completed = 'green'
|
||||
expect(completedTag?.className).toMatch(/green/)
|
||||
|
||||
const failedTag = screen.getByText('失败').closest('.ant-tag')
|
||||
expect(failedTag).toBeTruthy()
|
||||
// statusColors.failed = 'red'
|
||||
expect(failedTag?.className).toMatch(/red/)
|
||||
})
|
||||
|
||||
it('renders model_id column', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/relay/tasks', () => {
|
||||
return HttpResponse.json(mockRelayTasks)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Relay />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('已完成')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText('gpt-4o')).toBeInTheDocument()
|
||||
expect(screen.getByText('claude-3.5-sonnet')).toBeInTheDocument()
|
||||
expect(screen.getByText('gpt-4o-mini')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders token count column', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/relay/tasks', () => {
|
||||
return HttpResponse.json(mockRelayTasks)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Relay />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('已完成')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Token (入/出): 1,500 / 800
|
||||
expect(screen.getByText(/1,500 \/ 800/)).toBeInTheDocument()
|
||||
// 2,000 / 0
|
||||
expect(screen.getByText(/2,000 \/ 0/)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
248
admin-v2/tests/pages/Usage.test.tsx
Normal file
248
admin-v2/tests/pages/Usage.test.tsx
Normal file
@@ -0,0 +1,248 @@
|
||||
// ============================================================
|
||||
// Usage 页面测试
|
||||
// ============================================================
|
||||
|
||||
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
|
||||
import { render, screen, waitFor } from '@testing-library/react'
|
||||
import { http, HttpResponse } from 'msw'
|
||||
import { setupServer } from 'msw/node'
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
|
||||
|
||||
import Usage from '@/pages/Usage'
|
||||
|
||||
// ── Mock data ────────────────────────────────────────────────
|
||||
|
||||
const mockDailyStats = [
|
||||
{
|
||||
day: '2026-03-28',
|
||||
request_count: 120,
|
||||
input_tokens: 24000,
|
||||
output_tokens: 8000,
|
||||
unique_devices: 5,
|
||||
},
|
||||
{
|
||||
day: '2026-03-29',
|
||||
request_count: 80,
|
||||
input_tokens: 16000,
|
||||
output_tokens: 5000,
|
||||
unique_devices: 3,
|
||||
},
|
||||
{
|
||||
day: '2026-03-30',
|
||||
request_count: 200,
|
||||
input_tokens: 40000,
|
||||
output_tokens: 12000,
|
||||
unique_devices: 7,
|
||||
},
|
||||
]
|
||||
|
||||
const mockModelStats = [
|
||||
{
|
||||
model_id: 'gpt-4o',
|
||||
request_count: 300,
|
||||
input_tokens: 60000,
|
||||
output_tokens: 18000,
|
||||
avg_latency_ms: 450.3,
|
||||
success_rate: 0.98,
|
||||
},
|
||||
{
|
||||
model_id: 'claude-sonnet-4-20250514',
|
||||
request_count: 100,
|
||||
input_tokens: 20000,
|
||||
output_tokens: 7000,
|
||||
avg_latency_ms: 620.7,
|
||||
success_rate: 0.95,
|
||||
},
|
||||
]
|
||||
|
||||
// ── MSW server ───────────────────────────────────────────────
|
||||
|
||||
const server = setupServer()
|
||||
|
||||
beforeEach(() => {
|
||||
server.listen({ onUnhandledRequest: 'bypass' })
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
server.close()
|
||||
})
|
||||
|
||||
// ── Helper: render with QueryClient ──────────────────────────
|
||||
|
||||
function renderWithProviders(ui: React.ReactElement) {
|
||||
const queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: { retry: false },
|
||||
},
|
||||
})
|
||||
return render(
|
||||
<QueryClientProvider client={queryClient}>
|
||||
{ui}
|
||||
</QueryClientProvider>,
|
||||
)
|
||||
}
|
||||
|
||||
// ── Tests ────────────────────────────────────────────────────
|
||||
|
||||
describe('Usage page', () => {
|
||||
it('renders page title and summary cards', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/telemetry/daily', () => {
|
||||
return HttpResponse.json(mockDailyStats)
|
||||
}),
|
||||
http.get('*/api/v1/telemetry/stats', () => {
|
||||
return HttpResponse.json(mockModelStats)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Usage />)
|
||||
|
||||
expect(screen.getByText('用量统计')).toBeInTheDocument()
|
||||
expect(screen.getByText('查看模型使用情况和 Token 消耗')).toBeInTheDocument()
|
||||
|
||||
// Summary card titles
|
||||
expect(screen.getByText('总请求数')).toBeInTheDocument()
|
||||
expect(screen.getByText('总 Token 数')).toBeInTheDocument()
|
||||
|
||||
// Total requests: 120 + 80 + 200 = 400
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('400')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Total tokens: (24000+8000) + (16000+5000) + (40000+12000) = 105,000
|
||||
expect(screen.getByText('105,000')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('fetches and displays daily stats table', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/telemetry/daily', () => {
|
||||
return HttpResponse.json(mockDailyStats)
|
||||
}),
|
||||
http.get('*/api/v1/telemetry/stats', () => {
|
||||
return HttpResponse.json(mockModelStats)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Usage />)
|
||||
|
||||
// Table column headers
|
||||
expect(screen.getByText('每日统计')).toBeInTheDocument()
|
||||
|
||||
// Wait for data rows to render
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('2026-03-28')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Formatted request counts
|
||||
expect(screen.getByText('120')).toBeInTheDocument()
|
||||
expect(screen.getByText('80')).toBeInTheDocument()
|
||||
expect(screen.getByText('200')).toBeInTheDocument()
|
||||
|
||||
// Device counts
|
||||
expect(screen.getByText('5')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('fetches and displays model stats table', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/telemetry/daily', () => {
|
||||
return HttpResponse.json(mockDailyStats)
|
||||
}),
|
||||
http.get('*/api/v1/telemetry/stats', () => {
|
||||
return HttpResponse.json(mockModelStats)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Usage />)
|
||||
|
||||
expect(screen.getByText('按模型统计')).toBeInTheDocument()
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('gpt-4o')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
expect(screen.getByText('claude-sonnet-4-20250514')).toBeInTheDocument()
|
||||
|
||||
// Success rate: 0.98 -> "98.0%"
|
||||
expect(screen.getByText('98.0%')).toBeInTheDocument()
|
||||
|
||||
// Avg latency: 450.3 -> "450ms"
|
||||
expect(screen.getByText('450ms')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows loading spinner before data loads', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/telemetry/daily', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockDailyStats)
|
||||
}),
|
||||
http.get('*/api/v1/telemetry/stats', async () => {
|
||||
await new Promise((resolve) => setTimeout(resolve, 500))
|
||||
return HttpResponse.json(mockModelStats)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Usage />)
|
||||
|
||||
// Ant Design Spin component renders a .ant-spin element
|
||||
const spinner = document.querySelector('.ant-spin')
|
||||
expect(spinner).toBeTruthy()
|
||||
|
||||
// Wait for loading to complete so afterEach cleanup is clean
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('用量统计')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('shows ErrorState when daily stats request fails', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/telemetry/daily', () => {
|
||||
return HttpResponse.json(
|
||||
{ error: 'internal_error', message: '服务器内部错误' },
|
||||
{ status: 500 },
|
||||
)
|
||||
}),
|
||||
http.get('*/api/v1/telemetry/stats', () => {
|
||||
return HttpResponse.json(mockModelStats)
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Usage />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// ErrorState renders a retry button (antd v6 may split Chinese characters)
|
||||
expect(screen.getByRole('button', { name: /重.*试/ })).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('calculates totals correctly from daily data', async () => {
|
||||
server.use(
|
||||
http.get('*/api/v1/telemetry/daily', () => {
|
||||
return HttpResponse.json([
|
||||
{
|
||||
day: '2026-03-30',
|
||||
request_count: 1500,
|
||||
input_tokens: 10000,
|
||||
output_tokens: 3000,
|
||||
unique_devices: 2,
|
||||
},
|
||||
])
|
||||
}),
|
||||
http.get('*/api/v1/telemetry/stats', () => {
|
||||
return HttpResponse.json([])
|
||||
}),
|
||||
)
|
||||
|
||||
renderWithProviders(<Usage />)
|
||||
|
||||
// Total requests: 1500 (formatted as "1,500" by Statistic)
|
||||
await waitFor(() => {
|
||||
const elements = screen.getAllByText('1,500')
|
||||
expect(elements.length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
|
||||
// Total tokens: 10000 + 3000 = 13,000
|
||||
expect(screen.getAllByText('13,000').length).toBeGreaterThanOrEqual(1)
|
||||
})
|
||||
})
|
||||
@@ -31,4 +31,5 @@ jobs = [
|
||||
{ name = "cleanup_rate_limit", interval = "5m", task = "cleanup_rate_limit", run_on_start = false },
|
||||
{ name = "cleanup_refresh_tokens", interval = "1h", task = "cleanup_refresh_tokens", run_on_start = false },
|
||||
{ name = "cleanup_devices", interval = "24h", task = "cleanup_devices", run_on_start = false },
|
||||
{ name = "aggregate_usage", interval = "1h", task = "aggregate_usage", run_on_start = true, args = { account_id = null } },
|
||||
]
|
||||
|
||||
@@ -216,9 +216,10 @@ impl QueryAnalyzer {
|
||||
expansions
|
||||
}
|
||||
|
||||
/// Get synonyms for a keyword (simplified)
|
||||
/// Get synonyms for a keyword (simplified, English + Chinese)
|
||||
fn get_synonyms(&self, keyword: &str) -> Option<Vec<String>> {
|
||||
let synonyms: &[&str] = match keyword {
|
||||
// English synonyms
|
||||
"code" => &["program", "script", "source"],
|
||||
"error" => &["bug", "issue", "problem", "exception"],
|
||||
"fix" => &["solve", "resolve", "repair", "patch"],
|
||||
@@ -226,6 +227,20 @@ impl QueryAnalyzer {
|
||||
"slow" => &["performance", "optimize", "speed"],
|
||||
"help" => &["assist", "support", "guide", "aid"],
|
||||
"learn" => &["study", "understand", "know", "grasp"],
|
||||
// Chinese synonyms — critical for Chinese-language queries
|
||||
"错误" => &["问题", "bug", "异常", "故障"],
|
||||
"修复" => &["解决", "修正", "处理", "fix"],
|
||||
"优化" => &["改进", "提升", "加速", "improve"],
|
||||
"配置" => &["设置", "参数", "选项", "config"],
|
||||
"性能" => &["速度", "效率", "performance"],
|
||||
"问题" => &["错误", "故障", "issue", "problem"],
|
||||
"帮助" => &["协助", "支持", "help"],
|
||||
"学习" => &["了解", "掌握", "learn"],
|
||||
"代码" => &["程序", "脚本", "code"],
|
||||
"数据库" => &["DB", "database", "存储"],
|
||||
"部署" => &["发布", "上线", "deploy"],
|
||||
"测试" => &["验证", "检验", "test"],
|
||||
"安全" => &["防护", "加密", "security"],
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
|
||||
@@ -111,7 +111,7 @@ impl HandResult {
|
||||
}
|
||||
|
||||
/// Hand execution status
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum HandStatus {
|
||||
Idle,
|
||||
|
||||
@@ -134,7 +134,7 @@ impl BrowserHand {
|
||||
id: "browser".to_string(),
|
||||
name: "浏览器".to_string(),
|
||||
description: "网页浏览器自动化,支持导航、交互和数据采集".to_string(),
|
||||
needs_approval: false,
|
||||
needs_approval: true,
|
||||
dependencies: vec!["webdriver".to_string()],
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
@@ -420,8 +420,211 @@ impl BrowserSequence {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set whether to stop on error
|
||||
pub fn stop_on_error(mut self, stop: bool) -> Self {
|
||||
self.stop_on_error = stop;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the sequence
|
||||
pub fn build(self) -> Vec<BrowserAction> {
|
||||
self.steps
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Hand;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn fresh_context() -> HandContext {
|
||||
HandContext {
|
||||
agent_id: zclaw_types::AgentId::new(),
|
||||
working_dir: None,
|
||||
env: HashMap::new(),
|
||||
timeout_secs: 30,
|
||||
callback_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_browser_config() {
|
||||
let hand = BrowserHand::new();
|
||||
let config = hand.config();
|
||||
assert_eq!(config.id, "browser");
|
||||
assert!(config.enabled);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_browser_config_needs_approval() {
|
||||
let hand = BrowserHand::new();
|
||||
assert!(hand.config().needs_approval, "Browser hand should require approval per TOML config");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_deserialize_navigate() {
|
||||
let json = serde_json::json!({
|
||||
"action": "navigate",
|
||||
"url": "https://example.com",
|
||||
"wait_for": "body"
|
||||
});
|
||||
let action: BrowserAction = serde_json::from_value(json).expect("deserialize navigate");
|
||||
match action {
|
||||
BrowserAction::Navigate { url, wait_for } => {
|
||||
assert_eq!(url, "https://example.com");
|
||||
assert_eq!(wait_for, Some("body".to_string()));
|
||||
}
|
||||
_ => panic!("Expected Navigate action, got {:?}", action),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_deserialize_click() {
|
||||
let json = serde_json::json!({
|
||||
"action": "click",
|
||||
"selector": "#submit-btn",
|
||||
"wait_ms": 500
|
||||
});
|
||||
let action: BrowserAction = serde_json::from_value(json).expect("deserialize click");
|
||||
match action {
|
||||
BrowserAction::Click { selector, wait_ms } => {
|
||||
assert_eq!(selector, "#submit-btn");
|
||||
assert_eq!(wait_ms, Some(500));
|
||||
}
|
||||
_ => panic!("Expected Click action, got {:?}", action),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_deserialize_type() {
|
||||
let json = serde_json::json!({
|
||||
"action": "type",
|
||||
"selector": "#search",
|
||||
"text": "hello world",
|
||||
"clear_first": true
|
||||
});
|
||||
let action: BrowserAction = serde_json::from_value(json).expect("deserialize type");
|
||||
match action {
|
||||
BrowserAction::Type { selector, text, clear_first } => {
|
||||
assert_eq!(selector, "#search");
|
||||
assert_eq!(text, "hello world");
|
||||
assert!(clear_first);
|
||||
}
|
||||
_ => panic!("Expected Type action, got {:?}", action),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_deserialize_scrape() {
|
||||
let json = serde_json::json!({
|
||||
"action": "scrape",
|
||||
"selectors": ["h1", ".content", "#price"]
|
||||
});
|
||||
let action: BrowserAction = serde_json::from_value(json).expect("deserialize scrape");
|
||||
match action {
|
||||
BrowserAction::Scrape { selectors, wait_for } => {
|
||||
assert_eq!(selectors, vec!["h1", ".content", "#price"]);
|
||||
assert!(wait_for.is_none());
|
||||
}
|
||||
_ => panic!("Expected Scrape action, got {:?}", action),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_deserialize_screenshot() {
|
||||
let json = serde_json::json!({
|
||||
"action": "screenshot",
|
||||
"full_page": true
|
||||
});
|
||||
let action: BrowserAction = serde_json::from_value(json).expect("deserialize screenshot");
|
||||
match action {
|
||||
BrowserAction::Screenshot { selector, full_page } => {
|
||||
assert!(selector.is_none());
|
||||
assert!(full_page);
|
||||
}
|
||||
_ => panic!("Expected Screenshot action, got {:?}", action),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_all_major_actions_roundtrip() {
|
||||
let actions = vec![
|
||||
BrowserAction::Navigate { url: "https://example.com".into(), wait_for: None },
|
||||
BrowserAction::Click { selector: "#btn".into(), wait_ms: None },
|
||||
BrowserAction::Type { selector: "#input".into(), text: "test".into(), clear_first: false },
|
||||
BrowserAction::Scrape { selectors: vec!["h1".into()], wait_for: None },
|
||||
BrowserAction::Screenshot { selector: None, full_page: false },
|
||||
BrowserAction::Wait { selector: "#loaded".into(), timeout_ms: 5000 },
|
||||
BrowserAction::Execute { script: "return 1".into(), args: vec![] },
|
||||
BrowserAction::FillForm {
|
||||
fields: vec![FormField { selector: "#name".into(), value: "Alice".into() }],
|
||||
submit_selector: Some("#submit".into()),
|
||||
},
|
||||
];
|
||||
|
||||
for original in actions {
|
||||
let json = serde_json::to_value(&original).expect("serialize action");
|
||||
let roundtripped: BrowserAction = serde_json::from_value(json).expect("deserialize action");
|
||||
assert_eq!(
|
||||
serde_json::to_value(&original).unwrap(),
|
||||
serde_json::to_value(&roundtripped).unwrap(),
|
||||
"Roundtrip failed for {:?}",
|
||||
original
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_browser_sequence_builder() {
|
||||
let ctx = fresh_context();
|
||||
let hand = BrowserHand::new();
|
||||
|
||||
let sequence = BrowserSequence::new("test_sequence")
|
||||
.navigate("https://example.com")
|
||||
.stop_on_error(false);
|
||||
|
||||
assert_eq!(sequence.name, "test_sequence");
|
||||
assert!(!sequence.stop_on_error);
|
||||
assert_eq!(sequence.steps.len(), 1);
|
||||
|
||||
// Execute the navigate step
|
||||
let action_json = serde_json::to_value(&sequence.steps[0]).expect("serialize step");
|
||||
let result = hand.execute(&ctx, action_json).await.expect("execute");
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["action"], "navigate");
|
||||
assert_eq!(result.output["url"], "https://example.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_browser_sequence_multiple_steps() {
|
||||
let ctx = fresh_context();
|
||||
let hand = BrowserHand::new();
|
||||
|
||||
let sequence = BrowserSequence::new("multi_step")
|
||||
.navigate("https://example.com")
|
||||
.click("#login-btn")
|
||||
.type_text("#username", "admin")
|
||||
.screenshot();
|
||||
|
||||
assert_eq!(sequence.steps.len(), 4);
|
||||
|
||||
// Verify each step can execute
|
||||
for (i, step) in sequence.steps.iter().enumerate() {
|
||||
let action_json = serde_json::to_value(step).expect("serialize step");
|
||||
let result = hand.execute(&ctx, action_json).await.expect("execute step");
|
||||
assert!(result.success, "Step {} failed: {:?}", i, result.error);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_form_field_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"selector": "#email",
|
||||
"value": "user@example.com"
|
||||
});
|
||||
let field: FormField = serde_json::from_value(json).expect("deserialize form field");
|
||||
assert_eq!(field.selector, "#email");
|
||||
assert_eq!(field.value, "user@example.com");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -640,3 +640,390 @@ impl Hand for ClipHand {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// === Config & Defaults ===
|
||||
|
||||
#[test]
|
||||
fn test_hand_config() {
|
||||
let hand = ClipHand::new();
|
||||
assert_eq!(hand.config().id, "clip");
|
||||
assert_eq!(hand.config().name, "视频剪辑");
|
||||
assert!(!hand.config().needs_approval);
|
||||
assert!(hand.config().enabled);
|
||||
assert!(hand.config().tags.contains(&"video".to_string()));
|
||||
assert!(hand.config().input_schema.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_impl() {
|
||||
let hand = ClipHand::default();
|
||||
assert_eq!(hand.config().id, "clip");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_needs_approval() {
|
||||
let hand = ClipHand::new();
|
||||
assert!(!hand.needs_approval());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_dependencies() {
|
||||
let hand = ClipHand::new();
|
||||
let deps = hand.check_dependencies().unwrap();
|
||||
// May or may not find ffmpeg depending on test environment
|
||||
// Just verify it doesn't panic
|
||||
let _ = deps;
|
||||
}
|
||||
|
||||
// === VideoFormat ===
|
||||
|
||||
#[test]
|
||||
fn test_video_format_default() {
|
||||
assert!(matches!(VideoFormat::default(), VideoFormat::Mp4));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_video_format_deserialize() {
|
||||
let fmt: VideoFormat = serde_json::from_value(json!("mp4")).unwrap();
|
||||
assert!(matches!(fmt, VideoFormat::Mp4));
|
||||
|
||||
let fmt: VideoFormat = serde_json::from_value(json!("webm")).unwrap();
|
||||
assert!(matches!(fmt, VideoFormat::Webm));
|
||||
|
||||
let fmt: VideoFormat = serde_json::from_value(json!("gif")).unwrap();
|
||||
assert!(matches!(fmt, VideoFormat::Gif));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_video_format_serialize() {
|
||||
assert_eq!(serde_json::to_value(&VideoFormat::Mp4).unwrap(), "mp4");
|
||||
assert_eq!(serde_json::to_value(&VideoFormat::Webm).unwrap(), "webm");
|
||||
}
|
||||
|
||||
// === Resolution ===
|
||||
|
||||
#[test]
|
||||
fn test_resolution_default() {
|
||||
assert!(matches!(Resolution::default(), Resolution::Original));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolution_presets() {
|
||||
let r: Resolution = serde_json::from_value(json!("p720")).unwrap();
|
||||
assert!(matches!(r, Resolution::P720));
|
||||
|
||||
let r: Resolution = serde_json::from_value(json!("p1080")).unwrap();
|
||||
assert!(matches!(r, Resolution::P1080));
|
||||
|
||||
let r: Resolution = serde_json::from_value(json!("p4k")).unwrap();
|
||||
assert!(matches!(r, Resolution::P4k));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolution_custom() {
|
||||
let r: Resolution = serde_json::from_value(json!({"custom": {"width": 800, "height": 600}})).unwrap();
|
||||
match r {
|
||||
Resolution::Custom { width, height } => {
|
||||
assert_eq!(width, 800);
|
||||
assert_eq!(height, 600);
|
||||
}
|
||||
_ => panic!("Expected Custom"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolution_serialize() {
|
||||
assert_eq!(serde_json::to_value(&Resolution::P720).unwrap(), "p720");
|
||||
assert_eq!(serde_json::to_value(&Resolution::Original).unwrap(), "original");
|
||||
}
|
||||
|
||||
// === TrimConfig ===
|
||||
|
||||
#[test]
|
||||
fn test_trim_config_deserialize() {
|
||||
let config: TrimConfig = serde_json::from_value(json!({
|
||||
"inputPath": "/input.mp4",
|
||||
"outputPath": "/output.mp4",
|
||||
"startTime": 5.0,
|
||||
"duration": 10.0
|
||||
})).unwrap();
|
||||
assert_eq!(config.input_path, "/input.mp4");
|
||||
assert_eq!(config.output_path, "/output.mp4");
|
||||
assert_eq!(config.start_time, Some(5.0));
|
||||
assert_eq!(config.duration, Some(10.0));
|
||||
assert!(config.end_time.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_config_minimal() {
|
||||
let config: TrimConfig = serde_json::from_value(json!({
|
||||
"inputPath": "/in.mp4",
|
||||
"outputPath": "/out.mp4"
|
||||
})).unwrap();
|
||||
assert!(config.start_time.is_none());
|
||||
assert!(config.end_time.is_none());
|
||||
assert!(config.duration.is_none());
|
||||
}
|
||||
|
||||
// === ConvertConfig ===
|
||||
|
||||
#[test]
|
||||
fn test_convert_config_deserialize() {
|
||||
let config: ConvertConfig = serde_json::from_value(json!({
|
||||
"inputPath": "/input.avi",
|
||||
"outputPath": "/output.mp4",
|
||||
"format": "mp4",
|
||||
"resolution": "p1080",
|
||||
"videoBitrate": "4M",
|
||||
"audioBitrate": "192k"
|
||||
})).unwrap();
|
||||
assert_eq!(config.input_path, "/input.avi");
|
||||
assert!(matches!(config.format, VideoFormat::Mp4));
|
||||
assert!(matches!(config.resolution, Resolution::P1080));
|
||||
assert_eq!(config.video_bitrate, Some("4M".to_string()));
|
||||
assert_eq!(config.audio_bitrate, Some("192k".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_config_defaults() {
|
||||
let config: ConvertConfig = serde_json::from_value(json!({
|
||||
"inputPath": "/in.mp4",
|
||||
"outputPath": "/out.mp4"
|
||||
})).unwrap();
|
||||
assert!(matches!(config.format, VideoFormat::Mp4));
|
||||
assert!(matches!(config.resolution, Resolution::Original));
|
||||
assert!(config.video_bitrate.is_none());
|
||||
assert!(config.audio_bitrate.is_none());
|
||||
}
|
||||
|
||||
// === ThumbnailConfig ===
|
||||
|
||||
#[test]
|
||||
fn test_thumbnail_config_deserialize() {
|
||||
let config: ThumbnailConfig = serde_json::from_value(json!({
|
||||
"inputPath": "/video.mp4",
|
||||
"outputPath": "/thumb.jpg",
|
||||
"time": 5.0,
|
||||
"width": 320,
|
||||
"height": 240
|
||||
})).unwrap();
|
||||
assert_eq!(config.input_path, "/video.mp4");
|
||||
assert_eq!(config.time, 5.0);
|
||||
assert_eq!(config.width, Some(320));
|
||||
assert_eq!(config.height, Some(240));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thumbnail_config_defaults() {
|
||||
let config: ThumbnailConfig = serde_json::from_value(json!({
|
||||
"inputPath": "/v.mp4",
|
||||
"outputPath": "/t.jpg"
|
||||
})).unwrap();
|
||||
assert_eq!(config.time, 0.0);
|
||||
assert!(config.width.is_none());
|
||||
assert!(config.height.is_none());
|
||||
}
|
||||
|
||||
// === ConcatConfig ===
|
||||
|
||||
#[test]
|
||||
fn test_concat_config_deserialize() {
|
||||
let config: ConcatConfig = serde_json::from_value(json!({
|
||||
"inputPaths": ["/a.mp4", "/b.mp4"],
|
||||
"outputPath": "/merged.mp4"
|
||||
})).unwrap();
|
||||
assert_eq!(config.input_paths.len(), 2);
|
||||
assert_eq!(config.output_path, "/merged.mp4");
|
||||
}
|
||||
|
||||
// === VideoInfo ===
|
||||
|
||||
#[test]
|
||||
fn test_video_info_deserialize() {
|
||||
let info: VideoInfo = serde_json::from_value(json!({
|
||||
"path": "/test.mp4",
|
||||
"durationSecs": 120.5,
|
||||
"width": 1920,
|
||||
"height": 1080,
|
||||
"fps": 30.0,
|
||||
"format": "mp4",
|
||||
"videoCodec": "h264",
|
||||
"audioCodec": "aac",
|
||||
"bitrateKbps": 5000,
|
||||
"fileSizeBytes": 75_000_000
|
||||
})).unwrap();
|
||||
assert_eq!(info.path, "/test.mp4");
|
||||
assert_eq!(info.duration_secs, 120.5);
|
||||
assert_eq!(info.width, 1920);
|
||||
assert_eq!(info.fps, 30.0);
|
||||
assert_eq!(info.video_codec, "h264");
|
||||
assert_eq!(info.audio_codec, Some("aac".to_string()));
|
||||
assert_eq!(info.bitrate_kbps, Some(5000));
|
||||
assert_eq!(info.file_size_bytes, 75_000_000);
|
||||
}
|
||||
|
||||
// === ClipAction Deserialization ===
|
||||
|
||||
#[test]
|
||||
fn test_action_trim() {
|
||||
let action: ClipAction = serde_json::from_value(json!({
|
||||
"action": "trim",
|
||||
"config": {
|
||||
"inputPath": "/in.mp4",
|
||||
"outputPath": "/out.mp4",
|
||||
"startTime": 1.0,
|
||||
"endTime": 5.0
|
||||
}
|
||||
})).unwrap();
|
||||
match action {
|
||||
ClipAction::Trim { config } => {
|
||||
assert_eq!(config.input_path, "/in.mp4");
|
||||
assert_eq!(config.start_time, Some(1.0));
|
||||
}
|
||||
_ => panic!("Expected Trim"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_convert() {
|
||||
let action: ClipAction = serde_json::from_value(json!({
|
||||
"action": "convert",
|
||||
"config": {
|
||||
"inputPath": "/in.avi",
|
||||
"outputPath": "/out.mp4"
|
||||
}
|
||||
})).unwrap();
|
||||
assert!(matches!(action, ClipAction::Convert { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_resize() {
|
||||
let action: ClipAction = serde_json::from_value(json!({
|
||||
"action": "resize",
|
||||
"input_path": "/in.mp4",
|
||||
"output_path": "/out.mp4",
|
||||
"resolution": "p720"
|
||||
})).unwrap();
|
||||
match action {
|
||||
ClipAction::Resize { input_path, resolution, .. } => {
|
||||
assert_eq!(input_path, "/in.mp4");
|
||||
assert!(matches!(resolution, Resolution::P720));
|
||||
}
|
||||
_ => panic!("Expected Resize"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_thumbnail() {
|
||||
let action: ClipAction = serde_json::from_value(json!({
|
||||
"action": "thumbnail",
|
||||
"config": {
|
||||
"inputPath": "/in.mp4",
|
||||
"outputPath": "/thumb.jpg"
|
||||
}
|
||||
})).unwrap();
|
||||
assert!(matches!(action, ClipAction::Thumbnail { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_concat() {
|
||||
let action: ClipAction = serde_json::from_value(json!({
|
||||
"action": "concat",
|
||||
"config": {
|
||||
"inputPaths": ["/a.mp4", "/b.mp4"],
|
||||
"outputPath": "/out.mp4"
|
||||
}
|
||||
})).unwrap();
|
||||
assert!(matches!(action, ClipAction::Concat { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_info() {
|
||||
let action: ClipAction = serde_json::from_value(json!({
|
||||
"action": "info",
|
||||
"path": "/video.mp4"
|
||||
})).unwrap();
|
||||
match action {
|
||||
ClipAction::Info { path } => assert_eq!(path, "/video.mp4"),
|
||||
_ => panic!("Expected Info"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_check_ffmpeg() {
|
||||
let action: ClipAction = serde_json::from_value(json!({"action": "check_ffmpeg"})).unwrap();
|
||||
assert!(matches!(action, ClipAction::CheckFfmpeg));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_invalid() {
|
||||
let result = serde_json::from_value::<ClipAction>(json!({"action": "nonexistent"}));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// === Hand execute dispatch ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_check_ffmpeg() {
|
||||
let hand = ClipHand::new();
|
||||
let ctx = HandContext::default();
|
||||
let result = hand.execute(&ctx, json!({"action": "check_ffmpeg"})).await.unwrap();
|
||||
// Just verify it doesn't crash and returns a valid result
|
||||
assert!(result.output.is_object());
|
||||
// "available" field should exist
|
||||
assert!(result.output["available"].is_boolean());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_invalid_action() {
|
||||
let hand = ClipHand::new();
|
||||
let ctx = HandContext::default();
|
||||
let result = hand.execute(&ctx, json!({"action": "bogus"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// === Status ===
|
||||
|
||||
#[test]
|
||||
fn test_status() {
|
||||
let hand = ClipHand::new();
|
||||
let status = hand.status();
|
||||
// Either Idle (ffmpeg found) or Failed (not found) — just verify it doesn't panic
|
||||
assert!(matches!(status, crate::HandStatus::Idle | crate::HandStatus::Failed));
|
||||
}
|
||||
|
||||
// === Roundtrip ===
|
||||
|
||||
#[test]
|
||||
fn test_trim_action_roundtrip() {
|
||||
let json = json!({
|
||||
"action": "trim",
|
||||
"config": {
|
||||
"inputPath": "/in.mp4",
|
||||
"outputPath": "/out.mp4",
|
||||
"startTime": 2.0,
|
||||
"duration": 5.0
|
||||
}
|
||||
});
|
||||
let action: ClipAction = serde_json::from_value(json).unwrap();
|
||||
let serialized = serde_json::to_value(&action).unwrap();
|
||||
assert_eq!(serialized["action"], "trim");
|
||||
assert_eq!(serialized["config"]["inputPath"], "/in.mp4");
|
||||
assert_eq!(serialized["config"]["startTime"], 2.0);
|
||||
assert_eq!(serialized["config"]["duration"], 5.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_info_action_roundtrip() {
|
||||
let json = json!({"action": "info", "path": "/video.mp4"});
|
||||
let action: ClipAction = serde_json::from_value(json).unwrap();
|
||||
let serialized = serde_json::to_value(&action).unwrap();
|
||||
assert_eq!(serialized["action"], "info");
|
||||
assert_eq!(serialized["path"], "/video.mp4");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use zclaw_types::Result;
|
||||
use crate::{Hand, HandConfig, HandContext, HandResult};
|
||||
|
||||
/// Output format options
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OutputFormat {
|
||||
Json,
|
||||
@@ -234,16 +234,37 @@ impl CollectorHand {
|
||||
self.extract_visible_text(html)
|
||||
}
|
||||
|
||||
/// Extract visible text from HTML
|
||||
/// Extract visible text from HTML, stripping scripts and styles
|
||||
fn extract_visible_text(&self, html: &str) -> String {
|
||||
let html_lower = html.to_lowercase();
|
||||
let mut text = String::new();
|
||||
let mut in_tag = false;
|
||||
let mut in_script = false;
|
||||
let mut in_style = false;
|
||||
let mut pos: usize = 0;
|
||||
|
||||
for c in html.chars() {
|
||||
let char_len = c.len_utf8();
|
||||
match c {
|
||||
'<' => in_tag = true,
|
||||
'>' => in_tag = false,
|
||||
'<' => {
|
||||
let remaining = &html_lower[pos..];
|
||||
if remaining.starts_with("</script") {
|
||||
in_script = false;
|
||||
} else if remaining.starts_with("</style") {
|
||||
in_style = false;
|
||||
}
|
||||
if remaining.starts_with("<script") {
|
||||
in_script = true;
|
||||
} else if remaining.starts_with("<style") {
|
||||
in_style = true;
|
||||
}
|
||||
in_tag = true;
|
||||
}
|
||||
'>' => {
|
||||
in_tag = false;
|
||||
}
|
||||
_ if in_tag => {}
|
||||
_ if in_script || in_style => {}
|
||||
' ' | '\n' | '\t' | '\r' => {
|
||||
if !text.ends_with(' ') && !text.is_empty() {
|
||||
text.push(' ');
|
||||
@@ -251,11 +272,11 @@ impl CollectorHand {
|
||||
}
|
||||
_ => text.push(c),
|
||||
}
|
||||
pos += char_len;
|
||||
}
|
||||
|
||||
// Limit length
|
||||
if text.len() > 500 {
|
||||
text.truncate(500);
|
||||
if text.len() > 10000 {
|
||||
text.truncate(10000);
|
||||
text.push_str("...");
|
||||
}
|
||||
|
||||
@@ -407,3 +428,166 @@ impl Hand for CollectorHand {
|
||||
crate::HandStatus::Idle
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_collector_config() {
|
||||
let hand = CollectorHand::new();
|
||||
assert_eq!(hand.config().id, "collector");
|
||||
assert_eq!(hand.config().name, "数据采集器");
|
||||
assert!(hand.config().enabled);
|
||||
assert!(!hand.config().needs_approval);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_output_format_serialize() {
|
||||
let formats = vec![
|
||||
(OutputFormat::Csv, "\"csv\""),
|
||||
(OutputFormat::Markdown, "\"markdown\""),
|
||||
(OutputFormat::Json, "\"json\""),
|
||||
(OutputFormat::Text, "\"text\""),
|
||||
];
|
||||
|
||||
for (fmt, expected) in formats {
|
||||
let serialized = serde_json::to_string(&fmt).unwrap();
|
||||
assert_eq!(serialized, expected);
|
||||
}
|
||||
|
||||
// Verify round-trip deserialization
|
||||
for json_str in &["\"csv\"", "\"markdown\"", "\"json\"", "\"text\""] {
|
||||
let deserialized: OutputFormat = serde_json::from_str(json_str).unwrap();
|
||||
let re_serialized = serde_json::to_string(&deserialized).unwrap();
|
||||
assert_eq!(&re_serialized, json_str);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_visible_text_basic() {
|
||||
let hand = CollectorHand::new();
|
||||
let html = "<html><body><h1>Title</h1><p>Content here</p></body></html>";
|
||||
let text = hand.extract_visible_text(html);
|
||||
assert!(text.contains("Title"), "should contain 'Title', got: {}", text);
|
||||
assert!(text.contains("Content here"), "should contain 'Content here', got: {}", text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_visible_text_strips_scripts() {
|
||||
let hand = CollectorHand::new();
|
||||
let html = "<html><body><script>alert('xss')</script><p>Safe content</p></body></html>";
|
||||
let text = hand.extract_visible_text(html);
|
||||
assert!(!text.contains("alert"), "script content should be removed, got: {}", text);
|
||||
assert!(!text.contains("xss"), "script content should be removed, got: {}", text);
|
||||
assert!(text.contains("Safe content"), "visible content should remain, got: {}", text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_visible_text_strips_styles() {
|
||||
let hand = CollectorHand::new();
|
||||
let html = "<html><head><style>body { color: red; }</style></head><body><p>Text</p></body></html>";
|
||||
let text = hand.extract_visible_text(html);
|
||||
assert!(!text.contains("color"), "style content should be removed, got: {}", text);
|
||||
assert!(!text.contains("red"), "style content should be removed, got: {}", text);
|
||||
assert!(text.contains("Text"), "visible content should remain, got: {}", text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_visible_text_empty() {
|
||||
let hand = CollectorHand::new();
|
||||
let text = hand.extract_visible_text("");
|
||||
assert!(text.is_empty(), "empty HTML should produce empty text, got: '{}'", text);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_aggregate_action_empty_urls() {
|
||||
let hand = CollectorHand::new();
|
||||
let config = AggregationConfig {
|
||||
urls: vec![],
|
||||
aggregate_fields: vec![],
|
||||
};
|
||||
|
||||
let result = hand.execute_aggregate(&config).await.unwrap();
|
||||
let results = result.get("results").unwrap().as_array().unwrap();
|
||||
assert_eq!(results.len(), 0, "empty URLs should produce empty results");
|
||||
assert_eq!(result.get("source_count").unwrap().as_u64().unwrap(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collector_action_deserialize() {
|
||||
// Collect action
|
||||
let collect_json = json!({
|
||||
"action": "collect",
|
||||
"target": {
|
||||
"url": "https://example.com",
|
||||
"selector": ".article",
|
||||
"fields": { "title": "h1" },
|
||||
"maxItems": 10
|
||||
},
|
||||
"format": "markdown"
|
||||
});
|
||||
let action: CollectorAction = serde_json::from_value(collect_json).unwrap();
|
||||
match action {
|
||||
CollectorAction::Collect { target, format } => {
|
||||
assert_eq!(target.url, "https://example.com");
|
||||
assert_eq!(target.selector.as_deref(), Some(".article"));
|
||||
assert_eq!(target.max_items, 10);
|
||||
assert!(format.is_some());
|
||||
assert_eq!(format.unwrap(), OutputFormat::Markdown);
|
||||
}
|
||||
_ => panic!("Expected Collect action"),
|
||||
}
|
||||
|
||||
// Aggregate action
|
||||
let aggregate_json = json!({
|
||||
"action": "aggregate",
|
||||
"config": {
|
||||
"urls": ["https://a.com", "https://b.com"],
|
||||
"aggregateFields": ["title", "content"]
|
||||
}
|
||||
});
|
||||
let action: CollectorAction = serde_json::from_value(aggregate_json).unwrap();
|
||||
match action {
|
||||
CollectorAction::Aggregate { config } => {
|
||||
assert_eq!(config.urls.len(), 2);
|
||||
assert_eq!(config.aggregate_fields.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected Aggregate action"),
|
||||
}
|
||||
|
||||
// Extract action
|
||||
let extract_json = json!({
|
||||
"action": "extract",
|
||||
"url": "https://example.com",
|
||||
"selectors": { "title": "h1", "body": "p" }
|
||||
});
|
||||
let action: CollectorAction = serde_json::from_value(extract_json).unwrap();
|
||||
match action {
|
||||
CollectorAction::Extract { url, selectors } => {
|
||||
assert_eq!(url, "https://example.com");
|
||||
assert_eq!(selectors.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected Extract action"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_collection_target_deserialize() {
|
||||
let json = json!({
|
||||
"url": "https://example.com/page",
|
||||
"selector": ".content",
|
||||
"fields": {
|
||||
"title": "h1",
|
||||
"author": ".author-name"
|
||||
},
|
||||
"maxItems": 50
|
||||
});
|
||||
|
||||
let target: CollectionTarget = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(target.url, "https://example.com/page");
|
||||
assert_eq!(target.selector.as_deref(), Some(".content"));
|
||||
assert_eq!(target.fields.len(), 2);
|
||||
assert_eq!(target.max_items, 50);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -344,31 +344,34 @@ impl ResearcherHand {
|
||||
|
||||
/// Extract readable text from HTML
|
||||
fn extract_text_from_html(&self, html: &str) -> String {
|
||||
// Simple text extraction - remove HTML tags
|
||||
let html_lower = html.to_lowercase();
|
||||
let mut text = String::new();
|
||||
let mut in_tag = false;
|
||||
let mut in_script = false;
|
||||
let mut in_style = false;
|
||||
let mut pos: usize = 0;
|
||||
|
||||
for c in html.chars() {
|
||||
let char_len = c.len_utf8();
|
||||
match c {
|
||||
'<' => {
|
||||
in_tag = true;
|
||||
let remaining = html[text.len()..].to_lowercase();
|
||||
// Check for closing tags before entering tag mode
|
||||
let remaining = &html_lower[pos..];
|
||||
if remaining.starts_with("</script") {
|
||||
in_script = false;
|
||||
} else if remaining.starts_with("</style") {
|
||||
in_style = false;
|
||||
}
|
||||
// Check for opening tags
|
||||
if remaining.starts_with("<script") {
|
||||
in_script = true;
|
||||
} else if remaining.starts_with("<style") {
|
||||
in_style = true;
|
||||
}
|
||||
in_tag = true;
|
||||
}
|
||||
'>' => {
|
||||
in_tag = false;
|
||||
let remaining = html[text.len()..].to_lowercase();
|
||||
if remaining.starts_with("</script>") {
|
||||
in_script = false;
|
||||
} else if remaining.starts_with("</style>") {
|
||||
in_style = false;
|
||||
}
|
||||
}
|
||||
_ if in_tag => {}
|
||||
_ if in_script || in_style => {}
|
||||
@@ -379,9 +382,9 @@ impl ResearcherHand {
|
||||
}
|
||||
_ => text.push(c),
|
||||
}
|
||||
pos += char_len;
|
||||
}
|
||||
|
||||
// Limit length
|
||||
if text.len() > 10000 {
|
||||
text.truncate(10000);
|
||||
text.push_str("...");
|
||||
@@ -445,10 +448,33 @@ impl ResearcherHand {
|
||||
|
||||
let duration = start.elapsed().as_millis() as u64;
|
||||
|
||||
// Generate summary from top results
|
||||
let summary = if results.is_empty() {
|
||||
"未找到相关结果,建议调整搜索关键词后重试".to_string()
|
||||
} else {
|
||||
let top_snippets: Vec<&str> = results
|
||||
.iter()
|
||||
.take(3)
|
||||
.filter_map(|r| {
|
||||
let s = r.snippet.trim();
|
||||
if s.is_empty() { None } else { Some(s) }
|
||||
})
|
||||
.collect();
|
||||
if top_snippets.is_empty() {
|
||||
format!("找到 {} 条相关结果,但无摘要信息", results.len())
|
||||
} else {
|
||||
format!(
|
||||
"基于 {} 条搜索结果:{}",
|
||||
results.len(),
|
||||
top_snippets.join(";")
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ResearchReport {
|
||||
query: query.query.clone(),
|
||||
results,
|
||||
summary: None, // Would require LLM integration
|
||||
summary: Some(summary),
|
||||
key_findings,
|
||||
related_topics,
|
||||
researched_at: chrono::Utc::now().to_rfc3339(),
|
||||
@@ -543,3 +569,276 @@ fn url_encode(s: &str) -> String {
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn create_test_hand() -> ResearcherHand {
|
||||
ResearcherHand::new()
|
||||
}
|
||||
|
||||
fn test_context() -> HandContext {
|
||||
HandContext::default()
|
||||
}
|
||||
|
||||
// --- Config & Type Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_config_id() {
|
||||
let hand = create_test_hand();
|
||||
assert_eq!(hand.config().id, "researcher");
|
||||
assert_eq!(hand.config().name, "研究员");
|
||||
assert!(hand.config().enabled);
|
||||
assert!(!hand.config().needs_approval);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_engine_default_is_auto() {
|
||||
let engine = SearchEngine::default();
|
||||
assert!(matches!(engine, SearchEngine::Auto));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_depth_default_is_standard() {
|
||||
let depth = ResearchDepth::default();
|
||||
assert!(matches!(depth, ResearchDepth::Standard));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_depth_serialize() {
|
||||
let json = serde_json::to_string(&ResearchDepth::Deep).unwrap();
|
||||
assert_eq!(json, "\"deep\"");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_depth_deserialize() {
|
||||
let depth: ResearchDepth = serde_json::from_str("\"quick\"").unwrap();
|
||||
assert!(matches!(depth, ResearchDepth::Quick));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_engine_serialize_roundtrip() {
|
||||
for engine in [SearchEngine::Google, SearchEngine::Bing, SearchEngine::DuckDuckGo, SearchEngine::Auto] {
|
||||
let json = serde_json::to_string(&engine).unwrap();
|
||||
let back: SearchEngine = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(json, serde_json::to_string(&back).unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
// --- Action Deserialization Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_action_search_deserialize() {
|
||||
let json = json!({
|
||||
"action": "search",
|
||||
"query": {
|
||||
"query": "Rust programming",
|
||||
"engine": "duckduckgo",
|
||||
"depth": "quick",
|
||||
"maxResults": 5
|
||||
}
|
||||
});
|
||||
let action: ResearcherAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
ResearcherAction::Search { query } => {
|
||||
assert_eq!(query.query, "Rust programming");
|
||||
assert!(matches!(query.engine, SearchEngine::DuckDuckGo));
|
||||
assert!(matches!(query.depth, ResearchDepth::Quick));
|
||||
assert_eq!(query.max_results, 5);
|
||||
}
|
||||
_ => panic!("Expected Search action"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_fetch_deserialize() {
|
||||
let json = json!({
|
||||
"action": "fetch",
|
||||
"url": "https://example.com/page"
|
||||
});
|
||||
let action: ResearcherAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
ResearcherAction::Fetch { url } => {
|
||||
assert_eq!(url, "https://example.com/page");
|
||||
}
|
||||
_ => panic!("Expected Fetch action"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_report_deserialize() {
|
||||
let json = json!({
|
||||
"action": "report",
|
||||
"query": {
|
||||
"query": "AI trends 2026",
|
||||
"depth": "deep"
|
||||
}
|
||||
});
|
||||
let action: ResearcherAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
ResearcherAction::Report { query } => {
|
||||
assert_eq!(query.query, "AI trends 2026");
|
||||
assert!(matches!(query.depth, ResearchDepth::Deep));
|
||||
}
|
||||
_ => panic!("Expected Report action"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_action_invalid_rejected() {
|
||||
let json = json!({
|
||||
"action": "unknown_action",
|
||||
"data": "whatever"
|
||||
});
|
||||
let result: std::result::Result<ResearcherAction, _> = serde_json::from_value(json);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// --- URL Encoding Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_url_encode_ascii() {
|
||||
assert_eq!(url_encode("hello world"), "hello%20world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_encode_chinese() {
|
||||
let encoded = url_encode("中文搜索");
|
||||
assert!(encoded.contains("%"));
|
||||
// Chinese chars should be percent-encoded
|
||||
assert!(!encoded.contains("中文"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_encode_safe_chars() {
|
||||
assert_eq!(url_encode("abc123-_."), "abc123-_.".to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_url_encode_empty() {
|
||||
assert_eq!(url_encode(""), "");
|
||||
}
|
||||
|
||||
// --- HTML Text Extraction Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_extract_text_basic() {
|
||||
let hand = create_test_hand();
|
||||
let html = "<html><body><h1>Title</h1><p>Content here</p></body></html>";
|
||||
let text = hand.extract_text_from_html(html);
|
||||
assert!(text.contains("Title"));
|
||||
assert!(text.contains("Content here"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_text_strips_scripts() {
|
||||
let hand = create_test_hand();
|
||||
let html = "<html><body><script>alert('xss')</script><p>Safe text</p></body></html>";
|
||||
let text = hand.extract_text_from_html(html);
|
||||
assert!(!text.contains("alert"));
|
||||
assert!(text.contains("Safe text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_text_strips_styles() {
|
||||
let hand = create_test_hand();
|
||||
let html = "<html><body><style>.class{color:red}</style><p>Visible</p></body></html>";
|
||||
let text = hand.extract_text_from_html(html);
|
||||
assert!(!text.contains("color"));
|
||||
assert!(text.contains("Visible"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_text_truncates_long_content() {
|
||||
let hand = create_test_hand();
|
||||
let long_body: String = "x".repeat(20000);
|
||||
let html = format!("<html><body><p>{}</p></body></html>", long_body);
|
||||
let text = hand.extract_text_from_html(&html);
|
||||
assert!(text.len() <= 10003); // 10000 + "..."
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_text_empty_body() {
|
||||
let hand = create_test_hand();
|
||||
let html = "<html><body></body></html>";
|
||||
let text = hand.extract_text_from_html(html);
|
||||
assert!(text.is_empty());
|
||||
}
|
||||
|
||||
// --- Hand Trait Tests ---
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_needs_approval_is_false() {
|
||||
let hand = create_test_hand();
|
||||
assert!(!hand.needs_approval());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_status_is_idle() {
|
||||
let hand = create_test_hand();
|
||||
assert!(matches!(hand.status(), crate::HandStatus::Idle));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_dependencies_ok() {
|
||||
let hand = create_test_hand();
|
||||
let missing = hand.check_dependencies().unwrap();
|
||||
// Default is_dependency_available returns true for all
|
||||
assert!(missing.is_empty());
|
||||
}
|
||||
|
||||
// --- Default Values Tests ---
|
||||
|
||||
#[test]
|
||||
fn test_research_query_defaults() {
|
||||
let json = json!({ "query": "test" });
|
||||
let query: ResearchQuery = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(query.query, "test");
|
||||
assert!(matches!(query.engine, SearchEngine::Auto));
|
||||
assert!(matches!(query.depth, ResearchDepth::Standard));
|
||||
assert_eq!(query.max_results, 10);
|
||||
assert_eq!(query.time_limit_secs, 60);
|
||||
assert!(!query.include_related);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_result_serialization() {
|
||||
let result = SearchResult {
|
||||
title: "Test".to_string(),
|
||||
url: "https://example.com".to_string(),
|
||||
snippet: "A snippet".to_string(),
|
||||
source: "TestSource".to_string(),
|
||||
relevance: 90,
|
||||
content: None,
|
||||
fetched_at: None,
|
||||
};
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("Test"));
|
||||
assert!(json.contains("https://example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_report_summary_is_some_when_results() {
|
||||
// Verify the struct allows Some value
|
||||
let report = ResearchReport {
|
||||
query: "test".to_string(),
|
||||
results: vec![SearchResult {
|
||||
title: "R".to_string(),
|
||||
url: "https://r.co".to_string(),
|
||||
snippet: "snippet text".to_string(),
|
||||
source: "S".to_string(),
|
||||
relevance: 80,
|
||||
content: None,
|
||||
fetched_at: None,
|
||||
}],
|
||||
summary: Some("基于 1 条搜索结果:snippet text".to_string()),
|
||||
key_findings: vec![],
|
||||
related_topics: vec![],
|
||||
researched_at: "2026-01-01T00:00:00Z".to_string(),
|
||||
duration_ms: 100,
|
||||
};
|
||||
assert!(report.summary.is_some());
|
||||
assert!(report.summary.unwrap().contains("snippet text"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -346,13 +346,50 @@ impl Hand for SlideshowHand {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::json;
|
||||
|
||||
// === Config & Defaults ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_slideshow_creation() {
|
||||
let hand = SlideshowHand::new();
|
||||
assert_eq!(hand.config().id, "slideshow");
|
||||
assert_eq!(hand.config().name, "幻灯片");
|
||||
assert!(!hand.config().needs_approval);
|
||||
assert!(hand.config().enabled);
|
||||
assert!(hand.config().tags.contains(&"presentation".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_impl() {
|
||||
let hand = SlideshowHand::default();
|
||||
assert_eq!(hand.config().id, "slideshow");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_needs_approval() {
|
||||
let hand = SlideshowHand::new();
|
||||
assert!(!hand.needs_approval());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_status() {
|
||||
let hand = SlideshowHand::new();
|
||||
assert_eq!(hand.status(), HandStatus::Idle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_state() {
|
||||
let state = SlideshowState::default();
|
||||
assert_eq!(state.current_slide, 0);
|
||||
assert_eq!(state.total_slides, 0);
|
||||
assert!(!state.is_playing);
|
||||
assert_eq!(state.auto_play_interval_ms, 5000);
|
||||
assert!(state.slides.is_empty());
|
||||
}
|
||||
|
||||
// === Navigation ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_navigation() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
@@ -374,6 +411,53 @@ mod tests {
|
||||
assert_eq!(hand.get_state().await.current_slide, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_next_slide_at_end() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
SlideContent { title: "Only Slide".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
]).await;
|
||||
|
||||
// At slide 0, should not advance past last slide
|
||||
hand.execute_action(SlideshowAction::NextSlide).await.unwrap();
|
||||
assert_eq!(hand.get_state().await.current_slide, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_prev_slide_at_beginning() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
SlideContent { title: "Slide 1".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
SlideContent { title: "Slide 2".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
]).await;
|
||||
|
||||
// At slide 0, should not go below 0
|
||||
hand.execute_action(SlideshowAction::PrevSlide).await.unwrap();
|
||||
assert_eq!(hand.get_state().await.current_slide, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_goto_slide_out_of_range() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
SlideContent { title: "Slide 1".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
]).await;
|
||||
|
||||
let result = hand.execute_action(SlideshowAction::GotoSlide { slide_number: 5 }).await.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_goto_slide_returns_content() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
SlideContent { title: "First".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
SlideContent { title: "Second".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
]).await;
|
||||
|
||||
let result = hand.execute_action(SlideshowAction::GotoSlide { slide_number: 1 }).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["slide_content"]["title"], "Second");
|
||||
}
|
||||
|
||||
// === Spotlight & Laser & Highlight ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spotlight() {
|
||||
let hand = SlideshowHand::new();
|
||||
@@ -384,6 +468,20 @@ mod tests {
|
||||
|
||||
let result = hand.execute_action(action).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["element_id"], "title");
|
||||
assert_eq!(result.output["duration_ms"], 2000);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spotlight_default_duration() {
|
||||
let hand = SlideshowHand::new();
|
||||
let action = SlideshowAction::Spotlight {
|
||||
element_id: "elem".to_string(),
|
||||
duration_ms: default_spotlight_duration(),
|
||||
};
|
||||
|
||||
let result = hand.execute_action(action).await.unwrap();
|
||||
assert_eq!(result.output["duration_ms"], 2000);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -397,8 +495,96 @@ mod tests {
|
||||
|
||||
let result = hand.execute_action(action).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["x"], 100.0);
|
||||
assert_eq!(result.output["y"], 200.0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_highlight_default_color() {
|
||||
let hand = SlideshowHand::new();
|
||||
let action = SlideshowAction::Highlight {
|
||||
x: 10.0, y: 20.0, width: 100.0, height: 50.0,
|
||||
color: None, duration_ms: 2000,
|
||||
};
|
||||
|
||||
let result = hand.execute_action(action).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["color"], "#ffcc00");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_highlight_custom_color() {
|
||||
let hand = SlideshowHand::new();
|
||||
let action = SlideshowAction::Highlight {
|
||||
x: 0.0, y: 0.0, width: 50.0, height: 50.0,
|
||||
color: Some("#ff0000".to_string()), duration_ms: 1000,
|
||||
};
|
||||
|
||||
let result = hand.execute_action(action).await.unwrap();
|
||||
assert_eq!(result.output["color"], "#ff0000");
|
||||
}
|
||||
|
||||
// === AutoPlay / Pause / Resume ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_autoplay_pause_resume() {
|
||||
let hand = SlideshowHand::new();
|
||||
|
||||
// AutoPlay
|
||||
let result = hand.execute_action(SlideshowAction::AutoPlay { interval_ms: 3000 }).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(hand.get_state().await.is_playing);
|
||||
assert_eq!(hand.get_state().await.auto_play_interval_ms, 3000);
|
||||
|
||||
// Pause
|
||||
hand.execute_action(SlideshowAction::Pause).await.unwrap();
|
||||
assert!(!hand.get_state().await.is_playing);
|
||||
|
||||
// Resume
|
||||
hand.execute_action(SlideshowAction::Resume).await.unwrap();
|
||||
assert!(hand.get_state().await.is_playing);
|
||||
|
||||
// Stop
|
||||
hand.execute_action(SlideshowAction::StopAutoPlay).await.unwrap();
|
||||
assert!(!hand.get_state().await.is_playing);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_autoplay_default_interval() {
|
||||
let hand = SlideshowHand::new();
|
||||
hand.execute_action(SlideshowAction::AutoPlay { interval_ms: default_interval() }).await.unwrap();
|
||||
assert_eq!(hand.get_state().await.auto_play_interval_ms, 5000);
|
||||
}
|
||||
|
||||
// === PlayAnimation ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_play_animation() {
|
||||
let hand = SlideshowHand::new();
|
||||
let result = hand.execute_action(SlideshowAction::PlayAnimation {
|
||||
animation_id: "fade_in".to_string(),
|
||||
}).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["animation_id"], "fade_in");
|
||||
}
|
||||
|
||||
// === GetState ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_state() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
SlideContent { title: "A".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
]).await;
|
||||
|
||||
let result = hand.execute_action(SlideshowAction::GetState).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["total_slides"], 1);
|
||||
assert_eq!(result.output["current_slide"], 0);
|
||||
}
|
||||
|
||||
// === SetContent ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_content() {
|
||||
let hand = SlideshowHand::new();
|
||||
@@ -421,5 +607,188 @@ mod tests {
|
||||
|
||||
assert!(result.success);
|
||||
assert_eq!(hand.get_state().await.total_slides, 1);
|
||||
assert_eq!(hand.get_state().await.slides[0].title, "Test Slide");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_content_append() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
SlideContent { title: "First".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
]).await;
|
||||
|
||||
let content = SlideContent {
|
||||
title: "Appended".to_string(), subtitle: None, content: vec![], notes: None, background: None,
|
||||
};
|
||||
|
||||
let result = hand.execute_action(SlideshowAction::SetContent {
|
||||
slide_number: 1,
|
||||
content,
|
||||
}).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["status"], "slide_added");
|
||||
assert_eq!(hand.get_state().await.total_slides, 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_content_invalid_index() {
|
||||
let hand = SlideshowHand::new();
|
||||
|
||||
let content = SlideContent {
|
||||
title: "Gap".to_string(), subtitle: None, content: vec![], notes: None, background: None,
|
||||
};
|
||||
|
||||
let result = hand.execute_action(SlideshowAction::SetContent {
|
||||
slide_number: 5,
|
||||
content,
|
||||
}).await.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
// === Action Deserialization ===
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_next_slide() {
|
||||
let action: SlideshowAction = serde_json::from_value(json!({"action": "next_slide"})).unwrap();
|
||||
assert!(matches!(action, SlideshowAction::NextSlide));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_goto_slide() {
|
||||
let action: SlideshowAction = serde_json::from_value(json!({"action": "goto_slide", "slide_number": 3})).unwrap();
|
||||
match action {
|
||||
SlideshowAction::GotoSlide { slide_number } => assert_eq!(slide_number, 3),
|
||||
_ => panic!("Expected GotoSlide"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_laser() {
|
||||
let action: SlideshowAction = serde_json::from_value(json!({
|
||||
"action": "laser", "x": 50.0, "y": 75.0
|
||||
})).unwrap();
|
||||
match action {
|
||||
SlideshowAction::Laser { x, y, .. } => {
|
||||
assert_eq!(x, 50.0);
|
||||
assert_eq!(y, 75.0);
|
||||
}
|
||||
_ => panic!("Expected Laser"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_autoplay() {
|
||||
let action: SlideshowAction = serde_json::from_value(json!({"action": "auto_play"})).unwrap();
|
||||
match action {
|
||||
SlideshowAction::AutoPlay { interval_ms } => assert_eq!(interval_ms, 5000),
|
||||
_ => panic!("Expected AutoPlay"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deserialize_invalid_action() {
|
||||
let result = serde_json::from_value::<SlideshowAction>(json!({"action": "nonexistent"}));
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// === ContentBlock Deserialization ===
|
||||
|
||||
#[test]
|
||||
fn test_content_block_text() {
|
||||
let block: ContentBlock = serde_json::from_value(json!({
|
||||
"type": "text", "text": "Hello"
|
||||
})).unwrap();
|
||||
match block {
|
||||
ContentBlock::Text { text, style } => {
|
||||
assert_eq!(text, "Hello");
|
||||
assert!(style.is_none());
|
||||
}
|
||||
_ => panic!("Expected Text"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_list() {
|
||||
let block: ContentBlock = serde_json::from_value(json!({
|
||||
"type": "list", "items": ["A", "B"], "ordered": true
|
||||
})).unwrap();
|
||||
match block {
|
||||
ContentBlock::List { items, ordered } => {
|
||||
assert_eq!(items, vec!["A", "B"]);
|
||||
assert!(ordered);
|
||||
}
|
||||
_ => panic!("Expected List"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_code() {
|
||||
let block: ContentBlock = serde_json::from_value(json!({
|
||||
"type": "code", "code": "fn main() {}", "language": "rust"
|
||||
})).unwrap();
|
||||
match block {
|
||||
ContentBlock::Code { code, language } => {
|
||||
assert_eq!(code, "fn main() {}");
|
||||
assert_eq!(language, Some("rust".to_string()));
|
||||
}
|
||||
_ => panic!("Expected Code"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_content_block_table() {
|
||||
let block: ContentBlock = serde_json::from_value(json!({
|
||||
"type": "table",
|
||||
"headers": ["Name", "Age"],
|
||||
"rows": [["Alice", "30"]]
|
||||
})).unwrap();
|
||||
match block {
|
||||
ContentBlock::Table { headers, rows } => {
|
||||
assert_eq!(headers, vec!["Name", "Age"]);
|
||||
assert_eq!(rows, vec![vec!["Alice", "30"]]);
|
||||
}
|
||||
_ => panic!("Expected Table"),
|
||||
}
|
||||
}
|
||||
|
||||
// === Hand trait via execute ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hand_execute_dispatch() {
|
||||
let hand = SlideshowHand::with_slides_async(vec![
|
||||
SlideContent { title: "S1".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
SlideContent { title: "S2".to_string(), subtitle: None, content: vec![], notes: None, background: None },
|
||||
]).await;
|
||||
|
||||
let ctx = HandContext::default();
|
||||
let result = hand.execute(&ctx, json!({"action": "next_slide"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["current_slide"], 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hand_execute_invalid_action() {
|
||||
let hand = SlideshowHand::new();
|
||||
let ctx = HandContext::default();
|
||||
let result = hand.execute(&ctx, json!({"action": "invalid"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
}
|
||||
|
||||
// === add_slide helper ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_add_slide() {
|
||||
let hand = SlideshowHand::new();
|
||||
hand.add_slide(SlideContent {
|
||||
title: "Dynamic".to_string(), subtitle: None, content: vec![], notes: None, background: None,
|
||||
}).await;
|
||||
hand.add_slide(SlideContent {
|
||||
title: "Dynamic 2".to_string(), subtitle: None, content: vec![], notes: None, background: None,
|
||||
}).await;
|
||||
|
||||
let state = hand.get_state().await;
|
||||
assert_eq!(state.total_slides, 2);
|
||||
assert_eq!(state.slides.len(), 2);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -823,3 +823,417 @@ impl Hand for TwitterHand {
|
||||
crate::HandStatus::Idle
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Hand;
|
||||
use zclaw_types::id::AgentId;
|
||||
|
||||
fn make_context() -> HandContext {
|
||||
HandContext {
|
||||
agent_id: AgentId::new(),
|
||||
working_dir: None,
|
||||
env: std::collections::HashMap::new(),
|
||||
timeout_secs: 30,
|
||||
callback_url: None,
|
||||
}
|
||||
}
|
||||
|
||||
// === Config & Defaults ===
|
||||
|
||||
#[test]
|
||||
fn test_hand_config() {
|
||||
let hand = TwitterHand::new();
|
||||
assert_eq!(hand.config().id, "twitter");
|
||||
assert_eq!(hand.config().name, "Twitter 自动化");
|
||||
assert!(hand.config().needs_approval);
|
||||
assert!(hand.config().enabled);
|
||||
assert!(hand.config().tags.contains(&"twitter".to_string()));
|
||||
assert!(hand.config().input_schema.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_impl() {
|
||||
let hand = TwitterHand::default();
|
||||
assert_eq!(hand.config().id, "twitter");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_needs_approval() {
|
||||
let hand = TwitterHand::new();
|
||||
assert!(hand.needs_approval());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_status() {
|
||||
let hand = TwitterHand::new();
|
||||
assert_eq!(hand.status(), crate::HandStatus::Idle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_dependencies() {
|
||||
let hand = TwitterHand::new();
|
||||
let deps = hand.check_dependencies().unwrap();
|
||||
assert!(!deps.is_empty());
|
||||
}
|
||||
|
||||
// === Action Deserialization ===
|
||||
|
||||
#[test]
|
||||
fn test_tweet_action_deserialize() {
|
||||
let json = json!({
|
||||
"action": "tweet",
|
||||
"config": {
|
||||
"text": "Hello world!"
|
||||
}
|
||||
});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Tweet { config } => {
|
||||
assert_eq!(config.text, "Hello world!");
|
||||
assert!(config.media_urls.is_empty());
|
||||
assert!(config.reply_to.is_none());
|
||||
assert!(config.quote_tweet.is_none());
|
||||
assert!(config.poll.is_none());
|
||||
}
|
||||
_ => panic!("Expected Tweet action"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tweet_action_with_reply() {
|
||||
let json = json!({
|
||||
"action": "tweet",
|
||||
"config": {
|
||||
"text": "@user reply",
|
||||
"replyTo": "123456"
|
||||
}
|
||||
});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Tweet { config } => {
|
||||
assert_eq!(config.reply_to.as_deref(), Some("123456"));
|
||||
}
|
||||
_ => panic!("Expected Tweet action"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tweet_action_with_poll() {
|
||||
let json = json!({
|
||||
"action": "tweet",
|
||||
"config": {
|
||||
"text": "Vote!",
|
||||
"poll": {
|
||||
"options": ["A", "B", "C"],
|
||||
"durationMinutes": 60
|
||||
}
|
||||
}
|
||||
});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Tweet { config } => {
|
||||
let poll = config.poll.unwrap();
|
||||
assert_eq!(poll.options, vec!["A", "B", "C"]);
|
||||
assert_eq!(poll.duration_minutes, 60);
|
||||
}
|
||||
_ => panic!("Expected Tweet action"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete_tweet_action() {
|
||||
let json = json!({"action": "delete_tweet", "tweet_id": "789"});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::DeleteTweet { tweet_id } => assert_eq!(tweet_id, "789"),
|
||||
_ => panic!("Expected DeleteTweet"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_like_unlike_actions() {
|
||||
let like: TwitterAction = serde_json::from_value(json!({"action": "like", "tweet_id": "111"})).unwrap();
|
||||
match like {
|
||||
TwitterAction::Like { tweet_id } => assert_eq!(tweet_id, "111"),
|
||||
_ => panic!("Expected Like"),
|
||||
}
|
||||
|
||||
let unlike: TwitterAction = serde_json::from_value(json!({"action": "unlike", "tweet_id": "111"})).unwrap();
|
||||
match unlike {
|
||||
TwitterAction::Unlike { tweet_id } => assert_eq!(tweet_id, "111"),
|
||||
_ => panic!("Expected Unlike"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retweet_unretweet_actions() {
|
||||
let rt: TwitterAction = serde_json::from_value(json!({"action": "retweet", "tweet_id": "222"})).unwrap();
|
||||
match rt {
|
||||
TwitterAction::Retweet { tweet_id } => assert_eq!(tweet_id, "222"),
|
||||
_ => panic!("Expected Retweet"),
|
||||
}
|
||||
|
||||
let unrt: TwitterAction = serde_json::from_value(json!({"action": "unretweet", "tweet_id": "222"})).unwrap();
|
||||
match unrt {
|
||||
TwitterAction::Unretweet { tweet_id } => assert_eq!(tweet_id, "222"),
|
||||
_ => panic!("Expected Unretweet"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_action_defaults() {
|
||||
let json = json!({"action": "search", "config": {"query": "rust lang"}});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Search { config } => {
|
||||
assert_eq!(config.query, "rust lang");
|
||||
assert_eq!(config.max_results, 10); // default
|
||||
assert!(config.next_token.is_none());
|
||||
}
|
||||
_ => panic!("Expected Search"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_action_custom_max() {
|
||||
let json = json!({"action": "search", "config": {"query": "test", "maxResults": 50}});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Search { config } => assert_eq!(config.max_results, 50),
|
||||
_ => panic!("Expected Search"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timeline_action_defaults() {
|
||||
let json = json!({"action": "timeline", "config": {}});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Timeline { config } => {
|
||||
assert!(config.user_id.is_none());
|
||||
assert_eq!(config.max_results, 10); // default
|
||||
assert!(!config.exclude_replies);
|
||||
assert!(config.include_retweets);
|
||||
}
|
||||
_ => panic!("Expected Timeline"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_tweet_action() {
|
||||
let json = json!({"action": "get_tweet", "tweet_id": "999"});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::GetTweet { tweet_id } => assert_eq!(tweet_id, "999"),
|
||||
_ => panic!("Expected GetTweet"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_get_user_action() {
|
||||
let json = json!({"action": "get_user", "username": "elonmusk"});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::GetUser { username } => assert_eq!(username, "elonmusk"),
|
||||
_ => panic!("Expected GetUser"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_followers_action() {
|
||||
let json = json!({"action": "followers", "user_id": "u1", "max_results": 50});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Followers { user_id, max_results } => {
|
||||
assert_eq!(user_id, "u1");
|
||||
assert_eq!(max_results, Some(50));
|
||||
}
|
||||
_ => panic!("Expected Followers"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_following_action_no_max() {
|
||||
let json = json!({"action": "following", "user_id": "u2"});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::Following { user_id, max_results } => {
|
||||
assert_eq!(user_id, "u2");
|
||||
assert!(max_results.is_none());
|
||||
}
|
||||
_ => panic!("Expected Following"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_check_credentials_action() {
|
||||
let json = json!({"action": "check_credentials"});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::CheckCredentials => {}
|
||||
_ => panic!("Expected CheckCredentials"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_action() {
|
||||
let json = json!({"action": "invalid_action"});
|
||||
let result = serde_json::from_value::<TwitterAction>(json);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
// === Serialization Roundtrip ===
|
||||
|
||||
#[test]
|
||||
fn test_tweet_action_roundtrip() {
|
||||
let json = json!({
|
||||
"action": "tweet",
|
||||
"config": {
|
||||
"text": "Test tweet",
|
||||
"mediaUrls": ["https://example.com/img.jpg"],
|
||||
"replyTo": "123",
|
||||
"quoteTweet": "456"
|
||||
}
|
||||
});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
let serialized = serde_json::to_value(&action).unwrap();
|
||||
// Verify core fields survive roundtrip (camelCase via serde rename)
|
||||
assert_eq!(serialized["action"], "tweet");
|
||||
assert_eq!(serialized["config"]["text"], "Test tweet");
|
||||
assert_eq!(serialized["config"]["mediaUrls"][0], "https://example.com/img.jpg");
|
||||
assert_eq!(serialized["config"]["replyTo"], "123");
|
||||
assert_eq!(serialized["config"]["quoteTweet"], "456");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_search_action_roundtrip() {
|
||||
let json = json!({
|
||||
"action": "search",
|
||||
"config": {
|
||||
"query": "hello world",
|
||||
"maxResults": 25
|
||||
}
|
||||
});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
let serialized = serde_json::to_value(&action).unwrap();
|
||||
assert_eq!(serialized["action"], "search");
|
||||
assert_eq!(serialized["config"]["query"], "hello world");
|
||||
assert_eq!(serialized["config"]["maxResults"], 25);
|
||||
}
|
||||
|
||||
// === Credentials ===
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_and_get_credentials() {
|
||||
let hand = TwitterHand::new();
|
||||
|
||||
// Initially no credentials
|
||||
assert!(hand.get_credentials().await.is_none());
|
||||
|
||||
// Set credentials
|
||||
hand.set_credentials(TwitterCredentials {
|
||||
api_key: "key".into(),
|
||||
api_secret: "secret".into(),
|
||||
access_token: "token".into(),
|
||||
access_token_secret: "token_secret".into(),
|
||||
bearer_token: Some("bearer".into()),
|
||||
}).await;
|
||||
|
||||
let creds = hand.get_credentials().await.unwrap();
|
||||
assert_eq!(creds.api_key, "key");
|
||||
assert_eq!(creds.bearer_token.as_deref(), Some("bearer"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_credentials_without_config() {
|
||||
let hand = TwitterHand::new();
|
||||
let ctx = make_context();
|
||||
let result = hand.execute(&ctx, json!({"action": "check_credentials"})).await.unwrap();
|
||||
|
||||
// No "success" field in output → HandResult.success defaults to false
|
||||
assert!(!result.success);
|
||||
assert_eq!(result.output["configured"], false);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_check_credentials_with_config() {
|
||||
let hand = TwitterHand::new();
|
||||
hand.set_credentials(TwitterCredentials {
|
||||
api_key: "key".into(),
|
||||
api_secret: "secret".into(),
|
||||
access_token: "token".into(),
|
||||
access_token_secret: "token_secret".into(),
|
||||
bearer_token: Some("bearer".into()),
|
||||
}).await;
|
||||
|
||||
let ctx = make_context();
|
||||
let result = hand.execute(&ctx, json!({"action": "check_credentials"})).await.unwrap();
|
||||
|
||||
// execute_check_credentials returns {"configured": true, ...} without "success" field
|
||||
// HandResult.success = result["success"].as_bool().unwrap_or(false) = false
|
||||
// But the actual data is in output
|
||||
assert_eq!(result.output["configured"], true);
|
||||
assert_eq!(result.output["has_bearer_token"], true);
|
||||
}
|
||||
|
||||
// === Tweet Data Types ===
|
||||
|
||||
#[test]
|
||||
fn test_tweet_deserialize() {
|
||||
let json = json!({
|
||||
"id": "t123",
|
||||
"text": "Hello!",
|
||||
"authorId": "a456",
|
||||
"authorName": "Test User",
|
||||
"authorUsername": "testuser",
|
||||
"createdAt": "2026-01-01T00:00:00Z",
|
||||
"publicMetrics": {
|
||||
"retweetCount": 5,
|
||||
"replyCount": 2,
|
||||
"likeCount": 10,
|
||||
"quoteCount": 1,
|
||||
"impressionCount": 1000
|
||||
}
|
||||
});
|
||||
let tweet: Tweet = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(tweet.id, "t123");
|
||||
assert_eq!(tweet.public_metrics.like_count, 10);
|
||||
assert!(tweet.media.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twitter_user_deserialize() {
|
||||
let json = json!({
|
||||
"id": "u1",
|
||||
"name": "Alice",
|
||||
"username": "alice",
|
||||
"verified": true,
|
||||
"publicMetrics": {
|
||||
"followersCount": 100,
|
||||
"followingCount": 50,
|
||||
"tweetCount": 1000,
|
||||
"listedCount": 5
|
||||
}
|
||||
});
|
||||
let user: TwitterUser = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(user.username, "alice");
|
||||
assert!(user.verified);
|
||||
assert_eq!(user.public_metrics.followers_count, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_media_info_deserialize() {
|
||||
let json = json!({
|
||||
"mediaKey": "mk1",
|
||||
"mediaType": "photo",
|
||||
"url": "https://pbs.example.com/photo.jpg",
|
||||
"width": 1200,
|
||||
"height": 800
|
||||
});
|
||||
let media: MediaInfo = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(media.media_type, "photo");
|
||||
assert_eq!(media.width, 1200);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -71,4 +71,19 @@ CREATE INDEX IF NOT EXISTS idx_kv_agent ON kv_store(agent_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_hand_runs_hand ON hand_runs(hand_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_hand_runs_status ON hand_runs(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_hand_runs_created ON hand_runs(created_at);
|
||||
|
||||
-- Structured facts table (extracted from conversations)
|
||||
CREATE TABLE IF NOT EXISTS facts (
|
||||
id TEXT PRIMARY KEY,
|
||||
agent_id TEXT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
source_session TEXT,
|
||||
created_at INTEGER NOT NULL,
|
||||
FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_facts_agent ON facts(agent_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_facts_category ON facts(agent_id, category);
|
||||
CREATE INDEX IF NOT EXISTS idx_facts_confidence ON facts(agent_id, confidence DESC);
|
||||
"#;
|
||||
|
||||
@@ -482,6 +482,76 @@ impl MemoryStore {
|
||||
Ok(count as u32)
|
||||
}
|
||||
|
||||
// === Fact CRUD ===
|
||||
|
||||
/// Store extracted facts for an agent (upsert by id).
|
||||
pub async fn store_facts(&self, agent_id: &str, facts: &[crate::fact::Fact]) -> Result<()> {
|
||||
for fact in facts {
|
||||
let category_str = serde_json::to_string(&fact.category)
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
// Trim the JSON quotes from serialized enum variant
|
||||
let category_clean = category_str.trim_matches('"');
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO facts (id, agent_id, content, category, confidence, source_session, created_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
confidence = excluded.confidence,
|
||||
source_session = excluded.source_session
|
||||
"#,
|
||||
)
|
||||
.bind(&fact.id)
|
||||
.bind(agent_id)
|
||||
.bind(&fact.content)
|
||||
.bind(category_clean)
|
||||
.bind(fact.confidence)
|
||||
.bind(&fact.source)
|
||||
.bind(fact.created_at as i64)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get top facts for an agent, ordered by confidence descending.
|
||||
pub async fn get_top_facts(&self, agent_id: &str, limit: usize) -> Result<Vec<crate::fact::Fact>> {
|
||||
let rows = sqlx::query_as::<_, (String, String, String, f64, Option<String>, i64)>(
|
||||
r#"
|
||||
SELECT id, content, category, confidence, source_session, created_at
|
||||
FROM facts
|
||||
WHERE agent_id = ?
|
||||
ORDER BY confidence DESC
|
||||
LIMIT ?
|
||||
"#,
|
||||
)
|
||||
.bind(agent_id)
|
||||
.bind(limit as i64)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
let mut facts = Vec::with_capacity(rows.len());
|
||||
for (id, content, category_str, confidence, source, created_at) in rows {
|
||||
let category: crate::fact::FactCategory = serde_json::from_value(
|
||||
serde_json::Value::String(category_str)
|
||||
).map_err(|e| ZclawError::StorageError(format!("Invalid category: {}", e)))?;
|
||||
|
||||
facts.push(crate::fact::Fact {
|
||||
id,
|
||||
content,
|
||||
category,
|
||||
confidence,
|
||||
created_at: created_at as u64,
|
||||
source,
|
||||
});
|
||||
}
|
||||
Ok(facts)
|
||||
}
|
||||
|
||||
fn row_to_hand_run(
|
||||
row: (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>),
|
||||
) -> Result<HandRun> {
|
||||
@@ -527,10 +597,13 @@ mod tests {
|
||||
description: None,
|
||||
model: ModelConfig::default(),
|
||||
system_prompt: None,
|
||||
soul: None,
|
||||
capabilities: vec![],
|
||||
tools: vec![],
|
||||
max_tokens: None,
|
||||
temperature: None,
|
||||
workspace: None,
|
||||
compaction_threshold: None,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -60,6 +60,22 @@ impl AgentMiddleware for MemoryMiddleware {
|
||||
fn priority(&self) -> i32 { 150 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
// Skip memory injection for very short queries.
|
||||
// Short queries (e.g., "1+6", "hi", "好") don't benefit from memory context.
|
||||
// Worse, the retriever's scope-based fallback may return high-importance but
|
||||
// irrelevant old memories, causing the model to think about past conversations
|
||||
// instead of answering the current question.
|
||||
// Use char count (not byte count) so CJK queries are handled correctly:
|
||||
// a single Chinese char is 3 UTF-8 bytes but 1 meaningful character.
|
||||
let query = ctx.user_input.trim();
|
||||
if query.chars().count() < 2 {
|
||||
tracing::debug!(
|
||||
"[MemoryMiddleware] Skipping enhancement for short query ({:?}): no memory context needed",
|
||||
query
|
||||
);
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
match self.growth.enhance_prompt(
|
||||
&ctx.agent_id,
|
||||
&ctx.system_prompt,
|
||||
@@ -92,21 +108,27 @@ impl AgentMiddleware for MemoryMiddleware {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
match self.growth.process_conversation(
|
||||
// Combined extraction: single LLM call produces both memories and structured facts.
|
||||
// Avoids double LLM extraction ( process_conversation + extract_structured_facts).
|
||||
match self.growth.extract_combined(
|
||||
&ctx.agent_id,
|
||||
&ctx.messages,
|
||||
ctx.session_id.clone(),
|
||||
&ctx.session_id,
|
||||
).await {
|
||||
Ok(count) => {
|
||||
Ok(Some((mem_count, facts))) => {
|
||||
tracing::info!(
|
||||
"[MemoryMiddleware] Extracted {} memories for agent {}",
|
||||
count,
|
||||
"[MemoryMiddleware] Extracted {} memories + {} structured facts for agent {}",
|
||||
mem_count,
|
||||
facts.len(),
|
||||
agent_key
|
||||
);
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::debug!("[MemoryMiddleware] No memories or facts extracted");
|
||||
}
|
||||
Err(e) => {
|
||||
// Non-fatal: extraction failure should not affect the response
|
||||
tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
|
||||
tracing::warn!("[MemoryMiddleware] Combined extraction failed: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -26,14 +26,17 @@ chrono = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
pgvector = { version = "0.4", features = ["sqlx"] }
|
||||
reqwest = { workspace = true }
|
||||
secrecy = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
rsa = { workspace = true, features = ["sha2"] }
|
||||
base64 = { workspace = true }
|
||||
socket2 = { workspace = true }
|
||||
url = "2"
|
||||
url = { workspace = true }
|
||||
|
||||
axum = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
@@ -47,6 +50,7 @@ data-encoding = "2"
|
||||
regex = { workspace = true }
|
||||
aes-gcm = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
async-stream = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
-- Add is_embedding column to models table
|
||||
-- Distinguishes embedding models from chat/completion models
|
||||
ALTER TABLE models ADD COLUMN IF NOT EXISTS is_embedding BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
|
||||
-- Add model_type column for future extensibility (chat, embedding, image, audio, etc.)
|
||||
ALTER TABLE models ADD COLUMN IF NOT EXISTS model_type TEXT NOT NULL DEFAULT 'chat';
|
||||
|
||||
-- Index for quick filtering of embedding models
|
||||
CREATE INDEX IF NOT EXISTS idx_models_is_embedding ON models(is_embedding) WHERE is_embedding = TRUE;
|
||||
CREATE INDEX IF NOT EXISTS idx_models_model_type ON models(model_type);
|
||||
133
crates/zclaw-saas/migrations/20260402000001_billing_tables.sql
Normal file
133
crates/zclaw-saas/migrations/20260402000001_billing_tables.sql
Normal file
@@ -0,0 +1,133 @@
|
||||
-- Migration: Billing tables for subscription management
|
||||
-- Supports: Free/Pro/Team plans, Alipay + WeChat Pay, usage quotas
|
||||
|
||||
-- Plan definitions (Free/Pro/Team)
|
||||
CREATE TABLE IF NOT EXISTS billing_plans (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT NOT NULL,
|
||||
description TEXT,
|
||||
price_cents INTEGER NOT NULL DEFAULT 0,
|
||||
currency TEXT NOT NULL DEFAULT 'CNY',
|
||||
interval TEXT NOT NULL DEFAULT 'month',
|
||||
features JSONB NOT NULL DEFAULT '{}',
|
||||
limits JSONB NOT NULL DEFAULT '{}',
|
||||
is_default BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
sort_order INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_plans_status ON billing_plans(status);
|
||||
|
||||
-- Account subscriptions
|
||||
CREATE TABLE IF NOT EXISTS billing_subscriptions (
|
||||
id TEXT PRIMARY KEY,
|
||||
account_id TEXT NOT NULL,
|
||||
plan_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
current_period_start TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
current_period_end TIMESTAMPTZ NOT NULL,
|
||||
trial_end TIMESTAMPTZ,
|
||||
canceled_at TIMESTAMPTZ,
|
||||
cancel_at_period_end BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (plan_id) REFERENCES billing_plans(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_sub_account ON billing_subscriptions(account_id);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_billing_sub_active
|
||||
ON billing_subscriptions(account_id)
|
||||
WHERE status IN ('trial', 'active', 'past_due');
|
||||
|
||||
-- Invoices
|
||||
CREATE TABLE IF NOT EXISTS billing_invoices (
|
||||
id TEXT PRIMARY KEY,
|
||||
account_id TEXT NOT NULL,
|
||||
subscription_id TEXT,
|
||||
plan_id TEXT,
|
||||
amount_cents INTEGER NOT NULL,
|
||||
currency TEXT NOT NULL DEFAULT 'CNY',
|
||||
description TEXT,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
due_at TIMESTAMPTZ,
|
||||
paid_at TIMESTAMPTZ,
|
||||
voided_at TIMESTAMPTZ,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (subscription_id) REFERENCES billing_subscriptions(id) ON DELETE SET NULL,
|
||||
FOREIGN KEY (plan_id) REFERENCES billing_plans(id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_inv_account ON billing_invoices(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_inv_status ON billing_invoices(status);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_inv_time ON billing_invoices(created_at);
|
||||
|
||||
-- Payment records (Alipay / WeChat Pay)
|
||||
CREATE TABLE IF NOT EXISTS billing_payments (
|
||||
id TEXT PRIMARY KEY,
|
||||
invoice_id TEXT NOT NULL,
|
||||
account_id TEXT NOT NULL,
|
||||
amount_cents INTEGER NOT NULL,
|
||||
currency TEXT NOT NULL DEFAULT 'CNY',
|
||||
method TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
external_trade_no TEXT,
|
||||
paid_at TIMESTAMPTZ,
|
||||
refunded_at TIMESTAMPTZ,
|
||||
failure_reason TEXT,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
FOREIGN KEY (invoice_id) REFERENCES billing_invoices(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_pay_invoice ON billing_payments(invoice_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_pay_account ON billing_payments(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_pay_trade_no ON billing_payments(external_trade_no);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_pay_status ON billing_payments(status);
|
||||
|
||||
-- Monthly usage quotas (per account per billing period)
|
||||
CREATE TABLE IF NOT EXISTS billing_usage_quotas (
|
||||
id TEXT PRIMARY KEY,
|
||||
account_id TEXT NOT NULL,
|
||||
period_start TIMESTAMPTZ NOT NULL,
|
||||
period_end TIMESTAMPTZ NOT NULL,
|
||||
input_tokens BIGINT NOT NULL DEFAULT 0,
|
||||
output_tokens BIGINT NOT NULL DEFAULT 0,
|
||||
relay_requests INTEGER NOT NULL DEFAULT 0,
|
||||
hand_executions INTEGER NOT NULL DEFAULT 0,
|
||||
pipeline_runs INTEGER NOT NULL DEFAULT 0,
|
||||
max_input_tokens BIGINT,
|
||||
max_output_tokens BIGINT,
|
||||
max_relay_requests INTEGER,
|
||||
max_hand_executions INTEGER,
|
||||
max_pipeline_runs INTEGER,
|
||||
metadata JSONB NOT NULL DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
|
||||
UNIQUE(account_id, period_start)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_usage_account ON billing_usage_quotas(account_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_billing_usage_period ON billing_usage_quotas(period_start, period_end);
|
||||
|
||||
-- Seed: default plans
|
||||
INSERT INTO billing_plans (id, name, display_name, description, price_cents, interval, features, limits, is_default, sort_order)
|
||||
VALUES
|
||||
('plan-free', 'free', '免费版', '基础功能,适合个人体验', 0, 'month',
|
||||
'{"hands": ["browser", "collector", "researcher"], "chat_modes": ["flash", "thinking"], "pipelines": 3, "support": "community"}'::jsonb,
|
||||
'{"max_input_tokens_monthly": 500000, "max_output_tokens_monthly": 500000, "max_relay_requests_monthly": 100, "max_hand_executions_monthly": 20, "max_pipeline_runs_monthly": 5}'::jsonb,
|
||||
TRUE, 0),
|
||||
('plan-pro', 'pro', '专业版', '全功能解锁,适合知识工作者', 4900, 'month',
|
||||
'{"hands": "all", "chat_modes": "all", "pipelines": -1, "support": "priority", "memory": true, "export": true}'::jsonb,
|
||||
'{"max_input_tokens_monthly": 5000000, "max_output_tokens_monthly": 5000000, "max_relay_requests_monthly": 2000, "max_hand_executions_monthly": 200, "max_pipeline_runs_monthly": 100}'::jsonb,
|
||||
FALSE, 1),
|
||||
('plan-team', 'team', '团队版', '多席位协作,适合企业团队', 19900, 'month',
|
||||
'{"hands": "all", "chat_modes": "all", "pipelines": -1, "support": "dedicated", "memory": true, "export": true, "sharing": true, "admin": true}'::jsonb,
|
||||
'{"max_input_tokens_monthly": 50000000, "max_output_tokens_monthly": 50000000, "max_relay_requests_monthly": 20000, "max_hand_executions_monthly": 1000, "max_pipeline_runs_monthly": 500}'::jsonb,
|
||||
FALSE, 2)
|
||||
ON CONFLICT (name) DO NOTHING;
|
||||
123
crates/zclaw-saas/migrations/20260402000002_knowledge_base.sql
Normal file
123
crates/zclaw-saas/migrations/20260402000002_knowledge_base.sql
Normal file
@@ -0,0 +1,123 @@
|
||||
-- Migration: Knowledge Base tables with pgvector support
|
||||
-- 5 tables: knowledge_categories, knowledge_items, knowledge_chunks,
|
||||
-- knowledge_versions, knowledge_usage
|
||||
|
||||
-- Enable pgvector extension
|
||||
CREATE EXTENSION IF NOT EXISTS vector;
|
||||
|
||||
-- 行业分类树
|
||||
CREATE TABLE IF NOT EXISTS knowledge_categories (
|
||||
id TEXT PRIMARY KEY,
|
||||
name VARCHAR(100) NOT NULL,
|
||||
description TEXT,
|
||||
parent_id TEXT REFERENCES knowledge_categories(id) ON DELETE RESTRICT,
|
||||
icon VARCHAR(50),
|
||||
sort_order INT DEFAULT 0,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
CHECK (id != parent_id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_kc_parent ON knowledge_categories(parent_id);
|
||||
|
||||
-- 知识条目
|
||||
CREATE TABLE IF NOT EXISTS knowledge_items (
|
||||
id TEXT PRIMARY KEY,
|
||||
category_id TEXT NOT NULL REFERENCES knowledge_categories(id) ON DELETE RESTRICT,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
keywords TEXT[] DEFAULT '{}',
|
||||
related_questions TEXT[] DEFAULT '{}',
|
||||
priority INT DEFAULT 0,
|
||||
status VARCHAR(20) DEFAULT 'active' CHECK (status IN ('active', 'archived', 'deprecated', 'draft')),
|
||||
version INT DEFAULT 1,
|
||||
source VARCHAR(50) DEFAULT 'manual',
|
||||
tags TEXT[] DEFAULT '{}',
|
||||
created_by TEXT NOT NULL REFERENCES accounts(id),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
CHECK (length(content) <= 100000)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_ki_category ON knowledge_items(category_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_ki_status_updated ON knowledge_items(status, updated_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_ki_keywords ON knowledge_items USING GIN(keywords);
|
||||
|
||||
-- 知识分块(RAG 检索核心)
|
||||
CREATE TABLE IF NOT EXISTS knowledge_chunks (
|
||||
id TEXT PRIMARY KEY,
|
||||
item_id TEXT NOT NULL REFERENCES knowledge_items(id) ON DELETE CASCADE,
|
||||
chunk_index INT NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
embedding vector(1536),
|
||||
keywords TEXT[] DEFAULT '{}',
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_kchunks_item_idx ON knowledge_chunks(item_id, chunk_index);
|
||||
CREATE INDEX IF NOT EXISTS idx_kchunks_item ON knowledge_chunks(item_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_kchunks_keywords ON knowledge_chunks USING GIN(keywords);
|
||||
|
||||
-- 向量相似度索引(HNSW,无需预填充数据)
|
||||
-- 仅在有数据后创建此索引可提升性能,这里预创建
|
||||
CREATE INDEX IF NOT EXISTS idx_kchunks_embedding ON knowledge_chunks
|
||||
USING hnsw (embedding vector_cosine_ops)
|
||||
WITH (m = 16, ef_construction = 128);
|
||||
|
||||
-- 版本快照
|
||||
CREATE TABLE IF NOT EXISTS knowledge_versions (
|
||||
id TEXT PRIMARY KEY,
|
||||
item_id TEXT NOT NULL REFERENCES knowledge_items(id) ON DELETE CASCADE,
|
||||
version INT NOT NULL,
|
||||
title VARCHAR(255) NOT NULL,
|
||||
content TEXT NOT NULL,
|
||||
keywords TEXT[] DEFAULT '{}',
|
||||
related_questions TEXT[] DEFAULT '{}',
|
||||
change_summary TEXT,
|
||||
created_by TEXT NOT NULL REFERENCES accounts(id),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_kv_item ON knowledge_versions(item_id);
|
||||
|
||||
-- 使用追踪
|
||||
CREATE TABLE IF NOT EXISTS knowledge_usage (
|
||||
id TEXT PRIMARY KEY,
|
||||
item_id TEXT REFERENCES knowledge_items(id) ON DELETE SET NULL,
|
||||
chunk_id TEXT REFERENCES knowledge_chunks(id) ON DELETE SET NULL,
|
||||
session_id VARCHAR(100),
|
||||
query_text TEXT,
|
||||
relevance_score FLOAT,
|
||||
was_injected BOOLEAN DEFAULT FALSE,
|
||||
agent_feedback VARCHAR(20) CHECK (agent_feedback IN ('positive', 'negative')),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW()
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_ku_item ON knowledge_usage(item_id) WHERE item_id IS NOT NULL;
|
||||
-- BRIN 索引:追加写入的时间序列数据比 B-tree 更高效
|
||||
CREATE INDEX IF NOT EXISTS idx_ku_created_brin ON knowledge_usage USING brin(created_at);
|
||||
|
||||
-- 权限种子数据(使用 jsonb 操作避免 REPLACE 脆弱性)
|
||||
UPDATE roles
|
||||
SET permissions = (
|
||||
SELECT '[' || string_agg('"' || elem || '"', ', ') || ']'
|
||||
FROM (
|
||||
SELECT DISTINCT elem
|
||||
FROM json_array_elements_text(permissions::json) AS elem
|
||||
UNION ALL SELECT 'knowledge:read'
|
||||
UNION ALL SELECT 'knowledge:write'
|
||||
UNION ALL SELECT 'knowledge:admin'
|
||||
UNION ALL SELECT 'knowledge:search'
|
||||
) sub
|
||||
)
|
||||
WHERE id = 'super_admin'
|
||||
AND permissions NOT LIKE '%knowledge:read%';
|
||||
|
||||
UPDATE roles
|
||||
SET permissions = (
|
||||
SELECT '[' || string_agg('"' || elem || '"', ', ') || ']'
|
||||
FROM (
|
||||
SELECT DISTINCT elem
|
||||
FROM json_array_elements_text(permissions::json) AS elem
|
||||
UNION ALL SELECT 'knowledge:read'
|
||||
UNION ALL SELECT 'knowledge:write'
|
||||
UNION ALL SELECT 'knowledge:search'
|
||||
) sub
|
||||
)
|
||||
WHERE id = 'admin'
|
||||
AND permissions NOT LIKE '%knowledge:read%';
|
||||
@@ -0,0 +1,5 @@
|
||||
-- Add execution result columns to scheduled_tasks
|
||||
-- Tracks the output and duration of each task execution for observability
|
||||
|
||||
ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_result TEXT;
|
||||
ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_duration_ms INTEGER;
|
||||
@@ -67,14 +67,17 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
}
|
||||
}
|
||||
|
||||
// 异步更新 last_used_at(不阻塞请求)
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
||||
.bind(&now).bind(&token_hash)
|
||||
.execute(&db).await;
|
||||
});
|
||||
// 异步更新 last_used_at — 通过 Worker 通道派发,受 SpawnLimiter 门控
|
||||
// 替换原来的 tokio::spawn(DB UPDATE),消除每请求无限制 spawn
|
||||
{
|
||||
use crate::workers::update_last_used::UpdateLastUsedArgs;
|
||||
let args = UpdateLastUsedArgs {
|
||||
token_hash: token_hash.to_string(),
|
||||
};
|
||||
if let Err(e) = state.worker_dispatcher.dispatch("update_last_used", args).await {
|
||||
tracing::debug!("Failed to dispatch update_last_used: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(AuthContext {
|
||||
account_id,
|
||||
@@ -84,23 +87,43 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
|
||||
})
|
||||
}
|
||||
|
||||
/// 从请求中提取客户端 IP
|
||||
fn extract_client_ip(req: &Request) -> Option<String> {
|
||||
// 优先从 ConnectInfo 获取
|
||||
if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
|
||||
return Some(addr.ip().to_string());
|
||||
/// 从请求中提取客户端 IP(安全版:仅对 trusted_proxies 解析 XFF)
|
||||
fn extract_client_ip(req: &Request, trusted_proxies: &[String]) -> Option<String> {
|
||||
// 优先从 ConnectInfo 获取直接连接 IP
|
||||
let connect_ip = req.extensions()
|
||||
.get::<ConnectInfo<SocketAddr>>()
|
||||
.map(|ConnectInfo(addr)| addr.ip().to_string());
|
||||
|
||||
// 仅当直接连接 IP 在 trusted_proxies 中时,才信任 XFF/X-Real-IP
|
||||
if let Some(ref ip) = connect_ip {
|
||||
if trusted_proxies.iter().any(|p| p == ip) {
|
||||
// 受信代理 → 从 XFF 取真实客户端 IP
|
||||
if let Some(forwarded) = req.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
if let Some(client) = forwarded.split(',').next() {
|
||||
let trimmed = client.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
// 尝试 X-Real-IP
|
||||
if let Some(real_ip) = req.headers()
|
||||
.get("x-real-ip")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
let trimmed = real_ip.trim();
|
||||
if !trimmed.is_empty() {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 回退到 X-Forwarded-For / X-Real-IP
|
||||
if let Some(forwarded) = req.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
{
|
||||
return Some(forwarded.split(',').next()?.trim().to_string());
|
||||
}
|
||||
req.headers()
|
||||
.get("x-real-ip")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
|
||||
// 非受信来源或无代理头 → 返回直接连接 IP
|
||||
connect_ip
|
||||
}
|
||||
|
||||
/// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份
|
||||
@@ -110,7 +133,10 @@ pub async fn auth_middleware(
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let client_ip = extract_client_ip(&req);
|
||||
let client_ip = {
|
||||
let config = state.config.read().await;
|
||||
extract_client_ip(&req, &config.server.trusted_proxies)
|
||||
};
|
||||
let auth_header = req.headers()
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
395
crates/zclaw-saas/src/billing/handlers.rs
Normal file
395
crates/zclaw-saas/src/billing/handlers.rs
Normal file
@@ -0,0 +1,395 @@
|
||||
//! 计费 HTTP 处理器
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Form, Path, Query, State},
|
||||
Json,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::auth::types::AuthContext;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::state::AppState;
|
||||
use super::service;
|
||||
use super::types::*;
|
||||
|
||||
/// GET /api/v1/billing/plans — 列出所有活跃计划
|
||||
pub async fn list_plans(
|
||||
State(state): State<AppState>,
|
||||
) -> SaasResult<Json<Vec<BillingPlan>>> {
|
||||
let plans = service::list_plans(&state.db).await?;
|
||||
Ok(Json(plans))
|
||||
}
|
||||
|
||||
/// GET /api/v1/billing/plans/:id — 获取单个计划详情
|
||||
pub async fn get_plan(
|
||||
State(state): State<AppState>,
|
||||
Path(plan_id): Path<String>,
|
||||
) -> SaasResult<Json<BillingPlan>> {
|
||||
let plan = service::get_plan(&state.db, &plan_id).await?
|
||||
.ok_or_else(|| crate::error::SaasError::NotFound("计划不存在".into()))?;
|
||||
Ok(Json(plan))
|
||||
}
|
||||
|
||||
/// GET /api/v1/billing/subscription — 获取当前订阅
|
||||
pub async fn get_subscription(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
let plan = service::get_account_plan(&state.db, &ctx.account_id).await?;
|
||||
let sub = service::get_active_subscription(&state.db, &ctx.account_id).await?;
|
||||
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"plan": plan,
|
||||
"subscription": sub,
|
||||
"usage": usage,
|
||||
})))
|
||||
}
|
||||
|
||||
/// GET /api/v1/billing/usage — 获取当月用量
|
||||
pub async fn get_usage(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<UsageQuota>> {
|
||||
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
|
||||
Ok(Json(usage))
|
||||
}
|
||||
|
||||
/// POST /api/v1/billing/usage/increment — 客户端上报用量(Hand/Pipeline 执行后调用)
|
||||
///
|
||||
/// 请求体: `{ "dimension": "hand_executions" | "pipeline_runs" | "relay_requests", "count": 1 }`
|
||||
/// 需要认证 — account_id 从 JWT 提取。
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct IncrementUsageRequest {
|
||||
/// 用量维度:hand_executions / pipeline_runs / relay_requests
|
||||
pub dimension: String,
|
||||
/// 递增数量,默认 1
|
||||
#[serde(default = "default_count")]
|
||||
pub count: i32,
|
||||
}
|
||||
|
||||
fn default_count() -> i32 { 1 }
|
||||
|
||||
pub async fn increment_usage_dimension(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<IncrementUsageRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
// 验证维度白名单
|
||||
if !["hand_executions", "pipeline_runs", "relay_requests"].contains(&req.dimension.as_str()) {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("无效的用量维度: {},支持: hand_executions / pipeline_runs / relay_requests", req.dimension)
|
||||
));
|
||||
}
|
||||
|
||||
// 限制单次递增上限(防滥用)
|
||||
if req.count < 1 || req.count > 100 {
|
||||
return Err(SaasError::InvalidInput(
|
||||
format!("count 必须在 1~100 范围内,得到: {}", req.count)
|
||||
));
|
||||
}
|
||||
|
||||
// 单次原子更新,避免循环 N 次数据库查询
|
||||
service::increment_dimension_by(&state.db, &ctx.account_id, &req.dimension, req.count).await?;
|
||||
|
||||
// 返回更新后的用量
|
||||
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
|
||||
Ok(Json(serde_json::json!({
|
||||
"dimension": req.dimension,
|
||||
"incremented": req.count,
|
||||
"usage": usage,
|
||||
})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/billing/payments — 创建支付订单
|
||||
pub async fn create_payment(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreatePaymentRequest>,
|
||||
) -> SaasResult<Json<PaymentResult>> {
|
||||
let config = state.config.read().await;
|
||||
let result = super::payment::create_payment(
|
||||
&state.db,
|
||||
&ctx.account_id,
|
||||
&req,
|
||||
&config.payment,
|
||||
).await?;
|
||||
Ok(Json(result))
|
||||
}
|
||||
|
||||
/// GET /api/v1/billing/payments/:id — 查询支付状态
|
||||
pub async fn get_payment_status(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(payment_id): Path<String>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
let status = super::payment::query_payment_status(
|
||||
&state.db,
|
||||
&payment_id,
|
||||
&ctx.account_id,
|
||||
).await?;
|
||||
Ok(Json(status))
|
||||
}
|
||||
|
||||
/// POST /api/v1/billing/callback/:method — 支付回调(支付宝/微信异步通知)
|
||||
pub async fn payment_callback(
|
||||
State(state): State<AppState>,
|
||||
Path(method): Path<String>,
|
||||
body: axum::body::Bytes,
|
||||
) -> SaasResult<String> {
|
||||
tracing::info!("Payment callback received: method={}, body_len={}", method, body.len());
|
||||
|
||||
let body_str = String::from_utf8_lossy(&body);
|
||||
let config = state.config.read().await;
|
||||
|
||||
let (trade_no, status, callback_amount) = if method == "alipay" {
|
||||
parse_alipay_callback(&body_str, &config.payment)?
|
||||
} else if method == "wechat" {
|
||||
parse_wechat_callback(&body_str, &config.payment)?
|
||||
} else {
|
||||
tracing::warn!("Unknown payment callback method: {}", method);
|
||||
return Ok("fail".into());
|
||||
};
|
||||
|
||||
// trade_no 是必填字段,缺失说明回调格式异常
|
||||
let trade_no = trade_no.ok_or_else(|| {
|
||||
tracing::warn!("Payment callback missing out_trade_no: method={}", method);
|
||||
SaasError::InvalidInput("回调缺少交易号".into())
|
||||
})?;
|
||||
|
||||
if let Err(e) = super::payment::handle_payment_callback(&state.db, &trade_no, &status, callback_amount).await {
|
||||
// 对外返回通用错误,不泄露内部细节
|
||||
tracing::error!("Payment callback processing failed: method={}, error={}", method, e);
|
||||
return Ok("fail".into());
|
||||
}
|
||||
|
||||
// 支付宝期望 "success",微信期望 JSON
|
||||
if method == "alipay" {
|
||||
Ok("success".into())
|
||||
} else {
|
||||
Ok(r#"{"code":"SUCCESS","message":"OK"}"#.into())
|
||||
}
|
||||
}
|
||||
|
||||
// === Mock 支付(开发模式) ===
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MockPayQuery {
|
||||
trade_no: String,
|
||||
amount: i32,
|
||||
subject: String,
|
||||
}
|
||||
|
||||
/// GET /api/v1/billing/mock-pay — 开发模式 Mock 支付页面
|
||||
pub async fn mock_pay_page(
|
||||
Query(params): Query<MockPayQuery>,
|
||||
) -> axum::response::Html<String> {
|
||||
// HTML 转义防止 XSS
|
||||
let safe_subject = html_escape(¶ms.subject);
|
||||
let safe_trade_no = html_escape(¶ms.trade_no);
|
||||
let amount_yuan = params.amount as f64 / 100.0;
|
||||
|
||||
axum::response::Html(format!(r#"
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh">
|
||||
<head><meta charset="utf-8"><title>Mock 支付</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20px; }}
|
||||
.card {{ background: #fff; border-radius: 12px; padding: 24px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }}
|
||||
.amount {{ font-size: 32px; font-weight: 700; color: #333; text-align: center; margin: 20px 0; }}
|
||||
.btn {{ display: block; width: 100%; padding: 12px; border: none; border-radius: 8px; font-size: 16px; cursor: pointer; margin-top: 12px; }}
|
||||
.btn-pay {{ background: #1677ff; color: #fff; }}
|
||||
.btn-pay:hover {{ background: #0958d9; }}
|
||||
.btn-fail {{ background: #f5f5f5; color: #999; }}
|
||||
.subject {{ text-align: center; color: #666; font-size: 14px; }}
|
||||
</style></head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<div class="subject">{safe_subject}</div>
|
||||
<div class="amount">¥{amount_yuan}</div>
|
||||
<div style="text-align:center;color:#999;font-size:12px;margin-bottom:16px;">
|
||||
订单号: {safe_trade_no}
|
||||
</div>
|
||||
<form action="/api/v1/billing/mock-pay/confirm" method="POST">
|
||||
<input type="hidden" name="trade_no" value="{safe_trade_no}" />
|
||||
<button type="submit" name="action" value="success" class="btn btn-pay">确认支付 ¥{amount_yuan}</button>
|
||||
<button type="submit" name="action" value="fail" class="btn btn-fail">模拟失败</button>
|
||||
</form>
|
||||
</div>
|
||||
</body></html>
|
||||
"#))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MockPayConfirm {
|
||||
trade_no: String,
|
||||
action: String,
|
||||
}
|
||||
|
||||
/// POST /api/v1/billing/mock-pay/confirm — Mock 支付确认
|
||||
pub async fn mock_pay_confirm(
|
||||
State(state): State<AppState>,
|
||||
Form(form): Form<MockPayConfirm>,
|
||||
) -> SaasResult<axum::response::Html<String>> {
|
||||
let status = if form.action == "success" { "success" } else { "failed" };
|
||||
|
||||
if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status, None).await {
|
||||
tracing::error!("Mock payment callback failed: {}", e);
|
||||
}
|
||||
|
||||
let msg = if status == "success" {
|
||||
"支付成功!您可以关闭此页面。"
|
||||
} else {
|
||||
"支付已取消。"
|
||||
};
|
||||
|
||||
Ok(axum::response::Html(format!(r#"
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh">
|
||||
<head><meta charset="utf-8"><title>支付结果</title>
|
||||
<style>
|
||||
body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20px; text-align: center; }}
|
||||
.msg {{ font-size: 18px; color: #333; margin: 40px 0; }}
|
||||
</style></head>
|
||||
<body><div class="msg">{msg}</div></body>
|
||||
</html>
|
||||
"#)))
|
||||
}
|
||||
|
||||
// === 回调解析 ===
|
||||
|
||||
/// 解析支付宝回调并验签,返回 (trade_no, status, callback_amount_cents)
|
||||
fn parse_alipay_callback(
|
||||
body: &str,
|
||||
config: &crate::config::PaymentConfig,
|
||||
) -> SaasResult<(Option<String>, String, Option<i32>)> {
|
||||
// form-urlencoded → key=value 对
|
||||
let mut params: Vec<(String, String)> = Vec::new();
|
||||
for pair in body.split('&') {
|
||||
if let Some((k, v)) = pair.split_once('=') {
|
||||
params.push((
|
||||
k.to_string(),
|
||||
urlencoding::decode(v).unwrap_or_default().to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let mut trade_no = None;
|
||||
let mut callback_amount: Option<i32> = None;
|
||||
|
||||
// 验签:生产环境强制,开发环境允许跳过
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
if let Some(ref public_key) = config.alipay_public_key {
|
||||
match super::payment::verify_alipay_callback(¶ms, public_key) {
|
||||
Ok(true) => {}
|
||||
Ok(false) => {
|
||||
tracing::warn!("Alipay callback signature verification FAILED");
|
||||
return Err(SaasError::InvalidInput("支付宝回调验签失败".into()));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Alipay callback verification error: {}", e);
|
||||
return Err(SaasError::InvalidInput("支付宝回调验签异常".into()));
|
||||
}
|
||||
}
|
||||
} else if !is_dev {
|
||||
tracing::error!("Alipay public key not configured in production — rejecting callback");
|
||||
return Err(SaasError::InvalidInput("支付宝公钥未配置,无法验签".into()));
|
||||
} else {
|
||||
tracing::warn!("Alipay public key not configured (dev mode), skipping signature verification");
|
||||
}
|
||||
|
||||
// 提取 trade_no、trade_status 和 total_amount
|
||||
let mut trade_status = "unknown".to_string();
|
||||
for (k, v) in ¶ms {
|
||||
match k.as_str() {
|
||||
"out_trade_no" => trade_no = Some(v.clone()),
|
||||
"trade_status" => trade_status = v.clone(),
|
||||
"total_amount" => {
|
||||
// 支付宝金额为元(字符串),转为分(整数)
|
||||
if let Ok(yuan) = v.parse::<f64>() {
|
||||
callback_amount = Some((yuan * 100.0).round() as i32);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// 支付宝成功状态映射
|
||||
let status = if trade_status == "TRADE_SUCCESS" || trade_status == "TRADE_FINISHED" {
|
||||
"TRADE_SUCCESS"
|
||||
} else {
|
||||
&trade_status
|
||||
};
|
||||
|
||||
Ok((trade_no, status.to_string(), callback_amount))
|
||||
}
|
||||
|
||||
/// 解析微信支付回调,解密 resource 字段,返回 (trade_no, status, callback_amount_cents)
|
||||
fn parse_wechat_callback(
|
||||
body: &str,
|
||||
config: &crate::config::PaymentConfig,
|
||||
) -> SaasResult<(Option<String>, String, Option<i32>)> {
|
||||
let v: serde_json::Value = serde_json::from_str(body)
|
||||
.map_err(|e| SaasError::InvalidInput(format!("微信回调 JSON 解析失败: {}", e)))?;
|
||||
|
||||
let event_type = v.get("event_type")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if event_type != "TRANSACTION.SUCCESS" {
|
||||
// 非支付成功事件(如退款等),忽略
|
||||
return Ok((None, event_type.to_string(), None));
|
||||
}
|
||||
|
||||
// 解密 resource 字段
|
||||
let resource = v.get("resource")
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信回调缺少 resource 字段".into()))?;
|
||||
|
||||
let ciphertext = resource.get("ciphertext")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 ciphertext".into()))?;
|
||||
let nonce = resource.get("nonce")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 nonce".into()))?;
|
||||
let associated_data = resource.get("associated_data")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let api_v3_key = config.wechat_api_v3_key.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信 API v3 密钥未配置,无法解密回调".into()))?;
|
||||
|
||||
let plaintext = super::payment::decrypt_wechat_resource(
|
||||
ciphertext, nonce, associated_data, api_v3_key,
|
||||
)?;
|
||||
|
||||
let decrypted: serde_json::Value = serde_json::from_str(&plaintext)
|
||||
.map_err(|e| SaasError::Internal(format!("微信回调解密内容 JSON 解析失败: {}", e)))?;
|
||||
|
||||
let trade_no = decrypted.get("out_trade_no")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
let trade_state = decrypted.get("trade_state")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("UNKNOWN");
|
||||
|
||||
// 微信金额已为分(整数)
|
||||
let callback_amount = decrypted.get("amount")
|
||||
.and_then(|a| a.get("total"))
|
||||
.and_then(|v| v.as_i64())
|
||||
.map(|v| v as i32);
|
||||
|
||||
Ok((trade_no, trade_state.to_string(), callback_amount))
|
||||
}
|
||||
|
||||
/// HTML 转义,防止 XSS 注入
|
||||
fn html_escape(s: &str) -> String {
|
||||
s.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
.replace('\'', "'")
|
||||
}
|
||||
33
crates/zclaw-saas/src/billing/mod.rs
Normal file
33
crates/zclaw-saas/src/billing/mod.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
//! 计费模块 — 计划管理、订阅、用量配额、支付
|
||||
|
||||
pub mod types;
|
||||
pub mod service;
|
||||
pub mod handlers;
|
||||
pub mod payment;
|
||||
|
||||
use axum::routing::{get, post};
|
||||
|
||||
/// 需要认证的计费路由
|
||||
pub fn routes() -> axum::Router<crate::state::AppState> {
|
||||
axum::Router::new()
|
||||
.route("/api/v1/billing/plans", get(handlers::list_plans))
|
||||
.route("/api/v1/billing/plans/{id}", get(handlers::get_plan))
|
||||
.route("/api/v1/billing/subscription", get(handlers::get_subscription))
|
||||
.route("/api/v1/billing/usage", get(handlers::get_usage))
|
||||
.route("/api/v1/billing/usage/increment", post(handlers::increment_usage_dimension))
|
||||
.route("/api/v1/billing/payments", post(handlers::create_payment))
|
||||
.route("/api/v1/billing/payments/{id}", get(handlers::get_payment_status))
|
||||
}
|
||||
|
||||
/// 支付回调路由(无需 auth — 支付宝/微信服务器回调)
|
||||
pub fn callback_routes() -> axum::Router<crate::state::AppState> {
|
||||
axum::Router::new()
|
||||
.route("/api/v1/billing/callback/{method}", post(handlers::payment_callback))
|
||||
}
|
||||
|
||||
/// mock 支付页面路由(开发模式)
|
||||
pub fn mock_routes() -> axum::Router<crate::state::AppState> {
|
||||
axum::Router::new()
|
||||
.route("/api/v1/billing/mock-pay", get(handlers::mock_pay_page))
|
||||
.route("/api/v1/billing/mock-pay/confirm", post(handlers::mock_pay_confirm))
|
||||
}
|
||||
647
crates/zclaw-saas/src/billing/payment.rs
Normal file
647
crates/zclaw-saas/src/billing/payment.rs
Normal file
@@ -0,0 +1,647 @@
|
||||
//! 支付集成 — 支付宝/微信支付(直连 HTTP 实现)
|
||||
//!
|
||||
//! 不依赖第三方 SDK,使用 `rsa` crate 做 RSA2 签名,`reqwest` 做 HTTP 调用。
|
||||
//! 开发模式(`ZCLAW_SAAS_DEV=true`)使用 mock 支付。
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::config::PaymentConfig;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use super::types::*;
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 公开 API
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 创建支付订单,返回支付链接/二维码 URL
|
||||
///
|
||||
/// 发票和支付记录在事务中创建,确保原子性。
|
||||
pub async fn create_payment(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
req: &CreatePaymentRequest,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<PaymentResult> {
|
||||
// 1. 获取计划信息
|
||||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||||
"SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
|
||||
)
|
||||
.bind(&req.plan_id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.ok_or_else(|| SaasError::NotFound("计划不存在或已下架".into()))?;
|
||||
|
||||
// 检查是否已有活跃订阅
|
||||
let existing = sqlx::query_scalar::<_, i64>(
|
||||
"SELECT COUNT(*) FROM billing_subscriptions \
|
||||
WHERE account_id = $1 AND status IN ('trial', 'active') AND plan_id = $2"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(&req.plan_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if existing > 0 {
|
||||
return Err(SaasError::InvalidInput("已订阅该计划".into()));
|
||||
}
|
||||
|
||||
// 2. 在事务中创建发票和支付记录
|
||||
let mut tx = pool.begin().await
|
||||
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
|
||||
|
||||
let invoice_id = uuid::Uuid::new_v4().to_string();
|
||||
let now = chrono::Utc::now();
|
||||
let due = now + chrono::Duration::days(1);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO billing_invoices \
|
||||
(id, account_id, plan_id, amount_cents, currency, description, status, due_at, created_at, updated_at) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, $8, $8)"
|
||||
)
|
||||
.bind(&invoice_id)
|
||||
.bind(account_id)
|
||||
.bind(&req.plan_id)
|
||||
.bind(plan.price_cents)
|
||||
.bind(&plan.currency)
|
||||
.bind(format!("{} - {} ({})", plan.display_name, plan.interval, now.format("%Y-%m")))
|
||||
.bind(due.to_rfc3339())
|
||||
.bind(now.to_rfc3339())
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let payment_id = uuid::Uuid::new_v4().to_string();
|
||||
let trade_no = format!("ZCLAW-{}-{}", chrono::Utc::now().format("%Y%m%d%H%M%S"), &payment_id[..8]);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO billing_payments \
|
||||
(id, invoice_id, account_id, amount_cents, currency, method, status, external_trade_no, metadata, created_at, updated_at) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, '{}', $8, $8)"
|
||||
)
|
||||
.bind(&payment_id)
|
||||
.bind(&invoice_id)
|
||||
.bind(account_id)
|
||||
.bind(plan.price_cents)
|
||||
.bind(&plan.currency)
|
||||
.bind(req.payment_method.to_string())
|
||||
.bind(&trade_no)
|
||||
.bind(now.to_rfc3339())
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await
|
||||
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
|
||||
|
||||
// 3. 生成支付链接
|
||||
let pay_url = generate_pay_url(
|
||||
req.payment_method,
|
||||
&trade_no,
|
||||
plan.price_cents,
|
||||
&plan.display_name,
|
||||
config,
|
||||
).await?;
|
||||
|
||||
Ok(PaymentResult {
|
||||
payment_id,
|
||||
trade_no,
|
||||
pay_url,
|
||||
amount_cents: plan.price_cents,
|
||||
})
|
||||
}
|
||||
|
||||
/// 处理支付回调(支付宝/微信异步通知)
|
||||
///
|
||||
/// `callback_amount_cents` 来自回调报文的金额(分),用于与 DB 金额交叉验证。
|
||||
/// 整个操作在数据库事务中执行,使用 SELECT FOR UPDATE 防止并发竞态。
|
||||
pub async fn handle_payment_callback(
|
||||
pool: &PgPool,
|
||||
trade_no: &str,
|
||||
status: &str,
|
||||
callback_amount_cents: Option<i32>,
|
||||
) -> SaasResult<()> {
|
||||
// 1. 在事务中锁定支付记录,防止 TOCTOU 竞态
|
||||
let mut tx = pool.begin().await
|
||||
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
|
||||
|
||||
let payment: Option<(String, String, String, i32, String)> = sqlx::query_as::<_, (String, String, String, i32, String)>(
|
||||
"SELECT id, invoice_id, account_id, amount_cents, status \
|
||||
FROM billing_payments WHERE external_trade_no = $1 FOR UPDATE"
|
||||
)
|
||||
.bind(trade_no)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let (payment_id, invoice_id, account_id, db_amount, current_status) = match payment {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
tracing::error!("Payment callback for unknown trade: {}", sanitize_log(trade_no));
|
||||
tx.rollback().await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// 幂等性:已处理过直接返回
|
||||
if current_status != "pending" {
|
||||
tracing::info!("Payment already processed (idempotent): trade={}, status={}", sanitize_log(trade_no), current_status);
|
||||
tx.rollback().await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 2. 金额交叉验证(防篡改)
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
if let Some(callback_amount) = callback_amount_cents {
|
||||
if callback_amount != db_amount {
|
||||
tracing::error!(
|
||||
"Amount mismatch: trade={}, db_amount={}, callback_amount={}. Rejecting.",
|
||||
sanitize_log(trade_no), db_amount, callback_amount
|
||||
);
|
||||
tx.rollback().await?;
|
||||
return Err(SaasError::InvalidInput("回调验证失败".into()));
|
||||
}
|
||||
} else if !is_dev {
|
||||
// 非开发环境必须有金额
|
||||
tracing::error!("Callback without amount in non-dev mode: trade={}", sanitize_log(trade_no));
|
||||
tx.rollback().await?;
|
||||
return Err(SaasError::InvalidInput("回调缺少金额验证".into()));
|
||||
} else {
|
||||
tracing::warn!("DEV: Skipping amount verification for trade={}", sanitize_log(trade_no));
|
||||
}
|
||||
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
if status == "success" || status == "TRADE_SUCCESS" || status == "SUCCESS" {
|
||||
// 3. 更新支付状态
|
||||
sqlx::query(
|
||||
"UPDATE billing_payments SET status = 'succeeded', paid_at = $1, updated_at = $1 WHERE id = $2"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&payment_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 4. 更新发票状态
|
||||
sqlx::query(
|
||||
"UPDATE billing_invoices SET status = 'paid', paid_at = $1, updated_at = $1 WHERE id = $2"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&invoice_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 5. 获取发票关联的计划
|
||||
let plan_id: Option<String> = sqlx::query_scalar(
|
||||
"SELECT plan_id FROM billing_invoices WHERE id = $1"
|
||||
)
|
||||
.bind(&invoice_id)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?
|
||||
.flatten();
|
||||
|
||||
if let Some(plan_id) = plan_id {
|
||||
// 6. 取消旧订阅
|
||||
sqlx::query(
|
||||
"UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \
|
||||
WHERE account_id = $2 AND status IN ('trial', 'active')"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&account_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 7. 创建新订阅(30 天周期)
|
||||
let sub_id = uuid::Uuid::new_v4().to_string();
|
||||
let period_end = (chrono::Utc::now() + chrono::Duration::days(30)).to_rfc3339();
|
||||
let period_start = chrono::Utc::now().to_rfc3339();
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO billing_subscriptions \
|
||||
(id, account_id, plan_id, status, current_period_start, current_period_end, created_at, updated_at) \
|
||||
VALUES ($1, $2, $3, 'active', $4, $5, $6, $6)"
|
||||
)
|
||||
.bind(&sub_id)
|
||||
.bind(&account_id)
|
||||
.bind(&plan_id)
|
||||
.bind(&period_start)
|
||||
.bind(&period_end)
|
||||
.bind(&now)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tracing::info!(
|
||||
"Payment succeeded: account={}, plan={}, subscription={}",
|
||||
account_id, plan_id, sub_id
|
||||
);
|
||||
}
|
||||
|
||||
tx.commit().await
|
||||
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
|
||||
} else {
|
||||
// 支付失败:截断 status 防止注入,更新发票为 void
|
||||
let safe_reason = truncate_str(status, 200);
|
||||
sqlx::query(
|
||||
"UPDATE billing_payments SET status = 'failed', failure_reason = $1, updated_at = $2 WHERE id = $3"
|
||||
)
|
||||
.bind(&safe_reason)
|
||||
.bind(&now)
|
||||
.bind(&payment_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 同时将发票标记为 void
|
||||
sqlx::query(
|
||||
"UPDATE billing_invoices SET status = 'void', voided_at = $1, updated_at = $1 WHERE id = $2"
|
||||
)
|
||||
.bind(&now)
|
||||
.bind(&invoice_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await
|
||||
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
|
||||
|
||||
tracing::warn!("Payment failed: trade={}, status={}", sanitize_log(trade_no), safe_reason);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 查询支付状态
|
||||
pub async fn query_payment_status(
|
||||
pool: &PgPool,
|
||||
payment_id: &str,
|
||||
account_id: &str,
|
||||
) -> SaasResult<serde_json::Value> {
|
||||
let payment: (String, String, i32, String, String) = sqlx::query_as::<_, (String, String, i32, String, String)>(
|
||||
"SELECT id, method, amount_cents, currency, status \
|
||||
FROM billing_payments WHERE id = $1 AND account_id = $2"
|
||||
)
|
||||
.bind(payment_id)
|
||||
.bind(account_id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.ok_or_else(|| SaasError::NotFound("支付记录不存在".into()))?;
|
||||
|
||||
let (id, method, amount, currency, status) = payment;
|
||||
Ok(serde_json::json!({
|
||||
"id": id,
|
||||
"method": method,
|
||||
"amount_cents": amount,
|
||||
"currency": currency,
|
||||
"status": status,
|
||||
}))
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 支付 URL 生成
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 生成支付 URL:根据配置决定 mock 或真实支付
|
||||
async fn generate_pay_url(
|
||||
method: PaymentMethod,
|
||||
trade_no: &str,
|
||||
amount_cents: i32,
|
||||
subject: &str,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<String> {
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_dev {
|
||||
return Ok(mock_pay_url(trade_no, amount_cents, subject));
|
||||
}
|
||||
|
||||
match method {
|
||||
PaymentMethod::Alipay => generate_alipay_url(trade_no, amount_cents, subject, config),
|
||||
PaymentMethod::Wechat => generate_wechat_url(trade_no, amount_cents, subject, config).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn mock_pay_url(trade_no: &str, amount_cents: i32, subject: &str) -> String {
|
||||
let base = std::env::var("ZCLAW_SAAS_URL")
|
||||
.unwrap_or_else(|_| "http://localhost:8080".into());
|
||||
format!(
|
||||
"{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}",
|
||||
base,
|
||||
urlencoding::encode(trade_no),
|
||||
amount_cents,
|
||||
urlencoding::encode(subject),
|
||||
)
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 支付宝 — alipay.trade.page.pay(RSA2 签名 + 证书模式)
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
fn generate_alipay_url(
|
||||
trade_no: &str,
|
||||
amount_cents: i32,
|
||||
subject: &str,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<String> {
|
||||
let app_id = config.alipay_app_id.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("支付宝 app_id 未配置".into()))?;
|
||||
let private_key_pem = config.alipay_private_key.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("支付宝商户私钥未配置".into()))?;
|
||||
let notify_url = config.alipay_notify_url.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("支付宝回调 URL 未配置".into()))?;
|
||||
|
||||
// 金额:分 → 元(整数运算避免浮点精度问题)
|
||||
let yuan_part = amount_cents / 100;
|
||||
let cent_part = amount_cents % 100;
|
||||
let amount_yuan = format!("{}.{:02}", yuan_part, cent_part);
|
||||
|
||||
// 构建请求参数(字典序)
|
||||
let mut params: Vec<(&str, String)> = vec![
|
||||
("app_id", app_id.to_string()),
|
||||
("method", "alipay.trade.page.pay".to_string()),
|
||||
("charset", "utf-8".to_string()),
|
||||
("sign_type", "RSA2".to_string()),
|
||||
("timestamp", chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()),
|
||||
("version", "1.0".to_string()),
|
||||
("notify_url", notify_url.to_string()),
|
||||
("biz_content", serde_json::json!({
|
||||
"out_trade_no": trade_no,
|
||||
"total_amount": amount_yuan,
|
||||
"subject": subject,
|
||||
"product_code": "FAST_INSTANT_TRADE_PAY",
|
||||
}).to_string()),
|
||||
];
|
||||
|
||||
// 按 key 字典序排列并拼接
|
||||
params.sort_by(|a, b| a.0.cmp(b.0));
|
||||
let sign_str: String = params.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, v))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
|
||||
// RSA2 签名
|
||||
let sign = rsa_sign_sha256_base64(private_key_pem, sign_str.as_bytes())?;
|
||||
|
||||
// 构建 gateway URL
|
||||
params.push(("sign", sign));
|
||||
let query: String = params.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
|
||||
Ok(format!("https://openapi.alipay.com/gateway.do?{}", query))
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 微信支付 — V3 Native Pay(QR 码模式)
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn generate_wechat_url(
|
||||
trade_no: &str,
|
||||
amount_cents: i32,
|
||||
subject: &str,
|
||||
config: &PaymentConfig,
|
||||
) -> SaasResult<String> {
|
||||
let mch_id = config.wechat_mch_id.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付商户号未配置".into()))?;
|
||||
let serial_no = config.wechat_serial_no.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付证书序列号未配置".into()))?;
|
||||
let private_key_pem = config.wechat_private_key_path.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付私钥路径未配置".into()))?;
|
||||
let notify_url = config.wechat_notify_url.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付回调 URL 未配置".into()))?;
|
||||
let app_id = config.wechat_app_id.as_deref()
|
||||
.ok_or_else(|| SaasError::InvalidInput("微信支付 App ID 未配置".into()))?;
|
||||
|
||||
// 读取私钥文件
|
||||
let private_key = std::fs::read_to_string(private_key_pem)
|
||||
.map_err(|e| SaasError::InvalidInput(format!("微信支付私钥文件读取失败: {}", e)))?;
|
||||
|
||||
let body = serde_json::json!({
|
||||
"appid": app_id,
|
||||
"mchid": mch_id,
|
||||
"description": subject,
|
||||
"out_trade_no": trade_no,
|
||||
"notify_url": notify_url,
|
||||
"amount": {
|
||||
"total": amount_cents,
|
||||
"currency": "CNY",
|
||||
},
|
||||
});
|
||||
let body_str = body.to_string();
|
||||
|
||||
// 构建签名字符串
|
||||
let timestamp = chrono::Utc::now().timestamp().to_string();
|
||||
let nonce_str = uuid::Uuid::new_v4().to_string().replace("-", "");
|
||||
let sign_message = format!(
|
||||
"POST\n/v3/pay/transactions/native\n{}\n{}\n{}\n",
|
||||
timestamp, nonce_str, body_str
|
||||
);
|
||||
|
||||
let signature = rsa_sign_sha256_base64(&private_key, sign_message.as_bytes())?;
|
||||
|
||||
// 构建 Authorization 头
|
||||
let auth_header = format!(
|
||||
"WECHATPAY2-SHA256-RSA2048 mchid=\"{}\",nonce_str=\"{}\",timestamp=\"{}\",serial_no=\"{}\",signature=\"{}\"",
|
||||
mch_id, nonce_str, timestamp, serial_no, signature
|
||||
);
|
||||
|
||||
// 发送请求
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post("https://api.mch.weixin.qq.com/v3/pay/transactions/native")
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", auth_header)
|
||||
.header("Accept", "application/json")
|
||||
.body(body_str)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| SaasError::Internal(format!("微信支付请求失败: {}", e)))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
tracing::error!("WeChat Pay API error: status={}, body={}", status, text);
|
||||
return Err(SaasError::InvalidInput(format!(
|
||||
"微信支付创建订单失败 (HTTP {})", status
|
||||
)));
|
||||
}
|
||||
|
||||
let resp_json: serde_json::Value = resp.json().await
|
||||
.map_err(|e| SaasError::Internal(format!("微信支付响应解析失败: {}", e)))?;
|
||||
|
||||
let code_url = resp_json.get("code_url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| SaasError::Internal("微信支付响应缺少 code_url".into()))?
|
||||
.to_string();
|
||||
|
||||
Ok(code_url)
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 回调验签
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 验证支付宝回调签名
|
||||
pub fn verify_alipay_callback(
|
||||
params: &[(String, String)],
|
||||
alipay_public_key_pem: &str,
|
||||
) -> SaasResult<bool> {
|
||||
// 1. 提取 sign 和 sign_type,剩余参数字典序拼接
|
||||
let mut sign = None;
|
||||
let mut filtered: Vec<(&str, &str)> = Vec::new();
|
||||
|
||||
for (k, v) in params {
|
||||
match k.as_str() {
|
||||
"sign" => sign = Some(v.clone()),
|
||||
"sign_type" => {} // 跳过
|
||||
_ => {
|
||||
if !v.is_empty() {
|
||||
filtered.push((k.as_str(), v.as_str()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let sign = match sign {
|
||||
Some(s) => s,
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
filtered.sort_by(|a, b| a.0.cmp(b.0));
|
||||
let sign_str: String = filtered.iter()
|
||||
.map(|(k, v)| format!("{}={}", k, v))
|
||||
.collect::<Vec<_>>()
|
||||
.join("&");
|
||||
|
||||
// 2. 用支付宝公钥验签
|
||||
rsa_verify_sha256(alipay_public_key_pem, sign_str.as_bytes(), &sign)
|
||||
}
|
||||
|
||||
/// 解密微信支付回调 resource 字段(AES-256-GCM)
|
||||
pub fn decrypt_wechat_resource(
|
||||
ciphertext_b64: &str,
|
||||
nonce: &str,
|
||||
associated_data: &str,
|
||||
api_v3_key: &str,
|
||||
) -> SaasResult<String> {
|
||||
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
|
||||
use aes_gcm::aead::Aead;
|
||||
use base64::Engine;
|
||||
|
||||
let key_bytes = api_v3_key.as_bytes();
|
||||
if key_bytes.len() != 32 {
|
||||
return Err(SaasError::Internal("微信 API v3 密钥必须为 32 字节".into()));
|
||||
}
|
||||
|
||||
let nonce_bytes = nonce.as_bytes();
|
||||
if nonce_bytes.len() != 12 {
|
||||
return Err(SaasError::InvalidInput("微信回调 nonce 长度必须为 12 字节".into()));
|
||||
}
|
||||
|
||||
let ciphertext = base64::engine::general_purpose::STANDARD
|
||||
.decode(ciphertext_b64)
|
||||
.map_err(|e| SaasError::Internal(format!("base64 解码失败: {}", e)))?;
|
||||
|
||||
let cipher = Aes256Gcm::new_from_slice(key_bytes)
|
||||
.map_err(|e| SaasError::Internal(format!("AES 密钥初始化失败: {}", e)))?;
|
||||
let nonce = Nonce::from_slice(nonce_bytes);
|
||||
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, aes_gcm::aead::Payload {
|
||||
msg: &ciphertext,
|
||||
aad: associated_data.as_bytes(),
|
||||
})
|
||||
.map_err(|e| SaasError::Internal(format!("AES-GCM 解密失败: {}", e)))?;
|
||||
|
||||
String::from_utf8(plaintext)
|
||||
.map_err(|e| SaasError::Internal(format!("解密结果 UTF-8 转换失败: {}", e)))
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// RSA 工具函数
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// SHA256WithRSA 签名 + Base64 编码(PKCS#1 v1.5)
|
||||
fn rsa_sign_sha256_base64(
|
||||
private_key_pem: &str,
|
||||
message: &[u8],
|
||||
) -> SaasResult<String> {
|
||||
use rsa::pkcs8::DecodePrivateKey;
|
||||
use rsa::signature::{Signer, SignatureEncoding};
|
||||
use sha2::Sha256;
|
||||
use rsa::pkcs1v15::SigningKey;
|
||||
use base64::Engine;
|
||||
|
||||
let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(private_key_pem)
|
||||
.map_err(|e| SaasError::Internal(format!("RSA 私钥解析失败: {}", e)))?;
|
||||
|
||||
let signing_key = SigningKey::<Sha256>::new(private_key);
|
||||
let signature = signing_key.sign(message);
|
||||
|
||||
Ok(base64::engine::general_purpose::STANDARD.encode(signature.to_bytes()))
|
||||
}
|
||||
|
||||
/// SHA256WithRSA 验签
|
||||
fn rsa_verify_sha256(
|
||||
public_key_pem: &str,
|
||||
message: &[u8],
|
||||
signature_b64: &str,
|
||||
) -> SaasResult<bool> {
|
||||
use rsa::pkcs8::DecodePublicKey;
|
||||
use rsa::signature::Verifier;
|
||||
use sha2::Sha256;
|
||||
use rsa::pkcs1v15::VerifyingKey;
|
||||
use base64::Engine;
|
||||
|
||||
let public_key = match rsa::RsaPublicKey::from_public_key_pem(public_key_pem) {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
tracing::error!("RSA 公钥解析失败: {}", e);
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
let signature_bytes = match base64::engine::general_purpose::STANDARD.decode(signature_b64) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
tracing::error!("签名 base64 解码失败: {}", e);
|
||||
return Ok(false);
|
||||
}
|
||||
};
|
||||
|
||||
let verifying_key = VerifyingKey::<Sha256>::new(public_key);
|
||||
let signature = match rsa::pkcs1v15::Signature::try_from(signature_bytes.as_slice()) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Ok(false),
|
||||
};
|
||||
|
||||
Ok(verifying_key.verify(message, &signature).is_ok())
|
||||
}
|
||||
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
// 辅助函数
|
||||
// ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// 日志安全:只保留字母数字和 `-` `_`,防止日志注入
|
||||
fn sanitize_log(s: &str) -> String {
|
||||
s.chars()
|
||||
.filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// 截断字符串到指定长度(按字符而非字节)
|
||||
fn truncate_str(s: &str, max_len: usize) -> String {
|
||||
let chars: Vec<char> = s.chars().collect();
|
||||
if chars.len() <= max_len {
|
||||
s.to_string()
|
||||
} else {
|
||||
chars.into_iter().take(max_len).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PaymentMethod {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Alipay => write!(f, "alipay"),
|
||||
Self::Wechat => write!(f, "wechat"),
|
||||
}
|
||||
}
|
||||
}
|
||||
303
crates/zclaw-saas/src/billing/service.rs
Normal file
303
crates/zclaw-saas/src/billing/service.rs
Normal file
@@ -0,0 +1,303 @@
|
||||
//! 计费服务层 — 计划查询、订阅管理、用量检查
|
||||
|
||||
use chrono::{Datelike, Timelike};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::error::SaasResult;
|
||||
|
||||
use super::types::*;
|
||||
|
||||
/// 获取所有活跃计划
|
||||
pub async fn list_plans(pool: &PgPool) -> SaasResult<Vec<BillingPlan>> {
|
||||
let plans = sqlx::query_as::<_, BillingPlan>(
|
||||
"SELECT * FROM billing_plans WHERE status = 'active' ORDER BY sort_order"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
Ok(plans)
|
||||
}
|
||||
|
||||
/// 获取单个计划(公开 API 只返回 active 计划)
|
||||
pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
|
||||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||||
"SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
|
||||
)
|
||||
.bind(plan_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
/// 获取单个计划(内部使用,不过滤 status,用于已订阅用户查看旧计划)
|
||||
pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
|
||||
let plan = sqlx::query_as::<_, BillingPlan>(
|
||||
"SELECT * FROM billing_plans WHERE id = $1"
|
||||
)
|
||||
.bind(plan_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(plan)
|
||||
}
|
||||
|
||||
/// 获取账户当前有效订阅
|
||||
pub async fn get_active_subscription(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
) -> SaasResult<Option<Subscription>> {
|
||||
let sub = sqlx::query_as::<_, Subscription>(
|
||||
"SELECT * FROM billing_subscriptions \
|
||||
WHERE account_id = $1 AND status IN ('trial', 'active', 'past_due') \
|
||||
ORDER BY created_at DESC LIMIT 1"
|
||||
)
|
||||
.bind(account_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(sub)
|
||||
}
|
||||
|
||||
/// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free)
|
||||
pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<BillingPlan> {
|
||||
if let Some(sub) = get_active_subscription(pool, account_id).await? {
|
||||
if let Some(plan) = get_plan_any_status(pool, &sub.plan_id).await? {
|
||||
return Ok(plan);
|
||||
}
|
||||
}
|
||||
// 回退到 Free 计划
|
||||
let free = sqlx::query_as::<_, BillingPlan>(
|
||||
"SELECT * FROM billing_plans WHERE name = 'free' AND status = 'active' LIMIT 1"
|
||||
)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(free.unwrap_or_else(|| BillingPlan {
|
||||
id: "plan-free".into(),
|
||||
name: "free".into(),
|
||||
display_name: "免费版".into(),
|
||||
description: Some("基础功能".into()),
|
||||
price_cents: 0,
|
||||
currency: "CNY".into(),
|
||||
interval: "month".into(),
|
||||
features: serde_json::json!({}),
|
||||
limits: serde_json::json!({
|
||||
"max_input_tokens_monthly": 500000,
|
||||
"max_output_tokens_monthly": 500000,
|
||||
"max_relay_requests_monthly": 100,
|
||||
"max_hand_executions_monthly": 20,
|
||||
"max_pipeline_runs_monthly": 5,
|
||||
}),
|
||||
is_default: true,
|
||||
sort_order: 0,
|
||||
status: "active".into(),
|
||||
created_at: chrono::Utc::now(),
|
||||
updated_at: chrono::Utc::now(),
|
||||
}))
|
||||
}
|
||||
|
||||
/// 获取或创建当月用量记录(原子操作,使用 INSERT ON CONFLICT 防止 TOCTOU 竞态)
|
||||
pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<UsageQuota> {
|
||||
let now = chrono::Utc::now();
|
||||
let period_start = now
|
||||
.with_day(1).unwrap_or(now)
|
||||
.with_hour(0).unwrap_or(now)
|
||||
.with_minute(0).unwrap_or(now)
|
||||
.with_second(0).unwrap_or(now)
|
||||
.with_nanosecond(0).unwrap_or(now);
|
||||
|
||||
// 先尝试获取已有记录
|
||||
let existing = sqlx::query_as::<_, UsageQuota>(
|
||||
"SELECT * FROM billing_usage_quotas \
|
||||
WHERE account_id = $1 AND period_start = $2"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(period_start)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
if let Some(usage) = existing {
|
||||
return Ok(usage);
|
||||
}
|
||||
|
||||
// 获取当前计划限额
|
||||
let plan = get_account_plan(pool, account_id).await?;
|
||||
let limits: PlanLimits = serde_json::from_value(plan.limits.clone())
|
||||
.unwrap_or_else(|_| PlanLimits::free());
|
||||
|
||||
// 计算月末
|
||||
let period_end = if now.month() == 12 {
|
||||
now.with_year(now.year() + 1).and_then(|d| d.with_month(1))
|
||||
} else {
|
||||
now.with_month(now.month() + 1)
|
||||
}.unwrap_or(now)
|
||||
.with_day(1).unwrap_or(now)
|
||||
.with_hour(0).unwrap_or(now)
|
||||
.with_minute(0).unwrap_or(now)
|
||||
.with_second(0).unwrap_or(now)
|
||||
.with_nanosecond(0).unwrap_or(now);
|
||||
|
||||
// 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入)
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let inserted = sqlx::query_as::<_, UsageQuota>(
|
||||
"INSERT INTO billing_usage_quotas \
|
||||
(id, account_id, period_start, period_end, \
|
||||
max_input_tokens, max_output_tokens, max_relay_requests, \
|
||||
max_hand_executions, max_pipeline_runs) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
|
||||
ON CONFLICT (account_id, period_start) DO NOTHING \
|
||||
RETURNING *"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(account_id)
|
||||
.bind(period_start)
|
||||
.bind(period_end)
|
||||
.bind(limits.max_input_tokens_monthly)
|
||||
.bind(limits.max_output_tokens_monthly)
|
||||
.bind(limits.max_relay_requests_monthly)
|
||||
.bind(limits.max_hand_executions_monthly)
|
||||
.bind(limits.max_pipeline_runs_monthly)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
if let Some(usage) = inserted {
|
||||
return Ok(usage);
|
||||
}
|
||||
|
||||
// ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回
|
||||
let usage = sqlx::query_as::<_, UsageQuota>(
|
||||
"SELECT * FROM billing_usage_quotas \
|
||||
WHERE account_id = $1 AND period_start = $2"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(period_start)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok(usage)
|
||||
}
|
||||
|
||||
/// 增加用量计数(Relay 请求:tokens + relay_requests +1)
|
||||
///
|
||||
/// 在 relay handler 响应成功后直接调用,实现实时配额更新。
|
||||
/// 聚合器 `AggregateUsageWorker` 每小时做一次对账修正。
|
||||
pub async fn increment_usage(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
input_tokens: i64,
|
||||
output_tokens: i64,
|
||||
) -> SaasResult<()> {
|
||||
let usage = get_or_create_usage(pool, account_id).await?;
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas \
|
||||
SET input_tokens = input_tokens + $1, \
|
||||
output_tokens = output_tokens + $2, \
|
||||
relay_requests = relay_requests + 1, \
|
||||
updated_at = NOW() \
|
||||
WHERE id = $3"
|
||||
)
|
||||
.bind(input_tokens)
|
||||
.bind(output_tokens)
|
||||
.bind(&usage.id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 增加单一维度用量计数(单次 +1)
|
||||
///
|
||||
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
|
||||
pub async fn increment_dimension(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
dimension: &str,
|
||||
) -> SaasResult<()> {
|
||||
let usage = get_or_create_usage(pool, account_id).await?;
|
||||
|
||||
match dimension {
|
||||
"relay_requests" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1"
|
||||
).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"hand_executions" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1"
|
||||
).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"pipeline_runs" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1"
|
||||
).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
_ => return Err(crate::error::SaasError::InvalidInput(
|
||||
format!("Unknown usage dimension: {}", dimension)
|
||||
)),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 增加单一维度用量计数(批量 +N,原子操作,替代循环调用)
|
||||
///
|
||||
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
|
||||
pub async fn increment_dimension_by(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
dimension: &str,
|
||||
count: i32,
|
||||
) -> SaasResult<()> {
|
||||
let usage = get_or_create_usage(pool, account_id).await?;
|
||||
|
||||
match dimension {
|
||||
"relay_requests" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2"
|
||||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"hand_executions" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2"
|
||||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
"pipeline_runs" => {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2"
|
||||
).bind(count).bind(&usage.id).execute(pool).await?;
|
||||
}
|
||||
_ => return Err(crate::error::SaasError::InvalidInput(
|
||||
format!("Unknown usage dimension: {}", dimension)
|
||||
)),
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 检查用量配额
|
||||
pub async fn check_quota(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
quota_type: &str,
|
||||
) -> SaasResult<QuotaCheck> {
|
||||
let usage = get_or_create_usage(pool, account_id).await?;
|
||||
|
||||
let (current, limit) = match quota_type {
|
||||
"input_tokens" => (usage.input_tokens, usage.max_input_tokens),
|
||||
"output_tokens" => (usage.output_tokens, usage.max_output_tokens),
|
||||
"relay_requests" => (usage.relay_requests as i64, usage.max_relay_requests.map(|v| v as i64)),
|
||||
"hand_executions" => (usage.hand_executions as i64, usage.max_hand_executions.map(|v| v as i64)),
|
||||
"pipeline_runs" => (usage.pipeline_runs as i64, usage.max_pipeline_runs.map(|v| v as i64)),
|
||||
_ => return Ok(QuotaCheck {
|
||||
allowed: true,
|
||||
reason: None,
|
||||
current: 0,
|
||||
limit: None,
|
||||
remaining: None,
|
||||
}),
|
||||
};
|
||||
|
||||
let allowed = limit.map_or(true, |lim| current < lim);
|
||||
let remaining = limit.map(|lim| (lim - current).max(0));
|
||||
|
||||
Ok(QuotaCheck {
|
||||
allowed,
|
||||
reason: if !allowed { Some(format!("{} 配额已用尽", quota_type)) } else { None },
|
||||
current,
|
||||
limit,
|
||||
remaining,
|
||||
})
|
||||
}
|
||||
161
crates/zclaw-saas/src/billing/types.rs
Normal file
161
crates/zclaw-saas/src/billing/types.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
//! 计费类型定义
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// 计费计划定义 — 对应 billing_plans 表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct BillingPlan {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub display_name: String,
|
||||
pub description: Option<String>,
|
||||
pub price_cents: i32,
|
||||
pub currency: String,
|
||||
pub interval: String,
|
||||
pub features: serde_json::Value,
|
||||
pub limits: serde_json::Value,
|
||||
pub is_default: bool,
|
||||
pub sort_order: i32,
|
||||
pub status: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// 计划限额(从 limits JSON 反序列化)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlanLimits {
|
||||
#[serde(default)]
|
||||
pub max_input_tokens_monthly: Option<i64>,
|
||||
#[serde(default)]
|
||||
pub max_output_tokens_monthly: Option<i64>,
|
||||
#[serde(default)]
|
||||
pub max_relay_requests_monthly: Option<i32>,
|
||||
#[serde(default)]
|
||||
pub max_hand_executions_monthly: Option<i32>,
|
||||
#[serde(default)]
|
||||
pub max_pipeline_runs_monthly: Option<i32>,
|
||||
}
|
||||
|
||||
impl PlanLimits {
|
||||
pub fn free() -> Self {
|
||||
Self {
|
||||
max_input_tokens_monthly: Some(500_000),
|
||||
max_output_tokens_monthly: Some(500_000),
|
||||
max_relay_requests_monthly: Some(100),
|
||||
max_hand_executions_monthly: Some(20),
|
||||
max_pipeline_runs_monthly: Some(5),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 账户订阅 — 对应 billing_subscriptions 表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct Subscription {
|
||||
pub id: String,
|
||||
pub account_id: String,
|
||||
pub plan_id: String,
|
||||
pub status: String,
|
||||
pub current_period_start: DateTime<Utc>,
|
||||
pub current_period_end: DateTime<Utc>,
|
||||
pub trial_end: Option<DateTime<Utc>>,
|
||||
pub canceled_at: Option<DateTime<Utc>>,
|
||||
pub cancel_at_period_end: bool,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// 发票 — 对应 billing_invoices 表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct Invoice {
|
||||
pub id: String,
|
||||
pub account_id: String,
|
||||
pub subscription_id: Option<String>,
|
||||
pub plan_id: Option<String>,
|
||||
pub amount_cents: i32,
|
||||
pub currency: String,
|
||||
pub description: Option<String>,
|
||||
pub status: String,
|
||||
pub due_at: Option<DateTime<Utc>>,
|
||||
pub paid_at: Option<DateTime<Utc>>,
|
||||
pub voided_at: Option<DateTime<Utc>>,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// 支付记录 — 对应 billing_payments 表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct Payment {
|
||||
pub id: String,
|
||||
pub invoice_id: String,
|
||||
pub account_id: String,
|
||||
pub amount_cents: i32,
|
||||
pub currency: String,
|
||||
pub method: String,
|
||||
pub status: String,
|
||||
pub external_trade_no: Option<String>,
|
||||
pub paid_at: Option<DateTime<Utc>>,
|
||||
pub refunded_at: Option<DateTime<Utc>>,
|
||||
pub failure_reason: Option<String>,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// 月度用量配额 — 对应 billing_usage_quotas 表
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct UsageQuota {
|
||||
pub id: String,
|
||||
pub account_id: String,
|
||||
pub period_start: DateTime<Utc>,
|
||||
pub period_end: DateTime<Utc>,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub relay_requests: i32,
|
||||
pub hand_executions: i32,
|
||||
pub pipeline_runs: i32,
|
||||
pub max_input_tokens: Option<i64>,
|
||||
pub max_output_tokens: Option<i64>,
|
||||
pub max_relay_requests: Option<i32>,
|
||||
pub max_hand_executions: Option<i32>,
|
||||
pub max_pipeline_runs: Option<i32>,
|
||||
pub metadata: serde_json::Value,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// 用量检查结果
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct QuotaCheck {
|
||||
pub allowed: bool,
|
||||
pub reason: Option<String>,
|
||||
pub current: i64,
|
||||
pub limit: Option<i64>,
|
||||
pub remaining: Option<i64>,
|
||||
}
|
||||
|
||||
/// 支付方式
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PaymentMethod {
|
||||
Alipay,
|
||||
Wechat,
|
||||
}
|
||||
|
||||
/// 创建支付请求
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreatePaymentRequest {
|
||||
pub plan_id: String,
|
||||
pub payment_method: PaymentMethod,
|
||||
}
|
||||
|
||||
/// 支付结果
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct PaymentResult {
|
||||
pub payment_id: String,
|
||||
pub trade_no: String,
|
||||
pub pay_url: String,
|
||||
pub amount_cents: i32,
|
||||
}
|
||||
@@ -167,6 +167,22 @@ impl AppCache {
|
||||
self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
|
||||
}
|
||||
|
||||
// ============ 快捷查找(Phase 2: 减少关键路径 DB 查询) ============
|
||||
|
||||
/// 按 model_id 查找已启用的模型。O(1) DashMap 查找。
|
||||
pub fn get_model(&self, model_id: &str) -> Option<CachedModel> {
|
||||
self.models.get(model_id)
|
||||
.filter(|m| m.enabled)
|
||||
.map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
/// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。
|
||||
pub fn get_provider(&self, provider_id: &str) -> Option<CachedProvider> {
|
||||
self.providers.get(provider_id)
|
||||
.filter(|p| p.enabled)
|
||||
.map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
// ============ 缓存失效 ============
|
||||
|
||||
/// 清除 model 缓存中的指定条目(Admin CRUD 后调用)
|
||||
|
||||
@@ -4,9 +4,15 @@ use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use secrecy::SecretString;
|
||||
|
||||
/// 当前期望的配置版本
|
||||
const CURRENT_CONFIG_VERSION: u32 = 1;
|
||||
|
||||
/// SaaS 服务器完整配置
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SaaSConfig {
|
||||
/// Configuration schema version
|
||||
#[serde(default = "default_config_version")]
|
||||
pub config_version: u32,
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub auth: AuthConfig,
|
||||
@@ -15,6 +21,8 @@ pub struct SaaSConfig {
|
||||
pub rate_limit: RateLimitConfig,
|
||||
#[serde(default)]
|
||||
pub scheduler: SchedulerConfig,
|
||||
#[serde(default)]
|
||||
pub payment: PaymentConfig,
|
||||
}
|
||||
|
||||
/// Scheduler 定时任务配置
|
||||
@@ -66,6 +74,30 @@ pub struct ServerConfig {
|
||||
pub struct DatabaseConfig {
|
||||
#[serde(default = "default_db_url")]
|
||||
pub url: String,
|
||||
/// 连接池最大连接数
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub max_connections: u32,
|
||||
/// 连接池最小连接数
|
||||
#[serde(default = "default_min_connections")]
|
||||
pub min_connections: u32,
|
||||
/// 获取连接超时 (秒)
|
||||
#[serde(default = "default_acquire_timeout")]
|
||||
pub acquire_timeout_secs: u64,
|
||||
/// 空闲连接回收超时 (秒)
|
||||
#[serde(default = "default_idle_timeout")]
|
||||
pub idle_timeout_secs: u64,
|
||||
/// 连接最大生命周期 (秒)
|
||||
#[serde(default = "default_max_lifetime")]
|
||||
pub max_lifetime_secs: u64,
|
||||
/// Worker 并发上限 (Semaphore permits)
|
||||
#[serde(default = "default_worker_concurrency")]
|
||||
pub worker_concurrency: usize,
|
||||
/// 限流事件批量 flush 间隔 (秒)
|
||||
#[serde(default = "default_rate_limit_batch_interval")]
|
||||
pub rate_limit_batch_interval_secs: u64,
|
||||
/// 限流事件批量 flush 最大条目数
|
||||
#[serde(default = "default_rate_limit_batch_max")]
|
||||
pub rate_limit_batch_max_size: usize,
|
||||
}
|
||||
|
||||
/// 认证配置
|
||||
@@ -97,12 +129,21 @@ pub struct RelayConfig {
|
||||
pub max_attempts: u32,
|
||||
}
|
||||
|
||||
fn default_config_version() -> u32 { 1 }
|
||||
fn default_host() -> String { "0.0.0.0".into() }
|
||||
fn default_port() -> u16 { 8080 }
|
||||
fn default_db_url() -> String { "postgres://localhost:5432/zclaw".into() }
|
||||
fn default_jwt_hours() -> i64 { 24 }
|
||||
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
|
||||
fn default_refresh_hours() -> i64 { 168 }
|
||||
fn default_max_connections() -> u32 { 100 }
|
||||
fn default_min_connections() -> u32 { 5 }
|
||||
fn default_acquire_timeout() -> u64 { 8 }
|
||||
fn default_idle_timeout() -> u64 { 180 }
|
||||
fn default_max_lifetime() -> u64 { 900 }
|
||||
fn default_worker_concurrency() -> usize { 20 }
|
||||
fn default_rate_limit_batch_interval() -> u64 { 5 }
|
||||
fn default_rate_limit_batch_max() -> usize { 500 }
|
||||
fn default_max_queue() -> usize { 1000 }
|
||||
fn default_max_concurrent() -> usize { 5 }
|
||||
fn default_batch_window() -> u64 { 50 }
|
||||
@@ -132,15 +173,115 @@ impl Default for RateLimitConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 支付配置
|
||||
///
|
||||
/// 支付宝和微信支付商户配置。所有字段通过环境变量传入(不写入 TOML 文件)。
|
||||
/// 字段缺失时自动降级为 mock 支付模式。
|
||||
///
|
||||
/// 注意:自定义 Debug 和 Serialize 实现会隐藏敏感字段。
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct PaymentConfig {
|
||||
/// 支付宝 App ID(来自支付宝开放平台)
|
||||
#[serde(default)]
|
||||
pub alipay_app_id: Option<String>,
|
||||
/// 支付宝商户私钥(RSA2)— 敏感,不序列化
|
||||
#[serde(default, skip_serializing)]
|
||||
pub alipay_private_key: Option<String>,
|
||||
/// 支付宝公钥证书路径(用于验签)
|
||||
#[serde(default)]
|
||||
pub alipay_cert_path: Option<String>,
|
||||
/// 支付宝回调通知 URL
|
||||
#[serde(default)]
|
||||
pub alipay_notify_url: Option<String>,
|
||||
/// 支付宝公钥(用于回调验签,PEM 格式)— 敏感,不序列化
|
||||
#[serde(default, skip_serializing)]
|
||||
pub alipay_public_key: Option<String>,
|
||||
|
||||
/// 微信支付商户号
|
||||
#[serde(default)]
|
||||
pub wechat_mch_id: Option<String>,
|
||||
/// 微信支付商户证书序列号
|
||||
#[serde(default)]
|
||||
pub wechat_serial_no: Option<String>,
|
||||
/// 微信支付商户私钥路径
|
||||
#[serde(default)]
|
||||
pub wechat_private_key_path: Option<String>,
|
||||
/// 微信支付 API v3 密钥 — 敏感,不序列化
|
||||
#[serde(default, skip_serializing)]
|
||||
pub wechat_api_v3_key: Option<String>,
|
||||
/// 微信支付回调通知 URL
|
||||
#[serde(default)]
|
||||
pub wechat_notify_url: Option<String>,
|
||||
/// 微信支付 App ID(公众号/小程序)
|
||||
#[serde(default)]
|
||||
pub wechat_app_id: Option<String>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PaymentConfig {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PaymentConfig")
|
||||
.field("alipay_app_id", &self.alipay_app_id)
|
||||
.field("alipay_private_key", &self.alipay_private_key.as_ref().map(|_| "***REDACTED***"))
|
||||
.field("alipay_cert_path", &self.alipay_cert_path)
|
||||
.field("alipay_notify_url", &self.alipay_notify_url)
|
||||
.field("alipay_public_key", &self.alipay_public_key.as_ref().map(|_| "***REDACTED***"))
|
||||
.field("wechat_mch_id", &self.wechat_mch_id)
|
||||
.field("wechat_serial_no", &self.wechat_serial_no)
|
||||
.field("wechat_private_key_path", &self.wechat_private_key_path)
|
||||
.field("wechat_api_v3_key", &self.wechat_api_v3_key.as_ref().map(|_| "***REDACTED***"))
|
||||
.field("wechat_notify_url", &self.wechat_notify_url)
|
||||
.field("wechat_app_id", &self.wechat_app_id)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PaymentConfig {
|
||||
fn default() -> Self {
|
||||
// 优先从环境变量读取,未配置则降级 mock
|
||||
Self {
|
||||
alipay_app_id: std::env::var("ALIPAY_APP_ID").ok(),
|
||||
alipay_private_key: std::env::var("ALIPAY_PRIVATE_KEY").ok(),
|
||||
alipay_cert_path: std::env::var("ALIPAY_CERT_PATH").ok(),
|
||||
alipay_notify_url: std::env::var("ALIPAY_NOTIFY_URL").ok(),
|
||||
alipay_public_key: std::env::var("ALIPAY_PUBLIC_KEY").ok(),
|
||||
wechat_mch_id: std::env::var("WECHAT_PAY_MCH_ID").ok(),
|
||||
wechat_serial_no: std::env::var("WECHAT_PAY_SERIAL_NO").ok(),
|
||||
wechat_private_key_path: std::env::var("WECHAT_PAY_PRIVATE_KEY_PATH").ok(),
|
||||
wechat_api_v3_key: std::env::var("WECHAT_PAY_API_V3_KEY").ok(),
|
||||
wechat_notify_url: std::env::var("WECHAT_PAY_NOTIFY_URL").ok(),
|
||||
wechat_app_id: std::env::var("WECHAT_PAY_APP_ID").ok(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PaymentConfig {
|
||||
/// 支付宝是否已完整配置
|
||||
pub fn alipay_configured(&self) -> bool {
|
||||
self.alipay_app_id.is_some()
|
||||
&& self.alipay_private_key.is_some()
|
||||
&& self.alipay_notify_url.is_some()
|
||||
}
|
||||
|
||||
/// 微信支付是否已完整配置
|
||||
pub fn wechat_configured(&self) -> bool {
|
||||
self.wechat_mch_id.is_some()
|
||||
&& self.wechat_serial_no.is_some()
|
||||
&& self.wechat_private_key_path.is_some()
|
||||
&& self.wechat_notify_url.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SaaSConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
config_version: 1,
|
||||
server: ServerConfig::default(),
|
||||
database: DatabaseConfig::default(),
|
||||
auth: AuthConfig::default(),
|
||||
relay: RelayConfig::default(),
|
||||
rate_limit: RateLimitConfig::default(),
|
||||
scheduler: SchedulerConfig::default(),
|
||||
payment: PaymentConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -158,7 +299,17 @@ impl Default for ServerConfig {
|
||||
|
||||
impl Default for DatabaseConfig {
|
||||
fn default() -> Self {
|
||||
Self { url: default_db_url() }
|
||||
Self {
|
||||
url: default_db_url(),
|
||||
max_connections: default_max_connections(),
|
||||
min_connections: default_min_connections(),
|
||||
acquire_timeout_secs: default_acquire_timeout(),
|
||||
idle_timeout_secs: default_idle_timeout(),
|
||||
max_lifetime_secs: default_max_lifetime(),
|
||||
worker_concurrency: default_worker_concurrency(),
|
||||
rate_limit_batch_interval_secs: default_rate_limit_batch_interval(),
|
||||
rate_limit_batch_max_size: default_rate_limit_batch_max(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,6 +371,26 @@ impl SaaSConfig {
|
||||
SaaSConfig::default()
|
||||
};
|
||||
|
||||
// 配置版本兼容性检查
|
||||
if config.config_version < CURRENT_CONFIG_VERSION {
|
||||
tracing::warn!(
|
||||
"[Config] config_version ({}) is below current version ({}). \
|
||||
Some features may not work correctly. \
|
||||
Please update your saas-config.toml. \
|
||||
See docs for migration guide.",
|
||||
config.config_version,
|
||||
CURRENT_CONFIG_VERSION
|
||||
);
|
||||
} else if config.config_version > CURRENT_CONFIG_VERSION {
|
||||
tracing::error!(
|
||||
"[Config] config_version ({}) is ahead of supported version ({}). \
|
||||
This server version may not support all configured features. \
|
||||
Consider upgrading the server.",
|
||||
config.config_version,
|
||||
CURRENT_CONFIG_VERSION
|
||||
);
|
||||
}
|
||||
|
||||
// 环境变量覆盖数据库 URL (避免在配置文件中存储密码)
|
||||
if let Ok(db_url) = std::env::var("ZCLAW_DATABASE_URL") {
|
||||
config.database.url = db_url;
|
||||
|
||||
@@ -2,34 +2,44 @@
|
||||
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use sqlx::PgPool;
|
||||
use crate::config::DatabaseConfig;
|
||||
use crate::error::SaasResult;
|
||||
|
||||
const SCHEMA_VERSION: i32 = 11;
|
||||
const SCHEMA_VERSION: i32 = 13;
|
||||
|
||||
/// 初始化数据库
|
||||
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
|
||||
// 连接池大小可通过环境变量配置,默认 100(relay 请求每次 10+ 串行查询,50 偏紧)
|
||||
pub async fn init_db(config: &DatabaseConfig) -> SaasResult<PgPool> {
|
||||
// 环境变量覆盖 URL(避免在配置文件中存储密码)
|
||||
let database_url = std::env::var("ZCLAW_DATABASE_URL")
|
||||
.unwrap_or_else(|_| config.url.clone());
|
||||
|
||||
// 环境变量覆盖连接数(向后兼容)
|
||||
let max_connections: u32 = std::env::var("ZCLAW_DB_MAX_CONNECTIONS")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(100);
|
||||
.unwrap_or(config.max_connections);
|
||||
let min_connections: u32 = std::env::var("ZCLAW_DB_MIN_CONNECTIONS")
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(5);
|
||||
.unwrap_or(config.min_connections);
|
||||
|
||||
tracing::info!("Database pool: max={}, min={}", max_connections, min_connections);
|
||||
tracing::info!(
|
||||
"Database pool: max={}, min={}, acquire_timeout={}s, idle_timeout={}s, max_lifetime={}s",
|
||||
max_connections, min_connections,
|
||||
config.acquire_timeout_secs, config.idle_timeout_secs, config.max_lifetime_secs
|
||||
);
|
||||
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(max_connections)
|
||||
.min_connections(min_connections)
|
||||
.acquire_timeout(std::time::Duration::from_secs(8))
|
||||
.idle_timeout(std::time::Duration::from_secs(180))
|
||||
.max_lifetime(std::time::Duration::from_secs(900))
|
||||
.connect(database_url)
|
||||
.acquire_timeout(std::time::Duration::from_secs(config.acquire_timeout_secs))
|
||||
.idle_timeout(std::time::Duration::from_secs(config.idle_timeout_secs))
|
||||
.max_lifetime(std::time::Duration::from_secs(config.max_lifetime_secs))
|
||||
.connect(&database_url)
|
||||
.await?;
|
||||
|
||||
run_migrations(&pool).await?;
|
||||
ensure_security_columns(&pool).await?;
|
||||
seed_admin_account(&pool).await?;
|
||||
seed_builtin_prompts(&pool).await?;
|
||||
seed_demo_data(&pool).await?;
|
||||
@@ -884,6 +894,56 @@ async fn fix_seed_data(pool: &PgPool) -> SaasResult<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 防御性检查:确保安全审计新增的列存在(即使 schema_version 显示已是最新)
|
||||
///
|
||||
/// 场景:旧数据库的 schema_version 已被手动更新但迁移文件未实际执行,
|
||||
/// 或者迁移文件在 version check 时被跳过。
|
||||
async fn ensure_security_columns(pool: &PgPool) -> SaasResult<()> {
|
||||
// 检查 password_version 列是否存在
|
||||
let col_exists: bool = sqlx::query_scalar(
|
||||
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'accounts' AND column_name = 'password_version')"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !col_exists {
|
||||
tracing::warn!("[DB] 'password_version' column missing — applying security fix migration");
|
||||
sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS password_version INTEGER NOT NULL DEFAULT 1")
|
||||
.execute(pool).await?;
|
||||
sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS failed_login_count INTEGER NOT NULL DEFAULT 0")
|
||||
.execute(pool).await?;
|
||||
sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS locked_until TIMESTAMPTZ")
|
||||
.execute(pool).await?;
|
||||
tracing::info!("[DB] Security columns (password_version, failed_login_count, locked_until) applied");
|
||||
}
|
||||
|
||||
// 检查 rate_limit_events 表是否存在
|
||||
let table_exists: bool = sqlx::query_scalar(
|
||||
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'rate_limit_events')"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !table_exists {
|
||||
tracing::warn!("[DB] 'rate_limit_events' table missing — applying rate limit migration");
|
||||
if let Err(e) = sqlx::query(
|
||||
"CREATE TABLE IF NOT EXISTS rate_limit_events (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
key TEXT NOT NULL,
|
||||
count BIGINT NOT NULL DEFAULT 1,
|
||||
window_start TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)"
|
||||
).execute(pool).await {
|
||||
tracing::warn!("[DB] Failed to create rate_limit_events: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容
|
||||
|
||||
591
crates/zclaw-saas/src/knowledge/handlers.rs
Normal file
591
crates/zclaw-saas/src/knowledge/handlers.rs
Normal file
@@ -0,0 +1,591 @@
|
||||
//! 知识库 HTTP 处理器
|
||||
|
||||
use axum::{
|
||||
extract::{Extension, Path, Query, State},
|
||||
Json,
|
||||
};
|
||||
|
||||
use crate::auth::types::AuthContext;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::state::AppState;
|
||||
use super::service;
|
||||
use super::types::*;
|
||||
|
||||
// === 分类管理 ===
|
||||
|
||||
/// GET /api/v1/knowledge/categories — 树形分类列表
|
||||
pub async fn list_categories(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<Vec<CategoryResponse>>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let tree = service::list_categories_tree(&state.db).await?;
|
||||
Ok(Json(tree))
|
||||
}
|
||||
|
||||
/// POST /api/v1/knowledge/categories — 创建分类
|
||||
pub async fn create_category(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreateCategoryRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:write")?;
|
||||
|
||||
if req.name.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("分类名称不能为空".into()));
|
||||
}
|
||||
|
||||
let cat = service::create_category(
|
||||
&state.db,
|
||||
req.name.trim(),
|
||||
req.description.as_deref(),
|
||||
req.parent_id.as_deref(),
|
||||
req.icon.as_deref(),
|
||||
).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"id": cat.id,
|
||||
"name": cat.name,
|
||||
})))
|
||||
}
|
||||
|
||||
/// PUT /api/v1/knowledge/categories/:id — 更新分类
|
||||
pub async fn update_category(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateCategoryRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:write")?;
|
||||
|
||||
if let Some(ref name) = req.name {
|
||||
if name.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("分类名称不能为空".into()));
|
||||
}
|
||||
}
|
||||
|
||||
let cat = service::update_category(
|
||||
&state.db,
|
||||
&id,
|
||||
req.name.as_deref().map(|n| n.trim()),
|
||||
req.description.as_deref(),
|
||||
req.parent_id.as_deref(),
|
||||
req.icon.as_deref(),
|
||||
).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"id": cat.id,
|
||||
"name": cat.name,
|
||||
"updated": true,
|
||||
})))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/knowledge/categories/:id — 删除分类
|
||||
pub async fn delete_category(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(id): Path<String>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:admin")?;
|
||||
service::delete_category(&state.db, &id).await?;
|
||||
Ok(Json(serde_json::json!({"deleted": true})))
|
||||
}
|
||||
|
||||
/// GET /api/v1/knowledge/categories/:id/items — 分类下条目列表
|
||||
pub async fn list_category_items(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(id): Path<String>,
|
||||
Query(query): Query<ListItemsQuery>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let page = query.page.unwrap_or(1).max(1);
|
||||
let page_size = query.page_size.unwrap_or(20).max(1).min(100);
|
||||
let status_filter = query.status.as_deref().unwrap_or("active");
|
||||
|
||||
let (items, total) = service::list_items_by_category(
|
||||
&state.db,
|
||||
&id,
|
||||
status_filter,
|
||||
page,
|
||||
page_size,
|
||||
).await?;
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"items": items,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
})))
|
||||
}
|
||||
|
||||
// === 知识条目 CRUD ===
|
||||
|
||||
/// GET /api/v1/knowledge/items — 分页列表
|
||||
pub async fn list_items(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Query(query): Query<ListItemsQuery>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let page = query.page.unwrap_or(1).max(1).min(10000);
|
||||
let page_size = query.page_size.unwrap_or(20).max(1).min(100);
|
||||
let offset = (page - 1) * page_size;
|
||||
|
||||
// 转义 ILIKE 通配符,防止用户输入的 % 和 _ 被当作通配符
|
||||
let keyword = query.keyword.as_ref().map(|k| {
|
||||
k.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_")
|
||||
});
|
||||
|
||||
let items: Vec<KnowledgeItem> = sqlx::query_as(
|
||||
"SELECT ki.* FROM knowledge_items ki \
|
||||
JOIN knowledge_categories kc ON ki.category_id = kc.id \
|
||||
WHERE ($1::text IS NULL OR ki.category_id = $1) \
|
||||
AND ($2::text IS NULL OR ki.status = $2) \
|
||||
AND ($3::text IS NULL OR ki.title ILIKE '%' || $3 || '%') \
|
||||
ORDER BY ki.priority DESC, ki.updated_at DESC \
|
||||
LIMIT $4 OFFSET $5"
|
||||
)
|
||||
.bind(&query.category_id)
|
||||
.bind(&query.status)
|
||||
.bind(&keyword)
|
||||
.bind(page_size)
|
||||
.bind(offset)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
|
||||
let total: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_items ki \
|
||||
WHERE ($1::text IS NULL OR ki.category_id = $1) \
|
||||
AND ($2::text IS NULL OR ki.status = $2) \
|
||||
AND ($3::text IS NULL OR ki.title ILIKE '%' || $3 || '%')"
|
||||
)
|
||||
.bind(&query.category_id)
|
||||
.bind(&query.status)
|
||||
.bind(&keyword)
|
||||
.fetch_one(&state.db)
|
||||
.await?;
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"items": items,
|
||||
"total": total.0,
|
||||
"page": page,
|
||||
"page_size": page_size,
|
||||
})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/knowledge/items — 创建条目
|
||||
pub async fn create_item(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<CreateItemRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:write")?;
|
||||
|
||||
if req.title.trim().is_empty() || req.content.trim().is_empty() {
|
||||
return Err(SaasError::InvalidInput("标题和内容不能为空".into()));
|
||||
}
|
||||
|
||||
if req.content.len() > 100_000 {
|
||||
return Err(SaasError::InvalidInput("内容不能超过 100KB".into()));
|
||||
}
|
||||
|
||||
let item = service::create_item(&state.db, &ctx.account_id, &req).await?;
|
||||
|
||||
// 异步触发 embedding 生成
|
||||
if let Err(e) = state.worker_dispatcher.dispatch(
|
||||
"generate_embedding",
|
||||
serde_json::json!({ "item_id": item.id }),
|
||||
).await {
|
||||
tracing::warn!("Failed to dispatch embedding generation: {}", e);
|
||||
}
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"id": item.id,
|
||||
"title": item.title,
|
||||
"version": item.version,
|
||||
})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/knowledge/items/batch — 批量创建
|
||||
pub async fn batch_create_items(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(items): Json<Vec<CreateItemRequest>>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:write")?;
|
||||
|
||||
if items.len() > 50 {
|
||||
return Err(SaasError::InvalidInput("单次批量创建不能超过 50 条".into()));
|
||||
}
|
||||
|
||||
let mut created = Vec::new();
|
||||
for req in &items {
|
||||
if req.title.trim().is_empty() || req.content.trim().is_empty() {
|
||||
tracing::warn!("Batch create: skipping item with empty title or content");
|
||||
continue;
|
||||
}
|
||||
if req.content.len() > 100_000 {
|
||||
tracing::warn!("Batch create: skipping item '{}' (content too long)", req.title);
|
||||
continue;
|
||||
}
|
||||
match service::create_item(&state.db, &ctx.account_id, req).await {
|
||||
Ok(item) => {
|
||||
let _ = state.worker_dispatcher.dispatch(
|
||||
"generate_embedding",
|
||||
serde_json::json!({ "item_id": item.id }),
|
||||
).await.map_err(|e| {
|
||||
tracing::warn!("[Knowledge] Failed to dispatch embedding for item {}: {}", item.id, e);
|
||||
e
|
||||
});
|
||||
created.push(item.id);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Batch create item failed: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"created_count": created.len(),
|
||||
"ids": created,
|
||||
})))
|
||||
}
|
||||
|
||||
/// GET /api/v1/knowledge/items/:id — 条目详情
|
||||
pub async fn get_item(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(id): Path<String>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let item = service::get_item(&state.db, &id).await?
|
||||
.ok_or_else(|| SaasError::NotFound("知识条目不存在".into()))?;
|
||||
Ok(Json(serde_json::to_value(item).unwrap_or_default()))
|
||||
}
|
||||
|
||||
/// PUT /api/v1/knowledge/items/:id — 更新条目
|
||||
pub async fn update_item(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(id): Path<String>,
|
||||
Json(req): Json<UpdateItemRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:write")?;
|
||||
|
||||
if let Some(ref content) = req.content {
|
||||
if content.len() > 100_000 {
|
||||
return Err(SaasError::InvalidInput("内容不能超过 100KB".into()));
|
||||
}
|
||||
}
|
||||
|
||||
let updated = service::update_item(&state.db, &id, &ctx.account_id, &req).await?;
|
||||
|
||||
// 触发 re-embedding
|
||||
if let Err(e) = state.worker_dispatcher.dispatch(
|
||||
"generate_embedding",
|
||||
serde_json::json!({ "item_id": id }),
|
||||
).await {
|
||||
tracing::warn!("[Knowledge] Failed to dispatch re-embedding for item {}: {}", id, e);
|
||||
}
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"id": updated.id,
|
||||
"version": updated.version,
|
||||
})))
|
||||
}
|
||||
|
||||
/// DELETE /api/v1/knowledge/items/:id — 删除条目
|
||||
pub async fn delete_item(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(id): Path<String>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:admin")?;
|
||||
service::delete_item(&state.db, &id).await?;
|
||||
Ok(Json(serde_json::json!({"deleted": true})))
|
||||
}
|
||||
|
||||
// === 版本控制 ===
|
||||
|
||||
/// GET /api/v1/knowledge/items/:id/versions
|
||||
pub async fn list_versions(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path(id): Path<String>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let versions: Vec<KnowledgeVersion> = sqlx::query_as(
|
||||
"SELECT * FROM knowledge_versions WHERE item_id = $1 ORDER BY version DESC"
|
||||
)
|
||||
.bind(&id)
|
||||
.fetch_all(&state.db)
|
||||
.await?;
|
||||
Ok(Json(serde_json::json!({"versions": versions})))
|
||||
}
|
||||
|
||||
/// GET /api/v1/knowledge/items/:id/versions/:v
|
||||
pub async fn get_version(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path((id, v)): Path<(String, i32)>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let version: KnowledgeVersion = sqlx::query_as(
|
||||
"SELECT * FROM knowledge_versions WHERE item_id = $1 AND version = $2"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(v)
|
||||
.fetch_optional(&state.db)
|
||||
.await?
|
||||
.ok_or_else(|| SaasError::NotFound("版本不存在".into()))?;
|
||||
Ok(Json(serde_json::to_value(version).unwrap_or_default()))
|
||||
}
|
||||
|
||||
/// POST /api/v1/knowledge/items/:id/rollback/:v
|
||||
pub async fn rollback_version(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Path((id, v)): Path<(String, i32)>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:admin")?;
|
||||
|
||||
let updated = service::rollback_version(&state.db, &id, v, &ctx.account_id).await?;
|
||||
|
||||
// 触发 re-embedding
|
||||
if let Err(e) = state.worker_dispatcher.dispatch(
|
||||
"generate_embedding",
|
||||
serde_json::json!({ "item_id": id }),
|
||||
).await {
|
||||
tracing::warn!("[Knowledge] Failed to dispatch re-embedding after rollback for item {}: {}", id, e);
|
||||
}
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"id": updated.id,
|
||||
"version": updated.version,
|
||||
"rolled_back_to": v,
|
||||
})))
|
||||
}
|
||||
|
||||
// === 检索 ===
|
||||
|
||||
/// POST /api/v1/knowledge/search — 语义搜索
|
||||
pub async fn search(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<SearchRequest>,
|
||||
) -> SaasResult<Json<Vec<SearchResult>>> {
|
||||
check_permission(&ctx, "knowledge:search")?;
|
||||
let limit = req.limit.unwrap_or(5).min(10);
|
||||
let min_score = req.min_score.unwrap_or(0.5);
|
||||
let results = service::search(
|
||||
&state.db,
|
||||
&req.query,
|
||||
req.category_id.as_deref(),
|
||||
limit,
|
||||
min_score,
|
||||
).await?;
|
||||
Ok(Json(results))
|
||||
}
|
||||
|
||||
/// POST /api/v1/knowledge/recommend — 关联推荐
|
||||
pub async fn recommend(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<SearchRequest>,
|
||||
) -> SaasResult<Json<Vec<SearchResult>>> {
|
||||
check_permission(&ctx, "knowledge:search")?;
|
||||
let limit = req.limit.unwrap_or(5).min(10);
|
||||
let results = service::search(
|
||||
&state.db,
|
||||
&req.query,
|
||||
req.category_id.as_deref(),
|
||||
limit,
|
||||
0.3,
|
||||
).await?;
|
||||
Ok(Json(results))
|
||||
}
|
||||
|
||||
// === 分析看板 ===
|
||||
|
||||
/// GET /api/v1/knowledge/analytics/overview
|
||||
pub async fn analytics_overview(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<AnalyticsOverview>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let overview = service::analytics_overview(&state.db).await?;
|
||||
Ok(Json(overview))
|
||||
}
|
||||
|
||||
/// GET /api/v1/knowledge/analytics/trends
|
||||
pub async fn analytics_trends(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
// 使用 serde_json::Value 行来避免 PgRow 序列化
|
||||
let trends: Vec<(serde_json::Value,)> = sqlx::query_as(
|
||||
"SELECT json_build_object(
|
||||
'date', DATE(created_at),
|
||||
'count', COUNT(*),
|
||||
'injected_count', SUM(CASE WHEN was_injected THEN 1 ELSE 0 END)
|
||||
) as row \
|
||||
FROM knowledge_usage \
|
||||
WHERE created_at >= NOW() - interval '30 days' \
|
||||
GROUP BY DATE(created_at) ORDER BY DATE(created_at)"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let trends: Vec<serde_json::Value> = trends.into_iter().map(|(v,)| v).collect();
|
||||
Ok(Json(serde_json::json!({"trends": trends})))
|
||||
}
|
||||
|
||||
/// GET /api/v1/knowledge/analytics/top-items
|
||||
pub async fn analytics_top_items(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let items: Vec<(serde_json::Value,)> = sqlx::query_as(
|
||||
"SELECT json_build_object(
|
||||
'id', ki.id,
|
||||
'title', ki.title,
|
||||
'category', kc.name,
|
||||
'ref_count', COUNT(ku.id)
|
||||
) as row \
|
||||
FROM knowledge_items ki \
|
||||
JOIN knowledge_categories kc ON ki.category_id = kc.id \
|
||||
LEFT JOIN knowledge_usage ku ON ku.item_id = ki.id \
|
||||
WHERE ki.status = 'active' \
|
||||
GROUP BY ki.id, ki.title, kc.name \
|
||||
ORDER BY COUNT(ku.id) DESC LIMIT 20"
|
||||
)
|
||||
.fetch_all(&state.db)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
let items: Vec<serde_json::Value> = items.into_iter().map(|(v,)| v).collect();
|
||||
Ok(Json(serde_json::json!({"items": items})))
|
||||
}
|
||||
|
||||
/// GET /api/v1/knowledge/analytics/quality
|
||||
pub async fn analytics_quality(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let quality = service::analytics_quality(&state.db).await?;
|
||||
Ok(Json(quality))
|
||||
}
|
||||
|
||||
/// GET /api/v1/knowledge/analytics/gaps
|
||||
pub async fn analytics_gaps(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:read")?;
|
||||
let gaps = service::analytics_gaps(&state.db).await?;
|
||||
Ok(Json(gaps))
|
||||
}
|
||||
|
||||
// === 批量操作 ===
|
||||
|
||||
/// PATCH /api/v1/knowledge/categories/reorder — 批量排序
|
||||
pub async fn reorder_categories(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(items): Json<Vec<ReorderItem>>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:write")?;
|
||||
|
||||
if items.is_empty() {
|
||||
return Ok(Json(serde_json::json!({"reordered": false, "count": 0})));
|
||||
}
|
||||
if items.len() > 100 {
|
||||
return Err(SaasError::InvalidInput("单次排序不能超过 100 个".into()));
|
||||
}
|
||||
|
||||
// 使用事务保证原子性
|
||||
let mut tx = state.db.begin().await?;
|
||||
for item in &items {
|
||||
sqlx::query("UPDATE knowledge_categories SET sort_order = $1, updated_at = NOW() WHERE id = $2")
|
||||
.bind(item.sort_order)
|
||||
.bind(&item.id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(Json(serde_json::json!({"reordered": true, "count": items.len()})))
|
||||
}
|
||||
|
||||
/// POST /api/v1/knowledge/items/import — Markdown 文件导入
|
||||
pub async fn import_items(
|
||||
State(state): State<AppState>,
|
||||
Extension(ctx): Extension<AuthContext>,
|
||||
Json(req): Json<ImportRequest>,
|
||||
) -> SaasResult<Json<serde_json::Value>> {
|
||||
check_permission(&ctx, "knowledge:write")?;
|
||||
|
||||
if req.files.len() > 20 {
|
||||
return Err(SaasError::InvalidInput("单次导入不能超过 20 个文件".into()));
|
||||
}
|
||||
|
||||
let mut created = Vec::new();
|
||||
for file in &req.files {
|
||||
// 内容长度检查(数据库限制 100KB)
|
||||
if file.content.len() > 100_000 {
|
||||
tracing::warn!("跳过文件 '{}': 内容超长 ({} bytes)", file.title.as_deref().unwrap_or("未命名"), file.content.len());
|
||||
continue;
|
||||
}
|
||||
// 空内容检查
|
||||
if file.content.trim().is_empty() {
|
||||
tracing::warn!("跳过空文件: '{}'", file.title.as_deref().unwrap_or("未命名"));
|
||||
continue;
|
||||
}
|
||||
|
||||
let title = file.title.clone().unwrap_or_else(|| {
|
||||
file.content.lines().next()
|
||||
.map(|l| l.trim_start_matches('#').trim().to_string())
|
||||
.unwrap_or_else(|| format!("导入条目 {}", created.len() + 1))
|
||||
});
|
||||
|
||||
let item_req = CreateItemRequest {
|
||||
category_id: req.category_id.clone(),
|
||||
title,
|
||||
content: file.content.clone(),
|
||||
keywords: file.keywords.clone(),
|
||||
related_questions: None,
|
||||
priority: None,
|
||||
tags: file.tags.clone(),
|
||||
};
|
||||
|
||||
match service::create_item(&state.db, &ctx.account_id, &item_req).await {
|
||||
Ok(item) => {
|
||||
let _ = state.worker_dispatcher.dispatch(
|
||||
"generate_embedding",
|
||||
serde_json::json!({ "item_id": item.id }),
|
||||
).await.map_err(|e| {
|
||||
tracing::warn!("[Knowledge] Failed to dispatch embedding for item {}: {}", item.id, e);
|
||||
e
|
||||
});
|
||||
created.push(item.id);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Import item '{}' failed: {}", item_req.title, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(serde_json::json!({
|
||||
"created_count": created.len(),
|
||||
"ids": created,
|
||||
})))
|
||||
}
|
||||
|
||||
// === 辅助函数 ===
|
||||
|
||||
fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
|
||||
crate::auth::handlers::check_permission(ctx, permission)
|
||||
}
|
||||
39
crates/zclaw-saas/src/knowledge/mod.rs
Normal file
39
crates/zclaw-saas/src/knowledge/mod.rs
Normal file
@@ -0,0 +1,39 @@
|
||||
//! 知识库模块 — 行业知识管理、RAG 检索、版本控制
|
||||
|
||||
pub mod types;
|
||||
pub mod service;
|
||||
pub mod handlers;
|
||||
|
||||
use axum::routing::{delete, get, patch, post, put};
|
||||
|
||||
pub fn routes() -> axum::Router<crate::state::AppState> {
|
||||
axum::Router::new()
|
||||
// 分类管理
|
||||
.route("/api/v1/knowledge/categories", get(handlers::list_categories))
|
||||
.route("/api/v1/knowledge/categories", post(handlers::create_category))
|
||||
.route("/api/v1/knowledge/categories/{id}", put(handlers::update_category))
|
||||
.route("/api/v1/knowledge/categories/{id}", delete(handlers::delete_category))
|
||||
.route("/api/v1/knowledge/categories/{id}/items", get(handlers::list_category_items))
|
||||
.route("/api/v1/knowledge/categories/reorder", patch(handlers::reorder_categories))
|
||||
// 知识条目 CRUD
|
||||
.route("/api/v1/knowledge/items", get(handlers::list_items))
|
||||
.route("/api/v1/knowledge/items", post(handlers::create_item))
|
||||
.route("/api/v1/knowledge/items/batch", post(handlers::batch_create_items))
|
||||
.route("/api/v1/knowledge/items/import", post(handlers::import_items))
|
||||
.route("/api/v1/knowledge/items/{id}", get(handlers::get_item))
|
||||
.route("/api/v1/knowledge/items/{id}", put(handlers::update_item))
|
||||
.route("/api/v1/knowledge/items/{id}", delete(handlers::delete_item))
|
||||
// 版本控制
|
||||
.route("/api/v1/knowledge/items/{id}/versions", get(handlers::list_versions))
|
||||
.route("/api/v1/knowledge/items/{id}/versions/{v}", get(handlers::get_version))
|
||||
.route("/api/v1/knowledge/items/{id}/rollback/{v}", post(handlers::rollback_version))
|
||||
// 检索
|
||||
.route("/api/v1/knowledge/search", post(handlers::search))
|
||||
.route("/api/v1/knowledge/recommend", post(handlers::recommend))
|
||||
// 分析看板
|
||||
.route("/api/v1/knowledge/analytics/overview", get(handlers::analytics_overview))
|
||||
.route("/api/v1/knowledge/analytics/trends", get(handlers::analytics_trends))
|
||||
.route("/api/v1/knowledge/analytics/top-items", get(handlers::analytics_top_items))
|
||||
.route("/api/v1/knowledge/analytics/quality", get(handlers::analytics_quality))
|
||||
.route("/api/v1/knowledge/analytics/gaps", get(handlers::analytics_gaps))
|
||||
}
|
||||
783
crates/zclaw-saas/src/knowledge/service.rs
Normal file
783
crates/zclaw-saas/src/knowledge/service.rs
Normal file
@@ -0,0 +1,783 @@
|
||||
//! 知识库服务层 — CRUD、检索、分析
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::SaasResult;
|
||||
use super::types::*;
|
||||
|
||||
// === 分类管理 ===
|
||||
|
||||
/// 获取分类树(带条目计数)
|
||||
pub async fn list_categories_tree(pool: &PgPool) -> SaasResult<Vec<CategoryResponse>> {
|
||||
let categories: Vec<KnowledgeCategory> = sqlx::query_as(
|
||||
"SELECT * FROM knowledge_categories ORDER BY sort_order, name"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
// 获取每个分类的条目计数
|
||||
let counts: Vec<(String, i64)> = sqlx::query_as(
|
||||
"SELECT category_id, COUNT(*) FROM knowledge_items WHERE status = 'active' GROUP BY category_id"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let count_map: std::collections::HashMap<String, i64> = counts.into_iter().collect();
|
||||
|
||||
// 构建树形结构
|
||||
let mut roots = Vec::new();
|
||||
let mut all: Vec<CategoryResponse> = categories.into_iter().map(|c| {
|
||||
let count = *count_map.get(&c.id).unwrap_or(&0);
|
||||
CategoryResponse {
|
||||
id: c.id,
|
||||
name: c.name,
|
||||
description: c.description,
|
||||
parent_id: c.parent_id,
|
||||
icon: c.icon,
|
||||
sort_order: c.sort_order,
|
||||
item_count: count,
|
||||
children: Vec::new(),
|
||||
created_at: c.created_at.to_rfc3339(),
|
||||
updated_at: c.updated_at.to_rfc3339(),
|
||||
}
|
||||
}).collect();
|
||||
|
||||
// 构建子节点映射
|
||||
let mut children_map: std::collections::HashMap<String, Vec<CategoryResponse>> =
|
||||
std::collections::HashMap::new();
|
||||
|
||||
for cat in all.drain(..) {
|
||||
if let Some(ref parent_id) = cat.parent_id {
|
||||
children_map.entry(parent_id.clone()).or_default().push(cat);
|
||||
} else {
|
||||
roots.push(cat);
|
||||
}
|
||||
}
|
||||
|
||||
// 递归填充子节点
|
||||
fn fill_children(
|
||||
cats: &mut Vec<CategoryResponse>,
|
||||
children_map: &mut std::collections::HashMap<String, Vec<CategoryResponse>>,
|
||||
) {
|
||||
for cat in cats.iter_mut() {
|
||||
if let Some(children) = children_map.remove(&cat.id) {
|
||||
cat.children = children;
|
||||
fill_children(&mut cat.children, children_map);
|
||||
}
|
||||
// 累加子节点条目数到父节点
|
||||
let child_count: i64 = cat.children.iter().map(|c| c.item_count).sum();
|
||||
cat.item_count += child_count;
|
||||
}
|
||||
}
|
||||
|
||||
fill_children(&mut roots, &mut children_map);
|
||||
Ok(roots)
|
||||
}
|
||||
|
||||
/// 创建分类
|
||||
pub async fn create_category(
|
||||
pool: &PgPool,
|
||||
name: &str,
|
||||
description: Option<&str>,
|
||||
parent_id: Option<&str>,
|
||||
icon: Option<&str>,
|
||||
) -> SaasResult<KnowledgeCategory> {
|
||||
// 验证 parent_id 存在性
|
||||
if let Some(pid) = parent_id {
|
||||
let exists: bool = sqlx::query_scalar(
|
||||
"SELECT EXISTS(SELECT 1 FROM knowledge_categories WHERE id = $1)"
|
||||
)
|
||||
.bind(pid)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
if !exists {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
format!("父分类 '{}' 不存在", pid),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let category = sqlx::query_as::<_, KnowledgeCategory>(
|
||||
"INSERT INTO knowledge_categories (id, name, description, parent_id, icon) \
|
||||
VALUES ($1, $2, $3, $4, $5) RETURNING *"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(name)
|
||||
.bind(description)
|
||||
.bind(parent_id)
|
||||
.bind(icon)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
Ok(category)
|
||||
}
|
||||
|
||||
/// 删除分类(有子分类或条目时拒绝)
|
||||
pub async fn delete_category(pool: &PgPool, category_id: &str) -> SaasResult<()> {
|
||||
// 检查子分类
|
||||
let child_count: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_categories WHERE parent_id = $1"
|
||||
)
|
||||
.bind(category_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if child_count.0 > 0 {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
"该分类下有子分类,无法删除".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// 检查条目
|
||||
let item_count: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_items WHERE category_id = $1"
|
||||
)
|
||||
.bind(category_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
if item_count.0 > 0 {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
"该分类下有知识条目,无法删除".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let result = sqlx::query("DELETE FROM knowledge_categories WHERE id = $1")
|
||||
.bind(category_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(crate::error::SaasError::NotFound("分类不存在".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 更新分类(含循环引用检测 + 深度限制)
|
||||
pub async fn update_category(
|
||||
pool: &PgPool,
|
||||
category_id: &str,
|
||||
name: Option<&str>,
|
||||
description: Option<&str>,
|
||||
parent_id: Option<&str>,
|
||||
icon: Option<&str>,
|
||||
) -> SaasResult<KnowledgeCategory> {
|
||||
if let Some(pid) = parent_id {
|
||||
if pid == category_id {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
"分类不能成为自身的子分类".into(),
|
||||
));
|
||||
}
|
||||
// 检查新的父级不是当前分类的后代(循环检测)
|
||||
let mut check_id = pid.to_string();
|
||||
let mut depth = 0;
|
||||
loop {
|
||||
if check_id == category_id {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
"循环引用:父级分类不能是当前分类的后代".into(),
|
||||
));
|
||||
}
|
||||
let parent: Option<(Option<String>,)> = sqlx::query_as(
|
||||
"SELECT parent_id FROM knowledge_categories WHERE id = $1"
|
||||
)
|
||||
.bind(&check_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
|
||||
match parent {
|
||||
Some((Some(gp),)) => {
|
||||
check_id = gp;
|
||||
depth += 1;
|
||||
if depth > 10 { break; }
|
||||
}
|
||||
_ => break,
|
||||
}
|
||||
}
|
||||
|
||||
// 检查深度限制(最多 3 层)
|
||||
let mut current_depth = 0;
|
||||
let mut check = pid.to_string();
|
||||
while let Some((Some(p),)) = sqlx::query_as::<_, (Option<String>,)>(
|
||||
"SELECT parent_id FROM knowledge_categories WHERE id = $1"
|
||||
)
|
||||
.bind(&check)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
{
|
||||
check = p;
|
||||
current_depth += 1;
|
||||
if current_depth > 10 { break; }
|
||||
}
|
||||
if current_depth >= 3 {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
"分类层级不能超过 3 层".into(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let category = sqlx::query_as::<_, KnowledgeCategory>(
|
||||
"UPDATE knowledge_categories SET \
|
||||
name = COALESCE($1, name), \
|
||||
description = COALESCE($2, description), \
|
||||
parent_id = COALESCE($3, parent_id), \
|
||||
icon = COALESCE($4, icon), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = $5 RETURNING *"
|
||||
)
|
||||
.bind(name)
|
||||
.bind(description)
|
||||
.bind(parent_id)
|
||||
.bind(icon)
|
||||
.bind(category_id)
|
||||
.fetch_optional(pool)
|
||||
.await?
|
||||
.ok_or_else(|| crate::error::SaasError::NotFound("分类不存在".into()))?;
|
||||
|
||||
Ok(category)
|
||||
}
|
||||
|
||||
// === 知识条目 CRUD ===
|
||||
|
||||
/// 按分类分页查询条目列表
|
||||
pub async fn list_items_by_category(
|
||||
pool: &PgPool,
|
||||
category_id: &str,
|
||||
status_filter: &str,
|
||||
page: i64,
|
||||
page_size: i64,
|
||||
) -> SaasResult<(Vec<KnowledgeItem>, i64)> {
|
||||
let offset = (page - 1) * page_size;
|
||||
|
||||
let items: Vec<KnowledgeItem> = sqlx::query_as(
|
||||
"SELECT * FROM knowledge_items \
|
||||
WHERE category_id = $1 AND status = $2 \
|
||||
ORDER BY priority DESC, updated_at DESC \
|
||||
LIMIT $3 OFFSET $4"
|
||||
)
|
||||
.bind(category_id)
|
||||
.bind(status_filter)
|
||||
.bind(page_size)
|
||||
.bind(offset)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
|
||||
let total: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_items WHERE category_id = $1 AND status = $2"
|
||||
)
|
||||
.bind(category_id)
|
||||
.bind(status_filter)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
Ok((items, total.0))
|
||||
}
|
||||
|
||||
/// 创建知识条目
|
||||
pub async fn create_item(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
req: &CreateItemRequest,
|
||||
) -> SaasResult<KnowledgeItem> {
|
||||
let id = uuid::Uuid::new_v4().to_string();
|
||||
let keywords = req.keywords.as_deref().unwrap_or(&[]);
|
||||
let related_questions = req.related_questions.as_deref().unwrap_or(&[]);
|
||||
let priority = req.priority.unwrap_or(0);
|
||||
let tags = req.tags.as_deref().unwrap_or(&[]);
|
||||
|
||||
// 验证 category_id 存在性
|
||||
let cat_exists: bool = sqlx::query_scalar(
|
||||
"SELECT EXISTS(SELECT 1 FROM knowledge_categories WHERE id = $1)"
|
||||
)
|
||||
.bind(&req.category_id)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
if !cat_exists {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
format!("分类 '{}' 不存在", req.category_id),
|
||||
));
|
||||
}
|
||||
|
||||
// 使用事务保证 item + version 原子性
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
let item = sqlx::query_as::<_, KnowledgeItem>(
|
||||
"INSERT INTO knowledge_items \
|
||||
(id, category_id, title, content, keywords, related_questions, priority, tags, created_by) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
|
||||
RETURNING *"
|
||||
)
|
||||
.bind(&id)
|
||||
.bind(&req.category_id)
|
||||
.bind(&req.title)
|
||||
.bind(&req.content)
|
||||
.bind(keywords)
|
||||
.bind(related_questions)
|
||||
.bind(priority)
|
||||
.bind(tags)
|
||||
.bind(account_id)
|
||||
.fetch_one(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 创建初始版本快照
|
||||
let version_id = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO knowledge_versions \
|
||||
(id, item_id, version, title, content, keywords, related_questions, created_by) \
|
||||
VALUES ($1, $2, 1, $3, $4, $5, $6, $7)"
|
||||
)
|
||||
.bind(&version_id)
|
||||
.bind(&id)
|
||||
.bind(&req.title)
|
||||
.bind(&req.content)
|
||||
.bind(keywords)
|
||||
.bind(related_questions)
|
||||
.bind(account_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(item)
|
||||
}
|
||||
|
||||
/// 获取条目详情
|
||||
pub async fn get_item(pool: &PgPool, item_id: &str) -> SaasResult<Option<KnowledgeItem>> {
|
||||
let item = sqlx::query_as::<_, KnowledgeItem>(
|
||||
"SELECT * FROM knowledge_items WHERE id = $1"
|
||||
)
|
||||
.bind(item_id)
|
||||
.fetch_optional(pool)
|
||||
.await?;
|
||||
Ok(item)
|
||||
}
|
||||
|
||||
/// 更新条目(含版本快照)— 事务保护防止并发竞态
|
||||
pub async fn update_item(
|
||||
pool: &PgPool,
|
||||
item_id: &str,
|
||||
account_id: &str,
|
||||
req: &UpdateItemRequest,
|
||||
) -> SaasResult<KnowledgeItem> {
|
||||
// status 验证在事务之前,避免无谓锁占用
|
||||
const VALID_STATUSES: &[&str] = &["active", "draft", "archived", "deprecated"];
|
||||
if let Some(ref status) = &req.status {
|
||||
if !VALID_STATUSES.contains(&status.as_str()) {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
format!("无效的状态值: {},有效值: {}", status, VALID_STATUSES.join(", "))
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
// 获取当前条目并锁定行防止并发修改
|
||||
let current = sqlx::query_as::<_, KnowledgeItem>(
|
||||
"SELECT * FROM knowledge_items WHERE id = $1 FOR UPDATE"
|
||||
)
|
||||
.bind(item_id)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?
|
||||
.ok_or_else(|| crate::error::SaasError::NotFound("知识条目不存在".into()))?;
|
||||
|
||||
// 合并更新
|
||||
let title = req.title.as_deref().unwrap_or(¤t.title);
|
||||
let content = req.content.as_deref().unwrap_or(¤t.content);
|
||||
let keywords: Vec<String> = req.keywords.as_ref()
|
||||
.or(Some(¤t.keywords))
|
||||
.unwrap_or(&vec![])
|
||||
.clone();
|
||||
let related_questions: Vec<String> = req.related_questions.as_ref()
|
||||
.or(Some(¤t.related_questions))
|
||||
.unwrap_or(&vec![])
|
||||
.clone();
|
||||
let priority = req.priority.unwrap_or(current.priority);
|
||||
let tags: Vec<String> = req.tags.as_ref()
|
||||
.or(Some(¤t.tags))
|
||||
.unwrap_or(&vec![])
|
||||
.clone();
|
||||
|
||||
|
||||
// 更新条目
|
||||
let updated = sqlx::query_as::<_, KnowledgeItem>(
|
||||
"UPDATE knowledge_items SET \
|
||||
title = $1, content = $2, keywords = $3, related_questions = $4, \
|
||||
priority = $5, tags = $6, status = COALESCE($7, status), \
|
||||
version = version + 1, updated_at = NOW() \
|
||||
WHERE id = $8 RETURNING *"
|
||||
)
|
||||
.bind(title)
|
||||
.bind(content)
|
||||
.bind(&keywords)
|
||||
.bind(&related_questions)
|
||||
.bind(priority)
|
||||
.bind(&tags)
|
||||
.bind(req.status.as_deref())
|
||||
.bind(item_id)
|
||||
.fetch_one(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 创建版本快照
|
||||
let version_id = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO knowledge_versions \
|
||||
(id, item_id, version, title, content, keywords, related_questions, \
|
||||
change_summary, created_by) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
)
|
||||
.bind(&version_id)
|
||||
.bind(item_id)
|
||||
.bind(updated.version)
|
||||
.bind(title)
|
||||
.bind(content)
|
||||
.bind(&keywords)
|
||||
.bind(&related_questions)
|
||||
.bind(req.change_summary.as_deref())
|
||||
.bind(account_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// 删除条目(级联删除 chunks + versions)
|
||||
pub async fn delete_item(pool: &PgPool, item_id: &str) -> SaasResult<()> {
|
||||
let result = sqlx::query("DELETE FROM knowledge_items WHERE id = $1")
|
||||
.bind(item_id)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err(crate::error::SaasError::NotFound("知识条目不存在".into()));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// === 分块 ===
|
||||
|
||||
/// 将内容按 Markdown 标题 + 固定长度分块
|
||||
pub fn chunk_content(content: &str, max_tokens: usize, overlap: usize) -> Vec<String> {
|
||||
// 先按 Markdown 标题分段
|
||||
let sections: Vec<&str> = content.split("\n# ").collect();
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
for (i, section) in sections.iter().enumerate() {
|
||||
// 第一个片段保留原始内容,其余片段重新添加标题标记
|
||||
let section_content = if i == 0 {
|
||||
section.to_string()
|
||||
} else {
|
||||
format!("# {}", section)
|
||||
};
|
||||
|
||||
// 磁盘估算 token(中文约 1.5 字符/token)
|
||||
let estimated_tokens = section_content.len() / 2;
|
||||
|
||||
if estimated_tokens <= max_tokens {
|
||||
if !section_content.trim().is_empty() {
|
||||
chunks.push(section_content.trim().to_string());
|
||||
}
|
||||
} else {
|
||||
// 超长段落按固定长度切分
|
||||
let chars: Vec<char> = section_content.chars().collect();
|
||||
let chunk_chars = max_tokens * 2; // 近似字符数
|
||||
let overlap_chars = overlap * 2;
|
||||
|
||||
let mut pos = 0;
|
||||
while pos < chars.len() {
|
||||
let end = (pos + chunk_chars).min(chars.len());
|
||||
let chunk_str: String = chars[pos..end].iter().collect();
|
||||
if !chunk_str.trim().is_empty() {
|
||||
chunks.push(chunk_str.trim().to_string());
|
||||
}
|
||||
pos = if end >= chars.len() { end} else { end.saturating_sub(overlap_chars) };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
chunks}
|
||||
|
||||
// === 搜索 ===
|
||||
|
||||
/// 语义搜索(向量 + 关键词混合)
|
||||
pub async fn search(
|
||||
pool: &PgPool,
|
||||
query: &str,
|
||||
category_id: Option<&str>,
|
||||
limit: i64,
|
||||
min_score: f64,
|
||||
) -> SaasResult<Vec<SearchResult>> {
|
||||
// 暂时使用关键词匹配(向量搜索需要 embedding 生成)
|
||||
let pattern = format!("%{}%", query.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_"));
|
||||
|
||||
let results = if let Some(cat_id) = category_id {
|
||||
sqlx::query_as::<_, (String, String, String, String, String, Vec<String>)>(
|
||||
"SELECT kc.id, kc.item_id, ki.title, kcat.name, kc.content, kc.keywords \
|
||||
FROM knowledge_chunks kc \
|
||||
JOIN knowledge_items ki ON kc.item_id = ki.id \
|
||||
JOIN knowledge_categories kcat ON ki.category_id = kcat.id \
|
||||
WHERE ki.status = 'active' \
|
||||
AND ki.category_id = $1 \
|
||||
AND (kc.content ILIKE $2 OR $3 = ANY(kc.keywords)) \
|
||||
ORDER BY ki.priority DESC \
|
||||
LIMIT $4"
|
||||
)
|
||||
.bind(cat_id)
|
||||
.bind(&pattern)
|
||||
.bind(query)
|
||||
.bind(limit)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
} else {
|
||||
sqlx::query_as::<_, (String, String, String, String, String, Vec<String>)>(
|
||||
"SELECT kc.id, kc.item_id, ki.title, kcat.name, kc.content, kc.keywords \
|
||||
FROM knowledge_chunks kc \
|
||||
JOIN knowledge_items ki ON kc.item_id = ki.id \
|
||||
JOIN knowledge_categories kcat ON ki.category_id = kcat.id \
|
||||
WHERE ki.status = 'active' \
|
||||
AND (kc.content ILIKE $1 OR $2 = ANY(kc.keywords)) \
|
||||
ORDER BY ki.priority DESC \
|
||||
LIMIT $3"
|
||||
)
|
||||
.bind(&pattern)
|
||||
.bind(query)
|
||||
.bind(limit)
|
||||
.fetch_all(pool)
|
||||
.await?
|
||||
};
|
||||
|
||||
Ok(results.into_iter().map(|(chunk_id, item_id, title, cat_name, content, keywords)| {
|
||||
// 基于关键词匹配数计算分数:匹配数 / 总查询关键词数
|
||||
let query_keywords: Vec<&str> = query.split_whitespace().collect();
|
||||
let matched_count = keywords.iter()
|
||||
.filter(|k| query_keywords.iter().any(|qk| k.to_lowercase().contains(&qk.to_lowercase())))
|
||||
.count();
|
||||
let score = if keywords.is_empty() || query_keywords.is_empty() {
|
||||
0.5
|
||||
} else {
|
||||
(matched_count as f64 / keywords.len().max(query_keywords.len()) as f64).min(1.0)
|
||||
};
|
||||
|
||||
SearchResult {
|
||||
chunk_id,
|
||||
item_id,
|
||||
item_title: title,
|
||||
category_name: cat_name,
|
||||
content,
|
||||
score,
|
||||
keywords,
|
||||
}
|
||||
}).filter(|r| r.score >= min_score).collect())
|
||||
}
|
||||
|
||||
// === 分析 ===
|
||||
|
||||
/// 分析总览
|
||||
pub async fn analytics_overview(pool: &PgPool) -> SaasResult<AnalyticsOverview> {
|
||||
let total_items: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_items"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let active_items: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_items WHERE status = 'active'"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let total_categories: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_categories"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let weekly_new: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_items WHERE created_at >= NOW() - interval '7 days'"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let total_refs: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_usage"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let injected: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_usage WHERE was_injected = true"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let positive: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_usage WHERE agent_feedback = 'positive'"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let with_feedback: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_usage WHERE agent_feedback IS NOT NULL"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let stale: (i64,) = sqlx::query_as(
|
||||
"SELECT COUNT(*) FROM knowledge_items ki \
|
||||
WHERE ki.status = 'active' \
|
||||
AND NOT EXISTS (SELECT 1 FROM knowledge_usage ku WHERE ku.item_id = ki.id AND ku.created_at >= NOW() - interval '90 days')"
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?;
|
||||
|
||||
let hit_rate = if total_refs.0 > 0 { with_feedback.0 as f64 / total_refs.0 as f64 } else { 0.0 };
|
||||
let injection_rate = if total_refs.0 > 0 { injected.0 as f64 / total_refs.0 as f64 } else { 0.0 };
|
||||
let positive_rate = if total_refs.0 > 0 { positive.0 as f64 / total_refs.0 as f64 } else { 0.0 };
|
||||
|
||||
Ok(AnalyticsOverview {
|
||||
total_items: total_items.0,
|
||||
active_items: active_items.0,
|
||||
total_categories: total_categories.0,
|
||||
weekly_new_items: weekly_new.0,
|
||||
total_references: total_refs.0,
|
||||
avg_reference_per_item: if total_items.0 > 0 { total_refs.0 as f64 / total_items.0 as f64 } else { 0.0 },
|
||||
hit_rate,
|
||||
injection_rate,
|
||||
positive_feedback_rate: positive_rate,
|
||||
stale_items_count: stale.0,
|
||||
})
|
||||
}
|
||||
|
||||
/// 回滚到指定版本(创建新版本快照)
|
||||
pub async fn rollback_version(
|
||||
pool: &PgPool,
|
||||
item_id: &str,
|
||||
target_version: i32,
|
||||
account_id: &str,
|
||||
) -> SaasResult<KnowledgeItem> {
|
||||
// 使用事务保证原子性,防止并发回滚冲突
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
// 获取目标版本
|
||||
let version: KnowledgeVersion = sqlx::query_as(
|
||||
"SELECT * FROM knowledge_versions WHERE item_id = $1 AND version = $2"
|
||||
)
|
||||
.bind(item_id)
|
||||
.bind(target_version)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?
|
||||
.ok_or_else(|| crate::error::SaasError::NotFound("版本不存在".into()))?;
|
||||
|
||||
// 锁定当前条目行防止并发修改(SELECT FOR UPDATE)
|
||||
let current: Option<(i32,)> = sqlx::query_as(
|
||||
"SELECT version FROM knowledge_items WHERE id = $1 FOR UPDATE"
|
||||
)
|
||||
.bind(item_id)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let current_version = current
|
||||
.ok_or_else(|| crate::error::SaasError::NotFound("知识条目不存在".into()))?
|
||||
.0;
|
||||
|
||||
// 防止版本无限递增: 最多 100 个版本
|
||||
if current_version >= 100 {
|
||||
return Err(crate::error::SaasError::InvalidInput(
|
||||
"版本数已达上限(100),请考虑合并历史版本".into(),
|
||||
));
|
||||
}
|
||||
|
||||
let new_version = current_version + 1;
|
||||
|
||||
// 更新条目为该版本内容
|
||||
let updated = sqlx::query_as::<_, KnowledgeItem>(
|
||||
"UPDATE knowledge_items SET \
|
||||
title = $1, content = $2, keywords = $3, related_questions = $4, \
|
||||
version = $5, updated_at = NOW() \
|
||||
WHERE id = $6 RETURNING *"
|
||||
)
|
||||
.bind(&version.title)
|
||||
.bind(&version.content)
|
||||
.bind(&version.keywords)
|
||||
.bind(&version.related_questions)
|
||||
.bind(new_version)
|
||||
.bind(item_id)
|
||||
.fetch_one(&mut *tx)
|
||||
.await?;
|
||||
|
||||
// 创建新版本快照(记录回滚来源)
|
||||
let version_id = uuid::Uuid::new_v4().to_string();
|
||||
let summary = format!("回滚到版本 {}(当前版本 {} → 新版本 {})", target_version, current_version, new_version);
|
||||
sqlx::query(
|
||||
"INSERT INTO knowledge_versions \
|
||||
(id, item_id, version, title, content, keywords, related_questions, \
|
||||
change_summary, created_by) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
)
|
||||
.bind(&version_id)
|
||||
.bind(item_id)
|
||||
.bind(new_version)
|
||||
.bind(&updated.title)
|
||||
.bind(&updated.content)
|
||||
.bind(&updated.keywords)
|
||||
.bind(&updated.related_questions)
|
||||
.bind(&summary)
|
||||
.bind(account_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
Ok(updated)
|
||||
}
|
||||
|
||||
/// 质量指标(按分类分组)
|
||||
pub async fn analytics_quality(pool: &PgPool) -> SaasResult<serde_json::Value> {
|
||||
let quality: Vec<(serde_json::Value,)> = sqlx::query_as(
|
||||
"SELECT json_build_object(
|
||||
'category', kc.name,
|
||||
'total', COUNT(ki.id),
|
||||
'active', COUNT(CASE WHEN ki.status = 'active' THEN 1 END),
|
||||
'with_keywords', COUNT(CASE WHEN array_length(ki.keywords, 1) > 0 THEN 1 END),
|
||||
'avg_priority', COALESCE(AVG(ki.priority), 0)
|
||||
) as row \
|
||||
FROM knowledge_categories kc \
|
||||
LEFT JOIN knowledge_items ki ON ki.category_id = kc.id \
|
||||
GROUP BY kc.id, kc.name \
|
||||
ORDER BY COUNT(ki.id) DESC"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!("analytics_quality query failed: {}", e);
|
||||
vec![]
|
||||
});
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"categories": quality.into_iter().map(|(v,)| v).collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
|
||||
/// 知识缺口检测(低分查询聚类)
|
||||
pub async fn analytics_gaps(pool: &PgPool) -> SaasResult<serde_json::Value> {
|
||||
let gaps: Vec<(serde_json::Value,)> = sqlx::query_as(
|
||||
"SELECT json_build_object(
|
||||
'query', ku.query_text,
|
||||
'count', COUNT(*),
|
||||
'avg_score', COALESCE(AVG(ku.relevance_score), 0)
|
||||
) as row \
|
||||
FROM knowledge_usage ku \
|
||||
WHERE ku.created_at >= NOW() - interval '30 days' \
|
||||
AND (ku.relevance_score IS NULL OR ku.relevance_score < 0.5) \
|
||||
AND ku.query_text IS NOT NULL \
|
||||
GROUP BY ku.query_text \
|
||||
ORDER BY COUNT(*) DESC \
|
||||
LIMIT 20"
|
||||
)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.unwrap_or_else(|e| {
|
||||
tracing::warn!("analytics_gaps query failed: {}", e);
|
||||
vec![]
|
||||
});
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"gaps": gaps.into_iter().map(|(v,)| v).collect::<Vec<_>>()
|
||||
}))
|
||||
}
|
||||
225
crates/zclaw-saas/src/knowledge/types.rs
Normal file
225
crates/zclaw-saas/src/knowledge/types.rs
Normal file
@@ -0,0 +1,225 @@
|
||||
//! 知识库类型定义
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// === 分类 ===
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct KnowledgeCategory {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub parent_id: Option<String>,
|
||||
pub icon: Option<String>,
|
||||
pub sort_order: i32,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateCategoryRequest {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub parent_id: Option<String>,
|
||||
pub icon: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateCategoryRequest {
|
||||
pub name: Option<String>,
|
||||
pub description: Option<String>,
|
||||
pub parent_id: Option<String>,
|
||||
pub icon: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct CategoryResponse {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub parent_id: Option<String>,
|
||||
pub icon: Option<String>,
|
||||
pub sort_order: i32,
|
||||
pub item_count: i64,
|
||||
pub children: Vec<CategoryResponse>,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
// === 知识条目 ===
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct KnowledgeItem {
|
||||
pub id: String,
|
||||
pub category_id: String,
|
||||
pub title: String,
|
||||
pub content: String,
|
||||
pub keywords: Vec<String>,
|
||||
pub related_questions: Vec<String>,
|
||||
pub priority: i32,
|
||||
pub status: String,
|
||||
pub version: i32,
|
||||
pub source: String,
|
||||
pub tags: Vec<String>,
|
||||
pub created_by: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct CreateItemRequest {
|
||||
pub category_id: String,
|
||||
pub title: String,
|
||||
pub content: String,
|
||||
pub keywords: Option<Vec<String>>,
|
||||
pub related_questions: Option<Vec<String>>,
|
||||
pub priority: Option<i32>,
|
||||
pub tags: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct UpdateItemRequest {
|
||||
pub category_id: Option<String>,
|
||||
pub title: Option<String>,
|
||||
pub content: Option<String>,
|
||||
pub keywords: Option<Vec<String>>,
|
||||
pub related_questions: Option<Vec<String>>,
|
||||
pub priority: Option<i32>,
|
||||
pub status: Option<String>,
|
||||
pub tags: Option<Vec<String>>,
|
||||
pub change_summary: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListItemsQuery {
|
||||
pub page: Option<i64>,
|
||||
pub page_size: Option<i64>,
|
||||
pub category_id: Option<String>,
|
||||
pub status: Option<String>,
|
||||
pub keyword: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ItemResponse {
|
||||
pub id: String,
|
||||
pub category_id: String,
|
||||
pub category_name: String,
|
||||
pub title: String,
|
||||
pub content: String,
|
||||
pub keywords: Vec<String>,
|
||||
pub related_questions: Vec<String>,
|
||||
pub priority: i32,
|
||||
pub status: String,
|
||||
pub version: i32,
|
||||
pub source: String,
|
||||
pub tags: Vec<String>,
|
||||
pub created_by: String,
|
||||
pub reference_count: i64,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
// === 知识分块 ===
|
||||
// 注意:DB 表含 embedding vector(1536) 列,但当前所有查询均显式指定列,
|
||||
// 故 struct 暂不映射该字段。若未来使用 SELECT * 需添加 embedding: Option<pgvector::Vector>。
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct KnowledgeChunk {
|
||||
pub id: String,
|
||||
pub item_id: String,
|
||||
pub chunk_index: i32,
|
||||
pub content: String,
|
||||
pub keywords: Vec<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
// === 版本快照 ===
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct KnowledgeVersion {
|
||||
pub id: String,
|
||||
pub item_id: String,
|
||||
pub version: i32,
|
||||
pub title: String,
|
||||
pub content: String,
|
||||
pub keywords: Vec<String>,
|
||||
pub related_questions: Vec<String>,
|
||||
pub change_summary: Option<String>,
|
||||
pub created_by: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
// === 使用追踪 ===
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
|
||||
pub struct KnowledgeUsage {
|
||||
pub id: String,
|
||||
pub item_id: String,
|
||||
pub chunk_id: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
pub query_text: Option<String>,
|
||||
pub relevance_score: Option<f64>,
|
||||
pub was_injected: bool,
|
||||
pub agent_feedback: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
// === 搜索 ===
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SearchRequest {
|
||||
pub query: String,
|
||||
pub category_id: Option<String>,
|
||||
pub limit: Option<i64>,
|
||||
pub min_score: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SearchResult {
|
||||
pub chunk_id: String,
|
||||
pub item_id: String,
|
||||
pub item_title: String,
|
||||
pub category_name: String,
|
||||
pub content: String,
|
||||
pub score: f64,
|
||||
pub keywords: Vec<String>,
|
||||
}
|
||||
|
||||
// === 分析 ===
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct AnalyticsOverview {
|
||||
pub total_items: i64,
|
||||
pub active_items: i64,
|
||||
pub total_categories: i64,
|
||||
pub weekly_new_items: i64,
|
||||
pub total_references: i64,
|
||||
pub avg_reference_per_item: f64,
|
||||
pub hit_rate: f64,
|
||||
pub injection_rate: f64,
|
||||
pub positive_feedback_rate: f64,
|
||||
pub stale_items_count: i64,
|
||||
}
|
||||
|
||||
// === 批量操作 ===
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ReorderItem {
|
||||
pub id: String,
|
||||
pub sort_order: i32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ImportFile {
|
||||
pub content: String,
|
||||
pub title: Option<String>,
|
||||
pub keywords: Option<Vec<String>>,
|
||||
pub tags: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ImportRequest {
|
||||
pub category_id: String,
|
||||
pub files: Vec<ImportFile>,
|
||||
}
|
||||
@@ -25,3 +25,5 @@ pub mod prompt;
|
||||
pub mod agent_template;
|
||||
pub mod scheduled_task;
|
||||
pub mod telemetry;
|
||||
pub mod billing;
|
||||
pub mod knowledge;
|
||||
|
||||
@@ -11,9 +11,14 @@ use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||||
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
|
||||
use zclaw_saas::workers::record_usage::RecordUsageWorker;
|
||||
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
|
||||
use zclaw_saas::workers::aggregate_usage::AggregateUsageWorker;
|
||||
use zclaw_saas::workers::generate_embedding::GenerateEmbeddingWorker;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// Load .env file from project root (walk up from current dir)
|
||||
load_dotenv();
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
@@ -24,26 +29,36 @@ async fn main() -> anyhow::Result<()> {
|
||||
let config = SaaSConfig::load()?;
|
||||
info!("SaaS config loaded: {}:{}", config.server.host, config.server.port);
|
||||
|
||||
let db = init_db(&config.database.url).await?;
|
||||
let db = init_db(&config.database).await?;
|
||||
info!("Database initialized");
|
||||
|
||||
// 创建 Worker spawn 限制器(门控并发 DB 操作数量)
|
||||
let worker_limiter = zclaw_saas::state::SpawnLimiter::new(
|
||||
"worker",
|
||||
config.database.worker_concurrency,
|
||||
);
|
||||
info!("Worker spawn limiter: {} permits", config.database.worker_concurrency);
|
||||
|
||||
// 初始化 Worker 调度器 + 注册所有 Worker
|
||||
let mut dispatcher = WorkerDispatcher::new(db.clone());
|
||||
let mut dispatcher = WorkerDispatcher::new(db.clone(), worker_limiter.clone());
|
||||
dispatcher.register(LogOperationWorker);
|
||||
dispatcher.register(CleanupRefreshTokensWorker);
|
||||
dispatcher.register(CleanupRateLimitWorker);
|
||||
dispatcher.register(RecordUsageWorker);
|
||||
dispatcher.register(UpdateLastUsedWorker);
|
||||
info!("Worker dispatcher initialized (5 workers registered)");
|
||||
dispatcher.register(AggregateUsageWorker);
|
||||
dispatcher.register(GenerateEmbeddingWorker);
|
||||
info!("Worker dispatcher initialized (7 workers registered)");
|
||||
|
||||
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
|
||||
let shutdown_token = CancellationToken::new();
|
||||
let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?;
|
||||
let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone(), worker_limiter.clone())?;
|
||||
|
||||
// Restore rate limit counts from DB so limits survive server restarts
|
||||
// 仅恢复最近 60s 的计数(与 middleware 的 60s 滑动窗口一致),避免过于保守的限流
|
||||
{
|
||||
let rows: Vec<(String, i64)> = sqlx::query_as(
|
||||
"SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '1 hour' GROUP BY key"
|
||||
"SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '60 seconds' GROUP BY key"
|
||||
)
|
||||
.fetch_all(&db)
|
||||
.await
|
||||
@@ -51,18 +66,17 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
let mut restored_count = 0usize;
|
||||
for (key, count) in rows {
|
||||
let mut entries = Vec::new();
|
||||
// Approximate: insert count timestamps at "now" — the DashMap will
|
||||
// expire them naturally via the retain() call in the middleware.
|
||||
// This is intentionally approximate; exact window alignment is not
|
||||
// required for rate limiting correctness.
|
||||
for _ in 0..count as usize {
|
||||
// 限制恢复计数不超过 RPM 配额,避免重启后过于保守
|
||||
let rpm = state.rate_limit_rpm() as usize;
|
||||
let capped = (count as usize).min(rpm);
|
||||
let mut entries = Vec::with_capacity(capped);
|
||||
for _ in 0..capped {
|
||||
entries.push(std::time::Instant::now());
|
||||
}
|
||||
state.rate_limit_entries.insert(key, entries);
|
||||
restored_count += 1;
|
||||
}
|
||||
info!("Restored rate limit state from DB: {} keys", restored_count);
|
||||
info!("Restored rate limit state from DB: {} keys (60s window, capped at RPM)", restored_count);
|
||||
}
|
||||
|
||||
// 迁移旧格式 TOTP secret(明文 → 加密 enc: 格式)
|
||||
@@ -117,20 +131,64 @@ async fn main() -> anyhow::Result<()> {
|
||||
});
|
||||
}
|
||||
|
||||
let app = build_router(state).await;
|
||||
// 限流事件批量 flush (可配置间隔,默认 5s)
|
||||
{
|
||||
let flush_state = state.clone();
|
||||
let batch_interval = config.database.rate_limit_batch_interval_secs;
|
||||
let batch_max = config.database.rate_limit_batch_max_size;
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(batch_interval));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
flush_state.flush_rate_limit_batch(batch_max).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 连接池可观测性 (30s 指标日志)
|
||||
{
|
||||
let metrics_db = db.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let pool = &metrics_db;
|
||||
let total = pool.options().get_max_connections() as usize;
|
||||
let idle = pool.num_idle() as usize;
|
||||
let used = total.saturating_sub(idle);
|
||||
let usage_pct = if total > 0 { used * 100 / total } else { 0 };
|
||||
tracing::info!(
|
||||
"[PoolMetrics] total={} idle={} used={} usage_pct={}%",
|
||||
total, idle, used, usage_pct,
|
||||
);
|
||||
if usage_pct >= 80 {
|
||||
tracing::warn!(
|
||||
"[PoolMetrics] HIGH USAGE: {}% of connections in use!",
|
||||
usage_pct,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let app = build_router(state.clone()).await;
|
||||
|
||||
// 配置 TCP keepalive + 短 SO_LINGER,防止 CLOSE_WAIT 累积
|
||||
let listener = create_listener(&config.server.host, config.server.port)?;
|
||||
info!("SaaS server listening on {}:{}", config.server.host, config.server.port);
|
||||
|
||||
// 优雅停机: Ctrl+C → 取消 CancellationToken → SSE 流终止 → 连接排空
|
||||
// 优雅停机: Ctrl+C → 最终批量 flush → 取消 CancellationToken → SSE 流终止 → 连接排空
|
||||
let token = shutdown_token.clone();
|
||||
let flush_state = state;
|
||||
let batch_max = config.database.rate_limit_batch_max_size;
|
||||
axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>())
|
||||
.with_graceful_shutdown(async move {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("Failed to install Ctrl+C handler");
|
||||
info!("Received shutdown signal, cancelling SSE streams and draining connections...");
|
||||
info!("Received shutdown signal, flushing pending rate limit batch...");
|
||||
flush_state.flush_rate_limit_batch(batch_max).await;
|
||||
info!("Cancelling SSE streams and draining connections...");
|
||||
token.cancel();
|
||||
})
|
||||
.await?;
|
||||
@@ -265,6 +323,7 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
|
||||
let public_routes = zclaw_saas::auth::routes()
|
||||
.route("/api/health", axum::routing::get(health_handler))
|
||||
.merge(zclaw_saas::billing::callback_routes())
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::public_rate_limit_middleware,
|
||||
@@ -280,6 +339,8 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
.merge(zclaw_saas::agent_template::routes())
|
||||
.merge(zclaw_saas::scheduled_task::routes())
|
||||
.merge(zclaw_saas::telemetry::routes())
|
||||
.merge(zclaw_saas::billing::routes())
|
||||
.merge(zclaw_saas::knowledge::routes())
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::api_version_middleware,
|
||||
@@ -313,6 +374,10 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::request_id_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::quota_check_middleware,
|
||||
))
|
||||
.layer(middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
zclaw_saas::middleware::rate_limit_middleware,
|
||||
@@ -322,10 +387,55 @@ async fn build_router(state: AppState) -> axum::Router {
|
||||
zclaw_saas::auth::auth_middleware,
|
||||
));
|
||||
|
||||
axum::Router::new()
|
||||
let mut router = axum::Router::new()
|
||||
.merge(non_streaming_routes)
|
||||
.merge(relay_routes)
|
||||
.merge(relay_routes);
|
||||
|
||||
// 开发模式挂载 mock 支付页面
|
||||
{
|
||||
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
|
||||
.map(|v| v == "true" || v == "1")
|
||||
.unwrap_or(false);
|
||||
if is_dev {
|
||||
router = router.merge(zclaw_saas::billing::mock_routes());
|
||||
info!("Mock payment routes mounted (dev mode)");
|
||||
}
|
||||
}
|
||||
|
||||
router
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors)
|
||||
.with_state(state)
|
||||
}
|
||||
|
||||
/// Load `.env` file from project root by walking up from current directory.
|
||||
/// Sets environment variables that are not already set (does not override).
|
||||
fn load_dotenv() {
|
||||
let mut dir = std::env::current_dir().unwrap_or_default();
|
||||
loop {
|
||||
let env_path = dir.join(".env");
|
||||
if env_path.is_file() {
|
||||
if let Ok(content) = std::fs::read_to_string(&env_path) {
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
let key = key.trim();
|
||||
let value = value.trim();
|
||||
// Only set if not already defined in environment
|
||||
if std::env::var(key).is_err() {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
tracing::debug!("Loaded .env from {}", env_path.display());
|
||||
}
|
||||
return;
|
||||
}
|
||||
if !dir.pop() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -93,17 +93,56 @@ pub async fn rate_limit_middleware(
|
||||
)).into_response();
|
||||
}
|
||||
|
||||
// Write-through to DB for persistence across restarts (fire-and-forget)
|
||||
// Write-through to batch accumulator (memory-only, flushed periodically by background task)
|
||||
// 替换原来的 fire-and-forget tokio::spawn(DB INSERT),消除每请求 1 个 DB 连接消耗
|
||||
if should_persist {
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
|
||||
)
|
||||
.bind(&key)
|
||||
.execute(&db)
|
||||
.await;
|
||||
});
|
||||
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
|
||||
*entry += 1;
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
/// 配额检查中间件
|
||||
/// 在 Relay 请求前检查账户月度用量配额
|
||||
/// 仅对 /api/v1/relay/chat/completions 生效
|
||||
pub async fn quota_check_middleware(
|
||||
State(state): State<AppState>,
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response<Body> {
|
||||
let path = req.uri().path();
|
||||
|
||||
// 仅对 relay 请求检查配额
|
||||
if !path.starts_with("/api/v1/relay/") {
|
||||
return next.run(req).await;
|
||||
}
|
||||
|
||||
// 从扩展中获取认证上下文
|
||||
let account_id = match req.extensions().get::<AuthContext>() {
|
||||
Some(ctx) => ctx.account_id.clone(),
|
||||
None => return next.run(req).await,
|
||||
};
|
||||
|
||||
// 检查 relay_requests 配额
|
||||
match crate::billing::service::check_quota(&state.db, &account_id, "relay_requests").await {
|
||||
Ok(check) if !check.allowed => {
|
||||
tracing::warn!(
|
||||
"Quota exceeded for account {}: {} ({}/{})",
|
||||
account_id,
|
||||
check.reason.as_deref().unwrap_or("配额已用尽"),
|
||||
check.current,
|
||||
check.limit.map(|l| l.to_string()).unwrap_or_else(|| "∞".into()),
|
||||
);
|
||||
return SaasError::RateLimited(
|
||||
check.reason.unwrap_or_else(|| "月度配额已用尽".into()),
|
||||
).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
// 配额检查失败不阻断请求(降级策略)
|
||||
tracing::warn!("Quota check failed for account {}: {}", account_id, e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
@@ -192,17 +231,10 @@ pub async fn public_rate_limit_middleware(
|
||||
return SaasError::RateLimited(error_msg.into()).into_response();
|
||||
}
|
||||
|
||||
// Write-through to DB for persistence across restarts (fire-and-forget)
|
||||
// Write-through to batch accumulator (memory-only, flushed periodically)
|
||||
if should_persist {
|
||||
let db = state.db.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = sqlx::query(
|
||||
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
|
||||
)
|
||||
.bind(&key)
|
||||
.execute(&db)
|
||||
.await;
|
||||
});
|
||||
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
|
||||
*entry += 1;
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
|
||||
@@ -82,6 +82,10 @@ pub async fn create_provider(
|
||||
let provider = service::create_provider(&state.db, &req, &enc_key).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
|
||||
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
|
||||
// Admin mutation 后立即刷新缓存,消除 60s 陈旧窗口
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await {
|
||||
tracing::warn!("Cache reload failed after provider.create: {}", e);
|
||||
}
|
||||
Ok((StatusCode::CREATED, Json(provider)))
|
||||
}
|
||||
|
||||
@@ -102,6 +106,9 @@ pub async fn update_provider(
|
||||
drop(config);
|
||||
let provider = service::update_provider(&state.db, &id, &req, &enc_key).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after provider.update: {}", e);
|
||||
}
|
||||
Ok(Json(provider))
|
||||
}
|
||||
|
||||
@@ -114,6 +121,9 @@ pub async fn delete_provider(
|
||||
check_permission(&ctx, "provider:manage")?;
|
||||
service::delete_provider(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after provider.delete: {}", e);
|
||||
}
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
@@ -150,6 +160,9 @@ pub async fn create_model(
|
||||
let model = service::create_model(&state.db, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id,
|
||||
Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after model.create: {}", e);
|
||||
}
|
||||
Ok((StatusCode::CREATED, Json(model)))
|
||||
}
|
||||
|
||||
@@ -163,6 +176,9 @@ pub async fn update_model(
|
||||
check_permission(&ctx, "model:manage")?;
|
||||
let model = service::update_model(&state.db, &id, &req).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after model.update: {}", e);
|
||||
}
|
||||
Ok(Json(model))
|
||||
}
|
||||
|
||||
@@ -175,6 +191,9 @@ pub async fn delete_model(
|
||||
check_permission(&ctx, "model:manage")?;
|
||||
service::delete_model(&state.db, &id).await?;
|
||||
log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, ctx.client_ip.as_deref()).await?;
|
||||
if let Err(e) = state.cache.load_from_db(&state.db).await{
|
||||
tracing::warn!("Cache reload failed after model.delete: {}", e);
|
||||
}
|
||||
Ok(Json(serde_json::json!({"ok": true})))
|
||||
}
|
||||
|
||||
|
||||
@@ -29,3 +29,12 @@ pub struct PromptVersionRow {
|
||||
pub min_app_version: Option<String>,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// prompt_sync_status 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct PromptSyncStatusRow {
|
||||
pub device_id: String,
|
||||
pub template_id: String,
|
||||
pub synced_version: i32,
|
||||
pub synced_at: String,
|
||||
}
|
||||
|
||||
@@ -2,6 +2,24 @@
|
||||
|
||||
use sqlx::FromRow;
|
||||
|
||||
/// telemetry_reports 表行
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct TelemetryReportRow {
|
||||
pub id: String,
|
||||
pub account_id: String,
|
||||
pub device_id: String,
|
||||
pub app_version: Option<String>,
|
||||
pub model_id: String,
|
||||
pub input_tokens: i64,
|
||||
pub output_tokens: i64,
|
||||
pub latency_ms: Option<i32>,
|
||||
pub success: bool,
|
||||
pub error_type: Option<String>,
|
||||
pub connection_mode: Option<String>,
|
||||
pub reported_at: String,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
/// telemetry 按 model 分组统计
|
||||
#[derive(Debug, FromRow)]
|
||||
pub struct TelemetryModelStatsRow {
|
||||
|
||||
@@ -4,7 +4,7 @@ use sqlx::PgPool;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::common::PaginatedResponse;
|
||||
use crate::common::normalize_pagination;
|
||||
use crate::models::{PromptTemplateRow, PromptVersionRow};
|
||||
use crate::models::{PromptTemplateRow, PromptVersionRow, PromptSyncStatusRow};
|
||||
use super::types::*;
|
||||
|
||||
/// 创建提示词模板 + 初始版本
|
||||
@@ -310,3 +310,21 @@ pub async fn check_updates(
|
||||
server_time: chrono::Utc::now().to_rfc3339(),
|
||||
})
|
||||
}
|
||||
|
||||
/// 查询设备的提示词同步状态
|
||||
pub async fn get_sync_status(
|
||||
db: &PgPool,
|
||||
device_id: &str,
|
||||
) -> SaasResult<Vec<PromptSyncStatusRow>> {
|
||||
let rows = sqlx::query_as::<_, PromptSyncStatusRow>(
|
||||
"SELECT device_id, template_id, synced_version, synced_at \
|
||||
FROM prompt_sync_status \
|
||||
WHERE device_id = $1 \
|
||||
ORDER BY synced_at DESC \
|
||||
LIMIT 50"
|
||||
)
|
||||
.bind(device_id)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
@@ -23,18 +23,12 @@ pub async fn chat_completions(
|
||||
) -> SaasResult<Response> {
|
||||
check_permission(&ctx, "relay:use")?;
|
||||
|
||||
// 队列容量检查:防止过载(立即释放读锁)
|
||||
// 队列容量检查:使用内存 AtomicI64 计数器,消除 DB COUNT 查询
|
||||
let max_queue_size = {
|
||||
let config = state.config.read().await;
|
||||
config.relay.max_queue_size
|
||||
};
|
||||
let queued_count: i64 = sqlx::query_scalar(
|
||||
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
|
||||
)
|
||||
.bind(&ctx.account_id)
|
||||
.fetch_one(&state.db)
|
||||
.await
|
||||
.unwrap_or(0);
|
||||
let queued_count = state.cache.relay_queue_count(&ctx.account_id);
|
||||
|
||||
if queued_count >= max_queue_size as i64 {
|
||||
return Err(SaasError::RateLimited(
|
||||
@@ -128,18 +122,8 @@ pub async fn chat_completions(
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// 查找 model 对应的 provider — 使用精准查询避免全量加载
|
||||
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
|
||||
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
|
||||
created_at, updated_at
|
||||
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
|
||||
)
|
||||
.bind(&model_name)
|
||||
.fetch_optional(&state.db)
|
||||
.await?;
|
||||
|
||||
let target_model = target_model
|
||||
// 查找 model — 使用内存缓存(O(1) DashMap),消除关键路径 DB 查询
|
||||
let target_model = state.cache.get_model(model_name)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// Stream compatibility check: reject stream requests for non-streaming models
|
||||
@@ -149,8 +133,9 @@ pub async fn chat_completions(
|
||||
));
|
||||
}
|
||||
|
||||
// 获取 provider 信息
|
||||
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
|
||||
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
|
||||
let provider = state.cache.get_provider(&target_model.provider_id)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
|
||||
if !provider.enabled {
|
||||
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
|
||||
}
|
||||
@@ -171,6 +156,9 @@ pub async fn chat_completions(
|
||||
max_attempts,
|
||||
).await?;
|
||||
|
||||
// 递增内存队列计数器(替代 DB COUNT 查询)
|
||||
state.cache.relay_enqueue(&ctx.account_id);
|
||||
|
||||
// 异步派发操作日志(非阻塞,不占用关键路径 DB 连接)
|
||||
state.dispatch_log_operation(
|
||||
&ctx.account_id, "relay.request", "relay_task", &task.id,
|
||||
@@ -186,8 +174,7 @@ pub async fn chat_completions(
|
||||
&enc_key,
|
||||
).await;
|
||||
|
||||
// 克隆用于异步 usage 记录
|
||||
let db_usage = state.db.clone();
|
||||
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn)
|
||||
let account_id_usage = ctx.account_id.clone();
|
||||
let provider_id_usage = target_model.provider_id.clone();
|
||||
let model_id_usage = target_model.model_id.clone();
|
||||
@@ -195,30 +182,62 @@ pub async fn chat_completions(
|
||||
match response {
|
||||
Ok(service::RelayResponse::Json(body)) => {
|
||||
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
|
||||
// 异步记录 usage(不阻塞响应)
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = model_service::record_usage(
|
||||
&db_usage, &account_id_usage, &provider_id_usage,
|
||||
&model_id_usage, input_tokens, output_tokens,
|
||||
None, "success", None,
|
||||
).await {
|
||||
tracing::warn!("Failed to record relay usage: {}", e);
|
||||
// 通过 Worker dispatch 记录 usage(受 SpawnLimiter 门控,不阻塞响应)
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: input_tokens as i32,
|
||||
output_tokens: output_tokens as i32,
|
||||
latency_ms: None,
|
||||
status: "success".to_string(),
|
||||
error_message: None,
|
||||
};
|
||||
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch record_usage: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 实时更新计费配额(relay_requests + tokens 同步递增)
|
||||
if let Err(e) = crate::billing::service::increment_usage(
|
||||
&state.db, &account_id_usage, input_tokens as i64, output_tokens as i64,
|
||||
).await {
|
||||
tracing::warn!("Failed to increment billing usage for {}: {}", account_id_usage, e);
|
||||
}
|
||||
|
||||
// 任务完成,递减队列计数器
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
|
||||
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
|
||||
}
|
||||
Ok(service::RelayResponse::Sse(body)) => {
|
||||
// 异步记录 SSE 占位 usage
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = model_service::record_usage(
|
||||
&db_usage, &account_id_usage, &provider_id_usage,
|
||||
&model_id_usage, 0, 0,
|
||||
None, "streaming", None,
|
||||
).await {
|
||||
tracing::warn!("Failed to record SSE usage placeholder: {}", e);
|
||||
// 通过 Worker dispatch 记录 SSE 占位 usage
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "streaming".to_string(),
|
||||
error_message: None,
|
||||
};
|
||||
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch SSE usage: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// SSE: relay_requests 实时递增(tokens 由 AggregateUsageWorker 对账修正)
|
||||
if let Err(e) = crate::billing::service::increment_dimension(
|
||||
&state.db, &account_id_usage, "relay_requests",
|
||||
).await {
|
||||
tracing::warn!("Failed to increment billing relay_requests for {}: {}", account_id_usage, e);
|
||||
}
|
||||
|
||||
// SSE 流已返回,递减队列计数器(流式任务开始处理)
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
|
||||
let response = axum::response::Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
@@ -230,17 +249,25 @@ pub async fn chat_completions(
|
||||
Ok(response)
|
||||
}
|
||||
Err(e) => {
|
||||
// 异步记录失败 usage(不阻塞错误响应)
|
||||
// 通过 Worker dispatch 记录失败 usage
|
||||
let error_msg = e.to_string();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e2) = model_service::record_usage(
|
||||
&db_usage, &account_id_usage, &provider_id_usage,
|
||||
&model_id_usage, 0, 0,
|
||||
None, "failed", Some(&error_msg),
|
||||
).await {
|
||||
tracing::warn!("Failed to record relay failure usage: {}", e2);
|
||||
{
|
||||
let args = crate::workers::record_usage::RecordUsageArgs {
|
||||
account_id: account_id_usage.clone(),
|
||||
provider_id: provider_id_usage.clone(),
|
||||
model_id: model_id_usage.clone(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
latency_ms: None,
|
||||
status: "failed".to_string(),
|
||||
error_message: Some(error_msg),
|
||||
};
|
||||
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
|
||||
tracing::warn!("Failed to dispatch failure usage: {}", e2);
|
||||
}
|
||||
});
|
||||
}
|
||||
// 任务失败,递减队列计数器(失败请求不计费)
|
||||
state.cache.relay_dequeue(&account_id_usage);
|
||||
Err(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -281,6 +281,39 @@ pub async fn delete_provider_key(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Key 使用窗口统计
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KeyUsageStats {
|
||||
pub key_id: String,
|
||||
pub window_minute: String,
|
||||
pub request_count: i32,
|
||||
pub token_count: i64,
|
||||
}
|
||||
|
||||
/// 查询指定 Key 的最近使用窗口统计
|
||||
pub async fn get_key_usage_stats(
|
||||
db: &PgPool,
|
||||
key_id: &str,
|
||||
limit: i64,
|
||||
) -> SaasResult<Vec<KeyUsageStats>> {
|
||||
let limit = limit.min(60).max(1);
|
||||
let rows: Vec<(String, String, i32, i64)> = sqlx::query_as(
|
||||
"SELECT key_id, window_minute, request_count, token_count \
|
||||
FROM key_usage_window \
|
||||
WHERE key_id = $1 \
|
||||
ORDER BY window_minute DESC \
|
||||
LIMIT $2"
|
||||
)
|
||||
.bind(key_id)
|
||||
.bind(limit)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
Ok(rows.into_iter().map(|(key_id, window_minute, request_count, token_count)| {
|
||||
KeyUsageStats { key_id, window_minute, request_count, token_count }
|
||||
}).collect())
|
||||
}
|
||||
|
||||
/// 解析冷却剩余时间(秒)
|
||||
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
|
||||
let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until);
|
||||
|
||||
@@ -2,11 +2,23 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use crate::error::{SaasError, SaasResult};
|
||||
use crate::models::RelayTaskRow;
|
||||
use super::types::*;
|
||||
|
||||
// ============ StreamBridge 背压常量 ============
|
||||
|
||||
/// 上游无数据时,发送 SSE 心跳注释行的间隔
|
||||
const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
|
||||
|
||||
/// 上游无数据时,丢弃连接的超时阈值
|
||||
const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
|
||||
/// 流结束后延迟清理的时间窗口
|
||||
const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(60);
|
||||
|
||||
/// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429)
|
||||
fn is_retryable_status(status: u16) -> bool {
|
||||
status == 429 || (500..600).contains(&status)
|
||||
@@ -33,15 +45,24 @@ pub async fn create_relay_task(
|
||||
let request_hash = hash_request(request_body);
|
||||
let max_attempts = max_attempts.max(1).min(5);
|
||||
|
||||
sqlx::query(
|
||||
// INSERT ... RETURNING 合并两次 DB 往返为一次
|
||||
let row: RelayTaskRow = sqlx::query_as(
|
||||
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)"
|
||||
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)
|
||||
RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at"
|
||||
)
|
||||
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
|
||||
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
|
||||
.execute(db).await?;
|
||||
.fetch_one(db)
|
||||
.await?;
|
||||
|
||||
get_relay_task(db, &id).await
|
||||
Ok(RelayTaskInfo {
|
||||
id: row.id, account_id: row.account_id, provider_id: row.provider_id, model_id: row.model_id,
|
||||
status: row.status, priority: row.priority, attempt_count: row.attempt_count,
|
||||
max_attempts: row.max_attempts, input_tokens: row.input_tokens, output_tokens: row.output_tokens,
|
||||
error_message: row.error_message, queued_at: row.queued_at, started_at: row.started_at,
|
||||
completed_at: row.completed_at, created_at: row.created_at,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
|
||||
@@ -295,9 +316,9 @@ pub async fn execute_relay(
|
||||
}
|
||||
});
|
||||
|
||||
// Convert mpsc::Receiver into a Body stream
|
||||
let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
let body = axum::body::Body::from_stream(body_stream);
|
||||
// Build StreamBridge: wraps the bounded receiver with heartbeat,
|
||||
// timeout, and delayed cleanup (DeerFlow-inspired backpressure).
|
||||
let body = build_stream_bridge(rx, task_id.to_string());
|
||||
|
||||
// SSE 流结束后异步记录 usage + Key 使用量
|
||||
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks,防止高并发时耗尽连接池
|
||||
@@ -335,6 +356,14 @@ pub async fn execute_relay(
|
||||
if tokio::time::timeout(std::time::Duration::from_secs(5), db_op).await.is_err() {
|
||||
tracing::warn!("SSE usage recording timed out for task {}", task_id_clone);
|
||||
}
|
||||
|
||||
// StreamBridge 延迟清理:流结束 60s 后释放残留资源
|
||||
// (主要是 Arc<SseUsageCapture> 等,通过 drop(_permit) 归还信号量)
|
||||
tokio::time::sleep(STREAMBRIDGE_CLEANUP_DELAY).await;
|
||||
tracing::debug!(
|
||||
"[StreamBridge] Cleanup delay elapsed for task {}",
|
||||
task_id_clone
|
||||
);
|
||||
});
|
||||
|
||||
return Ok(RelayResponse::Sse(body));
|
||||
@@ -346,7 +375,9 @@ pub async fn execute_relay(
|
||||
// 记录 Key 使用量
|
||||
let _ = super::key_pool::record_key_usage(
|
||||
db, &key_id, Some(input_tokens + output_tokens),
|
||||
).await;
|
||||
).await.map_err(|e| {
|
||||
tracing::warn!("[Relay] Failed to record key usage for billing: {}", e);
|
||||
});
|
||||
return Ok(RelayResponse::Json(body));
|
||||
}
|
||||
}
|
||||
@@ -423,6 +454,98 @@ pub enum RelayResponse {
|
||||
Sse(axum::body::Body),
|
||||
}
|
||||
|
||||
// ============ StreamBridge ============
|
||||
|
||||
/// 构建 StreamBridge:将 mpsc::Receiver 包装为带心跳、超时的 axum Body。
|
||||
///
|
||||
/// 借鉴 DeerFlow StreamBridge 背压机制:
|
||||
/// - 15s 心跳:上游长时间无输出时,发送 SSE 注释行 `: heartbeat\n\n` 保持连接活跃
|
||||
/// - 30s 超时:上游连续 30s 无真实数据时,发送超时事件并关闭流
|
||||
/// - 60s 延迟清理:由调用方的 spawned task 在流结束后延迟释放资源
|
||||
fn build_stream_bridge(
|
||||
mut rx: tokio::sync::mpsc::Receiver<Result<bytes::Bytes, std::io::Error>>,
|
||||
task_id: String,
|
||||
) -> axum::body::Body {
|
||||
// SSE heartbeat comment bytes: `: heartbeat\n\n`
|
||||
// SSE spec: lines starting with `:` are comments and ignored by clients
|
||||
const HEARTBEAT_BYTES: &[u8] = b": heartbeat\n\n";
|
||||
// SSE timeout error event
|
||||
const TIMEOUT_EVENT: &[u8] = b"data: {\"error\":\"stream_timeout\",\"message\":\"upstream timed out\"}\n\n";
|
||||
|
||||
let stream = async_stream::stream! {
|
||||
// Track how many consecutive heartbeat-only cycles have elapsed.
|
||||
// Real data resets this counter; after 2 heartbeats (30s) without
|
||||
// real data, we terminate the stream.
|
||||
let mut idle_heartbeats: u32 = 0;
|
||||
|
||||
loop {
|
||||
// tokio::select! races the next data chunk against a heartbeat timer.
|
||||
// The timer resets on every iteration, ensuring heartbeats only fire
|
||||
// during genuine idle periods.
|
||||
tokio::select! {
|
||||
biased; // prioritize data over heartbeat
|
||||
|
||||
chunk = rx.recv() => {
|
||||
match chunk {
|
||||
Some(Ok(data)) => {
|
||||
// Real data received — reset idle counter
|
||||
idle_heartbeats = 0;
|
||||
yield Ok::<bytes::Bytes, std::io::Error>(data);
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
tracing::warn!(
|
||||
"[StreamBridge] Upstream error for task {}: {}",
|
||||
task_id, e
|
||||
);
|
||||
yield Err(e);
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
// Channel closed = upstream finished normally
|
||||
tracing::debug!(
|
||||
"[StreamBridge] Upstream completed for task {}",
|
||||
task_id
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Heartbeat: send SSE comment if no data for 15s
|
||||
_ = tokio::time::sleep(STREAMBRIDGE_HEARTBEAT_INTERVAL) => {
|
||||
idle_heartbeats += 1;
|
||||
tracing::trace!(
|
||||
"[StreamBridge] Heartbeat #{} for task {} (idle {}s)",
|
||||
idle_heartbeats,
|
||||
task_id,
|
||||
idle_heartbeats as u64 * STREAMBRIDGE_HEARTBEAT_INTERVAL.as_secs(),
|
||||
);
|
||||
|
||||
// After 2 consecutive heartbeats without real data (30s),
|
||||
// terminate the stream to prevent connection leaks.
|
||||
if idle_heartbeats >= 2 {
|
||||
tracing::warn!(
|
||||
"[StreamBridge] Timeout ({:?}) no real data, closing stream for task {}",
|
||||
STREAMBRIDGE_TIMEOUT,
|
||||
task_id,
|
||||
);
|
||||
yield Ok(bytes::Bytes::from_static(TIMEOUT_EVENT));
|
||||
break;
|
||||
}
|
||||
|
||||
yield Ok(bytes::Bytes::from_static(HEARTBEAT_BYTES));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Pin the stream to a Box<dyn Stream + Send> to satisfy Body::from_stream
|
||||
let boxed: std::pin::Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes, std::io::Error>> + Send>> =
|
||||
Box::pin(stream);
|
||||
|
||||
axum::body::Body::from_stream(boxed)
|
||||
}
|
||||
|
||||
// ============ Helpers ============
|
||||
|
||||
fn hash_request(body: &str) -> String {
|
||||
|
||||
@@ -20,7 +20,9 @@ struct ScheduledTaskRow {
|
||||
last_run_at: Option<String>,
|
||||
next_run_at: Option<String>,
|
||||
run_count: i32,
|
||||
last_result: Option<String>,
|
||||
last_error: Option<String>,
|
||||
last_duration_ms: Option<i64>,
|
||||
input_payload: Option<serde_json::Value>,
|
||||
created_at: String,
|
||||
}
|
||||
@@ -41,7 +43,9 @@ impl ScheduledTaskRow {
|
||||
last_run: self.last_run_at.clone(),
|
||||
next_run: self.next_run_at.clone(),
|
||||
run_count: self.run_count,
|
||||
last_result: self.last_result.clone(),
|
||||
last_error: self.last_error.clone(),
|
||||
last_duration_ms: self.last_duration_ms,
|
||||
created_at: self.created_at.clone(),
|
||||
}
|
||||
}
|
||||
@@ -86,7 +90,9 @@ pub async fn create_task(
|
||||
last_run: None,
|
||||
next_run: None,
|
||||
run_count: 0,
|
||||
last_result: None,
|
||||
last_error: None,
|
||||
last_duration_ms: None,
|
||||
created_at: now,
|
||||
})
|
||||
}
|
||||
@@ -99,7 +105,7 @@ pub async fn list_tasks(
|
||||
let rows: Vec<ScheduledTaskRow> = sqlx::query_as(
|
||||
"SELECT id, account_id, name, description, schedule, schedule_type,
|
||||
target_type, target_id, enabled, last_run_at, next_run_at,
|
||||
run_count, last_error, input_payload, created_at
|
||||
run_count, last_result, last_error, last_duration_ms, input_payload, created_at
|
||||
FROM scheduled_tasks WHERE account_id = $1 ORDER BY created_at DESC"
|
||||
)
|
||||
.bind(account_id)
|
||||
@@ -118,7 +124,7 @@ pub async fn get_task(
|
||||
let row: Option<ScheduledTaskRow> = sqlx::query_as(
|
||||
"SELECT id, account_id, name, description, schedule, schedule_type,
|
||||
target_type, target_id, enabled, last_run_at, next_run_at,
|
||||
run_count, last_error, input_payload, created_at
|
||||
run_count, last_result, last_error, last_duration_ms, input_payload, created_at
|
||||
FROM scheduled_tasks WHERE id = $1 AND account_id = $2"
|
||||
)
|
||||
.bind(task_id)
|
||||
|
||||
@@ -58,6 +58,8 @@ pub struct ScheduledTaskResponse {
|
||||
pub last_run: Option<String>,
|
||||
pub next_run: Option<String>,
|
||||
pub run_count: i32,
|
||||
pub last_result: Option<String>,
|
||||
pub last_error: Option<String>,
|
||||
pub last_duration_ms: Option<i64>,
|
||||
pub created_at: String,
|
||||
}
|
||||
|
||||
@@ -3,11 +3,18 @@
|
||||
//! 通过 TOML 配置定时任务,无需改代码调整调度时间。
|
||||
//! 配置格式在 config.rs 的 SchedulerConfig / JobConfig 中定义。
|
||||
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
use sqlx::PgPool;
|
||||
use crate::config::SchedulerConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
|
||||
/// 单次任务执行的产出
|
||||
struct TaskExecution {
|
||||
result: Option<String>,
|
||||
error: Option<String>,
|
||||
duration_ms: i64,
|
||||
}
|
||||
|
||||
/// 解析时间间隔字符串为 Duration
|
||||
pub fn parse_duration(s: &str) -> Result<Duration, String> {
|
||||
let s = s.trim().to_lowercase();
|
||||
@@ -143,23 +150,42 @@ pub fn start_user_task_scheduler(db: PgPool) {
|
||||
});
|
||||
}
|
||||
|
||||
/// 执行单个调度任务
|
||||
/// 执行单个调度任务,返回执行产出(结果/错误/耗时)
|
||||
async fn execute_scheduled_task(
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
target_type: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||
let task_info: Option<(String, Option<String>)> = sqlx::query_as(
|
||||
) -> TaskExecution {
|
||||
let start = Instant::now();
|
||||
|
||||
let task_info: Option<(String, Option<String>)> = match sqlx::query_as(
|
||||
"SELECT name, config_json FROM scheduled_tasks WHERE id = $1"
|
||||
)
|
||||
.bind(task_id)
|
||||
.fetch_optional(db)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to fetch task {}: {}", task_id, e))?;
|
||||
{
|
||||
Ok(info) => info,
|
||||
Err(e) => {
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
return TaskExecution {
|
||||
result: None,
|
||||
error: Some(format!("Failed to fetch task {}: {}", task_id, e)),
|
||||
duration_ms: elapsed,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let (task_name, _config_json) = match task_info {
|
||||
Some(info) => info,
|
||||
None => return Err(format!("Task {} not found", task_id).into()),
|
||||
None => {
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
return TaskExecution {
|
||||
result: None,
|
||||
error: Some(format!("Task {} not found", task_id)),
|
||||
duration_ms: elapsed,
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
@@ -167,22 +193,39 @@ async fn execute_scheduled_task(
|
||||
task_name, target_type
|
||||
);
|
||||
|
||||
match target_type {
|
||||
let exec_result = match target_type {
|
||||
t if t == "agent" => {
|
||||
tracing::info!("[UserScheduler] Agent task '{}' queued for execution", task_name);
|
||||
Ok("agent_dispatched".to_string())
|
||||
}
|
||||
t if t == "hand" => {
|
||||
tracing::info!("[UserScheduler] Hand task '{}' queued for execution", task_name);
|
||||
Ok("hand_dispatched".to_string())
|
||||
}
|
||||
t if t == "workflow" => {
|
||||
tracing::info!("[UserScheduler] Workflow task '{}' queued for execution", task_name);
|
||||
Ok("workflow_dispatched".to_string())
|
||||
}
|
||||
other => {
|
||||
tracing::warn!("[UserScheduler] Unknown target_type '{}' for task '{}'", other, task_name);
|
||||
Err(format!("Unknown target_type: {}", other))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Ok(())
|
||||
let elapsed = start.elapsed().as_millis() as i64;
|
||||
|
||||
match exec_result {
|
||||
Ok(msg) => TaskExecution {
|
||||
result: Some(msg),
|
||||
error: None,
|
||||
duration_ms: elapsed,
|
||||
},
|
||||
Err(err) => TaskExecution {
|
||||
result: None,
|
||||
error: Some(err),
|
||||
duration_ms: elapsed,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
@@ -206,17 +249,19 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
task_id, target_type, schedule_type
|
||||
);
|
||||
|
||||
// 执行任务
|
||||
match execute_scheduled_task(db, &task_id, &target_type).await {
|
||||
Ok(()) => {
|
||||
tracing::info!("[UserScheduler] task {} executed successfully", task_id);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, e);
|
||||
}
|
||||
// 执行任务并收集产出
|
||||
let exec = execute_scheduled_task(db, &task_id, &target_type).await;
|
||||
|
||||
if let Some(ref err) = exec.error {
|
||||
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, err);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"[UserScheduler] task {} executed successfully ({}ms)",
|
||||
task_id, exec.duration_ms
|
||||
);
|
||||
}
|
||||
|
||||
// 更新任务状态
|
||||
// 更新任务状态(含执行产出)
|
||||
let result = sqlx::query(
|
||||
"UPDATE scheduled_tasks
|
||||
SET last_run_at = NOW(),
|
||||
@@ -228,10 +273,16 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
|
||||
WHEN schedule_type = 'interval' AND interval_seconds IS NOT NULL
|
||||
THEN NOW() + (interval_seconds || ' seconds')::INTERVAL
|
||||
ELSE NULL
|
||||
END
|
||||
END,
|
||||
last_result = $2,
|
||||
last_error = $3,
|
||||
last_duration_ms = $4
|
||||
WHERE id = $1"
|
||||
)
|
||||
.bind(&task_id)
|
||||
.bind(&exec.result)
|
||||
.bind(&exec.error)
|
||||
.bind(exec.duration_ms)
|
||||
.execute(db)
|
||||
.await;
|
||||
|
||||
|
||||
@@ -10,6 +10,44 @@ use crate::config::SaaSConfig;
|
||||
use crate::workers::WorkerDispatcher;
|
||||
use crate::cache::AppCache;
|
||||
|
||||
// ============ SpawnLimiter ============
|
||||
|
||||
/// 可复用的并发限制器,基于 Arc<Semaphore>。
|
||||
/// 复用 SSE_SPAWN_SEMAPHORE 模式,为 Worker、中间件等场景提供统一门控。
|
||||
#[derive(Clone)]
|
||||
pub struct SpawnLimiter {
|
||||
semaphore: Arc<tokio::sync::Semaphore>,
|
||||
name: &'static str,
|
||||
}
|
||||
|
||||
impl SpawnLimiter {
|
||||
pub fn new(name: &'static str, max_permits: usize) -> Self {
|
||||
Self {
|
||||
semaphore: Arc::new(tokio::sync::Semaphore::new(max_permits)),
|
||||
name,
|
||||
}
|
||||
}
|
||||
|
||||
/// 尝试获取 permit,满时返回 None(适用于可丢弃的操作如 usage 记录)
|
||||
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
|
||||
self.semaphore.clone().try_acquire_owned().ok()
|
||||
}
|
||||
|
||||
/// 异步等待 permit(适用于不可丢弃的操作如 Worker 任务)
|
||||
pub async fn acquire(&self) -> tokio::sync::OwnedSemaphorePermit {
|
||||
self.semaphore
|
||||
.clone()
|
||||
.acquire_owned()
|
||||
.await
|
||||
.expect("SpawnLimiter semaphore closed unexpectedly")
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str { self.name }
|
||||
pub fn available(&self) -> usize { self.semaphore.available_permits() }
|
||||
}
|
||||
|
||||
// ============ AppState ============
|
||||
|
||||
/// 全局应用状态,通过 Axum State 共享
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
@@ -33,10 +71,20 @@ pub struct AppState {
|
||||
pub shutdown_token: CancellationToken,
|
||||
/// 应用缓存: Model/Provider/队列计数器
|
||||
pub cache: AppCache,
|
||||
/// Worker spawn 并发限制器
|
||||
pub worker_limiter: SpawnLimiter,
|
||||
/// 限流事件批量累加器: key → 待写入计数
|
||||
pub rate_limit_batch: Arc<dashmap::DashMap<String, i64>>,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher, shutdown_token: CancellationToken) -> anyhow::Result<Self> {
|
||||
pub fn new(
|
||||
db: PgPool,
|
||||
config: SaaSConfig,
|
||||
worker_dispatcher: WorkerDispatcher,
|
||||
shutdown_token: CancellationToken,
|
||||
worker_limiter: SpawnLimiter,
|
||||
) -> anyhow::Result<Self> {
|
||||
let jwt_secret = config.jwt_secret()?;
|
||||
let rpm = config.rate_limit.requests_per_minute;
|
||||
Ok(Self {
|
||||
@@ -50,6 +98,8 @@ impl AppState {
|
||||
worker_dispatcher,
|
||||
shutdown_token,
|
||||
cache: AppCache::new(),
|
||||
worker_limiter,
|
||||
rate_limit_batch: Arc::new(dashmap::DashMap::new()),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -96,4 +146,60 @@ impl AppState {
|
||||
tracing::warn!("Failed to dispatch log_operation: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
/// 限流事件批量 flush 到 DB
|
||||
///
|
||||
/// 使用 swap-to-zero 模式:先将计数器原子归零,DB 写入成功后删除条目。
|
||||
/// 如果 DB 写入失败,归零的计数会在下次 flush 时重新累加(因 middleware 持续写入)。
|
||||
pub async fn flush_rate_limit_batch(&self, max_batch: usize) {
|
||||
// 阶段1: 收集非零 key,将计数器原子归零(而非删除)
|
||||
// 这样如果 DB 写入失败,middleware 的新累加会在已有 key 上继续
|
||||
let mut batch: Vec<(String, i64)> = Vec::with_capacity(max_batch.min(64));
|
||||
|
||||
let keys: Vec<String> = self.rate_limit_batch.iter()
|
||||
.filter(|e| *e.value() > 0)
|
||||
.take(max_batch)
|
||||
.map(|e| e.key().clone())
|
||||
.collect();
|
||||
|
||||
for key in &keys {
|
||||
// 原子交换为 0,取走当前值
|
||||
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
|
||||
if *entry > 0 {
|
||||
batch.push((key.clone(), *entry));
|
||||
*entry = 0; // 归零而非删除
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if batch.is_empty() { return; }
|
||||
|
||||
let keys_buf: Vec<String> = batch.iter().map(|(k, _)| k.clone()).collect();
|
||||
let counts: Vec<i64> = batch.iter().map(|(_, c)| *c).collect();
|
||||
|
||||
let result = sqlx::query(
|
||||
"INSERT INTO rate_limit_events (key, window_start, count)
|
||||
SELECT u.key, NOW(), u.cnt FROM UNNEST($1::text[], $2::bigint[]) AS u(key, cnt)"
|
||||
)
|
||||
.bind(&keys_buf)
|
||||
.bind(&counts)
|
||||
.execute(&self.db)
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
// DB 写入失败:将归零的计数加回去,避免数据丢失
|
||||
tracing::warn!("[RateLimitBatch] flush failed ({} entries), restoring counts: {}", batch.len(), e);
|
||||
for (key, count) in &batch {
|
||||
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
|
||||
*entry += *count;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// DB 写入成功:删除已归零的条目
|
||||
for (key, _) in &batch {
|
||||
self.rate_limit_batch.remove_if(key, |_, v| *v == 0);
|
||||
}
|
||||
tracing::debug!("[RateLimitBatch] flushed {} entries", batch.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use sqlx::PgPool;
|
||||
use crate::error::SaasResult;
|
||||
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow};
|
||||
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow, TelemetryReportRow};
|
||||
use super::types::*;
|
||||
|
||||
const CHUNK_SIZE: usize = 100;
|
||||
@@ -270,3 +270,27 @@ pub async fn get_daily_stats(
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// 查询账号最近的遥测报告
|
||||
pub async fn get_recent_reports(
|
||||
db: &PgPool,
|
||||
account_id: &str,
|
||||
limit: i64,
|
||||
) -> SaasResult<Vec<TelemetryReportRow>> {
|
||||
let limit = limit.min(100).max(1);
|
||||
let rows = sqlx::query_as::<_, TelemetryReportRow>(
|
||||
"SELECT id, account_id, device_id, app_version, model_id, \
|
||||
input_tokens, output_tokens, latency_ms, success, \
|
||||
error_type, connection_mode, \
|
||||
reported_at::text, created_at::text \
|
||||
FROM telemetry_reports \
|
||||
WHERE account_id = $1 \
|
||||
ORDER BY reported_at DESC \
|
||||
LIMIT $2"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(limit)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
123
crates/zclaw-saas/src/workers/aggregate_usage.rs
Normal file
123
crates/zclaw-saas/src/workers/aggregate_usage.rs
Normal file
@@ -0,0 +1,123 @@
|
||||
//! 计费用量聚合 Worker
|
||||
//!
|
||||
//! 从 usage_records 聚合当月用量到 billing_usage_quotas 表。
|
||||
//! 由 Scheduler 每小时触发,或在 relay 请求完成时直接派发。
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{Datelike, Timelike};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
/// 用量聚合参数
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AggregateUsageArgs {
|
||||
/// 聚合的目标账户 ID(None = 聚合所有活跃账户)
|
||||
pub account_id: Option<String>,
|
||||
}
|
||||
|
||||
pub struct AggregateUsageWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for AggregateUsageWorker {
|
||||
type Args = AggregateUsageArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"aggregate_usage"
|
||||
}
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
match args.account_id {
|
||||
Some(account_id) => {
|
||||
aggregate_single_account(db, &account_id).await?;
|
||||
}
|
||||
None => {
|
||||
aggregate_all_accounts(db).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 聚合单个账户的当月用量
|
||||
async fn aggregate_single_account(db: &PgPool, account_id: &str) -> SaasResult<()> {
|
||||
// 获取或创建用量记录(确保存在)
|
||||
let usage = crate::billing::service::get_or_create_usage(db, account_id).await?;
|
||||
|
||||
// 从 usage_records 聚合当月实际 token 用量
|
||||
let now = chrono::Utc::now();
|
||||
let period_start = now
|
||||
.with_day(1).unwrap_or(now)
|
||||
.with_hour(0).unwrap_or(now)
|
||||
.with_minute(0).unwrap_or(now)
|
||||
.with_second(0).unwrap_or(now)
|
||||
.with_nanosecond(0).unwrap_or(now);
|
||||
|
||||
let aggregated: Option<(i64, i64, i64)> = sqlx::query_as(
|
||||
"SELECT COALESCE(SUM(input_tokens), 0), \
|
||||
COALESCE(SUM(output_tokens), 0), \
|
||||
COUNT(*) \
|
||||
FROM usage_records \
|
||||
WHERE account_id = $1 AND created_at >= $2 AND status = 'success'"
|
||||
)
|
||||
.bind(account_id)
|
||||
.bind(period_start)
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
if let Some((input_tokens, output_tokens, request_count)) = aggregated {
|
||||
sqlx::query(
|
||||
"UPDATE billing_usage_quotas \
|
||||
SET input_tokens = $1, \
|
||||
output_tokens = $2, \
|
||||
relay_requests = GREATEST(relay_requests, $3::int), \
|
||||
updated_at = NOW() \
|
||||
WHERE id = $4"
|
||||
)
|
||||
.bind(input_tokens)
|
||||
.bind(output_tokens)
|
||||
.bind(request_count as i32)
|
||||
.bind(&usage.id)
|
||||
.execute(db)
|
||||
.await?;
|
||||
|
||||
tracing::debug!(
|
||||
"Aggregated usage for account {}: in={}, out={}, reqs={}",
|
||||
account_id, input_tokens, output_tokens, request_count
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 聚合所有活跃账户
|
||||
async fn aggregate_all_accounts(db: &PgPool) -> SaasResult<()> {
|
||||
let account_ids: Vec<String> = sqlx::query_scalar(
|
||||
"SELECT DISTINCT account_id FROM billing_subscriptions \
|
||||
WHERE status IN ('trial', 'active', 'past_due') \
|
||||
UNION \
|
||||
SELECT DISTINCT account_id FROM billing_usage_quotas \
|
||||
WHERE period_start >= date_trunc('month', NOW())"
|
||||
)
|
||||
.fetch_all(db)
|
||||
.await?;
|
||||
|
||||
let total = account_ids.len();
|
||||
let mut errors = 0;
|
||||
|
||||
for account_id in &account_ids {
|
||||
if let Err(e) = aggregate_single_account(db, account_id).await {
|
||||
tracing::warn!("Failed to aggregate usage for {}: {}", account_id, e);
|
||||
errors += 1;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"Usage aggregation complete: {} accounts, {} errors",
|
||||
total, errors
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
168
crates/zclaw-saas/src/workers/generate_embedding.rs
Normal file
168
crates/zclaw-saas/src/workers/generate_embedding.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
//! 知识条目分块 + Embedding 生成 Worker
|
||||
//!
|
||||
//! 当知识条目创建/更新时触发:
|
||||
//! 1. 按 Markdown 标题 + 固定长度分块
|
||||
//! 2. 提取关键词(从 item 的 keywords 字段继承 + 内容提取)
|
||||
//! 3. 写入 knowledge_chunks 表
|
||||
//! 4. 如果配置了 embedding provider,生成向量 embedding(Phase 2)
|
||||
|
||||
use async_trait::async_trait;
|
||||
use sqlx::PgPool;
|
||||
use serde::{Serialize, Deserialize};
|
||||
use crate::error::SaasResult;
|
||||
use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateEmbeddingArgs {
|
||||
pub item_id: String,
|
||||
}
|
||||
|
||||
pub struct GenerateEmbeddingWorker;
|
||||
|
||||
#[async_trait]
|
||||
impl Worker for GenerateEmbeddingWorker {
|
||||
type Args = GenerateEmbeddingArgs;
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"generate_embedding"
|
||||
}
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
// 1. 加载条目
|
||||
let item: Option<(String, String, Vec<String>)> = sqlx::query_as(
|
||||
"SELECT content, title, keywords FROM knowledge_items WHERE id = $1"
|
||||
)
|
||||
.bind(&args.item_id)
|
||||
.fetch_optional(db)
|
||||
.await?;
|
||||
|
||||
let (content, title, keywords) = match item {
|
||||
Some(row) => row,
|
||||
None => {
|
||||
tracing::warn!("GenerateEmbedding: item {} not found, skipping", args.item_id);
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
// 2. 分块
|
||||
let chunks = crate::knowledge::service::chunk_content(&content, 512, 64);
|
||||
|
||||
if chunks.is_empty() {
|
||||
tracing::debug!("GenerateEmbedding: item {} has no content to chunk", args.item_id);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// 3. 在事务中删除旧分块 + 插入新分块(防止并发竞争条件)
|
||||
let mut tx = db.begin().await?;
|
||||
|
||||
// 锁定条目行防止并发 worker 同时处理同一条目
|
||||
let locked: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT id FROM knowledge_items WHERE id = $1 FOR UPDATE"
|
||||
)
|
||||
.bind(&args.item_id)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
|
||||
if locked.is_none() {
|
||||
tx.rollback().await?;
|
||||
tracing::warn!("GenerateEmbedding: item {} was deleted during processing", args.item_id);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
sqlx::query("DELETE FROM knowledge_chunks WHERE item_id = $1")
|
||||
.bind(&args.item_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
for (idx, chunk) in chunks.iter().enumerate() {
|
||||
let chunk_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let mut chunk_keywords = keywords.clone();
|
||||
extract_chunk_keywords(chunk, &mut chunk_keywords);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO knowledge_chunks (id, item_id, chunk_index, content, keywords, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())"
|
||||
)
|
||||
.bind(&chunk_id)
|
||||
.bind(&args.item_id)
|
||||
.bind(idx as i32)
|
||||
.bind(chunk)
|
||||
.bind(&chunk_keywords)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
tracing::info!(
|
||||
"GenerateEmbedding: item '{}' → {} chunks (keywords: {})",
|
||||
title,
|
||||
chunks.len(),
|
||||
keywords.len(),
|
||||
);
|
||||
|
||||
// Phase 2: 如果配置了 embedding provider,在此处调用 embedding API
|
||||
// 并更新 chunks 的 embedding 列
|
||||
// TODO: let _ = generate_vectors(db, &args.item_id, &chunks).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// 从 chunk 内容中提取高频中文词组作为补充关键词
|
||||
///
|
||||
/// 简单策略:提取 2-4 字的连续中文字符段,取出现频率 > 1 的
|
||||
fn extract_chunk_keywords(content: &str, keywords: &mut Vec<String>) {
|
||||
let chars: Vec<char> = content.chars().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < chars.len() {
|
||||
// 寻找连续中文字符段
|
||||
if is_cjk(chars[i]) {
|
||||
let start = i;
|
||||
while i < chars.len() && is_cjk(chars[i]) {
|
||||
i += 1;
|
||||
}
|
||||
let segment: String = chars[start..i].iter().collect();
|
||||
|
||||
// 提取 2-4 字的子串
|
||||
let seg_chars: Vec<char> = segment.chars().collect();
|
||||
if seg_chars.len() >= 2 {
|
||||
// 只取前 2-4 字的短语(避免过长无意义词组)
|
||||
for len in 2..=4.min(seg_chars.len()) {
|
||||
let phrase: String = seg_chars[..len].iter().collect();
|
||||
// 过滤常见停用词(简单版)
|
||||
if !is_stop_word(&phrase) && !keywords.contains(&phrase) {
|
||||
keywords.push(phrase);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// 限制关键词总数
|
||||
keywords.truncate(50);
|
||||
}
|
||||
|
||||
/// 判断是否为 CJK 字符
|
||||
fn is_cjk(c: char) -> bool {
|
||||
matches!(c,
|
||||
'\u{4E00}'..='\u{9FFF}' | // CJK Unified Ideographs
|
||||
'\u{3400}'..='\u{4DBF}' | // CJK Unified Ideographs Extension A
|
||||
'\u{F900}'..='\u{FAFF}' // CJK Compatibility Ideographs
|
||||
)
|
||||
}
|
||||
|
||||
/// 简单停用词表
|
||||
fn is_stop_word(s: &str) -> bool {
|
||||
matches!(s,
|
||||
"的" | "了" | "是" | "在" | "我" | "有" | "和" | "就" | "不" | "人" |
|
||||
"都" | "一" | "一个" | "上" | "也" | "很" | "到" | "说" | "要" | "去" |
|
||||
"你" | "会" | "着" | "没有" | "看" | "好" | "自己" | "这" | "他" | "她" |
|
||||
"它" | "们" | "那" | "些" | "什么" | "为" | "所以" | "但是" | "因为" |
|
||||
"如果" | "可以" | "能够" | "需要" | "应该" | "已经" | "还是" | "或者"
|
||||
)
|
||||
}
|
||||
@@ -42,10 +42,12 @@ struct TaskMessage {
|
||||
/// Worker 调度器 — 管理所有 Worker 的注册和派发
|
||||
///
|
||||
/// 使用 Arc 包装,可安全跨任务共享。
|
||||
/// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。
|
||||
pub struct WorkerDispatcher {
|
||||
db: PgPool,
|
||||
sender: mpsc::Sender<TaskMessage>,
|
||||
handlers: HashMap<String, Arc<dyn DynWorker>>,
|
||||
spawn_limiter: crate::state::SpawnLimiter,
|
||||
}
|
||||
|
||||
impl Clone for WorkerDispatcher {
|
||||
@@ -54,6 +56,7 @@ impl Clone for WorkerDispatcher {
|
||||
db: self.db.clone(),
|
||||
sender: self.sender.clone(),
|
||||
handlers: self.handlers.clone(),
|
||||
spawn_limiter: self.spawn_limiter.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,7 +93,7 @@ where
|
||||
|
||||
impl WorkerDispatcher {
|
||||
/// 创建新的调度器
|
||||
pub fn new(db: PgPool) -> Self {
|
||||
pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self {
|
||||
// channel 容量 1024,足够缓冲高峰期任务
|
||||
let (sender, receiver) = mpsc::channel(1024);
|
||||
|
||||
@@ -98,6 +101,7 @@ impl WorkerDispatcher {
|
||||
db,
|
||||
sender,
|
||||
handlers: HashMap::new(),
|
||||
spawn_limiter,
|
||||
};
|
||||
|
||||
// 启动消费循环
|
||||
@@ -152,10 +156,15 @@ impl WorkerDispatcher {
|
||||
}
|
||||
|
||||
/// 启动消费循环
|
||||
///
|
||||
/// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit,
|
||||
/// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。
|
||||
/// 重试时先 drop permit 再 sleep,避免浪费 permit 在等待期间。
|
||||
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
|
||||
let db = self.db.clone();
|
||||
let handlers = self.handlers.clone();
|
||||
let sender = self.sender.clone();
|
||||
let limiter = self.spawn_limiter.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
while let Some(msg) = receiver.recv().await {
|
||||
@@ -171,21 +180,34 @@ impl WorkerDispatcher {
|
||||
let max_retries = handler.max_retries();
|
||||
let db = db.clone();
|
||||
let sender = sender.clone();
|
||||
let limiter = limiter.clone();
|
||||
|
||||
// 关键:在 spawn 之前获取 permit
|
||||
// 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量
|
||||
let permit = limiter.acquire().await;
|
||||
tracing::trace!(
|
||||
"Worker '{}' acquired permit ({} available), spawning task",
|
||||
worker_name, limiter.available()
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
// permit 已预获取,任务立即执行
|
||||
let _permit = permit;
|
||||
|
||||
match handler.perform(&db, &msg.args_json).await {
|
||||
Ok(()) => {
|
||||
tracing::debug!("Worker {} completed successfully", worker_name);
|
||||
}
|
||||
Err(e) => {
|
||||
if msg.attempt < max_retries {
|
||||
// 先 drop permit,不占用并发配额在 sleep 期间
|
||||
drop(_permit);
|
||||
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
|
||||
tracing::warn!(
|
||||
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
|
||||
worker_name, msg.attempt, max_retries, e, delay
|
||||
);
|
||||
tokio::time::sleep(delay).await;
|
||||
// 重新入队(递增 attempt 计数)
|
||||
let retry_msg = TaskMessage {
|
||||
worker_name: msg.worker_name.clone(),
|
||||
args_json: msg.args_json.clone(),
|
||||
@@ -218,6 +240,8 @@ pub mod cleanup_rate_limit;
|
||||
pub mod cleanup_refresh_tokens;
|
||||
pub mod update_last_used;
|
||||
pub mod record_usage;
|
||||
pub mod aggregate_usage;
|
||||
pub mod generate_embedding;
|
||||
|
||||
// 便捷导出
|
||||
pub use log_operation::LogOperationWorker;
|
||||
@@ -225,3 +249,4 @@ pub use cleanup_rate_limit::CleanupRateLimitWorker;
|
||||
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
|
||||
pub use update_last_used::UpdateLastUsedWorker;
|
||||
pub use record_usage::RecordUsageWorker;
|
||||
pub use aggregate_usage::AggregateUsageWorker;
|
||||
|
||||
@@ -8,7 +8,8 @@ use super::Worker;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct UpdateLastUsedArgs {
|
||||
pub token_id: String,
|
||||
/// token_hash 用于 WHERE 条件匹配
|
||||
pub token_hash: String,
|
||||
}
|
||||
|
||||
pub struct UpdateLastUsedWorker;
|
||||
@@ -23,9 +24,9 @@ impl Worker for UpdateLastUsedWorker {
|
||||
|
||||
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now().to_rfc3339();
|
||||
sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE id = $2")
|
||||
sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
|
||||
.bind(&now)
|
||||
.bind(&args.token_id)
|
||||
.bind(&args.token_hash)
|
||||
.execute(db)
|
||||
.await?;
|
||||
Ok(())
|
||||
|
||||
223
desktop/src-tauri/src/classroom_commands/chat.rs
Normal file
223
desktop/src-tauri/src/classroom_commands/chat.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
//! Classroom multi-agent chat commands
|
||||
//!
|
||||
//! - `classroom_chat` — send a message and receive multi-agent responses
|
||||
//! - `classroom_chat_history` — retrieve chat history for a classroom
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tauri::State;
|
||||
|
||||
use zclaw_kernel::generation::{
|
||||
AgentProfile, AgentRole,
|
||||
ClassroomChatMessage, ClassroomChatState,
|
||||
ClassroomChatRequest,
|
||||
build_chat_prompt, parse_chat_responses,
|
||||
};
|
||||
use zclaw_runtime::CompletionRequest;
|
||||
|
||||
use super::ClassroomStore;
|
||||
use crate::kernel_commands::KernelState;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// State
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Chat state store: classroom_id → chat state
|
||||
pub type ChatStore = Arc<Mutex<std::collections::HashMap<String, ClassroomChatState>>>;
|
||||
|
||||
pub fn create_chat_state() -> ChatStore {
|
||||
Arc::new(Mutex::new(std::collections::HashMap::new()))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Request / Response
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClassroomChatCmdRequest {
|
||||
pub classroom_id: String,
|
||||
pub user_message: String,
|
||||
pub scene_context: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Commands
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Send a message in the classroom chat and get multi-agent responses.
|
||||
#[tauri::command]
|
||||
pub async fn classroom_chat(
|
||||
store: State<'_, ClassroomStore>,
|
||||
chat_store: State<'_, ChatStore>,
|
||||
kernel_state: State<'_, KernelState>,
|
||||
request: ClassroomChatCmdRequest,
|
||||
) -> Result<Vec<ClassroomChatMessage>, String> {
|
||||
if request.user_message.trim().is_empty() {
|
||||
return Err("Message cannot be empty".to_string());
|
||||
}
|
||||
|
||||
// Get classroom data
|
||||
let classroom = {
|
||||
let s = store.lock().await;
|
||||
s.get(&request.classroom_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("Classroom '{}' not found", request.classroom_id))?
|
||||
};
|
||||
|
||||
// Create user message
|
||||
let user_msg = ClassroomChatMessage::user_message(&request.user_message);
|
||||
|
||||
// Get chat history for context
|
||||
let history: Vec<ClassroomChatMessage> = {
|
||||
let cs = chat_store.lock().await;
|
||||
cs.get(&request.classroom_id)
|
||||
.map(|s| s.messages.clone())
|
||||
.unwrap_or_default()
|
||||
};
|
||||
|
||||
// Try LLM-powered multi-agent responses, fallback to placeholder
|
||||
let agent_responses = match generate_llm_responses(&kernel_state, &classroom.agents, &request.user_message, request.scene_context.as_deref(), &history).await {
|
||||
Ok(responses) => responses,
|
||||
Err(e) => {
|
||||
tracing::warn!("LLM chat generation failed, using placeholders: {}", e);
|
||||
generate_placeholder_responses(
|
||||
&classroom.agents,
|
||||
&request.user_message,
|
||||
request.scene_context.as_deref(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Store in chat state
|
||||
{
|
||||
let mut cs = chat_store.lock().await;
|
||||
let state = cs.entry(request.classroom_id.clone())
|
||||
.or_insert_with(|| ClassroomChatState {
|
||||
messages: vec![],
|
||||
active: true,
|
||||
});
|
||||
|
||||
state.messages.push(user_msg);
|
||||
state.messages.extend(agent_responses.clone());
|
||||
}
|
||||
|
||||
Ok(agent_responses)
|
||||
}
|
||||
|
||||
/// Retrieve chat history for a classroom
|
||||
#[tauri::command]
|
||||
pub async fn classroom_chat_history(
|
||||
chat_store: State<'_, ChatStore>,
|
||||
classroom_id: String,
|
||||
) -> Result<Vec<ClassroomChatMessage>, String> {
|
||||
let cs = chat_store.lock().await;
|
||||
Ok(cs.get(&classroom_id)
|
||||
.map(|s| s.messages.clone())
|
||||
.unwrap_or_default())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Placeholder response generation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn generate_placeholder_responses(
|
||||
agents: &[AgentProfile],
|
||||
user_message: &str,
|
||||
scene_context: Option<&str>,
|
||||
) -> Vec<ClassroomChatMessage> {
|
||||
let mut responses = Vec::new();
|
||||
|
||||
// Teacher always responds
|
||||
if let Some(teacher) = agents.iter().find(|a| a.role == AgentRole::Teacher) {
|
||||
let context_hint = scene_context
|
||||
.map(|ctx| format!("关于「{}」,", ctx))
|
||||
.unwrap_or_default();
|
||||
|
||||
responses.push(ClassroomChatMessage::agent_message(
|
||||
teacher,
|
||||
&format!("{}这是一个很好的问题!让我来详细解释一下「{}」的核心概念...", context_hint, user_message),
|
||||
));
|
||||
}
|
||||
|
||||
// Assistant chimes in
|
||||
if let Some(assistant) = agents.iter().find(|a| a.role == AgentRole::Assistant) {
|
||||
responses.push(ClassroomChatMessage::agent_message(
|
||||
assistant,
|
||||
"我来补充一下要点 📌",
|
||||
));
|
||||
}
|
||||
|
||||
// One student responds
|
||||
if let Some(student) = agents.iter().find(|a| a.role == AgentRole::Student) {
|
||||
responses.push(ClassroomChatMessage::agent_message(
|
||||
student,
|
||||
&format!("谢谢老师!我大概理解了{}", user_message),
|
||||
));
|
||||
}
|
||||
|
||||
responses
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// LLM-powered response generation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn generate_llm_responses(
|
||||
kernel_state: &State<'_, KernelState>,
|
||||
agents: &[AgentProfile],
|
||||
user_message: &str,
|
||||
scene_context: Option<&str>,
|
||||
history: &[ClassroomChatMessage],
|
||||
) -> Result<Vec<ClassroomChatMessage>, String> {
|
||||
let driver = {
|
||||
let ks = kernel_state.lock().await;
|
||||
ks.as_ref()
|
||||
.map(|k| k.driver())
|
||||
.ok_or_else(|| "Kernel not initialized".to_string())?
|
||||
};
|
||||
|
||||
if !driver.is_configured() {
|
||||
return Err("LLM driver not configured".to_string());
|
||||
}
|
||||
|
||||
// Build the chat request for prompt generation (include history)
|
||||
let chat_request = ClassroomChatRequest {
|
||||
classroom_id: String::new(),
|
||||
user_message: user_message.to_string(),
|
||||
agents: agents.to_vec(),
|
||||
scene_context: scene_context.map(|s| s.to_string()),
|
||||
history: history.to_vec(),
|
||||
};
|
||||
|
||||
let prompt = build_chat_prompt(&chat_request);
|
||||
|
||||
let request = CompletionRequest {
|
||||
model: "default".to_string(),
|
||||
system: Some("你是一个课堂多智能体讨论的协调器。".to_string()),
|
||||
messages: vec![zclaw_types::Message::User {
|
||||
content: prompt,
|
||||
}],
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = driver.complete(request).await
|
||||
.map_err(|e| format!("LLM call failed: {}", e))?;
|
||||
|
||||
// Extract text from response
|
||||
let text = response.content.iter()
|
||||
.filter_map(|block| match block {
|
||||
zclaw_runtime::ContentBlock::Text { text } => Some(text.as_str()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.join("");
|
||||
|
||||
let responses = parse_chat_responses(&text, agents);
|
||||
if responses.is_empty() {
|
||||
return Err("LLM returned no parseable agent responses".to_string());
|
||||
}
|
||||
|
||||
Ok(responses)
|
||||
}
|
||||
152
desktop/src-tauri/src/classroom_commands/export.rs
Normal file
152
desktop/src-tauri/src/classroom_commands/export.rs
Normal file
@@ -0,0 +1,152 @@
|
||||
//! Classroom export commands
|
||||
//!
|
||||
//! - `classroom_export` — export classroom as HTML, Markdown, or JSON
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tauri::State;
|
||||
|
||||
use zclaw_kernel::generation::Classroom;
|
||||
|
||||
use super::ClassroomStore;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClassroomExportRequest {
|
||||
pub classroom_id: String,
|
||||
pub format: String, // "html" | "markdown" | "json"
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClassroomExportResponse {
|
||||
pub content: String,
|
||||
pub filename: String,
|
||||
pub mime_type: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Command
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tauri::command]
|
||||
pub async fn classroom_export(
|
||||
store: State<'_, ClassroomStore>,
|
||||
request: ClassroomExportRequest,
|
||||
) -> Result<ClassroomExportResponse, String> {
|
||||
let classroom = {
|
||||
let s = store.lock().await;
|
||||
s.get(&request.classroom_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("Classroom '{}' not found", request.classroom_id))?
|
||||
};
|
||||
|
||||
match request.format.as_str() {
|
||||
"json" => export_json(&classroom),
|
||||
"html" => export_html(&classroom),
|
||||
"markdown" | "md" => export_markdown(&classroom),
|
||||
_ => Err(format!("Unsupported export format: '{}'. Use html, markdown, or json.", request.format)),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Exporters
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn export_json(classroom: &Classroom) -> Result<ClassroomExportResponse, String> {
|
||||
let content = serde_json::to_string_pretty(classroom)
|
||||
.map_err(|e| format!("JSON serialization failed: {}", e))?;
|
||||
|
||||
Ok(ClassroomExportResponse {
|
||||
filename: format!("{}.json", sanitize_filename(&classroom.title)),
|
||||
content,
|
||||
mime_type: "application/json".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn export_html(classroom: &Classroom) -> Result<ClassroomExportResponse, String> {
|
||||
let mut html = String::from(r#"<!DOCTYPE html><html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width,initial-scale=1">"#);
|
||||
html.push_str(&format!("<title>{}</title>", html_escape(&classroom.title)));
|
||||
html.push_str(r#"<style>body{font-family:system-ui,sans-serif;max-width:800px;margin:0 auto;padding:2rem;color:#333}h1{color:#4F46E5}h2{color:#7C3AED;border-bottom:2px solid #E5E7EB;padding-bottom:0.5rem}.scene{margin:2rem 0;padding:1rem;border-left:4px solid #4F46E5;background:#F9FAFB}.quiz{border-left-color:#F59E0B}.discussion{border-left-color:#10B981}.agent{display:inline-flex;align-items:center;gap:0.5rem;margin:0.25rem;padding:0.25rem 0.75rem;border-radius:9999px;font-size:0.875rem;font-weight:500}</style></head><body>"#);
|
||||
|
||||
html.push_str(&format!("<h1>{}</h1>", html_escape(&classroom.title)));
|
||||
html.push_str(&format!("<p>{}</p>", html_escape(&classroom.description)));
|
||||
|
||||
// Agents
|
||||
html.push_str("<h2>课堂角色</h2><div>");
|
||||
for agent in &classroom.agents {
|
||||
html.push_str(&format!(
|
||||
r#"<span class="agent" style="background:{};color:white">{} {}</span>"#,
|
||||
agent.color, agent.avatar, html_escape(&agent.name)
|
||||
));
|
||||
}
|
||||
html.push_str("</div>");
|
||||
|
||||
// Scenes
|
||||
html.push_str("<h2>课程内容</h2>");
|
||||
for scene in &classroom.scenes {
|
||||
let type_class = match scene.content.scene_type {
|
||||
zclaw_kernel::generation::SceneType::Quiz => "quiz",
|
||||
zclaw_kernel::generation::SceneType::Discussion => "discussion",
|
||||
_ => "",
|
||||
};
|
||||
html.push_str(&format!(
|
||||
r#"<div class="scene {}"><h3>{}</h3><p>类型: {:?} | 时长: {}秒</p></div>"#,
|
||||
type_class,
|
||||
html_escape(&scene.content.title),
|
||||
scene.content.scene_type,
|
||||
scene.content.duration_seconds
|
||||
));
|
||||
}
|
||||
|
||||
html.push_str("</body></html>");
|
||||
|
||||
Ok(ClassroomExportResponse {
|
||||
filename: format!("{}.html", sanitize_filename(&classroom.title)),
|
||||
content: html,
|
||||
mime_type: "text/html".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn export_markdown(classroom: &Classroom) -> Result<ClassroomExportResponse, String> {
|
||||
let mut md = String::new();
|
||||
md.push_str(&format!("# {}\n\n", &classroom.title));
|
||||
md.push_str(&format!("{}\n\n", &classroom.description));
|
||||
|
||||
md.push_str("## 课堂角色\n\n");
|
||||
for agent in &classroom.agents {
|
||||
md.push_str(&format!("- {} **{}** ({:?})\n", agent.avatar, agent.name, agent.role));
|
||||
}
|
||||
md.push('\n');
|
||||
|
||||
md.push_str("## 课程内容\n\n");
|
||||
for (i, scene) in classroom.scenes.iter().enumerate() {
|
||||
md.push_str(&format!("### {}. {}\n\n", i + 1, scene.content.title));
|
||||
md.push_str(&format!("- 类型: `{:?}`\n", scene.content.scene_type));
|
||||
md.push_str(&format!("- 时长: {}秒\n\n", scene.content.duration_seconds));
|
||||
}
|
||||
|
||||
Ok(ClassroomExportResponse {
|
||||
filename: format!("{}.md", sanitize_filename(&classroom.title)),
|
||||
content: md,
|
||||
mime_type: "text/markdown".to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn sanitize_filename(name: &str) -> String {
|
||||
name.chars()
|
||||
.map(|c| if c.is_alphanumeric() || c == '-' || c == '_' { c } else { '_' })
|
||||
.collect::<String>()
|
||||
.trim_matches('_')
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn html_escape(s: &str) -> String {
|
||||
s.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
}
|
||||
286
desktop/src-tauri/src/classroom_commands/generate.rs
Normal file
286
desktop/src-tauri/src/classroom_commands/generate.rs
Normal file
@@ -0,0 +1,286 @@
|
||||
//! Classroom generation commands
|
||||
//!
|
||||
//! - `classroom_generate` — start 4-stage pipeline, emit progress events
|
||||
//! - `classroom_generation_progress` — query current progress
|
||||
//! - `classroom_cancel_generation` — cancel active generation
|
||||
//! - `classroom_get` — retrieve generated classroom data
|
||||
//! - `classroom_list` — list all generated classrooms
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tauri::{AppHandle, Emitter, State};
|
||||
|
||||
use zclaw_kernel::generation::{
|
||||
Classroom, GenerationPipeline, GenerationRequest as KernelGenRequest, GenerationStage,
|
||||
TeachingStyle, DifficultyLevel,
|
||||
};
|
||||
|
||||
use super::{ClassroomStore, GenerationTasks};
|
||||
use crate::kernel_commands::KernelState;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Request / Response types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClassroomGenerateRequest {
|
||||
pub topic: String,
|
||||
pub document: Option<String>,
|
||||
pub style: Option<String>,
|
||||
pub level: Option<String>,
|
||||
pub target_duration_minutes: Option<u32>,
|
||||
pub scene_count: Option<usize>,
|
||||
pub custom_instructions: Option<String>,
|
||||
pub language: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClassroomGenerateResponse {
|
||||
pub classroom_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ClassroomProgressResponse {
|
||||
pub stage: String,
|
||||
pub progress: u8,
|
||||
pub activity: String,
|
||||
pub items_progress: Option<(usize, usize)>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn parse_style(s: Option<&str>) -> TeachingStyle {
|
||||
match s.unwrap_or("lecture") {
|
||||
"discussion" => TeachingStyle::Discussion,
|
||||
"pbl" => TeachingStyle::Pbl,
|
||||
"flipped" => TeachingStyle::Flipped,
|
||||
"socratic" => TeachingStyle::Socratic,
|
||||
_ => TeachingStyle::Lecture,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_level(l: Option<&str>) -> DifficultyLevel {
|
||||
match l.unwrap_or("intermediate") {
|
||||
"beginner" => DifficultyLevel::Beginner,
|
||||
"advanced" => DifficultyLevel::Advanced,
|
||||
"expert" => DifficultyLevel::Expert,
|
||||
_ => DifficultyLevel::Intermediate,
|
||||
}
|
||||
}
|
||||
|
||||
fn stage_name(stage: &GenerationStage) -> &'static str {
|
||||
match stage {
|
||||
GenerationStage::AgentProfiles => "agent_profiles",
|
||||
GenerationStage::Outline => "outline",
|
||||
GenerationStage::Scene => "scene",
|
||||
GenerationStage::Complete => "complete",
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Commands
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Start classroom generation (4-stage pipeline).
|
||||
/// Progress events are emitted via `classroom:progress`.
|
||||
/// Supports cancellation between stages by removing the task from GenerationTasks.
|
||||
#[tauri::command]
|
||||
pub async fn classroom_generate(
|
||||
app: AppHandle,
|
||||
store: State<'_, ClassroomStore>,
|
||||
tasks: State<'_, GenerationTasks>,
|
||||
kernel_state: State<'_, KernelState>,
|
||||
request: ClassroomGenerateRequest,
|
||||
) -> Result<ClassroomGenerateResponse, String> {
|
||||
if request.topic.trim().is_empty() {
|
||||
return Err("Topic is required".to_string());
|
||||
}
|
||||
|
||||
let topic_clone = request.topic.clone();
|
||||
|
||||
let kernel_request = KernelGenRequest {
|
||||
topic: request.topic.clone(),
|
||||
document: request.document.clone(),
|
||||
style: parse_style(request.style.as_deref()),
|
||||
level: parse_level(request.level.as_deref()),
|
||||
target_duration_minutes: request.target_duration_minutes.unwrap_or(30),
|
||||
scene_count: request.scene_count,
|
||||
custom_instructions: request.custom_instructions.clone(),
|
||||
language: request.language.clone().or_else(|| Some("zh-CN".to_string())),
|
||||
};
|
||||
|
||||
// Register generation task so cancellation can check it
|
||||
{
|
||||
use zclaw_kernel::generation::GenerationProgress;
|
||||
let mut t = tasks.lock().await;
|
||||
t.insert(topic_clone.clone(), GenerationProgress {
|
||||
stage: zclaw_kernel::generation::GenerationStage::AgentProfiles,
|
||||
progress: 0,
|
||||
activity: "Starting generation...".to_string(),
|
||||
items_progress: None,
|
||||
eta_seconds: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Get LLM driver from kernel if available, otherwise use placeholder mode
|
||||
let pipeline = {
|
||||
let ks = kernel_state.lock().await;
|
||||
if let Some(kernel) = ks.as_ref() {
|
||||
GenerationPipeline::with_driver(kernel.driver())
|
||||
} else {
|
||||
GenerationPipeline::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Helper: check if cancelled
|
||||
let is_cancelled = || {
|
||||
let t = tasks.blocking_lock();
|
||||
!t.contains_key(&topic_clone)
|
||||
};
|
||||
|
||||
// Helper: emit progress event
|
||||
let emit_progress = |stage: &str, progress: u8, activity: &str| {
|
||||
let _ = app.emit("classroom:progress", serde_json::json!({
|
||||
"topic": &topic_clone,
|
||||
"stage": stage,
|
||||
"progress": progress,
|
||||
"activity": activity
|
||||
}));
|
||||
};
|
||||
|
||||
// ── Stage 0: Agent Profiles ──
|
||||
emit_progress("agent_profiles", 5, "生成课堂角色...");
|
||||
let agents = pipeline.generate_agent_profiles(&kernel_request).await;
|
||||
emit_progress("agent_profiles", 25, "角色生成完成");
|
||||
if is_cancelled() {
|
||||
return Err("Generation cancelled".to_string());
|
||||
}
|
||||
|
||||
// ── Stage 1: Outline ──
|
||||
emit_progress("outline", 30, "分析主题,生成大纲...");
|
||||
let outline = pipeline.generate_outline(&kernel_request).await
|
||||
.map_err(|e| format!("Outline generation failed: {}", e))?;
|
||||
emit_progress("outline", 50, &format!("大纲完成:{} 个场景", outline.len()));
|
||||
if is_cancelled() {
|
||||
return Err("Generation cancelled".to_string());
|
||||
}
|
||||
|
||||
// ── Stage 2: Scenes (parallel) ──
|
||||
emit_progress("scene", 55, &format!("并行生成 {} 个场景...", outline.len()));
|
||||
let scenes = pipeline.generate_scenes(&outline).await
|
||||
.map_err(|e| format!("Scene generation failed: {}", e))?;
|
||||
if is_cancelled() {
|
||||
return Err("Generation cancelled".to_string());
|
||||
}
|
||||
|
||||
// ── Stage 3: Assemble ──
|
||||
emit_progress("complete", 90, "组装课堂...");
|
||||
|
||||
// Build classroom directly (pipeline.build_classroom is private)
|
||||
let total_duration: u32 = scenes.iter().map(|s| s.content.duration_seconds).sum();
|
||||
let objectives = outline.iter()
|
||||
.take(3)
|
||||
.map(|item| format!("理解: {}", item.title))
|
||||
.collect::<Vec<_>>();
|
||||
let classroom_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let classroom = Classroom {
|
||||
id: classroom_id.clone(),
|
||||
title: format!("课堂: {}", kernel_request.topic),
|
||||
description: format!("{:?} 风格课堂 — {}", kernel_request.style, kernel_request.topic),
|
||||
topic: kernel_request.topic.clone(),
|
||||
style: kernel_request.style,
|
||||
level: kernel_request.level,
|
||||
total_duration,
|
||||
objectives,
|
||||
scenes,
|
||||
agents,
|
||||
metadata: zclaw_kernel::generation::ClassroomMetadata {
|
||||
generated_at: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis() as i64,
|
||||
source_document: kernel_request.document.map(|_| "user_document".to_string()),
|
||||
model: None,
|
||||
version: "2.0.0".to_string(),
|
||||
custom: serde_json::Map::new(),
|
||||
},
|
||||
};
|
||||
|
||||
// Store classroom
|
||||
{
|
||||
let mut s = store.lock().await;
|
||||
s.insert(classroom_id.clone(), classroom);
|
||||
}
|
||||
|
||||
// Clear generation task
|
||||
{
|
||||
let mut t = tasks.lock().await;
|
||||
t.remove(&topic_clone);
|
||||
}
|
||||
|
||||
// Emit completion
|
||||
emit_progress("complete", 100, "课堂生成完成");
|
||||
|
||||
Ok(ClassroomGenerateResponse {
|
||||
classroom_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get current generation progress for a topic
|
||||
#[tauri::command]
|
||||
pub async fn classroom_generation_progress(
|
||||
tasks: State<'_, GenerationTasks>,
|
||||
topic: String,
|
||||
) -> Result<ClassroomProgressResponse, String> {
|
||||
let t = tasks.lock().await;
|
||||
let progress = t.get(&topic);
|
||||
Ok(ClassroomProgressResponse {
|
||||
stage: progress.map(|p| stage_name(&p.stage).to_string()).unwrap_or_else(|| "none".to_string()),
|
||||
progress: progress.map(|p| p.progress).unwrap_or(0),
|
||||
activity: progress.map(|p| p.activity.clone()).unwrap_or_default(),
|
||||
items_progress: progress.and_then(|p| p.items_progress),
|
||||
})
|
||||
}
|
||||
|
||||
/// Cancel an active generation
|
||||
#[tauri::command]
|
||||
pub async fn classroom_cancel_generation(
|
||||
tasks: State<'_, GenerationTasks>,
|
||||
topic: String,
|
||||
) -> Result<(), String> {
|
||||
let mut t = tasks.lock().await;
|
||||
t.remove(&topic);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve a generated classroom by ID
|
||||
#[tauri::command]
|
||||
pub async fn classroom_get(
|
||||
store: State<'_, ClassroomStore>,
|
||||
classroom_id: String,
|
||||
) -> Result<Classroom, String> {
|
||||
let s = store.lock().await;
|
||||
s.get(&classroom_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("Classroom '{}' not found", classroom_id))
|
||||
}
|
||||
|
||||
/// List all generated classrooms (id + title only)
|
||||
#[tauri::command]
|
||||
pub async fn classroom_list(
|
||||
store: State<'_, ClassroomStore>,
|
||||
) -> Result<Vec<serde_json::Value>, String> {
|
||||
let s = store.lock().await;
|
||||
Ok(s.values().map(|c| serde_json::json!({
|
||||
"id": c.id,
|
||||
"title": c.title,
|
||||
"topic": c.topic,
|
||||
"totalDuration": c.total_duration,
|
||||
"sceneCount": c.scenes.len(),
|
||||
})).collect())
|
||||
}
|
||||
41
desktop/src-tauri/src/classroom_commands/mod.rs
Normal file
41
desktop/src-tauri/src/classroom_commands/mod.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
//! Classroom generation and interaction commands
|
||||
//!
|
||||
//! Tauri commands for the OpenMAIC-style interactive classroom:
|
||||
//! - Generate classroom (4-stage pipeline with progress events)
|
||||
//! - Multi-agent chat
|
||||
//! - Export (HTML/Markdown/JSON)
|
||||
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
use zclaw_kernel::generation::Classroom;
|
||||
|
||||
pub mod chat;
|
||||
pub mod export;
|
||||
pub mod generate;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared state types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// In-memory classroom store: classroom_id → Classroom
|
||||
pub type ClassroomStore = Arc<Mutex<std::collections::HashMap<String, Classroom>>>;
|
||||
|
||||
/// Active generation tasks: topic → progress
|
||||
pub type GenerationTasks = Arc<Mutex<std::collections::HashMap<String, zclaw_kernel::generation::GenerationProgress>>>;
|
||||
|
||||
// Re-export chat state type
|
||||
// Re-export chat state type — used by lib.rs to construct managed state
|
||||
#[allow(unused_imports)]
|
||||
pub use chat::ChatStore;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// State constructors
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub fn create_classroom_state() -> ClassroomStore {
|
||||
Arc::new(Mutex::new(std::collections::HashMap::new()))
|
||||
}
|
||||
|
||||
pub fn create_generation_tasks() -> GenerationTasks {
|
||||
Arc::new(Mutex::new(std::collections::HashMap::new()))
|
||||
}
|
||||
@@ -258,11 +258,18 @@ impl AgentIdentityManager {
|
||||
if !identity.instructions.is_empty() {
|
||||
sections.push(identity.instructions.clone());
|
||||
}
|
||||
if !identity.user_profile.is_empty()
|
||||
&& identity.user_profile != default_user_profile()
|
||||
{
|
||||
sections.push(format!("## 用户画像\n{}", identity.user_profile));
|
||||
}
|
||||
// NOTE: user_profile injection is intentionally disabled.
|
||||
// The reflection engine may accumulate overly specific details from past
|
||||
// conversations (e.g., "广东光华", "汕头玩具产业") into user_profile.
|
||||
// These details then leak into every new conversation's system prompt,
|
||||
// causing the model to think about old topics instead of the current query.
|
||||
// Memory injection should only happen via MemoryMiddleware with relevance
|
||||
// filtering, not unconditionally via user_profile.
|
||||
// if !identity.user_profile.is_empty()
|
||||
// && identity.user_profile != default_user_profile()
|
||||
// {
|
||||
// sections.push(format!("## 用户画像\n{}", identity.user_profile));
|
||||
// }
|
||||
if let Some(ctx) = memory_context {
|
||||
sections.push(ctx.to_string());
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ pub struct ChatResponse {
|
||||
#[serde(rename_all = "camelCase", tag = "type")]
|
||||
pub enum StreamChatEvent {
|
||||
Delta { delta: String },
|
||||
ThinkingDelta { delta: String },
|
||||
ToolStart { name: String, input: serde_json::Value },
|
||||
ToolEnd { name: String, output: serde_json::Value },
|
||||
IterationStart { iteration: usize, max_iterations: usize },
|
||||
@@ -218,6 +219,10 @@ pub async fn agent_chat_stream(
|
||||
tracing::trace!("[agent_chat_stream] Delta: {} bytes", delta.len());
|
||||
StreamChatEvent::Delta { delta: delta.clone() }
|
||||
}
|
||||
LoopEvent::ThinkingDelta(delta) => {
|
||||
tracing::trace!("[agent_chat_stream] ThinkingDelta: {} bytes", delta.len());
|
||||
StreamChatEvent::ThinkingDelta { delta: delta.clone() }
|
||||
}
|
||||
LoopEvent::ToolStart { name, input } => {
|
||||
tracing::debug!("[agent_chat_stream] ToolStart: {}", name);
|
||||
if name.starts_with("hand_") {
|
||||
|
||||
@@ -249,3 +249,130 @@ pub async fn kernel_shutdown(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Apply SaaS-synced configuration to the Kernel config file.
|
||||
///
|
||||
/// Writes relevant config values (agent, llm categories) to the TOML config file.
|
||||
/// The changes take effect on the next Kernel restart.
|
||||
#[tauri::command]
|
||||
pub async fn kernel_apply_saas_config(
|
||||
configs: Vec<SaasConfigItem>,
|
||||
) -> Result<u32, String> {
|
||||
use std::io::Write;
|
||||
|
||||
let config_path = zclaw_kernel::config::KernelConfig::find_config_path()
|
||||
.ok_or_else(|| "No config file path found".to_string())?;
|
||||
|
||||
// Read existing config or create empty
|
||||
let existing = if config_path.exists() {
|
||||
std::fs::read_to_string(&config_path).unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let mut updated = existing;
|
||||
let mut applied: u32 = 0;
|
||||
|
||||
for config in &configs {
|
||||
// Only process kernel-relevant categories
|
||||
if !matches!(config.category.as_str(), "agent" | "llm") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Write key=value to the [llm] or [agent] section
|
||||
let section = &config.category;
|
||||
let key = config.key.replace('.', "_");
|
||||
let value = &config.value;
|
||||
|
||||
// Simple TOML patching: find or create section, update key
|
||||
let section_header = format!("[{}]", section);
|
||||
let line_to_set = format!("{} = {}", key, toml_quote_value(value));
|
||||
|
||||
if let Some(section_start) = updated.find(§ion_header) {
|
||||
// Section exists, find or add the key within it
|
||||
let after_header = section_start + section_header.len();
|
||||
let next_section = updated[after_header..].find("\n[")
|
||||
.map(|i| after_header + i)
|
||||
.unwrap_or(updated.len());
|
||||
|
||||
let section_content = &updated[after_header..next_section];
|
||||
let key_prefix = format!("\n{} =", key);
|
||||
let key_prefix_alt = format!("\n{}=", key);
|
||||
|
||||
if let Some(key_pos) = section_content.find(&key_prefix)
|
||||
.or_else(|| section_content.find(&key_prefix_alt))
|
||||
{
|
||||
// Key exists, replace the line
|
||||
let line_start = after_header + key_pos + 1; // skip \n
|
||||
let line_end = updated[line_start..].find('\n')
|
||||
.map(|i| line_start + i)
|
||||
.unwrap_or(updated.len());
|
||||
updated = format!(
|
||||
"{}{}{}\n{}",
|
||||
&updated[..line_start],
|
||||
line_to_set,
|
||||
if line_end < updated.len() { "" } else { "" },
|
||||
&updated[line_end..]
|
||||
);
|
||||
// Remove the extra newline if line_end included one
|
||||
updated = updated.replace(&format!("{}\n\n", line_to_set), &format!("{}\n", line_to_set));
|
||||
} else {
|
||||
// Key doesn't exist, append to section
|
||||
updated.insert_str(next_section, format!("\n{}", line_to_set).as_str());
|
||||
}
|
||||
} else {
|
||||
// Section doesn't exist, append it
|
||||
updated = format!("{}\n{}\n{}\n", updated.trim_end(), section_header, line_to_set);
|
||||
}
|
||||
applied += 1;
|
||||
}
|
||||
|
||||
if applied > 0 {
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = config_path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| format!("Failed to create config dir: {}", e))?;
|
||||
}
|
||||
|
||||
let mut file = std::fs::File::create(&config_path)
|
||||
.map_err(|e| format!("Failed to write config: {}", e))?;
|
||||
file.write_all(updated.as_bytes())
|
||||
.map_err(|e| format!("Failed to write config: {}", e))?;
|
||||
|
||||
tracing::info!(
|
||||
"[kernel_apply_saas_config] Applied {} config items to {:?} (restart required)",
|
||||
applied,
|
||||
config_path
|
||||
);
|
||||
}
|
||||
|
||||
Ok(applied)
|
||||
}
|
||||
|
||||
/// Single config item from SaaS sync
|
||||
#[derive(Debug, Clone, serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SaasConfigItem {
|
||||
pub category: String,
|
||||
pub key: String,
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
/// Quote a value for TOML format
|
||||
fn toml_quote_value(value: &str) -> String {
|
||||
// Try to parse as number or boolean
|
||||
if value == "true" || value == "false" {
|
||||
return value.to_string();
|
||||
}
|
||||
if let Ok(n) = value.parse::<i64>() {
|
||||
return n.to_string();
|
||||
}
|
||||
if let Ok(n) = value.parse::<f64>() {
|
||||
return n.to_string();
|
||||
}
|
||||
// Handle multi-line strings with TOML triple-quote syntax
|
||||
if value.contains('\n') {
|
||||
return format!("\"\"\"\n{}\"\"\"", value.replace('\\', "\\\\").replace("\"\"\"", "'\"'\"'\""));
|
||||
}
|
||||
// Default: quote as string
|
||||
format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
|
||||
}
|
||||
|
||||
@@ -34,6 +34,9 @@ mod kernel_commands;
|
||||
// Pipeline commands (DSL-based workflows)
|
||||
mod pipeline_commands;
|
||||
|
||||
// Classroom generation and interaction commands
|
||||
mod classroom_commands;
|
||||
|
||||
// Gateway sub-modules (runtime, config, io, commands)
|
||||
mod gateway;
|
||||
|
||||
@@ -99,6 +102,11 @@ pub fn run() {
|
||||
// Initialize Pipeline state (DSL-based workflows)
|
||||
let pipeline_state = pipeline_commands::create_pipeline_state();
|
||||
|
||||
// Initialize Classroom state (generation + chat)
|
||||
let classroom_state = classroom_commands::create_classroom_state();
|
||||
let classroom_chat_state = classroom_commands::chat::create_chat_state();
|
||||
let classroom_gen_tasks = classroom_commands::create_generation_tasks();
|
||||
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_opener::init())
|
||||
.manage(browser_state)
|
||||
@@ -110,11 +118,15 @@ pub fn run() {
|
||||
.manage(scheduler_state)
|
||||
.manage(kernel_commands::SessionStreamGuard::default())
|
||||
.manage(pipeline_state)
|
||||
.manage(classroom_state)
|
||||
.manage(classroom_chat_state)
|
||||
.manage(classroom_gen_tasks)
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
// Internal ZCLAW Kernel commands (preferred)
|
||||
kernel_commands::lifecycle::kernel_init,
|
||||
kernel_commands::lifecycle::kernel_status,
|
||||
kernel_commands::lifecycle::kernel_shutdown,
|
||||
kernel_commands::lifecycle::kernel_apply_saas_config,
|
||||
kernel_commands::agent::agent_create,
|
||||
kernel_commands::agent::agent_list,
|
||||
kernel_commands::agent::agent_get,
|
||||
@@ -300,7 +312,16 @@ pub fn run() {
|
||||
intelligence::identity::identity_get_snapshots,
|
||||
intelligence::identity::identity_restore_snapshot,
|
||||
intelligence::identity::identity_list_agents,
|
||||
intelligence::identity::identity_delete_agent
|
||||
intelligence::identity::identity_delete_agent,
|
||||
// Classroom generation and interaction commands
|
||||
classroom_commands::generate::classroom_generate,
|
||||
classroom_commands::generate::classroom_generation_progress,
|
||||
classroom_commands::generate::classroom_cancel_generation,
|
||||
classroom_commands::generate::classroom_get,
|
||||
classroom_commands::generate::classroom_list,
|
||||
classroom_commands::chat::classroom_chat,
|
||||
classroom_commands::chat::classroom_chat_history,
|
||||
classroom_commands::export::classroom_export
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
|
||||
@@ -29,6 +29,7 @@ import { useProposalNotifications, ProposalNotificationHandler } from './lib/use
|
||||
import { useToast } from './components/ui/Toast';
|
||||
import type { Clone } from './store/agentStore';
|
||||
import { createLogger } from './lib/logger';
|
||||
import { startOfflineMonitor } from './store/offlineStore';
|
||||
|
||||
const log = createLogger('App');
|
||||
|
||||
@@ -86,6 +87,8 @@ function App() {
|
||||
|
||||
useEffect(() => {
|
||||
document.title = 'ZCLAW';
|
||||
const stopOfflineMonitor = startOfflineMonitor();
|
||||
return () => { stopOfflineMonitor(); };
|
||||
}, []);
|
||||
|
||||
// Restore SaaS session from OS keyring on startup (before auth gate)
|
||||
@@ -152,8 +155,11 @@ function App() {
|
||||
let mounted = true;
|
||||
|
||||
const bootstrap = async () => {
|
||||
// 未登录时不启动 bootstrap
|
||||
if (!useSaaSStore.getState().isLoggedIn) return;
|
||||
// 未登录时不启动 bootstrap,直接结束 loading
|
||||
if (!useSaaSStore.getState().isLoggedIn) {
|
||||
setBootstrapping(false);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Step 1: Check and start local gateway in Tauri environment
|
||||
|
||||
@@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback, useMemo, type MutableRefObjec
|
||||
import { motion, AnimatePresence } from 'framer-motion';
|
||||
import { List, type ListImperativeAPI } from 'react-window';
|
||||
import { useChatStore, Message } from '../store/chatStore';
|
||||
import { useArtifactStore } from '../store/chat/artifactStore';
|
||||
import { useConnectionStore } from '../store/connectionStore';
|
||||
import { useAgentStore } from '../store/agentStore';
|
||||
import { useConfigStore } from '../store/configStore';
|
||||
@@ -12,6 +13,8 @@ import { ArtifactPanel } from './ai/ArtifactPanel';
|
||||
import { ToolCallChain } from './ai/ToolCallChain';
|
||||
import { listItemVariants, defaultTransition, fadeInVariants } from '../lib/animations';
|
||||
import { FirstConversationPrompt } from './FirstConversationPrompt';
|
||||
import { ClassroomPlayer } from './classroom_player';
|
||||
import { useClassroomStore } from '../store/classroomStore';
|
||||
// MessageSearch temporarily removed during DeerFlow redesign
|
||||
import { OfflineIndicator } from './OfflineIndicator';
|
||||
import {
|
||||
@@ -45,11 +48,14 @@ export function ChatArea() {
|
||||
messages, currentAgent, isStreaming, isLoading, currentModel,
|
||||
sendMessage: sendToGateway, setCurrentModel, initStreamListener,
|
||||
newConversation, chatMode, setChatMode, suggestions,
|
||||
artifacts, selectedArtifactId, artifactPanelOpen,
|
||||
selectArtifact, setArtifactPanelOpen,
|
||||
totalInputTokens, totalOutputTokens,
|
||||
} = useChatStore();
|
||||
const {
|
||||
artifacts, selectedArtifactId, artifactPanelOpen,
|
||||
selectArtifact, setArtifactPanelOpen,
|
||||
} = useArtifactStore();
|
||||
const connectionState = useConnectionStore((s) => s.connectionState);
|
||||
const { activeClassroom, classroomOpen, closeClassroom, generating, progressPercent, progressActivity, error: classroomError, clearError: clearClassroomError } = useClassroomStore();
|
||||
const clones = useAgentStore((s) => s.clones);
|
||||
const models = useConfigStore((s) => s.models);
|
||||
|
||||
@@ -203,9 +209,76 @@ export function ChatArea() {
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="relative h-full">
|
||||
{/* Generation progress overlay */}
|
||||
<AnimatePresence>
|
||||
{generating && (
|
||||
<motion.div
|
||||
key="generation-overlay"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
className="absolute inset-0 z-40 bg-white/80 dark:bg-gray-900/80 backdrop-blur-sm flex items-center justify-center"
|
||||
>
|
||||
<div className="text-center space-y-4">
|
||||
<div className="w-12 h-12 border-4 border-indigo-200 border-t-indigo-500 rounded-full animate-spin mx-auto" />
|
||||
<div>
|
||||
<p className="text-lg font-medium text-gray-900 dark:text-white">
|
||||
正在生成课堂...
|
||||
</p>
|
||||
<p className="text-sm text-gray-500 dark:text-gray-400 mt-1">
|
||||
{progressActivity || '准备中...'}
|
||||
</p>
|
||||
</div>
|
||||
{progressPercent > 0 && (
|
||||
<div className="w-64 mx-auto">
|
||||
<div className="h-2 bg-gray-200 dark:bg-gray-700 rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-indigo-500 rounded-full transition-all duration-500"
|
||||
style={{ width: `${progressPercent}%` }}
|
||||
/>
|
||||
</div>
|
||||
<p className="text-xs text-gray-400 mt-1">{progressPercent}%</p>
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={() => useClassroomStore.getState().cancelGeneration()}
|
||||
className="px-4 py-2 text-sm text-gray-500 hover:text-gray-700 dark:hover:text-gray-300 border border-gray-300 dark:border-gray-600 rounded-lg"
|
||||
>
|
||||
取消
|
||||
</button>
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
{/* ClassroomPlayer overlay */}
|
||||
<AnimatePresence>
|
||||
{classroomOpen && activeClassroom && (
|
||||
<motion.div
|
||||
key="classroom-overlay"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
className="absolute inset-0 z-50 bg-white dark:bg-gray-900"
|
||||
>
|
||||
<ClassroomPlayer
|
||||
onClose={closeClassroom}
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
|
||||
<ResizableChatLayout
|
||||
chatPanel={
|
||||
<div className="flex flex-col h-full">
|
||||
{/* Classroom generation error banner */}
|
||||
{classroomError && (
|
||||
<div className="mx-4 mt-2 px-4 py-2 bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800 rounded-lg flex items-center justify-between text-sm">
|
||||
<span className="text-red-600 dark:text-red-400">课堂生成失败: {classroomError}</span>
|
||||
<button onClick={clearClassroomError} className="text-red-400 hover:text-red-600 ml-3 text-xs">关闭</button>
|
||||
</div>
|
||||
)}
|
||||
{/* Header — DeerFlow-style: minimal */}
|
||||
<div className="h-14 border-b border-transparent flex items-center justify-between px-6 flex-shrink-0 bg-white dark:bg-gray-900">
|
||||
<div className="flex items-center gap-2 text-sm text-gray-500">
|
||||
@@ -298,6 +371,7 @@ export function ChatArea() {
|
||||
getHeight={getHeight}
|
||||
onHeightChange={setHeight}
|
||||
messageRefs={messageRefs}
|
||||
setInput={setInput}
|
||||
/>
|
||||
) : (
|
||||
messages.map((message) => (
|
||||
@@ -310,7 +384,7 @@ export function ChatArea() {
|
||||
layout
|
||||
transition={defaultTransition}
|
||||
>
|
||||
<MessageBubble message={message} />
|
||||
<MessageBubble message={message} setInput={setInput} />
|
||||
</motion.div>
|
||||
))
|
||||
)}
|
||||
@@ -433,19 +507,16 @@ export function ChatArea() {
|
||||
rightPanelOpen={artifactPanelOpen}
|
||||
onRightPanelToggle={setArtifactPanelOpen}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function MessageBubble({ message }: { message: Message }) {
|
||||
// Tool messages are now absorbed into the assistant message's toolSteps chain.
|
||||
// Legacy standalone tool messages (from older sessions) still render as before.
|
||||
function MessageBubble({ message, setInput }: { message: Message; setInput: (text: string) => void }) {
|
||||
if (message.role === 'tool') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const isUser = message.role === 'user';
|
||||
|
||||
// 思考中状态:streaming 且内容为空时显示思考指示器
|
||||
const isThinking = message.streaming && !message.content;
|
||||
|
||||
// Download message as Markdown file
|
||||
@@ -518,7 +589,20 @@ function MessageBubble({ message }: { message: Message }) {
|
||||
: '...'}
|
||||
</div>
|
||||
{message.error && (
|
||||
<p className="text-xs text-red-500 mt-2">{message.error}</p>
|
||||
<div className="flex items-center gap-2 mt-2">
|
||||
<p className="text-xs text-red-500">{message.error}</p>
|
||||
<button
|
||||
onClick={() => {
|
||||
const text = typeof message.content === 'string' ? message.content : '';
|
||||
if (text) {
|
||||
setInput(text);
|
||||
}
|
||||
}}
|
||||
className="text-xs px-2 py-0.5 rounded bg-red-100 dark:bg-red-900/30 text-red-600 dark:text-red-400 hover:bg-red-200 dark:hover:bg-red-900/50 transition-colors"
|
||||
>
|
||||
重试
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{/* Download button for AI messages - show on hover */}
|
||||
{!isUser && message.content && !message.streaming && (
|
||||
@@ -543,6 +627,7 @@ interface VirtualizedMessageRowProps {
|
||||
message: Message;
|
||||
onHeightChange: (height: number) => void;
|
||||
messageRefs: MutableRefObject<Map<string, HTMLDivElement>>;
|
||||
setInput: (text: string) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -553,6 +638,7 @@ function VirtualizedMessageRow({
|
||||
message,
|
||||
onHeightChange,
|
||||
messageRefs,
|
||||
setInput,
|
||||
style,
|
||||
ariaAttributes,
|
||||
}: VirtualizedMessageRowProps & {
|
||||
@@ -587,7 +673,7 @@ function VirtualizedMessageRow({
|
||||
className="py-3"
|
||||
{...ariaAttributes}
|
||||
>
|
||||
<MessageBubble message={message} />
|
||||
<MessageBubble message={message} setInput={setInput} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -598,6 +684,7 @@ interface VirtualizedMessageListProps {
|
||||
getHeight: (id: string, role: string) => number;
|
||||
onHeightChange: (id: string, height: number) => void;
|
||||
messageRefs: MutableRefObject<Map<string, HTMLDivElement>>;
|
||||
setInput: (text: string) => void;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -610,6 +697,7 @@ function VirtualizedMessageList({
|
||||
getHeight,
|
||||
onHeightChange,
|
||||
messageRefs,
|
||||
setInput,
|
||||
}: VirtualizedMessageListProps) {
|
||||
// Row component for react-window v2
|
||||
const RowComponent = (props: {
|
||||
@@ -625,6 +713,7 @@ function VirtualizedMessageList({
|
||||
message={messages[props.index]}
|
||||
onHeightChange={(h) => onHeightChange(messages[props.index].id, h)}
|
||||
messageRefs={messageRefs}
|
||||
setInput={setInput}
|
||||
style={props.style}
|
||||
ariaAttributes={props.ariaAttributes}
|
||||
/>
|
||||
|
||||
@@ -67,6 +67,7 @@ interface ClassroomPreviewerProps {
|
||||
data: ClassroomData;
|
||||
onClose?: () => void;
|
||||
onExport?: (format: 'pptx' | 'html' | 'pdf') => void;
|
||||
onOpenFullPlayer?: () => void;
|
||||
}
|
||||
|
||||
// === Sub-Components ===
|
||||
@@ -271,6 +272,7 @@ function OutlinePanel({
|
||||
export function ClassroomPreviewer({
|
||||
data,
|
||||
onExport,
|
||||
onOpenFullPlayer,
|
||||
}: ClassroomPreviewerProps) {
|
||||
const [currentSceneIndex, setCurrentSceneIndex] = useState(0);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
@@ -398,6 +400,15 @@ export function ClassroomPreviewer({
|
||||
</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{onOpenFullPlayer && (
|
||||
<button
|
||||
onClick={onOpenFullPlayer}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 text-sm bg-indigo-100 dark:bg-indigo-900/30 text-indigo-700 dark:text-indigo-300 rounded-md hover:bg-indigo-200 dark:hover:bg-indigo-900/50 transition-colors"
|
||||
>
|
||||
<Play className="w-4 h-4" />
|
||||
完整播放器
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={() => handleExport('pptx')}
|
||||
className="flex items-center gap-1.5 px-3 py-1.5 text-sm bg-orange-100 dark:bg-orange-900/30 text-orange-700 dark:text-orange-300 rounded-md hover:bg-orange-200 dark:hover:bg-orange-900/50 transition-colors"
|
||||
|
||||
@@ -22,23 +22,26 @@ import {
|
||||
} from '../lib/personality-presets';
|
||||
import type { Clone } from '../store/agentStore';
|
||||
import { useChatStore } from '../store/chatStore';
|
||||
import { useClassroomStore } from '../store/classroomStore';
|
||||
import { useHandStore } from '../store/handStore';
|
||||
|
||||
// Quick action chip definitions — DeerFlow-style colored pills
|
||||
// handId maps to actual Hand names in the runtime
|
||||
const QUICK_ACTIONS = [
|
||||
{ key: 'surprise', label: '小惊喜', icon: Sparkles, color: 'text-orange-500' },
|
||||
{ key: 'write', label: '写作', icon: PenLine, color: 'text-blue-500' },
|
||||
{ key: 'research', label: '研究', icon: Microscope, color: 'text-purple-500' },
|
||||
{ key: 'collect', label: '收集', icon: Layers, color: 'text-green-500' },
|
||||
{ key: 'research', label: '研究', icon: Microscope, color: 'text-purple-500', handId: 'researcher' },
|
||||
{ key: 'collect', label: '收集', icon: Layers, color: 'text-green-500', handId: 'collector' },
|
||||
{ key: 'learn', label: '学习', icon: GraduationCap, color: 'text-indigo-500' },
|
||||
];
|
||||
|
||||
// Pre-filled prompts for each quick action
|
||||
// Pre-filled prompts for each quick action — tailored for target industries
|
||||
const QUICK_ACTION_PROMPTS: Record<string, string> = {
|
||||
surprise: '给我一个小惊喜吧!来点创意的',
|
||||
write: '帮我写一篇文章,主题你来定',
|
||||
research: '帮我做一个深度研究分析',
|
||||
collect: '帮我收集整理一些有用的信息',
|
||||
learn: '我想学点新东西,教我一些有趣的知识',
|
||||
write: '帮我写一份关于"远程医疗行政管理优化方案"的提案大纲',
|
||||
research: '帮我深度研究"2026年教育数字化转型趋势",包括政策、技术和实践三个维度',
|
||||
collect: '帮我采集 5 个主流 AI 教育工具的产品信息,对比功能和价格',
|
||||
learn: '我想了解汕头玩具产业 2026 年出口趋势,能帮我分析一下吗?',
|
||||
};
|
||||
|
||||
interface FirstConversationPromptProps {
|
||||
@@ -69,6 +72,41 @@ export function FirstConversationPrompt({
|
||||
});
|
||||
|
||||
const handleQuickAction = (key: string) => {
|
||||
if (key === 'learn') {
|
||||
// Trigger classroom generation flow
|
||||
const classroomStore = useClassroomStore.getState();
|
||||
// Extract a clean topic from the prompt
|
||||
const prompt = QUICK_ACTION_PROMPTS[key] || '';
|
||||
const topic = prompt
|
||||
.replace(/^[你我].*?(想了解|想学|了解|学习|分析|研究|探索)\s*/g, '')
|
||||
.replace(/[,。?!].*$/g, '')
|
||||
.replace(/^(能|帮|请|可不可以).*/g, '')
|
||||
.trim() || '互动课堂';
|
||||
classroomStore.startGeneration({
|
||||
topic,
|
||||
style: 'lecture',
|
||||
level: 'intermediate',
|
||||
language: 'zh-CN',
|
||||
}).catch(() => {
|
||||
// Error is already stored in classroomStore.error and displayed in ChatArea
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if this action maps to a Hand
|
||||
const actionDef = QUICK_ACTIONS.find((a) => a.key === key);
|
||||
if (actionDef?.handId) {
|
||||
const handStore = useHandStore.getState();
|
||||
handStore.triggerHand(actionDef.handId, {
|
||||
action: key === 'research' ? 'report' : 'collect',
|
||||
query: { query: QUICK_ACTION_PROMPTS[key] || '' },
|
||||
}).catch(() => {
|
||||
// Fallback: fill prompt into input bar
|
||||
onSelectSuggestion?.(QUICK_ACTION_PROMPTS[key] || '你好!');
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
const prompt = QUICK_ACTION_PROMPTS[key] || '你好!';
|
||||
onSelectSuggestion?.(prompt);
|
||||
};
|
||||
|
||||
@@ -25,6 +25,8 @@ import { PipelineRunResponse } from '../lib/pipeline-client';
|
||||
import { useToast } from './ui/Toast';
|
||||
import DOMPurify from 'dompurify';
|
||||
import { ClassroomPreviewer, type ClassroomData } from './ClassroomPreviewer';
|
||||
import { useClassroomStore } from '../store/classroomStore';
|
||||
import { adaptToClassroom } from '../lib/classroom-adapter';
|
||||
|
||||
// === Types ===
|
||||
|
||||
@@ -286,6 +288,11 @@ export function PipelineResultPreview({
|
||||
// Handle export
|
||||
handleClassroomExport(format, classroomData);
|
||||
}}
|
||||
onOpenFullPlayer={() => {
|
||||
const classroom = adaptToClassroom(classroomData);
|
||||
useClassroomStore.getState().setActiveClassroom(classroom);
|
||||
useClassroomStore.getState().openClassroom();
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -18,6 +18,7 @@ import {
|
||||
Filter,
|
||||
X,
|
||||
} from 'lucide-react';
|
||||
import { PipelineResultPreview } from './PipelineResultPreview';
|
||||
import {
|
||||
PipelineClient,
|
||||
PipelineInfo,
|
||||
@@ -28,7 +29,7 @@ import {
|
||||
formatInputType,
|
||||
} from '../lib/pipeline-client';
|
||||
import { useToast } from './ui/Toast';
|
||||
import { PresentationContainer } from './presentation';
|
||||
import { saasClient } from '../lib/saas-client';
|
||||
|
||||
// === Category Badge Component ===
|
||||
|
||||
@@ -117,64 +118,6 @@ function PipelineCard({ pipeline, onRun }: PipelineCardProps) {
|
||||
);
|
||||
}
|
||||
|
||||
// === Pipeline Result Modal ===
|
||||
|
||||
interface ResultModalProps {
|
||||
result: PipelineRunResponse;
|
||||
pipeline: PipelineInfo;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
function ResultModal({ result, pipeline, onClose }: ResultModalProps) {
|
||||
return (
|
||||
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
|
||||
<div className="bg-white dark:bg-gray-800 rounded-lg shadow-xl w-[90vw] max-w-4xl h-[85vh] flex flex-col mx-4">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between p-4 border-b border-gray-200 dark:border-gray-700">
|
||||
<div className="flex items-center gap-3">
|
||||
<span className="text-2xl">{pipeline.icon}</span>
|
||||
<div>
|
||||
<h2 className="text-lg font-semibold text-gray-900 dark:text-white">
|
||||
{pipeline.displayName} - 执行结果
|
||||
</h2>
|
||||
<p className="text-sm text-gray-500 dark:text-gray-400">
|
||||
状态: {result.status === 'completed' ? '已完成' : '失败'}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="p-1 hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
|
||||
>
|
||||
<X className="w-5 h-5 text-gray-500" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex-1 overflow-hidden">
|
||||
{result.outputs ? (
|
||||
<PresentationContainer
|
||||
data={result.outputs}
|
||||
pipelineId={pipeline.id}
|
||||
supportedTypes={['document', 'chart', 'quiz', 'slideshow']}
|
||||
/>
|
||||
) : result.error ? (
|
||||
<div className="p-6 text-center text-red-500">
|
||||
<XCircle className="w-8 h-8 mx-auto mb-2" />
|
||||
<p>{result.error}</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="p-6 text-center text-gray-500">
|
||||
<Package className="w-8 h-8 mx-auto mb-2" />
|
||||
<p>无输出结果</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// === Pipeline Run Modal ===
|
||||
|
||||
interface RunModalProps {
|
||||
@@ -489,6 +432,13 @@ export function PipelinesPanel() {
|
||||
if (result.status === 'completed') {
|
||||
toast('Pipeline 执行完成', 'success');
|
||||
setRunResult({ result, pipeline: selectedPipeline! });
|
||||
|
||||
// Report pipeline execution to billing (fire-and-forget)
|
||||
try {
|
||||
if (saasClient.isAuthenticated()) {
|
||||
saasClient.reportUsageFireAndForget('pipeline_runs');
|
||||
}
|
||||
} catch { /* billing reporting must never block */ }
|
||||
} else {
|
||||
toast(`Pipeline 执行失败: ${result.error}`, 'error');
|
||||
}
|
||||
@@ -602,11 +552,11 @@ export function PipelinesPanel() {
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Result Modal */}
|
||||
{/* Result Preview */}
|
||||
{runResult && (
|
||||
<ResultModal
|
||||
<PipelineResultPreview
|
||||
result={runResult.result}
|
||||
pipeline={runResult.pipeline}
|
||||
pipelineId={runResult.pipeline.id}
|
||||
onClose={() => setRunResult(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -109,7 +109,7 @@ export function Conversation({ children, className = '' }: ConversationProps) {
|
||||
<div
|
||||
ref={containerRef}
|
||||
onScroll={handleScroll}
|
||||
className={`overflow-y-auto custom-scrollbar ${className}`}
|
||||
className={`overflow-y-auto custom-scrollbar min-h-0 ${className}`}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
|
||||
@@ -62,7 +62,7 @@ export function ResizableChatLayout({
|
||||
|
||||
if (!rightPanelOpen || !rightPanel) {
|
||||
return (
|
||||
<div className="flex-1 flex flex-col overflow-hidden relative">
|
||||
<div className="h-full flex flex-col overflow-hidden relative">
|
||||
{chatPanel}
|
||||
<button
|
||||
onClick={handleToggle}
|
||||
@@ -76,7 +76,7 @@ export function ResizableChatLayout({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex-1 flex flex-col overflow-hidden">
|
||||
<div className="h-full flex flex-col overflow-hidden">
|
||||
<Group
|
||||
orientation="horizontal"
|
||||
onLayoutChanged={(layout) => savePanelSizes(layout)}
|
||||
|
||||
121
desktop/src/components/classroom_player/AgentChat.tsx
Normal file
121
desktop/src/components/classroom_player/AgentChat.tsx
Normal file
@@ -0,0 +1,121 @@
|
||||
/**
|
||||
* AgentChat — Multi-agent chat panel for classroom interaction.
|
||||
*
|
||||
* Displays chat bubbles from different agents (teacher, assistant, students)
|
||||
* with distinct colors and avatars. Users can send messages.
|
||||
*/
|
||||
|
||||
import { useState, useRef, useEffect } from 'react';
|
||||
import type { ClassroomChatMessage as ChatMessage, AgentProfile } from '../../types/classroom';
|
||||
|
||||
interface AgentChatProps {
|
||||
messages: ChatMessage[];
|
||||
agents: AgentProfile[];
|
||||
loading: boolean;
|
||||
onSend: (message: string) => Promise<void>;
|
||||
}
|
||||
|
||||
export function AgentChat({ messages, loading, onSend }: AgentChatProps) {
|
||||
const [input, setInput] = useState('');
|
||||
const scrollRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Auto-scroll to bottom
|
||||
useEffect(() => {
|
||||
if (scrollRef.current) {
|
||||
scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
|
||||
}
|
||||
}, [messages]);
|
||||
|
||||
const handleSend = async () => {
|
||||
const trimmed = input.trim();
|
||||
if (!trimmed || loading) return;
|
||||
|
||||
setInput('');
|
||||
await onSend(trimmed);
|
||||
};
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
if (e.key === 'Enter' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
handleSend();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex flex-col w-80 border-l border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800">
|
||||
{/* Header */}
|
||||
<div className="px-3 py-2 border-b border-gray-200 dark:border-gray-700">
|
||||
<h3 className="text-sm font-medium text-gray-700 dark:text-gray-300">
|
||||
Classroom Chat
|
||||
</h3>
|
||||
</div>
|
||||
|
||||
{/* Messages */}
|
||||
<div ref={scrollRef} className="flex-1 overflow-auto p-3 space-y-3">
|
||||
{messages.length === 0 ? (
|
||||
<div className="text-center text-xs text-gray-400 py-8">
|
||||
Start a conversation with the classroom
|
||||
</div>
|
||||
) : (
|
||||
messages.map((msg) => {
|
||||
const isUser = msg.role === 'user';
|
||||
|
||||
return (
|
||||
<div key={msg.id} className={`flex gap-2 ${isUser ? 'justify-end' : ''}`}>
|
||||
{/* Avatar */}
|
||||
{!isUser && (
|
||||
<span
|
||||
className="flex-shrink-0 w-7 h-7 rounded-full flex items-center justify-center text-xs"
|
||||
style={{ backgroundColor: msg.color + '20' }}
|
||||
>
|
||||
{msg.agentAvatar}
|
||||
</span>
|
||||
)}
|
||||
|
||||
{/* Message bubble */}
|
||||
<div className={`max-w-[200px] ${isUser ? 'text-right' : ''}`}>
|
||||
{!isUser && (
|
||||
<span className="text-xs font-medium" style={{ color: msg.color }}>
|
||||
{msg.agentName}
|
||||
</span>
|
||||
)}
|
||||
<div
|
||||
className={`text-sm px-3 py-1.5 rounded-lg ${
|
||||
isUser
|
||||
? 'bg-indigo-500 text-white'
|
||||
: 'bg-gray-100 dark:bg-gray-700 text-gray-800 dark:text-gray-200'
|
||||
}`}
|
||||
>
|
||||
{msg.content}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
})
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Input */}
|
||||
<div className="px-3 py-2 border-t border-gray-200 dark:border-gray-700">
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="text"
|
||||
value={input}
|
||||
onChange={(e) => setInput(e.target.value)}
|
||||
onKeyDown={handleKeyDown}
|
||||
placeholder="Ask a question..."
|
||||
disabled={loading}
|
||||
className="flex-1 px-2 py-1.5 text-sm rounded border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-1 focus:ring-indigo-400 disabled:opacity-50"
|
||||
/>
|
||||
<button
|
||||
onClick={handleSend}
|
||||
disabled={loading || !input.trim()}
|
||||
className="px-3 py-1.5 text-sm rounded bg-indigo-500 text-white disabled:opacity-50 hover:bg-indigo-600"
|
||||
>
|
||||
{loading ? '...' : 'Send'}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
231
desktop/src/components/classroom_player/ClassroomPlayer.tsx
Normal file
231
desktop/src/components/classroom_player/ClassroomPlayer.tsx
Normal file
@@ -0,0 +1,231 @@
|
||||
/**
|
||||
* ClassroomPlayer — Full-screen interactive classroom player.
|
||||
*
|
||||
* Layout: Notes sidebar | Main stage | Chat panel
|
||||
* Top: Title + Agent avatars
|
||||
* Bottom: Scene navigation + playback controls
|
||||
*/
|
||||
|
||||
import { useState, useCallback, useEffect } from 'react';
|
||||
import { invoke } from '@tauri-apps/api/core';
|
||||
import { useClassroom } from '../../hooks/useClassroom';
|
||||
import { SceneRenderer } from './SceneRenderer';
|
||||
import { AgentChat } from './AgentChat';
|
||||
import { NotesSidebar } from './NotesSidebar';
|
||||
import { TtsPlayer } from './TtsPlayer';
|
||||
import { Download } from 'lucide-react';
|
||||
|
||||
interface ClassroomPlayerProps {
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export function ClassroomPlayer({ onClose }: ClassroomPlayerProps) {
|
||||
const {
|
||||
activeClassroom,
|
||||
chatMessages,
|
||||
chatLoading,
|
||||
sendChatMessage,
|
||||
} = useClassroom();
|
||||
|
||||
const [currentSceneIndex, setCurrentSceneIndex] = useState(0);
|
||||
const [sidebarOpen, setSidebarOpen] = useState(true);
|
||||
const [chatOpen, setChatOpen] = useState(true);
|
||||
const [exporting, setExporting] = useState(false);
|
||||
|
||||
const classroom = activeClassroom;
|
||||
const scenes = classroom?.scenes ?? [];
|
||||
const agents = classroom?.agents ?? [];
|
||||
const currentScene = scenes[currentSceneIndex] ?? null;
|
||||
|
||||
// Navigate to next/prev scene
|
||||
const goNext = useCallback(() => {
|
||||
setCurrentSceneIndex((i) => Math.min(i + 1, scenes.length - 1));
|
||||
}, [scenes.length]);
|
||||
|
||||
const goPrev = useCallback(() => {
|
||||
setCurrentSceneIndex((i) => Math.max(i - 1, 0));
|
||||
}, []);
|
||||
|
||||
// Keyboard shortcuts
|
||||
useEffect(() => {
|
||||
const handler = (e: KeyboardEvent) => {
|
||||
if (e.key === 'ArrowRight') goNext();
|
||||
else if (e.key === 'ArrowLeft') goPrev();
|
||||
else if (e.key === 'Escape') onClose();
|
||||
};
|
||||
window.addEventListener('keydown', handler);
|
||||
return () => window.removeEventListener('keydown', handler);
|
||||
}, [goNext, goPrev, onClose]);
|
||||
|
||||
// Chat handler
|
||||
const handleChatSend = useCallback(async (message: string) => {
|
||||
const sceneContext = currentScene?.content.title;
|
||||
await sendChatMessage(message, sceneContext);
|
||||
}, [sendChatMessage, currentScene]);
|
||||
|
||||
// Export handler
|
||||
const handleExport = useCallback(async (format: 'html' | 'markdown' | 'json') => {
|
||||
if (!classroom) return;
|
||||
setExporting(true);
|
||||
try {
|
||||
const result = await invoke<{ content: string; filename: string; mimeType: string }>(
|
||||
'classroom_export',
|
||||
{ request: { classroomId: classroom.id, format } }
|
||||
);
|
||||
// Download the exported file
|
||||
const blob = new Blob([result.content], { type: result.mimeType });
|
||||
const url = URL.createObjectURL(blob);
|
||||
const a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = result.filename;
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
} catch (e) {
|
||||
console.error('Export failed:', e);
|
||||
} finally {
|
||||
setExporting(false);
|
||||
}
|
||||
}, [classroom]);
|
||||
|
||||
if (!classroom) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-full text-gray-500">
|
||||
No classroom loaded
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col h-full bg-gray-50 dark:bg-gray-900">
|
||||
{/* Header */}
|
||||
<header className="flex items-center justify-between px-4 py-2 border-b border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800">
|
||||
<div className="flex items-center gap-3">
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="p-1 rounded hover:bg-gray-100 dark:hover:bg-gray-700"
|
||||
aria-label="Close classroom"
|
||||
>
|
||||
←
|
||||
</button>
|
||||
<h1 className="text-lg font-semibold text-gray-900 dark:text-white truncate max-w-md">
|
||||
{classroom.title}
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
{/* Agent avatars */}
|
||||
<div className="flex items-center gap-1">
|
||||
{agents.map((agent) => (
|
||||
<span
|
||||
key={agent.id}
|
||||
className="inline-flex items-center justify-center w-8 h-8 rounded-full text-sm"
|
||||
style={{ backgroundColor: agent.color + '20', color: agent.color }}
|
||||
title={agent.name}
|
||||
>
|
||||
{agent.avatar}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={() => setSidebarOpen(!sidebarOpen)}
|
||||
className={`px-2 py-1 rounded text-xs ${sidebarOpen ? 'bg-indigo-100 text-indigo-700' : 'text-gray-500'}`}
|
||||
>
|
||||
Notes
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setChatOpen(!chatOpen)}
|
||||
className={`px-2 py-1 rounded text-xs ${chatOpen ? 'bg-indigo-100 text-indigo-700' : 'text-gray-500'}`}
|
||||
>
|
||||
Chat
|
||||
</button>
|
||||
{/* Export dropdown */}
|
||||
<div className="relative group">
|
||||
<button
|
||||
disabled={exporting}
|
||||
className="px-2 py-1 rounded text-xs text-gray-500 hover:text-gray-700 flex items-center gap-1"
|
||||
title="导出课堂"
|
||||
>
|
||||
<Download className="w-3.5 h-3.5" />
|
||||
{exporting ? '...' : '导出'}
|
||||
</button>
|
||||
<div className="absolute right-0 top-full mt-1 bg-white dark:bg-gray-800 border border-gray-200 dark:border-gray-700 rounded shadow-lg hidden group-hover:block z-10">
|
||||
<button onClick={() => handleExport('html')} className="block w-full text-left px-3 py-1.5 text-xs text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700">HTML</button>
|
||||
<button onClick={() => handleExport('markdown')} className="block w-full text-left px-3 py-1.5 text-xs text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700">Markdown</button>
|
||||
<button onClick={() => handleExport('json')} className="block w-full text-left px-3 py-1.5 text-xs text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700">JSON</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
{/* Main content */}
|
||||
<div className="flex flex-1 overflow-hidden">
|
||||
{/* Notes sidebar */}
|
||||
{sidebarOpen && (
|
||||
<NotesSidebar
|
||||
scenes={scenes}
|
||||
currentIndex={currentSceneIndex}
|
||||
onSelectScene={setCurrentSceneIndex}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Main stage */}
|
||||
<main className="flex-1 overflow-auto p-4">
|
||||
{currentScene ? (
|
||||
<SceneRenderer key={currentScene.id} scene={currentScene} agents={agents} />
|
||||
) : (
|
||||
<div className="flex items-center justify-center h-full text-gray-400">
|
||||
No scenes available
|
||||
</div>
|
||||
)}
|
||||
</main>
|
||||
|
||||
{/* Chat panel */}
|
||||
{chatOpen && (
|
||||
<AgentChat
|
||||
messages={chatMessages}
|
||||
agents={agents}
|
||||
loading={chatLoading}
|
||||
onSend={handleChatSend}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Bottom navigation */}
|
||||
<footer className="flex items-center justify-between px-4 py-2 border-t border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800">
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={goPrev}
|
||||
disabled={currentSceneIndex === 0}
|
||||
className="px-3 py-1 rounded text-sm bg-gray-100 dark:bg-gray-700 disabled:opacity-50"
|
||||
>
|
||||
Previous
|
||||
</button>
|
||||
<span className="text-sm text-gray-500">
|
||||
{currentSceneIndex + 1} / {scenes.length}
|
||||
</span>
|
||||
<button
|
||||
onClick={goNext}
|
||||
disabled={currentSceneIndex >= scenes.length - 1}
|
||||
className="px-3 py-1 rounded text-sm bg-gray-100 dark:bg-gray-700 disabled:opacity-50"
|
||||
>
|
||||
Next
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* TTS + Scene info */}
|
||||
<div className="flex items-center gap-3">
|
||||
{currentScene?.content.notes && (
|
||||
<TtsPlayer text={currentScene.content.notes} />
|
||||
)}
|
||||
<div className="text-xs text-gray-400">
|
||||
{currentScene?.content.sceneType ?? ''}
|
||||
{currentScene?.content.durationSeconds
|
||||
? ` · ${Math.floor(currentScene.content.durationSeconds / 60)}:${String(currentScene.content.durationSeconds % 60).padStart(2, '0')}`
|
||||
: ''}
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
71
desktop/src/components/classroom_player/NotesSidebar.tsx
Normal file
71
desktop/src/components/classroom_player/NotesSidebar.tsx
Normal file
@@ -0,0 +1,71 @@
|
||||
/**
|
||||
* NotesSidebar — Scene outline navigation + notes.
|
||||
*
|
||||
* Left panel showing all scenes as clickable items with notes.
|
||||
*/
|
||||
|
||||
import type { GeneratedScene } from '../../types/classroom';
|
||||
|
||||
interface NotesSidebarProps {
|
||||
scenes: GeneratedScene[];
|
||||
currentIndex: number;
|
||||
onSelectScene: (index: number) => void;
|
||||
}
|
||||
|
||||
export function NotesSidebar({ scenes, currentIndex, onSelectScene }: NotesSidebarProps) {
|
||||
return (
|
||||
<div className="w-64 border-r border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800 overflow-auto">
|
||||
<div className="px-3 py-2 border-b border-gray-200 dark:border-gray-700">
|
||||
<h3 className="text-xs font-semibold text-gray-500 uppercase tracking-wider">
|
||||
Outline
|
||||
</h3>
|
||||
</div>
|
||||
|
||||
<nav className="py-1">
|
||||
{scenes.map((scene, i) => {
|
||||
const isActive = i === currentIndex;
|
||||
const typeColor = getTypeColor(scene.content.sceneType);
|
||||
|
||||
return (
|
||||
<button
|
||||
key={scene.id}
|
||||
onClick={() => onSelectScene(i)}
|
||||
className={`w-full text-left px-3 py-2 text-sm border-l-2 transition-colors ${
|
||||
isActive
|
||||
? 'border-indigo-500 bg-indigo-50 dark:bg-indigo-900/20'
|
||||
: 'border-transparent hover:bg-gray-50 dark:hover:bg-gray-700/50'
|
||||
}`}
|
||||
>
|
||||
<div className="flex items-center gap-2">
|
||||
<span
|
||||
className="inline-block w-1.5 h-1.5 rounded-full"
|
||||
style={{ backgroundColor: typeColor }}
|
||||
/>
|
||||
<span className={`font-medium ${isActive ? 'text-indigo-700 dark:text-indigo-300' : 'text-gray-700 dark:text-gray-300'}`}>
|
||||
{i + 1}. {scene.content.title}
|
||||
</span>
|
||||
</div>
|
||||
{scene.content.notes && (
|
||||
<p className="text-xs text-gray-400 mt-0.5 ml-3.5 line-clamp-2">
|
||||
{scene.content.notes}
|
||||
</p>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
})}
|
||||
</nav>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function getTypeColor(type: string): string {
|
||||
switch (type) {
|
||||
case 'slide': return '#6366F1';
|
||||
case 'quiz': return '#F59E0B';
|
||||
case 'discussion': return '#10B981';
|
||||
case 'interactive': return '#8B5CF6';
|
||||
case 'pbl': return '#EF4444';
|
||||
case 'media': return '#06B6D4';
|
||||
default: return '#9CA3AF';
|
||||
}
|
||||
}
|
||||
219
desktop/src/components/classroom_player/SceneRenderer.tsx
Normal file
219
desktop/src/components/classroom_player/SceneRenderer.tsx
Normal file
@@ -0,0 +1,219 @@
|
||||
/**
|
||||
* SceneRenderer — Renders a single classroom scene.
|
||||
*
|
||||
* Supports scene types: slide, quiz, discussion, interactive, text, pbl, media.
|
||||
* Executes scene actions (speech, whiteboard, quiz, discussion).
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import type { GeneratedScene, SceneContent, SceneAction, AgentProfile } from '../../types/classroom';
|
||||
|
||||
interface SceneRendererProps {
|
||||
scene: GeneratedScene;
|
||||
agents: AgentProfile[];
|
||||
autoPlay?: boolean;
|
||||
}
|
||||
|
||||
export function SceneRenderer({ scene, agents, autoPlay = true }: SceneRendererProps) {
|
||||
const { content } = scene;
|
||||
const [actionIndex, setActionIndex] = useState(0);
|
||||
const [isPlaying, setIsPlaying] = useState(autoPlay);
|
||||
const [whiteboardItems, setWhiteboardItems] = useState<Array<{
|
||||
type: string;
|
||||
data: SceneAction;
|
||||
}>>([]);
|
||||
|
||||
const actions = content.actions ?? [];
|
||||
const currentAction = actions[actionIndex] ?? null;
|
||||
|
||||
// Auto-advance through actions
|
||||
useEffect(() => {
|
||||
if (!isPlaying || actions.length === 0) return;
|
||||
if (actionIndex >= actions.length) {
|
||||
setIsPlaying(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const delay = getActionDelay(actions[actionIndex]);
|
||||
const timer = setTimeout(() => {
|
||||
processAction(actions[actionIndex]);
|
||||
setActionIndex((i) => i + 1);
|
||||
}, delay);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}, [actionIndex, isPlaying, actions]);
|
||||
|
||||
const processAction = useCallback((action: SceneAction) => {
|
||||
switch (action.type) {
|
||||
case 'whiteboard_draw_text':
|
||||
case 'whiteboard_draw_shape':
|
||||
case 'whiteboard_draw_chart':
|
||||
case 'whiteboard_draw_latex':
|
||||
setWhiteboardItems((prev) => [...prev, { type: action.type, data: action }]);
|
||||
break;
|
||||
case 'whiteboard_clear':
|
||||
setWhiteboardItems([]);
|
||||
break;
|
||||
}
|
||||
}, []);
|
||||
|
||||
// Render scene based on type
|
||||
return (
|
||||
<div className="flex flex-col h-full">
|
||||
{/* Scene title */}
|
||||
<div className="mb-4">
|
||||
<h2 className="text-2xl font-bold text-gray-900 dark:text-white">
|
||||
{content.title}
|
||||
</h2>
|
||||
{content.notes && (
|
||||
<p className="text-sm text-gray-500 mt-1">{content.notes}</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Main content area */}
|
||||
<div className="flex-1 flex gap-4 overflow-hidden">
|
||||
{/* Content panel */}
|
||||
<div className="flex-1 overflow-auto">
|
||||
{renderContent(content)}
|
||||
</div>
|
||||
|
||||
{/* Whiteboard area */}
|
||||
{whiteboardItems.length > 0 && (
|
||||
<div className="w-80 border border-gray-200 dark:border-gray-700 rounded-lg bg-white dark:bg-gray-800 p-2 overflow-auto">
|
||||
<svg viewBox="0 0 800 600" className="w-full h-full">
|
||||
{whiteboardItems.map((item, i) => (
|
||||
<g key={i}>{renderWhiteboardItem(item)}</g>
|
||||
))}
|
||||
</svg>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Current action indicator */}
|
||||
{currentAction && (
|
||||
<div className="mt-4 p-3 rounded-lg bg-indigo-50 dark:bg-indigo-900/20 border border-indigo-100 dark:border-indigo-800">
|
||||
{renderCurrentAction(currentAction, agents)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Playback controls */}
|
||||
<div className="flex items-center justify-center gap-2 mt-4">
|
||||
<button
|
||||
onClick={() => { setActionIndex(0); setWhiteboardItems([]); }}
|
||||
className="px-2 py-1 text-xs rounded bg-gray-100 dark:bg-gray-700"
|
||||
>
|
||||
Restart
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setIsPlaying(!isPlaying)}
|
||||
className="px-3 py-1 text-sm rounded bg-indigo-500 text-white"
|
||||
>
|
||||
{isPlaying ? 'Pause' : 'Play'}
|
||||
</button>
|
||||
<span className="text-xs text-gray-400">
|
||||
Action {Math.min(actionIndex + 1, actions.length)} / {actions.length}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function getActionDelay(action: SceneAction): number {
|
||||
switch (action.type) {
|
||||
case 'speech': return 2000;
|
||||
case 'whiteboard_draw_text': return 800;
|
||||
case 'whiteboard_draw_shape': return 600;
|
||||
case 'quiz_show': return 5000;
|
||||
case 'discussion': return 10000;
|
||||
default: return 1000;
|
||||
}
|
||||
}
|
||||
|
||||
function renderContent(content: SceneContent) {
|
||||
const data = content.content;
|
||||
if (!data || typeof data !== 'object') return null;
|
||||
|
||||
// Handle slide content
|
||||
const keyPoints = data.key_points as string[] | undefined;
|
||||
const description = data.description as string | undefined;
|
||||
const slides = data.slides as Array<{ title: string; content: string }> | undefined;
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
{description && (
|
||||
<p className="text-gray-700 dark:text-gray-300 leading-relaxed">{description}</p>
|
||||
)}
|
||||
{keyPoints && keyPoints.length > 0 && (
|
||||
<ul className="space-y-2">
|
||||
{keyPoints.map((point, i) => (
|
||||
<li key={i} className="flex items-start gap-2">
|
||||
<span className="text-indigo-500 mt-0.5">●</span>
|
||||
<span className="text-gray-700 dark:text-gray-300">{point}</span>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
)}
|
||||
{slides && slides.map((slide, i) => (
|
||||
<div key={i} className="p-3 rounded border border-gray-200 dark:border-gray-700">
|
||||
<h4 className="font-medium text-gray-900 dark:text-white">{slide.title}</h4>
|
||||
<p className="text-sm text-gray-600 dark:text-gray-400 mt-1">{slide.content}</p>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function renderCurrentAction(action: SceneAction, agents: AgentProfile[]) {
|
||||
switch (action.type) {
|
||||
case 'speech': {
|
||||
const agent = agents.find(a => a.role === action.agentRole);
|
||||
return (
|
||||
<div className="flex items-start gap-2">
|
||||
<span className="text-lg">{agent?.avatar ?? '💬'}</span>
|
||||
<div>
|
||||
<span className="text-xs font-medium text-gray-600">{agent?.name ?? action.agentRole}</span>
|
||||
<p className="text-sm text-gray-700 dark:text-gray-300">{action.text}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
case 'quiz_show':
|
||||
return <div className="text-sm text-amber-600">Quiz: {action.quizId}</div>;
|
||||
case 'discussion':
|
||||
return <div className="text-sm text-green-600">Discussion: {action.topic}</div>;
|
||||
default:
|
||||
return <div className="text-xs text-gray-400">{action.type}</div>;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
function renderWhiteboardItem(item: { type: string; data: Record<string, unknown> }) {
|
||||
switch (item.type) {
|
||||
case 'whiteboard_draw_text': {
|
||||
const d = item.data;
|
||||
if ('text' in d && 'x' in d && 'y' in d) {
|
||||
return (
|
||||
<text x={typeof d.x === 'number' ? d.x : 100} y={typeof d.y === 'number' ? d.y : 100} fontSize={typeof d.fontSize === 'number' ? d.fontSize : 16} fill={typeof d.color === 'string' ? d.color : '#333'}>
|
||||
{String(d.text ?? '')}
|
||||
</text>
|
||||
);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
case 'whiteboard_draw_shape': {
|
||||
const d = item.data as Record<string, unknown>;
|
||||
const x = typeof d.x === 'number' ? d.x : 0;
|
||||
const y = typeof d.y === 'number' ? d.y : 0;
|
||||
const w = typeof d.width === 'number' ? d.width : 100;
|
||||
const h = typeof d.height === 'number' ? d.height : 50;
|
||||
const fill = typeof d.fill === 'string' ? d.fill : '#e5e5e5';
|
||||
return (
|
||||
<rect x={x} y={y} width={w} height={h} fill={fill} />
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
155
desktop/src/components/classroom_player/TtsPlayer.tsx
Normal file
155
desktop/src/components/classroom_player/TtsPlayer.tsx
Normal file
@@ -0,0 +1,155 @@
|
||||
/**
|
||||
* TtsPlayer — Text-to-Speech playback controls for classroom narration.
|
||||
*
|
||||
* Uses the browser's built-in SpeechSynthesis API.
|
||||
* Provides play/pause, speed, and volume controls.
|
||||
*/
|
||||
|
||||
import { useState, useEffect, useCallback, useRef } from 'react';
|
||||
import { Volume2, VolumeX, Pause, Play, Gauge } from 'lucide-react';
|
||||
|
||||
interface TtsPlayerProps {
|
||||
text: string;
|
||||
autoPlay?: boolean;
|
||||
onEnd?: () => void;
|
||||
}
|
||||
|
||||
export function TtsPlayer({ text, autoPlay = false, onEnd }: TtsPlayerProps) {
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
const [isPaused, setIsPaused] = useState(false);
|
||||
const [rate, setRate] = useState(1.0);
|
||||
const [isMuted, setIsMuted] = useState(false);
|
||||
const utteranceRef = useRef<SpeechSynthesisUtterance | null>(null);
|
||||
|
||||
const speak = useCallback(() => {
|
||||
if (!text || typeof window === 'undefined') return;
|
||||
|
||||
window.speechSynthesis.cancel();
|
||||
|
||||
const utterance = new SpeechSynthesisUtterance(text);
|
||||
utterance.lang = 'zh-CN';
|
||||
utterance.rate = rate;
|
||||
utterance.volume = isMuted ? 0 : 1;
|
||||
|
||||
utterance.onend = () => {
|
||||
setIsPlaying(false);
|
||||
setIsPaused(false);
|
||||
onEnd?.();
|
||||
};
|
||||
utterance.onerror = () => {
|
||||
setIsPlaying(false);
|
||||
setIsPaused(false);
|
||||
};
|
||||
|
||||
utteranceRef.current = utterance;
|
||||
window.speechSynthesis.speak(utterance);
|
||||
setIsPlaying(true);
|
||||
setIsPaused(false);
|
||||
}, [text, rate, isMuted, onEnd]);
|
||||
|
||||
const togglePlay = useCallback(() => {
|
||||
if (isPlaying && !isPaused) {
|
||||
window.speechSynthesis.pause();
|
||||
setIsPaused(true);
|
||||
} else if (isPaused) {
|
||||
window.speechSynthesis.resume();
|
||||
setIsPaused(false);
|
||||
} else {
|
||||
speak();
|
||||
}
|
||||
}, [isPlaying, isPaused, speak]);
|
||||
|
||||
const stop = useCallback(() => {
|
||||
window.speechSynthesis.cancel();
|
||||
setIsPlaying(false);
|
||||
setIsPaused(false);
|
||||
}, []);
|
||||
|
||||
// Auto-play when text changes
|
||||
useEffect(() => {
|
||||
if (autoPlay && text) {
|
||||
speak();
|
||||
}
|
||||
return () => {
|
||||
if (typeof window !== 'undefined') {
|
||||
window.speechSynthesis.cancel();
|
||||
}
|
||||
};
|
||||
}, [text, autoPlay, speak]);
|
||||
|
||||
// Cleanup on unmount
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (typeof window !== 'undefined') {
|
||||
window.speechSynthesis.cancel();
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
if (!text) return null;
|
||||
|
||||
return (
|
||||
<div className="flex items-center gap-3 px-3 py-2 rounded-lg bg-gray-50 dark:bg-gray-800 border border-gray-200 dark:border-gray-700">
|
||||
{/* Play/Pause button */}
|
||||
<button
|
||||
onClick={togglePlay}
|
||||
className="w-8 h-8 flex items-center justify-center rounded-full bg-indigo-500 text-white hover:bg-indigo-600 transition-colors"
|
||||
aria-label={isPlaying && !isPaused ? '暂停' : '播放'}
|
||||
>
|
||||
{isPlaying && !isPaused ? (
|
||||
<Pause className="w-4 h-4" />
|
||||
) : (
|
||||
<Play className="w-4 h-4" />
|
||||
)}
|
||||
</button>
|
||||
|
||||
{/* Stop button */}
|
||||
{isPlaying && (
|
||||
<button
|
||||
onClick={stop}
|
||||
className="w-6 h-6 flex items-center justify-center rounded text-gray-500 hover:text-gray-700 dark:hover:text-gray-300"
|
||||
aria-label="停止"
|
||||
>
|
||||
<VolumeX className="w-3.5 h-3.5" />
|
||||
</button>
|
||||
)}
|
||||
|
||||
{/* Speed control */}
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Gauge className="w-3.5 h-3.5 text-gray-400" />
|
||||
<select
|
||||
value={rate}
|
||||
onChange={(e) => setRate(Number(e.target.value))}
|
||||
className="text-xs bg-transparent border-none text-gray-600 dark:text-gray-400 cursor-pointer"
|
||||
>
|
||||
<option value={0.5}>0.5x</option>
|
||||
<option value={0.75}>0.75x</option>
|
||||
<option value={1}>1x</option>
|
||||
<option value={1.25}>1.25x</option>
|
||||
<option value={1.5}>1.5x</option>
|
||||
<option value={2}>2x</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
{/* Mute toggle */}
|
||||
<button
|
||||
onClick={() => setIsMuted(!isMuted)}
|
||||
className="text-gray-400 hover:text-gray-600 dark:hover:text-gray-300"
|
||||
aria-label={isMuted ? '取消静音' : '静音'}
|
||||
>
|
||||
{isMuted ? (
|
||||
<VolumeX className="w-4 h-4" />
|
||||
) : (
|
||||
<Volume2 className="w-4 h-4" />
|
||||
)}
|
||||
</button>
|
||||
|
||||
{/* Status indicator */}
|
||||
{isPlaying && (
|
||||
<span className="text-xs text-gray-400">
|
||||
{isPaused ? '已暂停' : '朗读中...'}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
295
desktop/src/components/classroom_player/WhiteboardCanvas.tsx
Normal file
295
desktop/src/components/classroom_player/WhiteboardCanvas.tsx
Normal file
@@ -0,0 +1,295 @@
|
||||
/**
|
||||
* WhiteboardCanvas — SVG-based whiteboard for classroom scene rendering.
|
||||
*
|
||||
* Supports incremental drawing operations:
|
||||
* - Text (positioned labels)
|
||||
* - Shapes (rectangles, circles, arrows)
|
||||
* - Charts (bar/line/pie via simple SVG)
|
||||
* - LaTeX (rendered as styled text blocks)
|
||||
*/
|
||||
|
||||
import { useCallback } from 'react';
|
||||
import type { SceneAction } from '../../types/classroom';
|
||||
|
||||
interface WhiteboardCanvasProps {
|
||||
items: WhiteboardItem[];
|
||||
width?: number;
|
||||
height?: number;
|
||||
}
|
||||
|
||||
export interface WhiteboardItem {
|
||||
type: string;
|
||||
data: SceneAction;
|
||||
}
|
||||
|
||||
export function WhiteboardCanvas({
|
||||
items,
|
||||
width = 800,
|
||||
height = 600,
|
||||
}: WhiteboardCanvasProps) {
|
||||
const renderItem = useCallback((item: WhiteboardItem, index: number) => {
|
||||
switch (item.type) {
|
||||
case 'whiteboard_draw_text':
|
||||
return <TextItem key={index} data={item.data as TextDrawData} />;
|
||||
case 'whiteboard_draw_shape':
|
||||
return <ShapeItem key={index} data={item.data as ShapeDrawData} />;
|
||||
case 'whiteboard_draw_chart':
|
||||
return <ChartItem key={index} data={item.data as ChartDrawData} />;
|
||||
case 'whiteboard_draw_latex':
|
||||
return <LatexItem key={index} data={item.data as LatexDrawData} />;
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="w-full h-full border border-gray-200 dark:border-gray-700 rounded-lg bg-white dark:bg-gray-900 overflow-auto">
|
||||
<svg
|
||||
viewBox={`0 0 ${width} ${height}`}
|
||||
className="w-full h-full"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
{/* Grid background */}
|
||||
<defs>
|
||||
<pattern id="grid" width="40" height="40" patternUnits="userSpaceOnUse">
|
||||
<path d="M 40 0 L 0 0 0 40" fill="none" stroke="#f0f0f0" strokeWidth="0.5" />
|
||||
</pattern>
|
||||
</defs>
|
||||
<rect width={width} height={height} fill="url(#grid)" />
|
||||
|
||||
{/* Rendered items */}
|
||||
{items.map((item, i) => renderItem(item, i))}
|
||||
</svg>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sub-components
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface TextDrawData {
|
||||
type: 'whiteboard_draw_text';
|
||||
x: number;
|
||||
y: number;
|
||||
text: string;
|
||||
fontSize?: number;
|
||||
color?: string;
|
||||
}
|
||||
|
||||
function TextItem({ data }: { data: TextDrawData }) {
|
||||
return (
|
||||
<text
|
||||
x={data.x}
|
||||
y={data.y}
|
||||
fontSize={data.fontSize ?? 16}
|
||||
fill={data.color ?? '#333333'}
|
||||
fontFamily="system-ui, sans-serif"
|
||||
>
|
||||
{data.text}
|
||||
</text>
|
||||
);
|
||||
}
|
||||
|
||||
interface ShapeDrawData {
|
||||
type: 'whiteboard_draw_shape';
|
||||
shape: string;
|
||||
x: number;
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
fill?: string;
|
||||
}
|
||||
|
||||
function ShapeItem({ data }: { data: ShapeDrawData }) {
|
||||
switch (data.shape) {
|
||||
case 'circle':
|
||||
return (
|
||||
<ellipse
|
||||
cx={data.x + data.width / 2}
|
||||
cy={data.y + data.height / 2}
|
||||
rx={data.width / 2}
|
||||
ry={data.height / 2}
|
||||
fill={data.fill ?? '#e5e7eb'}
|
||||
stroke="#9ca3af"
|
||||
strokeWidth={1}
|
||||
/>
|
||||
);
|
||||
case 'arrow':
|
||||
return (
|
||||
<g>
|
||||
<line
|
||||
x1={data.x}
|
||||
y1={data.y + data.height / 2}
|
||||
x2={data.x + data.width}
|
||||
y2={data.y + data.height / 2}
|
||||
stroke={data.fill ?? '#6b7280'}
|
||||
strokeWidth={2}
|
||||
markerEnd="url(#arrowhead)"
|
||||
/>
|
||||
<defs>
|
||||
<marker id="arrowhead" markerWidth="10" markerHeight="7" refX="10" refY="3.5" orient="auto">
|
||||
<polygon points="0 0, 10 3.5, 0 7" fill={data.fill ?? '#6b7280'} />
|
||||
</marker>
|
||||
</defs>
|
||||
</g>
|
||||
);
|
||||
default: // rectangle
|
||||
return (
|
||||
<rect
|
||||
x={data.x}
|
||||
y={data.y}
|
||||
width={data.width}
|
||||
height={data.height}
|
||||
fill={data.fill ?? '#e5e7eb'}
|
||||
stroke="#9ca3af"
|
||||
strokeWidth={1}
|
||||
rx={4}
|
||||
/>
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
interface ChartDrawData {
|
||||
type: 'whiteboard_draw_chart';
|
||||
chartType: string;
|
||||
data: Record<string, unknown>;
|
||||
x: number;
|
||||
y: number;
|
||||
width: number;
|
||||
height: number;
|
||||
}
|
||||
|
||||
function ChartItem({ data }: { data: ChartDrawData }) {
|
||||
const chartData = data.data;
|
||||
const labels = (chartData?.labels as string[]) ?? [];
|
||||
const values = (chartData?.values as number[]) ?? [];
|
||||
|
||||
if (labels.length === 0 || values.length === 0) return null;
|
||||
|
||||
switch (data.chartType) {
|
||||
case 'bar':
|
||||
return <BarChart data={data} labels={labels} values={values} />;
|
||||
case 'line':
|
||||
return <LineChart data={data} labels={labels} values={values} />;
|
||||
default:
|
||||
return <BarChart data={data} labels={labels} values={values} />;
|
||||
}
|
||||
}
|
||||
|
||||
function BarChart({ data, labels, values }: {
|
||||
data: ChartDrawData;
|
||||
labels: string[];
|
||||
values: number[];
|
||||
}) {
|
||||
const maxVal = Math.max(...values, 1);
|
||||
const barWidth = data.width / (labels.length * 2);
|
||||
const chartHeight = data.height - 30;
|
||||
|
||||
return (
|
||||
<g transform={`translate(${data.x}, ${data.y})`}>
|
||||
{values.map((val, i) => {
|
||||
const barHeight = (val / maxVal) * chartHeight;
|
||||
return (
|
||||
<g key={i}>
|
||||
<rect
|
||||
x={i * (barWidth * 2) + barWidth / 2}
|
||||
y={chartHeight - barHeight}
|
||||
width={barWidth}
|
||||
height={barHeight}
|
||||
fill="#6366f1"
|
||||
rx={2}
|
||||
/>
|
||||
<text
|
||||
x={i * (barWidth * 2) + barWidth}
|
||||
y={data.height - 5}
|
||||
textAnchor="middle"
|
||||
fontSize={10}
|
||||
fill="#666"
|
||||
>
|
||||
{labels[i]}
|
||||
</text>
|
||||
</g>
|
||||
);
|
||||
})}
|
||||
</g>
|
||||
);
|
||||
}
|
||||
|
||||
function LineChart({ data, labels, values }: {
|
||||
data: ChartDrawData;
|
||||
labels: string[];
|
||||
values: number[];
|
||||
}) {
|
||||
const maxVal = Math.max(...values, 1);
|
||||
const chartHeight = data.height - 30;
|
||||
const stepX = data.width / Math.max(labels.length - 1, 1);
|
||||
|
||||
const points = values.map((val, i) => {
|
||||
const x = i * stepX;
|
||||
const y = chartHeight - (val / maxVal) * chartHeight;
|
||||
return `${x},${y}`;
|
||||
}).join(' ');
|
||||
|
||||
return (
|
||||
<g transform={`translate(${data.x}, ${data.y})`}>
|
||||
<polyline
|
||||
points={points}
|
||||
fill="none"
|
||||
stroke="#6366f1"
|
||||
strokeWidth={2}
|
||||
/>
|
||||
{values.map((val, i) => {
|
||||
const x = i * stepX;
|
||||
const y = chartHeight - (val / maxVal) * chartHeight;
|
||||
return (
|
||||
<g key={i}>
|
||||
<circle cx={x} cy={y} r={3} fill="#6366f1" />
|
||||
<text
|
||||
x={x}
|
||||
y={data.height - 5}
|
||||
textAnchor="middle"
|
||||
fontSize={10}
|
||||
fill="#666"
|
||||
>
|
||||
{labels[i]}
|
||||
</text>
|
||||
</g>
|
||||
);
|
||||
})}
|
||||
</g>
|
||||
);
|
||||
}
|
||||
|
||||
interface LatexDrawData {
|
||||
type: 'whiteboard_draw_latex';
|
||||
latex: string;
|
||||
x: number;
|
||||
y: number;
|
||||
}
|
||||
|
||||
function LatexItem({ data }: { data: LatexDrawData }) {
|
||||
return (
|
||||
<g transform={`translate(${data.x}, ${data.y})`}>
|
||||
<rect
|
||||
x={-4}
|
||||
y={-20}
|
||||
width={data.latex.length * 10 + 8}
|
||||
height={28}
|
||||
fill="#fef3c7"
|
||||
stroke="#f59e0b"
|
||||
strokeWidth={1}
|
||||
rx={4}
|
||||
/>
|
||||
<text
|
||||
x={0}
|
||||
y={0}
|
||||
fontSize={14}
|
||||
fill="#92400e"
|
||||
fontFamily="'Courier New', monospace"
|
||||
>
|
||||
{data.latex}
|
||||
</text>
|
||||
</g>
|
||||
);
|
||||
}
|
||||
12
desktop/src/components/classroom_player/index.ts
Normal file
12
desktop/src/components/classroom_player/index.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
/**
|
||||
* Classroom Player Components
|
||||
*
|
||||
* Re-exports all classroom player components.
|
||||
*/
|
||||
|
||||
export { ClassroomPlayer } from './ClassroomPlayer';
|
||||
export { SceneRenderer } from './SceneRenderer';
|
||||
export { AgentChat } from './AgentChat';
|
||||
export { NotesSidebar } from './NotesSidebar';
|
||||
export { WhiteboardCanvas } from './WhiteboardCanvas';
|
||||
export { TtsPlayer } from './TtsPlayer';
|
||||
76
desktop/src/hooks/useClassroom.ts
Normal file
76
desktop/src/hooks/useClassroom.ts
Normal file
@@ -0,0 +1,76 @@
|
||||
/**
|
||||
* useClassroom — React hook wrapping the classroom store for component consumption.
|
||||
*
|
||||
* Provides a simplified interface for classroom generation and chat.
|
||||
*/
|
||||
|
||||
import { useCallback } from 'react';
|
||||
import {
|
||||
useClassroomStore,
|
||||
type GenerationRequest,
|
||||
} from '../store/classroomStore';
|
||||
import type {
|
||||
Classroom,
|
||||
ClassroomChatMessage,
|
||||
} from '../types/classroom';
|
||||
|
||||
export interface UseClassroomReturn {
|
||||
/** Is generation in progress */
|
||||
generating: boolean;
|
||||
/** Current generation stage name */
|
||||
progressStage: string | null;
|
||||
/** Progress percentage 0-100 */
|
||||
progressPercent: number;
|
||||
/** The active classroom */
|
||||
activeClassroom: Classroom | null;
|
||||
/** Chat messages for active classroom */
|
||||
chatMessages: ClassroomChatMessage[];
|
||||
/** Is a chat request loading */
|
||||
chatLoading: boolean;
|
||||
/** Error message, if any */
|
||||
error: string | null;
|
||||
/** Start classroom generation */
|
||||
startGeneration: (request: GenerationRequest) => Promise<string>;
|
||||
/** Cancel active generation */
|
||||
cancelGeneration: () => void;
|
||||
/** Send a chat message in the active classroom */
|
||||
sendChatMessage: (message: string, sceneContext?: string) => Promise<void>;
|
||||
/** Clear current error */
|
||||
clearError: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Hook for classroom generation and multi-agent chat.
|
||||
*
|
||||
* Components should use this hook rather than accessing the store directly,
|
||||
* to keep the rendering logic decoupled from state management.
|
||||
*/
|
||||
export function useClassroom(): UseClassroomReturn {
|
||||
const {
|
||||
generating,
|
||||
progressStage,
|
||||
progressPercent,
|
||||
activeClassroom,
|
||||
chatMessages,
|
||||
chatLoading,
|
||||
error,
|
||||
startGeneration,
|
||||
cancelGeneration,
|
||||
sendChatMessage,
|
||||
clearError,
|
||||
} = useClassroomStore();
|
||||
|
||||
return {
|
||||
generating,
|
||||
progressStage,
|
||||
progressPercent,
|
||||
activeClassroom,
|
||||
chatMessages,
|
||||
chatLoading,
|
||||
error,
|
||||
startGeneration: useCallback((req: GenerationRequest) => startGeneration(req), [startGeneration]),
|
||||
cancelGeneration: useCallback(() => cancelGeneration(), [cancelGeneration]),
|
||||
sendChatMessage: useCallback((msg, ctx) => sendChatMessage(msg, ctx), [sendChatMessage]),
|
||||
clearError: useCallback(() => clearError(), [clearError]),
|
||||
};
|
||||
}
|
||||
@@ -1,27 +1,5 @@
|
||||
@import "tailwindcss";
|
||||
|
||||
/* Aurora gradient animation for welcome title (DeerFlow-inspired) */
|
||||
@keyframes gradient-shift {
|
||||
0%, 100% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
}
|
||||
|
||||
.aurora-title {
|
||||
background: linear-gradient(
|
||||
135deg,
|
||||
#f97316 0%, /* orange-500 */
|
||||
#ef4444 25%, /* red-500 */
|
||||
#f97316 50%, /* orange-500 */
|
||||
#fb923c 75%, /* orange-400 */
|
||||
#f97316 100% /* orange-500 */
|
||||
);
|
||||
background-size: 200% 200%;
|
||||
-webkit-background-clip: text;
|
||||
background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
animation: gradient-shift 4s ease infinite;
|
||||
}
|
||||
|
||||
:root {
|
||||
/* Brand Colors - 中性灰色系 */
|
||||
--color-primary: #374151; /* gray-700 */
|
||||
@@ -154,3 +132,38 @@ textarea:focus-visible {
|
||||
outline: none !important;
|
||||
box-shadow: none !important;
|
||||
}
|
||||
|
||||
/* === Accessibility: reduced motion === */
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
*, *::before, *::after {
|
||||
animation-duration: 0.01ms !important;
|
||||
animation-iteration-count: 1 !important;
|
||||
transition-duration: 0.01ms !important;
|
||||
scroll-behavior: auto !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* === Responsive breakpoints for small windows/tablets === */
|
||||
@media (max-width: 768px) {
|
||||
/* Auto-collapse sidebar aside on narrow viewports */
|
||||
aside.w-64 {
|
||||
width: 0 !important;
|
||||
min-width: 0 !important;
|
||||
overflow: hidden;
|
||||
border-right: none !important;
|
||||
}
|
||||
aside.w-64.sidebar-open {
|
||||
width: 260px !important;
|
||||
min-width: 260px !important;
|
||||
position: fixed;
|
||||
z-index: 50;
|
||||
height: 100vh;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
.chat-bubble-assistant,
|
||||
.chat-bubble-user {
|
||||
max-width: 95% !important;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
*
|
||||
* 为 ZCLAW 前端操作提供统一的审计日志记录功能。
|
||||
* 记录关键操作(Hand 触发、Agent 创建等)到本地存储。
|
||||
*
|
||||
* @reserved This module is reserved for future audit logging integration.
|
||||
* It is not currently imported by any component. When audit logging is needed,
|
||||
* import { logAudit, logAuditSuccess, logAuditFailure } from this module.
|
||||
*/
|
||||
|
||||
import { createLogger } from './logger';
|
||||
|
||||
142
desktop/src/lib/classroom-adapter.ts
Normal file
142
desktop/src/lib/classroom-adapter.ts
Normal file
@@ -0,0 +1,142 @@
|
||||
/**
|
||||
* Classroom Adapter
|
||||
*
|
||||
* Bridges the old ClassroomData type (ClassroomPreviewer) with the new
|
||||
* Classroom type (ClassroomPlayer + Tauri backend).
|
||||
*/
|
||||
|
||||
import type { Classroom, GeneratedScene } from '../types/classroom';
|
||||
import { SceneType, TeachingStyle, DifficultyLevel } from '../types/classroom';
|
||||
import type { ClassroomData, ClassroomScene } from '../components/ClassroomPreviewer';
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Old → New (ClassroomData → Classroom)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Convert a legacy ClassroomData to the new Classroom format.
|
||||
* Used when opening ClassroomPlayer from Pipeline result previews.
|
||||
*/
|
||||
export function adaptToClassroom(data: ClassroomData): Classroom {
|
||||
const scenes: GeneratedScene[] = data.scenes.map((scene, index) => ({
|
||||
id: scene.id,
|
||||
outlineId: `outline-${index}`,
|
||||
content: {
|
||||
title: scene.title,
|
||||
sceneType: mapSceneType(scene.type),
|
||||
content: {
|
||||
heading: scene.content.heading ?? scene.title,
|
||||
key_points: scene.content.bullets ?? [],
|
||||
description: scene.content.explanation,
|
||||
quiz: scene.content.quiz ?? undefined,
|
||||
},
|
||||
actions: [],
|
||||
durationSeconds: scene.duration ?? 60,
|
||||
notes: scene.narration,
|
||||
},
|
||||
order: index,
|
||||
})) as GeneratedScene[];
|
||||
|
||||
return {
|
||||
id: data.id,
|
||||
title: data.title,
|
||||
description: data.subject,
|
||||
topic: data.subject,
|
||||
style: TeachingStyle.Lecture,
|
||||
level: mapDifficulty(data.difficulty),
|
||||
totalDuration: data.duration * 60,
|
||||
objectives: [],
|
||||
scenes,
|
||||
agents: [],
|
||||
metadata: {
|
||||
generatedAt: new Date(data.createdAt).getTime(),
|
||||
version: '1.0',
|
||||
custom: {},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// New → Old (Classroom → ClassroomData)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Convert a new Classroom to the legacy ClassroomData format.
|
||||
* Used when rendering ClassroomPreviewer from new pipeline results.
|
||||
*/
|
||||
export function adaptToClassroomData(classroom: Classroom): ClassroomData {
|
||||
const scenes: ClassroomScene[] = classroom.scenes.map((scene) => {
|
||||
const data = scene.content.content as Record<string, unknown>;
|
||||
return {
|
||||
id: scene.id,
|
||||
title: scene.content.title,
|
||||
type: mapToLegacySceneType(scene.content.sceneType),
|
||||
content: {
|
||||
heading: (data?.heading as string) ?? scene.content.title,
|
||||
bullets: (data?.key_points as string[]) ?? [],
|
||||
explanation: (data?.description as string) ?? '',
|
||||
quiz: (data?.quiz as ClassroomScene['content']['quiz']) ?? undefined,
|
||||
},
|
||||
narration: scene.content.notes,
|
||||
duration: scene.content.durationSeconds,
|
||||
};
|
||||
});
|
||||
|
||||
return {
|
||||
id: classroom.id,
|
||||
title: classroom.title,
|
||||
subject: classroom.topic,
|
||||
difficulty: mapToLegacyDifficulty(classroom.level),
|
||||
duration: Math.ceil(classroom.totalDuration / 60),
|
||||
scenes,
|
||||
outline: {
|
||||
sections: classroom.scenes.map((scene) => ({
|
||||
title: scene.content.title,
|
||||
scenes: [scene.id],
|
||||
})),
|
||||
},
|
||||
createdAt: new Date(classroom.metadata.generatedAt).toISOString(),
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function mapSceneType(type: ClassroomScene['type']): SceneType {
|
||||
switch (type) {
|
||||
case 'title': return SceneType.Slide;
|
||||
case 'content': return SceneType.Slide;
|
||||
case 'quiz': return SceneType.Quiz;
|
||||
case 'interactive': return SceneType.Interactive;
|
||||
case 'summary': return SceneType.Text;
|
||||
default: return SceneType.Slide;
|
||||
}
|
||||
}
|
||||
|
||||
function mapToLegacySceneType(sceneType: string): ClassroomScene['type'] {
|
||||
switch (sceneType) {
|
||||
case 'quiz': return 'quiz';
|
||||
case 'interactive': return 'interactive';
|
||||
case 'text': return 'summary';
|
||||
default: return 'content';
|
||||
}
|
||||
}
|
||||
|
||||
function mapDifficulty(difficulty: string): DifficultyLevel {
|
||||
switch (difficulty) {
|
||||
case '初级': return DifficultyLevel.Beginner;
|
||||
case '中级': return DifficultyLevel.Intermediate;
|
||||
case '高级': return DifficultyLevel.Advanced;
|
||||
default: return DifficultyLevel.Intermediate;
|
||||
}
|
||||
}
|
||||
|
||||
function mapToLegacyDifficulty(level: string): ClassroomData['difficulty'] {
|
||||
switch (level) {
|
||||
case 'beginner': return '初级';
|
||||
case 'advanced': return '高级';
|
||||
case 'expert': return '高级';
|
||||
default: return '中级';
|
||||
}
|
||||
}
|
||||
@@ -56,12 +56,19 @@ function initErrorStore(): void {
|
||||
errors: [],
|
||||
|
||||
addError: (error: AppError) => {
|
||||
// Dedup: skip if same title+message already exists and undismissed
|
||||
const isDuplicate = errorStore.errors.some(
|
||||
(e) => !e.dismissed && e.title === error.title && e.message === error.message
|
||||
);
|
||||
if (isDuplicate) return;
|
||||
|
||||
const storedError: StoredError = {
|
||||
...error,
|
||||
dismissed: false,
|
||||
reported: false,
|
||||
};
|
||||
errorStore.errors = [storedError, ...errorStore.errors];
|
||||
// Cap at 50 errors to prevent unbounded growth
|
||||
errorStore.errors = [storedError, ...errorStore.errors].slice(0, 50);
|
||||
// Notify listeners
|
||||
notifyErrorListeners(error);
|
||||
},
|
||||
|
||||
@@ -103,6 +103,12 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo
|
||||
callbacks.onDelta(streamEvent.delta);
|
||||
break;
|
||||
|
||||
case 'thinkingDelta':
|
||||
if (callbacks.onThinkingDelta) {
|
||||
callbacks.onThinkingDelta(streamEvent.delta);
|
||||
}
|
||||
break;
|
||||
|
||||
case 'tool_start':
|
||||
log.debug('Tool started:', streamEvent.name, streamEvent.input);
|
||||
if (callbacks.onTool) {
|
||||
|
||||
@@ -5,8 +5,20 @@
|
||||
*/
|
||||
|
||||
import { invoke } from '@tauri-apps/api/core';
|
||||
import { listen, type UnlistenFn } from '@tauri-apps/api/event';
|
||||
import { createLogger } from './logger';
|
||||
import type { KernelClient } from './kernel-client';
|
||||
|
||||
const log = createLogger('KernelHands');
|
||||
|
||||
/** Payload emitted by the Rust backend on `hand-execution-complete` events. */
|
||||
export interface HandExecutionCompletePayload {
|
||||
approvalId: string;
|
||||
handId: string;
|
||||
success: boolean;
|
||||
error?: string | null;
|
||||
}
|
||||
|
||||
export function installHandMethods(ClientClass: { prototype: KernelClient }): void {
|
||||
const proto = ClientClass.prototype as any;
|
||||
|
||||
@@ -92,7 +104,7 @@ export function installHandMethods(ClientClass: { prototype: KernelClient }): vo
|
||||
*/
|
||||
proto.getHandStatus = async function (this: KernelClient, name: string, runId: string): Promise<{ status: string; result?: unknown }> {
|
||||
try {
|
||||
return await invoke('hand_run_status', { handName: name, runId });
|
||||
return await invoke('hand_run_status', { runId });
|
||||
} catch (e) {
|
||||
const { createLogger } = await import('./logger');
|
||||
createLogger('KernelHands').debug('hand_run_status failed', { name, runId, error: e });
|
||||
@@ -171,4 +183,26 @@ export function installHandMethods(ClientClass: { prototype: KernelClient }): vo
|
||||
proto.respondToApproval = async function (this: KernelClient, approvalId: string, approved: boolean, reason?: string): Promise<void> {
|
||||
return invoke('approval_respond', { id: approvalId, approved, reason });
|
||||
};
|
||||
|
||||
// ─── Event Listeners ───
|
||||
|
||||
/**
|
||||
* Listen for `hand-execution-complete` events emitted by the Rust backend
|
||||
* after a hand finishes executing (both from direct trigger and approval flow).
|
||||
*
|
||||
* Returns an unlisten function for cleanup.
|
||||
*/
|
||||
proto.onHandExecutionComplete = async function (
|
||||
this: KernelClient,
|
||||
callback: (payload: HandExecutionCompletePayload) => void,
|
||||
): Promise<UnlistenFn> {
|
||||
const unlisten = await listen<HandExecutionCompletePayload>(
|
||||
'hand-execution-complete',
|
||||
(event) => {
|
||||
log.debug('hand-execution-complete', event.payload);
|
||||
callback(event.payload);
|
||||
},
|
||||
);
|
||||
return unlisten;
|
||||
};
|
||||
}
|
||||
|
||||
@@ -109,7 +109,11 @@ export function installSkillMethods(ClientClass: { prototype: KernelClient }): v
|
||||
}> {
|
||||
return invoke('skill_execute', {
|
||||
id,
|
||||
context: {},
|
||||
context: {
|
||||
agentId: '',
|
||||
sessionId: '',
|
||||
workingDir: '',
|
||||
},
|
||||
input: input || {},
|
||||
});
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user