Compare commits

...

25 Commits

Author SHA1 Message Date
iven
8898bb399e docs: audit reports + feature docs + skills + admin-v2 + config sync
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Update audit tracker, roadmap, architecture docs,
add admin-v2 Roles page + Billing tests,
sync CLAUDE.md, Cargo.toml, docker-compose.yml,
add deep-research / frontend-design / chart-visualization skills

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 19:25:00 +08:00
iven
28299807b6 fix(desktop): DeerFlow UI — ChatArea refactor + ai-elements + dead CSS cleanup
ChatArea retry button uses setInput instead of direct sendToGateway,
fix bootstrap spinner stuck for non-logged-in users,
remove dead CSS (aurora-title/sidebar-open/quick-action-chips),
add ai components (ReasoningBlock/StreamingText/ChatMode/ModelSelector/TaskProgress),
add ClassroomPlayer + ResizableChatLayout + artifact panel

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 19:24:44 +08:00
iven
d40c4605b2 fix(knowledge): verification audit — 3 medium issues
- create_item: wrap item + version INSERT in transaction for atomicity
- update_item handler: validate content length (100KB) before DB hit
- KnowledgeChunk: document missing embedding field, safe per explicit SELECT usage
2026-04-02 19:16:32 +08:00
iven
7e4b787d5c fix(knowledge): deep audit — 18 bugs fixed across backend + frontend
CRITICAL:
- Migration permission seed WHERE name → WHERE id (matched 0 rows, all KB APIs broken)

HIGH:
- analytics_quality SQL alias + missing comma fix
- search() duplicate else block compile error
- chunk_content duplicate var declarations + type mismatch
- SQL invalid escape sequences
- delete_category missing rows_affected check

MEDIUM:
- analytics_overview hit_rate vs positive_feedback_rate separation
- analytics_quality GROUP BY kc.id,kc.name (same-name category merge)
- update_category handler trim + empty name validation
- update_item duplicate VALID_STATUSES inside transaction
- page_size max(1) lower bound in list handlers
- batch_create title/content/length validation
- embedding dispatch silent error → tracing::warn
- Version modal close clears detailItem state
- Search empty state distinguishes not-searched vs no-results
- Create modal cancel resets form
2026-04-02 19:07:42 +08:00
iven
837abec48a feat(billing): add usage increment API + wire hand/pipeline execution tracking
Server side:
- POST /api/v1/billing/usage/increment endpoint with dimension whitelist
  (hand_executions, pipeline_runs, relay_requests) and count validation (1-100)
- Returns updated usage quota after increment

Desktop side:
- New saas-billing.ts mixin with incrementUsageDimension() and
  reportUsageFireAndForget() (non-blocking, safe for finally blocks)
- handStore.triggerHand: reports hand_executions after successful run
- PipelinesPanel.handleRunComplete: reports pipeline_runs on completion
- SaaSClient type declarations for new billing methods

Billing pipeline now covers all three dimensions:
  relay_requests  → relay handler (server-side, real-time)
  hand_executions → handStore (client-side, fire-and-forget)
  pipeline_runs   → PipelinesPanel (client-side, fire-and-forget)
2026-04-02 02:02:59 +08:00
iven
11e3d37468 feat(billing): activate real-time quota enforcement pipeline
- Wire relay handler to increment_usage() for JSON responses (tokens + relay_requests)
- Wire relay handler to increment_dimension("relay_requests") for SSE streams
- Add increment_dimension() function for hand_executions/pipeline_runs dimensions
- Schedule AggregateUsageWorker hourly for reconciliation (run_on_start=true)
- Mount mock payment routes in dev mode (ZCLAW_SAAS_DEV=true)

Previously the quota middleware always allowed requests because usage
counters were never incremented. Now relay requests update billing_usage_quotas
in real-time, with the aggregator providing hourly reconciliation.
2026-04-02 01:52:01 +08:00
iven
8263b236fd refactor(desktop): wire PipelineResultPreview into PipelinesPanel
Replace the inline ResultModal with the full-featured
PipelineResultPreview component. This gives users JSON/Markdown/
Classroom mode switching, file download cards, and classroom export
support instead of the previous basic PresentationContainer wrapper.

Remove unused ResultModal component and PresentationContainer import.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 01:32:58 +08:00
iven
08268b32b8 feat(memory): implement FactStore SQLite persistence
Add `facts` table to schema with columns for id, agent_id, content,
category, confidence, source_session, and created_at. Implement
store_facts() and get_top_facts() on MemoryStore using upsert-by-id
and confidence-desc ordering. Facts extracted from conversations are
now durable across sessions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 01:26:57 +08:00
iven
1bf0d3a73d fix(memory): CJK-aware short query threshold + Chinese synonym expansion
1. MemoryMiddleware: replace byte-length check (query.len() < 4) with
   char-count check (query.chars().count() < 2). Single CJK characters
   are 3 UTF-8 bytes but 1 meaningful character — the old threshold
   incorrectly skipped 1-2 char Chinese queries like "你好".

2. QueryAnalyzer: add Chinese synonym mappings for 13 common technical
   terms (错误→bug, 优化→improve, 配置→config, etc.) so CJK queries
   can find relevant English-keyword memories and vice versa.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 01:21:29 +08:00
iven
07099e3ef0 test(hands): expand Slideshow tests (4→34) and fix Clip invalid action test
Slideshow: add navigation edge cases, autoplay/pause/resume, spotlight/
laser/highlight defaults, content block deserialization, Hand trait
dispatch, and add_slide helper tests.

Clip: fix test_execute_invalid_action to expect Err (execute returns
HandError for unknown variants).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 01:13:15 +08:00
iven
dce9035584 test(hands): add 28 unit tests for Twitter Hand
Cover config defaults, 13 action types deserialization, serialization
roundtrip, credential management, and data type parsing. Also add
PartialEq derive to HandStatus for test assertions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 01:01:37 +08:00
iven
c8dc654fd4 feat(admin-v2): add billing management page
- Plan cards with feature comparison and pricing
- Usage progress bars with quota visualization
- Alipay/WeChat Pay method selection modal
- Payment status polling with auto-refresh on success
- Navigation + route registration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 00:48:35 +08:00
iven
b1e3a27043 feat(saas): add payment integration with Alipay/WeChat mock support
- payment.rs: create_payment, handle_payment_callback, query_payment_status
- Mock pay page for development mode with HTML confirm/cancel flow
- Payment callback handler with subscription auto-creation on success
- Alipay form-urlencoded and WeChat JSON callback parsing
- 7 new routes including callback and mock-pay endpoints

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 00:41:35 +08:00
iven
becfda3fbf feat(admin-v2): add Knowledge base management page
- 4 tabs: Items (CRUD + ProTable), Categories (tree management), Search, Analytics
- Knowledge service with full API integration
- Nav item + breadcrumb + route registration
- Analytics overview with 8 KPI statistics

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 00:34:17 +08:00
iven
830e9fa301 feat(saas): add GenerateEmbedding worker for knowledge chunking
- Markdown-aware content splitting (512 token chunks with 64 overlap)
- CJK keyword extraction from chunk content with stop-word filtering
- Full refresh strategy (delete old chunks → re-insert on update)
- Phase 2 placeholder for vector embedding API integration

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 00:23:38 +08:00
iven
ef60f9a183 feat(saas): add knowledge base module — categories, items, versions, search, analytics
- 5 knowledge tables (categories, items, chunks, versions, usage) with pgvector + HNSW + GIN indexes
- 23+ API routes covering full CRUD, tree-structured categories, version snapshots
- Keyword-based search with ILIKE + array match (placeholder for vector search)
- Analytics endpoints: overview, trends, top-items, quality, gaps
- Markdown-aware content chunking with overlap strategy
- Worker dispatch for async embedding generation

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 00:21:28 +08:00
iven
b66087de0e feat(saas): add quota middleware and usage aggregation worker
B1.3 Quota middleware:
- quota_check_middleware for relay route chain
- Checks monthly relay_requests quota before processing
- Gracefully degrades on billing service failure

B1.5 AggregateUsageWorker:
- Aggregates usage_records into billing_usage_quotas monthly
- Supports single-account and all-accounts modes
- Scheduled hourly via Worker dispatcher (6 workers total)
2026-04-02 00:06:39 +08:00
iven
d06ecded34 feat(saas): add quota check middleware for relay requests
Injects billing quota verification before relay chat completion requests.
Checks monthly relay_requests quota via billing::service::check_quota.
Gracefully degrades on quota service failure (logs warning, allows request).
2026-04-02 00:03:26 +08:00
iven
9487cd7f72 feat(saas): add billing infrastructure — tables, types, service, handlers
B1.1 Billing database:
- 5 tables: billing_plans, billing_subscriptions, billing_invoices,
  billing_payments, billing_usage_quotas
- Seed data: Free(¥0)/Pro(¥49)/Team(¥199) plans
- JSONB limits for flexible plan configuration

Billing module (crates/zclaw-saas/src/billing/):
- types.rs: BillingPlan, Subscription, Invoice, Payment, UsageQuota
- service.rs: plan CRUD, subscription lookup, usage tracking, quota check
- handlers.rs: REST API (plans list/detail, subscription, usage)
- mod.rs: routes registered at /api/v1/billing/*

Cargo.toml: added chrono feature to sqlx for DateTime<Utc> support
2026-04-01 23:59:46 +08:00
iven
c6bd4aea27 feat(pipelines): add 10 industry-specific pipeline templates
Education (3): research-to-quiz, student-analysis, lesson-plan
Healthcare (3): policy-compliance, meeting-minutes, data-report
Design Shantou (4): trend-to-design, competitor-research,
  client-communication, supply-chain-collect
2026-04-01 23:43:45 +08:00
iven
17a2501808 test(hands): add unit tests for BrowserHand + fix requires_approval config
Fix needs_approval field in BrowserHand::new() from false to true to
match the TOML config (hands/browser.HAND.toml says requires_approval = true).
Browser automation has security implications and should require approval.

Add 11 unit tests covering:
- Config id and enabled state
- needs_approval correctness (after fix)
- Action deserialization (Navigate, Click, Type, Scrape, Screenshot)
- Roundtrip serialization for all major action variants
- BrowserSequence builder with stop_on_error()
- Multi-step sequence execution
- FormField deserialization

Also add stop_on_error() builder method to BrowserSequence which was
referenced in the test plan but missing from the struct.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 23:22:18 +08:00
iven
cc7ee3189d test(hands): add unit tests for CollectorHand + fix HTML extraction position tracking
Fix extract_visible_text to use proper byte position tracking (pos += char_len)
instead of iterating chars without position context, which caused script/style
tag detection to fail on multi-byte content. Also adds script/style stripping
logic and raises truncation limit to 10000 chars.

Adds 9 unit tests covering:
- Config identity verification
- OutputFormat serialization round-trip
- HTML text extraction (basic, script stripping, style stripping, empty input)
- Aggregate action with empty URLs
- CollectorAction deserialization (Collect/Aggregate/Extract)
- CollectionTarget deserialization
2026-04-01 23:21:43 +08:00
iven
62df7feac1 docs(spec): switch payment integration from Stripe to Alipay/WeChat Pay direct
Target market is domestic China users only — integrate Alipay Face-to-Face
Payment and WeChat Native Pay directly instead of Stripe as intermediary.
Updated billing module structure, risk table, and verification criteria.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 23:21:22 +08:00
iven
a851a2854f feat(desktop): update quick action prompts for education/healthcare/design industries
Tailor first-conversation prompts to the three target user groups:
- Education: AI tool comparison, digital transformation research
- Healthcare: administrative optimization proposal
- Design/Shantou: toy industry export trend analysis

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 23:21:06 +08:00
iven
59fc7debd6 feat(hands): add 25 unit tests + fix summary + fix HTML extraction for ResearcherHand
- Add comprehensive test suite: config, types, action deserialization, URL encoding,
  HTML text extraction, hand trait methods
- Fix summary field: generate rule-based summary from top search results (was always None)
- Fix extract_text_from_html: correct position tracking for script/style tag detection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-01 23:16:57 +08:00
161 changed files with 21989 additions and 870 deletions

View File

@@ -9,7 +9,7 @@
ZCLAW 是面向中文用户的 AI Agent 桌面端,核心能力包括:
- **智能对话** - 多模型支持、流式响应、上下文管理
- **自主能力** - 8 个 Hands浏览器、数据采集、研究、预测等
- **自主能力** - 11 个 Hands9 启用 + 2 禁用: Predictor, Lead
- **技能系统** - 可扩展的 SKILL.md 技能定义
- **工作流编排** - 多步骤自动化任务
- **安全审计** - 完整的操作日志和权限控制
@@ -69,7 +69,7 @@ ZCLAW/
| 桌面框架 | Tauri 2.x |
| 样式方案 | Tailwind CSS |
| 配置格式 | TOML |
| 后端核心 | Rust Workspace (9 crates) |
| 后端核心 | Rust Workspace (10 crates) |
| SaaS 后端 | Axum + PostgreSQL (zclaw-saas) |
| 管理后台 | Next.js (admin/) |

178
Cargo.lock generated
View File

@@ -1314,7 +1314,7 @@ dependencies = [
"serde_json",
"serde_yaml",
"sha2",
"sqlx",
"sqlx 0.7.4",
"tauri",
"tauri-build",
"tauri-plugin-opener",
@@ -2262,6 +2262,8 @@ version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash 0.1.5",
]
@@ -2285,6 +2287,15 @@ dependencies = [
"hashbrown 0.14.5",
]
[[package]]
name = "hashlink"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1"
dependencies = [
"hashbrown 0.15.5",
]
[[package]]
name = "headers"
version = "0.4.1"
@@ -3716,6 +3727,15 @@ version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pgvector"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc58e2d255979a31caa7cabfa7aac654af0354220719ab7a68520ae7a91e8c0b"
dependencies = [
"sqlx 0.8.6",
]
[[package]]
name = "phf"
version = "0.8.0"
@@ -4571,6 +4591,7 @@ dependencies = [
"pkcs1",
"pkcs8",
"rand_core 0.6.4",
"sha2",
"signature",
"spki",
"subtle",
@@ -5271,13 +5292,24 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9a2ccff1a000a5a59cd33da541d9f2fdcd9e6e8229cc200565942bff36d0aaa"
dependencies = [
"sqlx-core",
"sqlx-macros",
"sqlx-core 0.7.4",
"sqlx-macros 0.7.4",
"sqlx-mysql",
"sqlx-postgres",
"sqlx-postgres 0.7.4",
"sqlx-sqlite",
]
[[package]]
name = "sqlx"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fefb893899429669dcdd979aff487bd78f4064e5e7907e4269081e0ef7d97dc"
dependencies = [
"sqlx-core 0.8.6",
"sqlx-macros 0.8.6",
"sqlx-postgres 0.8.6",
]
[[package]]
name = "sqlx-core"
version = "0.7.4"
@@ -5288,6 +5320,7 @@ dependencies = [
"atoi",
"byteorder",
"bytes",
"chrono",
"crc",
"crossbeam-queue",
"either",
@@ -5297,7 +5330,7 @@ dependencies = [
"futures-intrusive",
"futures-io",
"futures-util",
"hashlink",
"hashlink 0.8.4",
"hex",
"indexmap 2.13.0",
"log",
@@ -5317,6 +5350,38 @@ dependencies = [
"url",
]
[[package]]
name = "sqlx-core"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee6798b1838b6a0f69c007c133b8df5866302197e404e8b6ee8ed3e3a5e68dc6"
dependencies = [
"base64 0.22.1",
"bytes",
"crc",
"crossbeam-queue",
"either",
"event-listener 5.4.1",
"futures-core",
"futures-intrusive",
"futures-io",
"futures-util",
"hashbrown 0.15.5",
"hashlink 0.10.0",
"indexmap 2.13.0",
"log",
"memchr",
"once_cell",
"percent-encoding",
"serde",
"serde_json",
"sha2",
"smallvec",
"thiserror 2.0.18",
"tracing",
"url",
]
[[package]]
name = "sqlx-macros"
version = "0.7.4"
@@ -5325,11 +5390,24 @@ checksum = "4ea40e2345eb2faa9e1e5e326db8c34711317d2b5e08d0d5741619048a803127"
dependencies = [
"proc-macro2",
"quote",
"sqlx-core",
"sqlx-macros-core",
"sqlx-core 0.7.4",
"sqlx-macros-core 0.7.4",
"syn 1.0.109",
]
[[package]]
name = "sqlx-macros"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2d452988ccaacfbf5e0bdbc348fb91d7c8af5bee192173ac3636b5fb6e6715d"
dependencies = [
"proc-macro2",
"quote",
"sqlx-core 0.8.6",
"sqlx-macros-core 0.8.6",
"syn 2.0.117",
]
[[package]]
name = "sqlx-macros-core"
version = "0.7.4"
@@ -5346,9 +5424,9 @@ dependencies = [
"serde",
"serde_json",
"sha2",
"sqlx-core",
"sqlx-core 0.7.4",
"sqlx-mysql",
"sqlx-postgres",
"sqlx-postgres 0.7.4",
"sqlx-sqlite",
"syn 1.0.109",
"tempfile",
@@ -5356,6 +5434,28 @@ dependencies = [
"url",
]
[[package]]
name = "sqlx-macros-core"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19a9c1841124ac5a61741f96e1d9e2ec77424bf323962dd894bdb93f37d5219b"
dependencies = [
"dotenvy",
"either",
"heck 0.5.0",
"hex",
"once_cell",
"proc-macro2",
"quote",
"serde",
"serde_json",
"sha2",
"sqlx-core 0.8.6",
"sqlx-postgres 0.8.6",
"syn 2.0.117",
"url",
]
[[package]]
name = "sqlx-mysql"
version = "0.7.4"
@@ -5367,6 +5467,7 @@ dependencies = [
"bitflags 2.11.0",
"byteorder",
"bytes",
"chrono",
"crc",
"digest",
"dotenvy",
@@ -5391,7 +5492,7 @@ dependencies = [
"sha1",
"sha2",
"smallvec",
"sqlx-core",
"sqlx-core 0.7.4",
"stringprep",
"thiserror 1.0.69",
"tracing",
@@ -5408,6 +5509,7 @@ dependencies = [
"base64 0.21.7",
"bitflags 2.11.0",
"byteorder",
"chrono",
"crc",
"dotenvy",
"etcetera",
@@ -5429,13 +5531,50 @@ dependencies = [
"serde_json",
"sha2",
"smallvec",
"sqlx-core",
"sqlx-core 0.7.4",
"stringprep",
"thiserror 1.0.69",
"tracing",
"whoami",
]
[[package]]
name = "sqlx-postgres"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "db58fcd5a53cf07c184b154801ff91347e4c30d17a3562a635ff028ad5deda46"
dependencies = [
"atoi",
"base64 0.22.1",
"bitflags 2.11.0",
"byteorder",
"crc",
"dotenvy",
"etcetera",
"futures-channel",
"futures-core",
"futures-util",
"hex",
"hkdf",
"hmac",
"home",
"itoa",
"log",
"md-5",
"memchr",
"once_cell",
"rand 0.8.5",
"serde",
"serde_json",
"sha2",
"smallvec",
"sqlx-core 0.8.6",
"stringprep",
"thiserror 2.0.18",
"tracing",
"whoami",
]
[[package]]
name = "sqlx-sqlite"
version = "0.7.4"
@@ -5443,6 +5582,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b244ef0a8414da0bed4bb1910426e890b19e5e9bccc27ada6b797d05c55ae0aa"
dependencies = [
"atoi",
"chrono",
"flume",
"futures-channel",
"futures-core",
@@ -5453,7 +5593,7 @@ dependencies = [
"log",
"percent-encoding",
"serde",
"sqlx-core",
"sqlx-core 0.7.4",
"tracing",
"url",
"urlencoding",
@@ -8211,7 +8351,7 @@ dependencies = [
"libsqlite3-sys",
"serde",
"serde_json",
"sqlx",
"sqlx 0.7.4",
"thiserror 2.0.18",
"tokio",
"tokio-test",
@@ -8227,11 +8367,9 @@ dependencies = [
"async-trait",
"base64 0.22.1",
"chrono",
"hmac",
"reqwest 0.12.28",
"serde",
"serde_json",
"sha1",
"thiserror 2.0.18",
"tokio",
"tracing",
@@ -8272,12 +8410,14 @@ dependencies = [
name = "zclaw-memory"
version = "0.1.0"
dependencies = [
"anyhow",
"async-trait",
"chrono",
"futures",
"libsqlite3-sys",
"serde",
"serde_json",
"sqlx",
"sqlx 0.7.4",
"thiserror 2.0.18",
"tokio",
"tracing",
@@ -8362,9 +8502,11 @@ dependencies = [
"aes-gcm",
"anyhow",
"argon2",
"async-stream",
"async-trait",
"axum",
"axum-extra",
"base64 0.22.1",
"bytes",
"chrono",
"dashmap",
@@ -8372,15 +8514,17 @@ dependencies = [
"futures",
"hex",
"jsonwebtoken",
"pgvector",
"rand 0.8.5",
"regex",
"reqwest 0.12.28",
"rsa",
"secrecy",
"serde",
"serde_json",
"sha2",
"socket2 0.5.10",
"sqlx",
"sqlx 0.7.4",
"tempfile",
"thiserror 2.0.18",
"tokio",

View File

@@ -57,7 +57,7 @@ chrono = { version = "0.4", features = ["serde"] }
uuid = { version = "1", features = ["v4", "v5", "serde"] }
# Database
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres"] }
sqlx = { version = "0.7", features = ["runtime-tokio", "sqlite", "postgres", "chrono"] }
libsqlite3-sys = { version = "0.27", features = ["bundled"] }
# HTTP client (for LLM drivers)
@@ -84,6 +84,7 @@ rand = "0.8"
# Crypto
sha2 = "0.10"
aes-gcm = "0.10"
rsa = { version = "0.9", features = ["pem"] }
# Home directory
dirs = "6"

View File

@@ -16,6 +16,9 @@ import {
SunOutlined,
MoonOutlined,
ApiOutlined,
BookOutlined,
CrownOutlined,
SafetyOutlined,
} from '@ant-design/icons'
import { Avatar, Dropdown, Tooltip, Drawer } from 'antd'
import { useAuthStore } from '@/stores/authStore'
@@ -37,11 +40,14 @@ interface NavItem {
const navItems: NavItem[] = [
{ path: '/', name: '仪表盘', icon: <DashboardOutlined />, group: '核心' },
{ path: '/accounts', name: '账号管理', icon: <TeamOutlined />, permission: 'account:admin', group: '资源管理' },
{ path: '/roles', name: '角色与权限', icon: <SafetyOutlined />, permission: 'account:admin', group: '资源管理' },
{ path: '/model-services', name: '模型服务', icon: <CloudServerOutlined />, permission: 'provider:manage', group: '资源管理' },
{ path: '/agent-templates', name: 'Agent 模板', icon: <RobotOutlined />, permission: 'model:read', group: '资源管理' },
{ path: '/api-keys', name: 'API 密钥', icon: <ApiOutlined />, permission: 'provider:manage', group: '资源管理' },
{ path: '/usage', name: '用量统计', icon: <BarChartOutlined />, permission: 'admin:full', group: '运维' },
{ path: '/relay', name: '中转任务', icon: <SwapOutlined />, permission: 'relay:use', group: '运维' },
{ path: '/knowledge', name: '知识库', icon: <BookOutlined />, permission: 'knowledge:read', group: '资源管理' },
{ path: '/billing', name: '计费管理', icon: <CrownOutlined />, permission: 'billing:read', group: '核心' },
{ path: '/logs', name: '操作日志', icon: <FileTextOutlined />, permission: 'admin:full', group: '运维' },
{ path: '/config', name: '系统配置', icon: <SettingOutlined />, permission: 'config:read', group: '系统' },
{ path: '/prompts', name: '提示词管理', icon: <MessageOutlined />, permission: 'prompt:read', group: '系统' },
@@ -197,6 +203,7 @@ function MobileDrawer({
const breadcrumbMap: Record<string, string> = {
'/': '仪表盘',
'/accounts': '账号管理',
'/roles': '角色与权限',
'/model-services': '模型服务',
'/providers': '模型服务',
'/models': '模型服务',
@@ -204,6 +211,8 @@ const breadcrumbMap: Record<string, string> = {
'/agent-templates': 'Agent 模板',
'/usage': '用量统计',
'/relay': '中转任务',
'/knowledge': '知识库',
'/billing': '计费管理',
'/config': '系统配置',
'/prompts': '提示词管理',
'/logs': '操作日志',

View File

@@ -0,0 +1,352 @@
// ============================================================
// 计费管理 — 计划/订阅/用量/支付
// ============================================================
import { useState } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import {
Button, message, Tag, Modal, Card, Row, Col, Statistic, Typography,
Progress, Space, Radio, Spin, Empty, Divider,
} from 'antd'
import {
CrownOutlined, CheckCircleOutlined, ThunderboltOutlined,
RocketOutlined, TeamOutlined, AlipayCircleOutlined,
WechatOutlined, LoadingOutlined,
} from '@ant-design/icons'
import { PageHeader } from '@/components/PageHeader'
import { ErrorState } from '@/components/ErrorState'
import { billingService } from '@/services/billing'
import type { BillingPlan, SubscriptionInfo, PaymentResult } from '@/services/billing'
const { Text, Title } = Typography
// === 计划卡片 ===
const planIcons: Record<string, React.ReactNode> = {
free: <RocketOutlined style={{ fontSize: 24 }} />,
pro: <ThunderboltOutlined style={{ fontSize: 24 }} />,
team: <TeamOutlined style={{ fontSize: 24 }} />,
}
const planColors: Record<string, string> = {
free: '#8c8c8c',
pro: '#863bff',
team: '#47bfff',
}
function PlanCard({
plan,
isCurrent,
onSelect,
}: {
plan: BillingPlan
isCurrent: boolean
onSelect: (plan: BillingPlan) => void
}) {
const color = planColors[plan.name] || '#666'
const limits = plan.limits as Record<string, unknown> | undefined
const maxRelay = (limits?.max_relay_requests_monthly as number) ?? '∞'
const maxHand = (limits?.max_hand_executions_monthly as number) ?? '∞'
const maxPipeline = (limits?.max_pipeline_runs_monthly as number) ?? '∞'
return (
<Card
className={`relative overflow-hidden transition-all duration-200 hover:shadow-lg ${
isCurrent ? 'ring-2 ring-offset-2' : ''
}`}
style={isCurrent ? { borderColor: color, '--tw-ring-color': color } as React.CSSProperties : {}}
>
{isCurrent && (
<div
className="absolute top-0 right-0 px-3 py-1 text-xs font-medium text-white rounded-bl-lg"
style={{ background: color }}
>
</div>
)}
<div className="text-center mb-4">
<div style={{ color }} className="mb-2">
{planIcons[plan.name] || <CrownOutlined style={{ fontSize: 24 }} />}
</div>
<Title level={4} style={{ margin: 0 }}>{plan.display_name}</Title>
{plan.description && (
<Text type="secondary" className="text-sm">{plan.description}</Text>
)}
</div>
<div className="text-center mb-4">
<span className="text-3xl font-bold" style={{ color }}>
¥{plan.price_cents === 0 ? '0' : (plan.price_cents / 100).toFixed(0)}
</span>
<Text type="secondary"> /{plan.interval === 'month' ? '月' : '年'}</Text>
</div>
<div className="space-y-2 text-sm">
<div className="flex items-center gap-2">
<CheckCircleOutlined style={{ color: '#52c41a' }} />
<span>: {maxRelay === Infinity ? '无限' : `${maxRelay} 次/月`}</span>
</div>
<div className="flex items-center gap-2">
<CheckCircleOutlined style={{ color: '#52c41a' }} />
<span>Hand : {maxHand === Infinity ? '无限' : `${maxHand} 次/月`}</span>
</div>
<div className="flex items-center gap-2">
<CheckCircleOutlined style={{ color: '#52c41a' }} />
<span>Pipeline : {maxPipeline === Infinity ? '无限' : `${maxPipeline} 次/月`}</span>
</div>
<div className="flex items-center gap-2">
<CheckCircleOutlined style={{ color: '#52c41a' }} />
<span>: {plan.name === 'free' ? '基础' : '高级'}</span>
</div>
<div className="flex items-center gap-2">
<CheckCircleOutlined style={{ color: '#52c41a' }} />
<span>: {plan.name === 'team' ? '最高' : plan.name === 'pro' ? '高' : '标准'}</span>
</div>
</div>
<Divider />
<Button
block
type={isCurrent ? 'default' : 'primary'}
disabled={isCurrent}
onClick={() => onSelect(plan)}
style={!isCurrent ? { background: color, borderColor: color } : {}}
>
{isCurrent ? '当前计划' : '升级'}
</Button>
</Card>
)
}
// === 用量进度条 ===
function UsageBar({ label, current, max }: { label: string; current: number; max: number | null }) {
const pct = max ? Math.min((current / max) * 100, 100) : 0
const displayMax = max ? max.toLocaleString() : '∞'
return (
<div className="mb-3">
<div className="flex justify-between text-xs text-neutral-500 dark:text-neutral-400 mb-1">
<span>{label}</span>
<span>{current.toLocaleString()} / {displayMax}</span>
</div>
<Progress
percent={pct}
showInfo={false}
strokeColor={pct >= 90 ? '#ff4d4f' : pct >= 70 ? '#faad14' : '#863bff'}
size="small"
/>
</div>
)
}
// === 主页面 ===
export default function Billing() {
const queryClient = useQueryClient()
const [payModalOpen, setPayModalOpen] = useState(false)
const [selectedPlan, setSelectedPlan] = useState<BillingPlan | null>(null)
const [payMethod, setPayMethod] = useState<'alipay' | 'wechat'>('alipay')
const [payResult, setPayResult] = useState<PaymentResult | null>(null)
const [pollingPayment, setPollingPayment] = useState<string | null>(null)
const { data: plans = [], isLoading: plansLoading, error: plansError, refetch } = useQuery({
queryKey: ['billing-plans'],
queryFn: ({ signal }) => billingService.listPlans(signal),
})
const { data: subInfo, isLoading: subLoading } = useQuery({
queryKey: ['billing-subscription'],
queryFn: ({ signal }) => billingService.getSubscription(signal),
})
// 支付状态轮询
const { data: paymentStatus } = useQuery({
queryKey: ['payment-status', pollingPayment],
queryFn: ({ signal }) => billingService.getPaymentStatus(pollingPayment!, signal),
enabled: !!pollingPayment,
refetchInterval: pollingPayment ? 3000 : false,
})
// 支付成功后刷新
if (paymentStatus?.status === 'succeeded' && pollingPayment) {
setPollingPayment(null)
setPayModalOpen(false)
setPayResult(null)
message.success('支付成功!计划已更新')
queryClient.invalidateQueries({ queryKey: ['billing-subscription'] })
}
const createPaymentMutation = useMutation({
mutationFn: (data: { plan_id: string; payment_method: 'alipay' | 'wechat' }) =>
billingService.createPayment(data),
onSuccess: (result) => {
setPayResult(result)
setPollingPayment(result.payment_id)
// 打开支付链接
window.open(result.pay_url, '_blank', 'width=480,height=640')
},
onError: (err: Error) => message.error(err.message || '创建支付失败'),
})
const handleSelectPlan = (plan: BillingPlan) => {
if (plan.price_cents === 0) return
setSelectedPlan(plan)
setPayResult(null)
setPayModalOpen(true)
}
const handleConfirmPay = () => {
if (!selectedPlan) return
createPaymentMutation.mutate({
plan_id: selectedPlan.id,
payment_method: payMethod,
})
}
if (plansError) {
return (
<>
<PageHeader title="计费管理" description="管理订阅计划和用量配额" />
<ErrorState message={(plansError as Error).message} onRetry={() => refetch()} />
</>
)
}
const currentPlanName = subInfo?.plan?.name || 'free'
const usage = subInfo?.usage
return (
<div>
<PageHeader title="计费管理" description="管理订阅计划和用量配额" />
{/* 当前计划 + 用量 */}
{subInfo && usage && (
<Card className="mb-6" title={<span className="text-sm font-semibold"></span>}>
<Row gutter={[24, 16]}>
<Col xs={24} md={8}>
<UsageBar
label="中转请求"
current={usage.relay_requests}
max={usage.max_relay_requests}
/>
</Col>
<Col xs={24} md={8}>
<UsageBar
label="Hand 执行"
current={usage.hand_executions}
max={usage.max_hand_executions}
/>
</Col>
<Col xs={24} md={8}>
<UsageBar
label="Pipeline 运行"
current={usage.pipeline_runs}
max={usage.max_pipeline_runs}
/>
</Col>
</Row>
{subInfo.subscription && (
<div className="mt-4 text-xs text-neutral-400">
: {new Date(subInfo.subscription.current_period_start).toLocaleDateString()} {new Date(subInfo.subscription.current_period_end).toLocaleDateString()}
</div>
)}
</Card>
)}
{/* 计划选择 */}
<Title level={5} className="mb-4"></Title>
{plansLoading ? (
<div className="flex justify-center py-8"><Spin /></div>
) : (
<Row gutter={[16, 16]}>
{plans.map((plan) => (
<Col key={plan.id} xs={24} sm={12} lg={8}>
<PlanCard
plan={plan}
isCurrent={plan.name === currentPlanName}
onSelect={handleSelectPlan}
/>
</Col>
))}
</Row>
)}
{/* 支付弹窗 */}
<Modal
title={selectedPlan ? `升级到 ${selectedPlan.display_name}` : '支付'}
open={payModalOpen}
onCancel={() => {
setPayModalOpen(false)
setPollingPayment(null)
setPayResult(null)
}}
footer={payResult ? null : undefined}
onOk={handleConfirmPay}
okText={createPaymentMutation.isPending ? '处理中...' : '确认支付'}
confirmLoading={createPaymentMutation.isPending}
>
{payResult ? (
<div className="text-center py-4">
<LoadingOutlined style={{ fontSize: 32, color: '#863bff' }} className="mb-4" />
<Title level={5}>...</Title>
<Text type="secondary">
<br />
: ¥{(payResult.amount_cents / 100).toFixed(2)}
</Text>
<div className="mt-4">
<Button onClick={() => { setPollingPayment(null); setPayModalOpen(false); setPayResult(null) }}>
</Button>
</div>
</div>
) : (
<div>
{selectedPlan && (
<div className="text-center mb-6">
<div className="text-2xl font-bold" style={{ color: planColors[selectedPlan.name] || '#666' }}>
¥{(selectedPlan.price_cents / 100).toFixed(0)}
</div>
<Text type="secondary">/{selectedPlan.interval === 'month' ? '月' : '年'}</Text>
</div>
)}
<Title level={5} className="text-center mb-4"></Title>
<Radio.Group
value={payMethod}
onChange={(e) => setPayMethod(e.target.value)}
className="w-full"
>
<Space direction="vertical" className="w-full" size={12}>
<Radio value="alipay" className="w-full">
<div className="flex items-center gap-3 p-3 border rounded-lg w-full hover:border-blue-400 transition-colors">
<AlipayCircleOutlined style={{ fontSize: 28, color: '#1677ff' }} />
<div>
<div className="font-medium"></div>
<div className="text-xs text-neutral-400"></div>
</div>
</div>
</Radio>
<Radio value="wechat" className="w-full">
<div className="flex items-center gap-3 p-3 border rounded-lg w-full hover:border-green-400 transition-colors">
<WechatOutlined style={{ fontSize: 28, color: '#07c160' }} />
<div>
<div className="font-medium"></div>
<div className="text-xs text-neutral-400"></div>
</div>
</div>
</Radio>
</Space>
</Radio.Group>
</div>
)}
</Modal>
</div>
)
}

View File

@@ -0,0 +1,750 @@
// ============================================================
// 知识库管理
// ============================================================
import { useState, useMemo, useEffect } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import {
Button, message, Tag, Modal, Form, Input, Select, Space, Popconfirm,
Card, Statistic, Row, Col, Tabs, Tree, Typography, Empty, Spin, InputNumber,
Table, Tooltip,
} from 'antd'
import {
PlusOutlined, SearchOutlined, BookOutlined, FolderOutlined,
DeleteOutlined, EditOutlined, EyeOutlined, BarChartOutlined,
HistoryOutlined, RollbackOutlined,
WarningOutlined,
} from '@ant-design/icons'
import type { ProColumns } from '@ant-design/pro-components'
import { ProTable } from '@ant-design/pro-components'
import { knowledgeService } from '@/services/knowledge'
import type { CategoryResponse, KnowledgeItem, SearchResult } from '@/services/knowledge'
const { TextArea } = Input
const { Text, Title } = Typography
// === 分类树 + 条目列表 Tab ===
function CategoriesPanel() {
const queryClient = useQueryClient()
const [createOpen, setCreateOpen] = useState(false)
const [editItem, setEditItem] = useState<CategoryResponse | null>(null)
const [createForm] = Form.useForm()
const [editForm] = Form.useForm()
const { data: categories = [], isLoading } = useQuery({
queryKey: ['knowledge-categories'],
queryFn: ({ signal }) => knowledgeService.listCategories(signal),
})
const createMutation = useMutation({
mutationFn: (data: Parameters<typeof knowledgeService.createCategory>[0]) =>
knowledgeService.createCategory(data),
onSuccess: () => {
message.success('分类已创建')
queryClient.invalidateQueries({ queryKey: ['knowledge-categories'] })
setCreateOpen(false)
createForm.resetFields()
},
onError: (err: Error) => message.error(err.message || '创建失败'),
})
const deleteMutation = useMutation({
mutationFn: (id: string) => knowledgeService.deleteCategory(id),
onSuccess: () => {
message.success('分类已删除')
queryClient.invalidateQueries({ queryKey: ['knowledge-categories'] })
},
onError: (err: Error) => message.error(err.message || '删除失败'),
})
const updateMutation = useMutation({
mutationFn: ({ id, ...data }: { id: string } & Record<string, unknown>) =>
knowledgeService.updateCategory(id, data),
onSuccess: () => {
message.success('分类已更新')
queryClient.invalidateQueries({ queryKey: ['knowledge-categories'] })
setEditItem(null)
},
onError: (err: Error) => message.error(err.message || '更新失败'),
})
// 编辑弹窗打开时同步表单值Ant Design Form initialValues 仅首次挂载生效)
useEffect(() => {
if (editItem) {
editForm.setFieldsValue({
name: editItem.name,
description: editItem.description,
parent_id: editItem.parent_id,
icon: editItem.icon,
})
}
}, [editItem, editForm])
// 获取当前编辑分类及其所有后代的 ID防止循环引用
const getDescendantIds = (id: string, cats: CategoryResponse[]): string[] => {
const ids: string[] = [id]
for (const c of cats) {
if (c.parent_id === id) {
ids.push(...getDescendantIds(c.id, cats))
}
}
return ids
}
const treeData = useMemo(
() => buildTreeData(categories, (id) => {
Modal.confirm({
title: '确认删除',
content: '删除后无法恢复,请确保分类下没有子分类和条目。',
okType: 'danger',
onOk: () => deleteMutation.mutate(id),
})
}, (id) => {
setEditItem(categories.find((c) => c.id === id) || null)
}),
[categories, deleteMutation],
)
return (
<div>
<div className="flex justify-between items-center mb-4">
<Title level={5} style={{ margin: 0 }}></Title>
<Button type="primary" icon={<PlusOutlined />} onClick={() => setCreateOpen(true)}>
</Button>
</div>
{isLoading ? (
<div className="flex justify-center py-8"><Spin /></div>
) : categories.length === 0 ? (
<Empty description="暂无分类,请新建一个" />
) : (
<Tree
treeData={treeData}
defaultExpandAll
showLine={{ showLeafIcon: false }}
showIcon
/>
)}
{/* 新建分类弹窗 */}
<Modal
title="新建分类"
open={createOpen}
onCancel={() => { setCreateOpen(false); createForm.resetFields() }}
onOk={() => createForm.submit()}
confirmLoading={createMutation.isPending}
>
<Form form={createForm} layout="vertical" onFinish={(v) => createMutation.mutate(v)}>
<Form.Item name="name" label="分类名称" rules={[{ required: true, message: '请输入分类名称' }]}>
<Input placeholder="例如:产品知识、技术文档" />
</Form.Item>
<Form.Item name="description" label="描述">
<TextArea rows={2} placeholder="可选描述" />
</Form.Item>
<Form.Item name="parent_id" label="父分类">
<Select placeholder="无(顶级分类)" allowClear>
{flattenCategories(categories).map((c) => (
<Select.Option key={c.id} value={c.id}>{c.name}</Select.Option>
))}
</Select>
</Form.Item>
<Form.Item name="icon" label="图标">
<Input placeholder="可选,如 📚" />
</Form.Item>
</Form>
</Modal>
{/* 编辑分类弹窗 */}
<Modal
title="编辑分类"
open={!!editItem}
onCancel={() => { setEditItem(null); editForm.resetFields() }}
onOk={() => editForm.submit()}
confirmLoading={updateMutation.isPending}
>
<Form
form={editForm}
layout="vertical"
initialValues={editItem ? { name: editItem.name, description: editItem.description, parent_id: editItem.parent_id, icon: editItem.icon } : undefined}
onFinish={(v) => editItem && updateMutation.mutate({ id: editItem.id, ...v })}
>
<Form.Item name="name" label="分类名称" rules={[{ required: true }]}>
<Input />
</Form.Item>
<Form.Item name="description" label="描述">
<TextArea rows={2} />
</Form.Item>
<Form.Item name="parent_id" label="父分类">
<Select placeholder="无(顶级分类)" allowClear>
{editItem && flattenCategories(categories)
.filter((c) => !getDescendantIds(editItem.id, categories).includes(c.id))
.map((c) => (
<Select.Option key={c.id} value={c.id}>{c.name}</Select.Option>
))}
</Form.Item>
<Form.Item name="icon" label="图标">
<Input placeholder="如 📚" />
</Form.Item>
</Form>
</Modal>
</div>
)
}
// === 条目列表 ===
function ItemsPanel() {
const queryClient = useQueryClient()
const [createOpen, setCreateOpen] = useState(false)
const [detailItem, setDetailItem] = useState<string | null>(null)
const [versionModalOpen, setVersionModalOpen] = useState(false)
const [rollingBackVersion, setRollingBackVersion] = useState<number | null>(null)
const [page, setPage] = useState(1)
const [pageSize, setPageSize] = useState(20)
const [filters, setFilters] = useState<{ category_id?: string; status?: string; keyword?: string }>({})
const [form] = Form.useForm()
const { data: categories = [] } = useQuery({
queryKey: ['knowledge-categories'],
queryFn: ({ signal }) => knowledgeService.listCategories(signal),
})
const { data: detailData, isLoading: detailLoading } = useQuery({
queryKey: ['knowledge-item-detail', detailItem],
queryFn: ({ signal }) => knowledgeService.getItem(detailItem!, signal),
enabled: !!detailItem,
})
const { data: versions } = useQuery({
queryKey: ['knowledge-item-versions', detailItem],
queryFn: ({ signal }) => knowledgeService.getVersions(detailItem!, signal),
enabled: !!detailItem,
})
const { data, isLoading } = useQuery({
queryKey: ['knowledge-items', page, pageSize, filters],
queryFn: ({ signal }) =>
knowledgeService.listItems({ page, page_size: pageSize, ...filters }, signal),
})
const createMutation = useMutation({
mutationFn: (data: Parameters<typeof knowledgeService.createItem>[0]) =>
knowledgeService.createItem(data),
onSuccess: () => {
message.success('条目已创建')
queryClient.invalidateQueries({ queryKey: ['knowledge-items'] })
setCreateOpen(false)
form.resetFields()
},
onError: (err: Error) => message.error(err.message || '创建失败'),
})
const deleteMutation = useMutation({
mutationFn: (id: string) => knowledgeService.deleteItem(id),
onSuccess: () => {
message.success('已删除')
queryClient.invalidateQueries({ queryKey: ['knowledge-items'] })
},
onError: (err: Error) => message.error(err.message || '删除失败'),
})
const rollbackMutation = useMutation({
mutationFn: ({ itemId, version }: { itemId: string; version: number }) =>
knowledgeService.rollbackVersion(itemId, version),
onSuccess: () => {
message.success('已回滚')
queryClient.invalidateQueries({ queryKey: ['knowledge-items'] })
queryClient.invalidateQueries({ queryKey: ['knowledge-item-detail'] })
queryClient.invalidateQueries({ queryKey: ['knowledge-item-versions'] })
setVersionModalOpen(false)
setRollingBackVersion(null)
},
onError: (err: Error) => {
message.error(err.message || '回滚失败')
setRollingBackVersion(null)
},
})
const statusColors: Record<string, string> = { active: 'green', draft: 'orange', archived: 'default' }
const statusLabels: Record<string, string> = { active: '活跃', draft: '草稿', archived: '已归档' }
const columns: ProColumns<KnowledgeItem>[] = [
{
title: '标题',
dataIndex: 'keyword',
width: 250,
render: (_, r) => (
<Button type="link" size="small" onClick={() => setDetailItem(r.id)}>
{r.title}
</Button>
),
},
{
title: '状态',
dataIndex: 'status',
width: 80,
valueEnum: Object.fromEntries(
Object.entries(statusLabels).map(([k, v]) => [k, { text: v, status: statusColors[k] === 'green' ? 'Success' : statusColors[k] === 'orange' ? 'Warning' : 'Default' }]),
),
},
{ title: '版本', dataIndex: 'version', width: 60, search: false },
{ title: '优先级', dataIndex: 'priority', width: 70, search: false },
{
title: '标签',
dataIndex: 'tags',
width: 200,
search: false,
render: (_, r) => (
<Space size={[4, 4]} wrap>
{r.tags?.map((t) => <Tag key={t}>{t}</Tag>)}
</Space>
),
},
{ title: '更新时间', dataIndex: 'updated_at', width: 160, valueType: 'dateTime', search: false },
{
title: '操作',
width: 150,
search: false,
render: (_, r) => (
<Space>
<Button type="link" size="small" icon={<EyeOutlined />} onClick={() => setDetailItem(r.id)} />
<Tooltip title="版本历史">
<Button type="link" size="small" icon={<HistoryOutlined />} onClick={() => { setDetailItem(r.id); setVersionModalOpen(true) }} />
</Tooltip>
<Popconfirm title="确认删除?" onConfirm={() => deleteMutation.mutate(r.id)}>
<Button type="link" size="small" danger icon={<DeleteOutlined />} />
</Popconfirm>
</Space>
),
},
]
return (
<div>
<ProTable<KnowledgeItem>
columns={columns}
dataSource={data?.items || []}
loading={isLoading}
rowKey="id"
search={{
onReset: () => { setFilters({}); setPage(1) },
onSearch: (values) => { setFilters(values); setPage(1) },
}}
toolBarRender={() => [
<Button key="create" type="primary" icon={<PlusOutlined />} onClick={() => setCreateOpen(true)}>
</Button>,
]}
pagination={{
current: page,
pageSize,
total: data?.total || 0,
showSizeChanger: true,
onChange: (p, ps) => { setPage(p); setPageSize(ps) },
}}
options={{ density: false, fullScreen: false, reload: () => queryClient.invalidateQueries({ queryKey: ['knowledge-items'] }) }}
/>
{/* 创建弹窗 */}
<Modal
title="新建知识条目"
open={createOpen}
onCancel={() => { setCreateOpen(false); form.resetFields() }}
onOk={() => form.submit()}
confirmLoading={createMutation.isPending}
width={640}
>
<Form form={form} layout="vertical" onFinish={(v) => createMutation.mutate(v)}>
<Form.Item name="category_id" label="分类" rules={[{ required: true, message: '请选择分类' }]}>
<Select placeholder="选择分类">
{flattenCategories(categories).map((c) => (
<Select.Option key={c.id} value={c.id}>{c.name}</Select.Option>
))}
</Select>
</Form.Item>
<Form.Item name="title" label="标题" rules={[{ required: true, message: '请输入标题' }]}>
<Input placeholder="知识条目标题" />
</Form.Item>
<Form.Item name="content" label="内容" rules={[{ required: true, message: '请输入内容' }]}>
<TextArea rows={8} placeholder="支持 Markdown 格式" />
</Form.Item>
<Row gutter={16}>
<Col span={12}>
<Form.Item name="keywords" label="关键词">
<Select mode="tags" placeholder="输入后回车添加" />
</Form.Item>
</Col>
<Col span={12}>
<Form.Item name="tags" label="标签">
<Select mode="tags" placeholder="输入后回车添加" />
</Form.Item>
</Col>
</Row>
<Form.Item name="priority" label="优先级" initialValue={0}>
<InputNumber min={0} max={100} />
</Form.Item>
</Form>
</Modal>
{/* 详情弹窗 */}
<Modal
title={detailData?.title || '条目详情'}
open={!!detailItem && !versionModalOpen}
onCancel={() => setDetailItem(null)}
footer={null}
width={720}
>
{detailData && (
<div>
<div className="mb-4 flex gap-2">
<Tag color={statusColors[detailData.status]}>{statusLabels[detailData.status] || detailData.status}</Tag>
<Tag> {detailData.version}</Tag>
<Tag> {detailData.priority}</Tag>
</div>
<div className="mb-4 whitespace-pre-wrap bg-neutral-50 dark:bg-neutral-900 p-4 rounded-lg max-h-96 overflow-y-auto text-sm">
{detailData.content}
</div>
<div className="flex gap-2 flex-wrap">
{detailData.tags?.map((t) => <Tag key={t} color="blue">{t}</Tag>)}
{detailData.keywords?.map((k) => <Tag key={k} color="cyan">{k}</Tag>)}
</div>
</div>
)}
</Modal>
{/* 版本历史弹窗 */}
<Modal
title={`版本历史 - ${detailData?.title || ''}`}
open={versionModalOpen}
onCancel={() => { setVersionModalOpen(false); setDetailItem(null) }}
footer={null}
width={720}
>
<Table
dataSource={versions?.versions || []}
rowKey="id"
loading={!versions}
size="small"
pagination={{ pageSize: 10 }}
columns={[
{ title: '版本', dataIndex: 'version', width: 70 },
{ title: '标题', dataIndex: 'title', ellipsis: true },
{ title: '摘要', dataIndex: 'change_summary', width: 200, ellipsis: true },
{ title: '创建者', dataIndex: 'created_by', width: 100 },
{ title: '创建时间', dataIndex: 'created_at', width: 160 },
{
title: '操作',
width: 80,
render: (_, r) => (
<Popconfirm
title={`确认回滚到版本 ${r.version}?`}
description="回滚将创建新版本,当前版本内容会被替换。"
onConfirm={() => {
setRollingBackVersion(r.version)
rollbackMutation.mutate({ itemId: detailItem!, version: r.version })
}}
>
<Button type="link" size="small" icon={<RollbackOutlined />} loading={rollingBackVersion === r.version}>
</Button>
</Popconfirm>
),
},
]}
/>
</div>
)
}
// === 搜索面板 ===
function SearchPanel() {
const [query, setQuery] = useState('')
const [results, setResults] = useState<SearchResult[]>([])
const [searching, setSearching] = useState(false)
const [hasSearched, setHasSearched] = useState(false)
const handleSearch = async () => {
if (!query.trim()) return
setSearching(true)
try {
const data = await knowledgeService.search({ query: query.trim(), limit: 10 })
setResults(data)
setHasSearched(true)
} catch {
message.error('搜索失败')
} finally {
setSearching(false)
}
}
return (
<div>
<Title level={5}></Title>
<Space.Compact className="w-full mb-4">
<Input
size="large"
placeholder="输入搜索关键词..."
value={query}
onChange={(e) => setQuery(e.target.value)}
onPressEnter={handleSearch}
prefix={<SearchOutlined />}
/>
<Button size="large" type="primary" loading={searching} onClick={handleSearch}>
</Button>
</Space.Compact>
{results.length === 0 && !searching && !hasSearched && (
<Empty description="输入关键词搜索知识库" />
)}
{results.length === 0 && !searching && hasSearched && (
<Empty description="未找到匹配的知识条目" />
)}
<div className="space-y-3">
{results.map((r) => (
<Card key={r.chunk_id} size="small" hoverable>
<div className="flex justify-between items-start mb-2">
<Text strong>{r.item_title}</Text>
<Tag>{r.category_name}</Tag>
</div>
<div className="text-sm text-neutral-600 dark:text-neutral-400 line-clamp-3 mb-2">
{r.content}
</div>
<div className="flex gap-1 flex-wrap">
{r.keywords?.slice(0, 5).map((k) => (
<Tag key={k} color="cyan" style={{ fontSize: 11 }}>{k}</Tag>
))}
</div>
</Card>
))}
</div>
</div>
)
}
// === 分析看板 ===
function AnalyticsPanel() {
const { data: overview, isLoading: overviewLoading } = useQuery({
queryKey: ['knowledge-analytics'],
queryFn: ({ signal }) => knowledgeService.getOverview(signal),
})
const { data: trends } = useQuery({
queryKey: ['knowledge-trends'],
queryFn: ({ signal }) => knowledgeService.getTrends(signal),
})
const { data: topItems } = useQuery({
queryKey: ['knowledge-top-items'],
queryFn: ({ signal }) => knowledgeService.getTopItems(signal),
})
const { data: quality } = useQuery({
queryKey: ['knowledge-quality'],
queryFn: ({ signal }) => knowledgeService.getQuality(signal),
})
const { data: gaps } = useQuery({
queryKey: ['knowledge-gaps'],
queryFn: ({ signal }) => knowledgeService.getGaps(signal),
})
if (overviewLoading) return <div className="flex justify-center py-8"><Spin /></div>
return (
<div>
<Title level={5} className="mb-4"></Title>
<Row gutter={[16, 16]}>
<Col span={6}>
<Card><Statistic title="总条目数" value={overview?.total_items || 0} /></Card>
</Col>
<Col span={6}>
<Card><Statistic title="活跃条目" value={overview?.active_items || 0} valueStyle={{ color: '#52c41a' }} /></Card>
</Col>
<Col span={6}>
<Card><Statistic title="分类数" value={overview?.total_categories || 0} /></Card>
</Col>
<Col span={6}>
<Card><Statistic title="本周新增" value={overview?.weekly_new_items || 0} valueStyle={{ color: '#1890ff' }} /></Card>
</Col>
</Row>
<Row gutter={[16, 16]} className="mt-4">
<Col span={6}>
<Card><Statistic title="总引用次数" value={overview?.total_references || 0} /></Card>
</Col>
<Col span={6}>
<Card>
<Statistic title="注入率" value={((overview?.injection_rate || 0) * 100).toFixed(1)} suffix="%" />
</Card>
</Col>
<Col span={6}>
<Card>
<Statistic title="正面反馈率" value={((overview?.positive_feedback_rate || 0) * 100).toFixed(1)} suffix="%" valueStyle={{ color: '#52c41a' }} />
</Card>
</Col>
<Col span={6}>
<Card><Statistic title="过期条目" value={overview?.stale_items_count || 0} valueStyle={{ color: '#faad14' }} /></Card>
</Col>
</Row>
{/* 趋势数据表格 */}
<Card title="检索趋势近30天" className="mt-4" size="small">
<Table
dataSource={trends?.trends || []}
rowKey="date"
loading={!trends}
size="small"
pagination={{ pageSize: 10 }}
columns={[
{ title: '日期', dataIndex: 'date', width: 120 },
{ title: '检索次数', dataIndex: 'count', width: 100 },
{ title: '注入次数', dataIndex: 'injected_count', width: 100 },
]}
/>
</Card>
{/* Top Items 表格 */}
<Card title="高频引用 Top 20" className="mt-4" size="small">
<Table
dataSource={topItems?.items || []}
rowKey="id"
loading={!topItems}
size="small"
pagination={{ pageSize: 10 }}
columns={[
{ title: '标题', dataIndex: 'title', ellipsis: true },
{ title: '分类', dataIndex: 'category', width: 120 },
{ title: '引用次数', dataIndex: 'ref_count', width: 100 },
]}
/>
</Card>
{/* 质量指标 */}
{quality?.categories?.length > 0 && (
<Card title="分类质量指标" className="mt-4" size="small">
<Table
dataSource={quality.categories}
rowKey="category"
size="small"
pagination={false}
columns={[
{ title: '分类', dataIndex: 'category', width: 150 },
{ title: '总条目', dataIndex: 'total', width: 80 },
{ title: '活跃', dataIndex: 'active', width: 80 },
{ title: '有关键词', dataIndex: 'with_keywords', width: 100 },
{ title: '平均优先级', dataIndex: 'avg_priority', width: 100, render: (v: number) => v?.toFixed(1) },
]}
/>
</Card>
)}
{/* 知识缺口 */}
{gaps?.gaps?.length > 0 && (
<Card
title={
<Space>
<WarningOutlined style={{ color: '#faad14' }} />
<span></span>
</Space>
}
className="mt-4"
size="small"
>
<Table
dataSource={gaps.gaps}
rowKey="query"
size="small"
pagination={{ pageSize: 10 }}
columns={[
{ title: '查询', dataIndex: 'query', ellipsis: true },
{ title: '次数', dataIndex: 'count', width: 80 },
{ title: '平均分', dataIndex: 'avg_score', width: 100, render: (v: number) => v?.toFixed(2) },
]}
/>
</Card>
)}
</div>
)
}
// === 主页面 ===
export default function Knowledge() {
return (
<div className="p-6">
<Tabs
defaultActiveKey="items"
items={[
{
key: 'items',
label: '知识条目',
icon: <BookOutlined />,
children: <ItemsPanel />,
},
{
key: 'categories',
label: '分类管理',
icon: <FolderOutlined />,
children: <CategoriesPanel />,
},
{
key: 'search',
label: '搜索',
icon: <SearchOutlined />,
children: <SearchPanel />,
},
{
key: 'analytics',
label: '分析看板',
icon: <BarChartOutlined />,
children: <AnalyticsPanel />,
},
]}
/>
</div>
)
}
// === 辅助函数 ===
function flattenCategories(cats: CategoryResponse[]): { id: string; name: string }[] {
const result: { id: string; name: string }[] = []
for (const c of cats) {
result.push({ id: c.id, name: c.name })
if (c.children?.length) {
result.push(...flattenCategories(c.children))
}
}
return result
}
interface TreeNode {
key: string
title: React.ReactNode
icon?: React.ReactNode
children?: TreeNode[]
}
function buildTreeData(cats: CategoryResponse[], onDelete: (id: string) => void, onEdit: (id: string) => void): TreeNode[] {
return cats.map((c) => ({
key: c.id,
title: (
<div className="flex items-center gap-2">
<span>{c.icon || '📁'} {c.name}</span>
<Tag>{c.item_count}</Tag>
<Button type="link" size="small" icon={<EditOutlined />} onClick={() => onEdit(c.id)} />
<Button type="link" size="small" danger onClick={() => onDelete(c.id)}>
<DeleteOutlined />
</Button>
</div>
),
children: c.children?.length ? buildTreeData(c.children, onDelete, onEdit) : undefined,
}))
}

View File

@@ -0,0 +1,509 @@
// ============================================================
// 角色与权限模板管理
// ============================================================
import { useState } from 'react'
import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query'
import {
Button,
message,
Tag,
Modal,
Form,
Input,
Select,
Space,
Popconfirm,
Tabs,
Tooltip,
} from 'antd'
import { PlusOutlined, SafetyOutlined, CheckCircleOutlined } from '@ant-design/icons'
import type { ProColumns } from '@ant-design/pro-components'
import { ProTable } from '@ant-design/pro-components'
import { roleService } from '@/services/roles'
import { PageHeader } from '@/components/PageHeader'
import type {
Role,
PermissionTemplate,
CreateRoleRequest,
UpdateRoleRequest,
CreateTemplateRequest,
} from '@/types'
// ============================================================
// 常见权限选项
// ============================================================
const permissionOptions = [
{ value: 'account:admin', label: 'account:admin' },
{ value: 'provider:manage', label: 'provider:manage' },
{ value: 'model:read', label: 'model:read' },
{ value: 'model:write', label: 'model:write' },
{ value: 'relay:use', label: 'relay:use' },
{ value: 'knowledge:read', label: 'knowledge:read' },
{ value: 'knowledge:write', label: 'knowledge:write' },
{ value: 'billing:read', label: 'billing:read' },
{ value: 'billing:write', label: 'billing:write' },
{ value: 'config:read', label: 'config:read' },
{ value: 'config:write', label: 'config:write' },
{ value: 'prompt:read', label: 'prompt:read' },
{ value: 'prompt:write', label: 'prompt:write' },
{ value: 'admin:full', label: 'admin:full' },
]
// ============================================================
// Roles Tab
// ============================================================
function RolesTab() {
const queryClient = useQueryClient()
const [form] = Form.useForm()
const [modalOpen, setModalOpen] = useState(false)
const [editingId, setEditingId] = useState<string | null>(null)
const { data, isLoading } = useQuery({
queryKey: ['roles'],
queryFn: ({ signal }) => roleService.list(signal),
})
const createMutation = useMutation({
mutationFn: (data: CreateRoleRequest) => roleService.create(data),
onSuccess: () => {
message.success('角色已创建')
queryClient.invalidateQueries({ queryKey: ['roles'] })
setModalOpen(false)
form.resetFields()
},
onError: (err: Error) => message.error(err.message || '创建失败'),
})
const updateMutation = useMutation({
mutationFn: ({ id, data }: { id: string; data: UpdateRoleRequest }) =>
roleService.update(id, data),
onSuccess: () => {
message.success('角色已更新')
queryClient.invalidateQueries({ queryKey: ['roles'] })
setModalOpen(false)
},
onError: (err: Error) => message.error(err.message || '更新失败'),
})
const deleteMutation = useMutation({
mutationFn: (id: string) => roleService.delete(id),
onSuccess: () => {
message.success('角色已删除')
queryClient.invalidateQueries({ queryKey: ['roles'] })
},
onError: (err: Error) => message.error(err.message || '删除失败'),
})
const handleSave = async () => {
const values = await form.validateFields()
if (editingId) {
updateMutation.mutate({ id: editingId, data: values })
} else {
createMutation.mutate(values)
}
}
const openEdit = async (record: Role) => {
setEditingId(record.id)
const permissions = await roleService.getPermissions(record.id).catch(() => record.permissions)
form.setFieldsValue({ ...record, permissions })
setModalOpen(true)
}
const openCreate = () => {
setEditingId(null)
form.resetFields()
setModalOpen(true)
}
const closeModal = () => {
setModalOpen(false)
setEditingId(null)
form.resetFields()
}
const columns: ProColumns<Role>[] = [
{
title: '角色名称',
dataIndex: 'name',
width: 160,
render: (_, record) => (
<span className="font-medium text-neutral-900 dark:text-neutral-100">
{record.name}
</span>
),
},
{
title: '描述',
dataIndex: 'description',
width: 240,
ellipsis: true,
render: (_, record) => record.description || '-',
},
{
title: '权限数',
dataIndex: 'permissions',
width: 100,
render: (_, record) => (
<Tooltip title={record.permissions?.join(', ') || '无权限'}>
<Tag>{record.permissions?.length ?? 0} </Tag>
</Tooltip>
),
},
{
title: '关联账号',
dataIndex: 'account_count',
width: 100,
render: (_, record) => record.account_count ?? 0,
},
{
title: '创建时间',
dataIndex: 'created_at',
width: 180,
render: (_, record) =>
record.created_at ? new Date(record.created_at).toLocaleString('zh-CN') : '-',
},
{
title: '操作',
width: 160,
render: (_, record) => (
<Space>
<Button size="small" onClick={() => openEdit(record)}>
</Button>
<Popconfirm
title="确定删除此角色?"
description="删除后关联的账号将失去此角色权限"
onConfirm={() => deleteMutation.mutate(record.id)}
>
<Button size="small" danger>
</Button>
</Popconfirm>
</Space>
),
},
]
return (
<div>
<ProTable<Role>
columns={columns}
dataSource={data ?? []}
loading={isLoading}
rowKey="id"
search={false}
toolBarRender={() => [
<Button key="add" type="primary" icon={<PlusOutlined />} onClick={openCreate}>
</Button>,
]}
pagination={{ showSizeChanger: false }}
/>
<Modal
title={editingId ? '编辑角色' : '新建角色'}
open={modalOpen}
onOk={handleSave}
onCancel={closeModal}
confirmLoading={createMutation.isPending || updateMutation.isPending}
width={560}
>
<Form form={form} layout="vertical" className="mt-4">
<Form.Item
name="name"
label="角色名称"
rules={[{ required: true, message: '请输入角色名称' }]}
>
<Input placeholder="如 editor, viewer" />
</Form.Item>
<Form.Item name="description" label="描述">
<Input.TextArea rows={2} placeholder="角色用途说明" />
</Form.Item>
<Form.Item name="permissions" label="权限">
<Select
mode="multiple"
placeholder="选择权限"
options={permissionOptions}
maxTagCount={5}
allowClear
filterOption={(input, option) =>
(option?.label as string)?.toLowerCase().includes(input.toLowerCase())
}
/>
</Form.Item>
</Form>
</Modal>
</div>
)
}
// ============================================================
// Permission Templates Tab
// ============================================================
function TemplatesTab() {
const queryClient = useQueryClient()
const [form] = Form.useForm()
const [modalOpen, setModalOpen] = useState(false)
const [applyOpen, setApplyOpen] = useState(false)
const [applyForm] = Form.useForm()
const [selectedTemplate, setSelectedTemplate] = useState<PermissionTemplate | null>(null)
const { data, isLoading } = useQuery({
queryKey: ['permission-templates'],
queryFn: ({ signal }) => roleService.listTemplates(signal),
})
const createMutation = useMutation({
mutationFn: (data: CreateTemplateRequest) => roleService.createTemplate(data),
onSuccess: () => {
message.success('模板已创建')
queryClient.invalidateQueries({ queryKey: ['permission-templates'] })
setModalOpen(false)
form.resetFields()
},
onError: (err: Error) => message.error(err.message || '创建失败'),
})
const deleteMutation = useMutation({
mutationFn: (id: string) => roleService.deleteTemplate(id),
onSuccess: () => {
message.success('模板已删除')
queryClient.invalidateQueries({ queryKey: ['permission-templates'] })
},
onError: (err: Error) => message.error(err.message || '删除失败'),
})
const applyMutation = useMutation({
mutationFn: ({ templateId, accountIds }: { templateId: string; accountIds: string[] }) =>
roleService.applyTemplate(templateId, accountIds),
onSuccess: () => {
message.success('模板已应用到所选账号')
queryClient.invalidateQueries({ queryKey: ['permission-templates'] })
setApplyOpen(false)
applyForm.resetFields()
setSelectedTemplate(null)
},
onError: (err: Error) => message.error(err.message || '应用失败'),
})
const openApply = (record: PermissionTemplate) => {
setSelectedTemplate(record)
applyForm.resetFields()
setApplyOpen(true)
}
const handleApply = async () => {
const values = await applyForm.validateFields()
if (!selectedTemplate) return
const accountIds = values.account_ids
?.split(',')
.map((s: string) => s.trim())
.filter(Boolean)
if (!accountIds?.length) {
message.warning('请输入至少一个账号 ID')
return
}
applyMutation.mutate({ templateId: selectedTemplate.id, accountIds })
}
const columns: ProColumns<PermissionTemplate>[] = [
{
title: '模板名称',
dataIndex: 'name',
width: 180,
render: (_, record) => (
<span className="font-medium text-neutral-900 dark:text-neutral-100">
{record.name}
</span>
),
},
{
title: '描述',
dataIndex: 'description',
width: 240,
ellipsis: true,
render: (_, record) => record.description || '-',
},
{
title: '权限数',
dataIndex: 'permissions',
width: 100,
render: (_, record) => (
<Tooltip title={record.permissions?.join(', ') || '无权限'}>
<Tag>{record.permissions?.length ?? 0} </Tag>
</Tooltip>
),
},
{
title: '创建时间',
dataIndex: 'created_at',
width: 180,
render: (_, record) =>
record.created_at ? new Date(record.created_at).toLocaleString('zh-CN') : '-',
},
{
title: '操作',
width: 180,
render: (_, record) => (
<Space>
<Button
size="small"
icon={<CheckCircleOutlined />}
onClick={() => openApply(record)}
>
</Button>
<Popconfirm
title="确定删除此模板?"
description="删除后已应用的账号不受影响"
onConfirm={() => deleteMutation.mutate(record.id)}
>
<Button size="small" danger>
</Button>
</Popconfirm>
</Space>
),
},
]
return (
<div>
<ProTable<PermissionTemplate>
columns={columns}
dataSource={data ?? []}
loading={isLoading}
rowKey="id"
search={false}
toolBarRender={() => [
<Button
key="add"
type="primary"
icon={<PlusOutlined />}
onClick={() => {
form.resetFields()
setModalOpen(true)
}}
>
</Button>,
]}
pagination={{ showSizeChanger: false }}
/>
{/* Create Template Modal */}
<Modal
title="新建权限模板"
open={modalOpen}
onOk={async () => {
const values = await form.validateFields()
createMutation.mutate(values)
}}
onCancel={() => {
setModalOpen(false)
form.resetFields()
}}
confirmLoading={createMutation.isPending}
width={560}
>
<Form form={form} layout="vertical" className="mt-4">
<Form.Item
name="name"
label="模板名称"
rules={[{ required: true, message: '请输入模板名称' }]}
>
<Input placeholder="如 basic-user, power-user" />
</Form.Item>
<Form.Item name="description" label="描述">
<Input.TextArea rows={2} placeholder="模板用途说明" />
</Form.Item>
<Form.Item name="permissions" label="权限">
<Select
mode="multiple"
placeholder="选择权限"
options={permissionOptions}
maxTagCount={5}
allowClear
filterOption={(input, option) =>
(option?.label as string)?.toLowerCase().includes(input.toLowerCase())
}
/>
</Form.Item>
</Form>
</Modal>
{/* Apply Template Modal */}
<Modal
title={`应用模板: ${selectedTemplate?.name ?? ''}`}
open={applyOpen}
onOk={handleApply}
onCancel={() => {
setApplyOpen(false)
setSelectedTemplate(null)
applyForm.resetFields()
}}
confirmLoading={applyMutation.isPending}
width={480}
>
<Form form={applyForm} layout="vertical" className="mt-4">
<div className="mb-4 text-sm text-neutral-500 dark:text-neutral-400">
{selectedTemplate?.permissions?.length ?? 0}
ID ID
</div>
<Form.Item
name="account_ids"
label="账号 ID"
rules={[{ required: true, message: '请输入账号 ID' }]}
>
<Input.TextArea
rows={3}
placeholder="如: acc_abc123, acc_def456"
/>
</Form.Item>
</Form>
</Modal>
</div>
)
}
// ============================================================
// Main Page: Roles & Permissions
// ============================================================
export default function Roles() {
return (
<div>
<PageHeader
title="角色与权限"
description="管理角色、权限模板,并将权限批量应用到账号"
/>
<Tabs
defaultActiveKey="roles"
items={[
{
key: 'roles',
label: (
<span className="flex items-center gap-1.5">
<SafetyOutlined />
</span>
),
children: <RolesTab />,
},
{
key: 'templates',
label: (
<span className="flex items-center gap-1.5">
<CheckCircleOutlined />
</span>
),
children: <TemplatesTab />,
},
]}
/>
</div>
)
}

View File

@@ -21,13 +21,16 @@ export const router = createBrowserRouter([
children: [
{ index: true, lazy: () => import('@/pages/Dashboard').then((m) => ({ Component: m.default })) },
{ path: 'accounts', lazy: () => import('@/pages/Accounts').then((m) => ({ Component: m.default })) },
{ path: 'roles', lazy: () => import('@/pages/Roles').then((m) => ({ Component: m.default })) },
{ path: 'model-services', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
{ path: 'providers', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
{ path: 'models', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
{ path: 'agent-templates', lazy: () => import('@/pages/AgentTemplates').then((m) => ({ Component: m.default })) },
{ path: 'api-keys', lazy: () => import('@/pages/ModelServices').then((m) => ({ Component: m.default })) },
{ path: 'usage', lazy: () => import('@/pages/Usage').then((m) => ({ Component: m.default })) },
{ path: 'billing', lazy: () => import('@/pages/Billing').then((m) => ({ Component: m.default })) },
{ path: 'relay', lazy: () => import('@/pages/Relay').then((m) => ({ Component: m.default })) },
{ path: 'knowledge', lazy: () => import('@/pages/Knowledge').then((m) => ({ Component: m.default })) },
{ path: 'config', lazy: () => import('@/pages/Config').then((m) => ({ Component: m.default })) },
{ path: 'prompts', lazy: () => import('@/pages/Prompts').then((m) => ({ Component: m.default })) },
{ path: 'logs', lazy: () => import('@/pages/Logs').then((m) => ({ Component: m.default })) },

View File

@@ -0,0 +1,101 @@
import request, { withSignal } from './request'
// === Types ===
export interface BillingPlan {
id: string
name: string
display_name: string
description: string | null
price_cents: number
currency: string
interval: string
features: Record<string, unknown>
limits: Record<string, unknown>
is_default: boolean
sort_order: number
status: string
created_at: string
updated_at: string
}
export interface Subscription {
id: string
account_id: string
plan_id: string
status: string
current_period_start: string
current_period_end: string
trial_end: string | null
canceled_at: string | null
cancel_at_period_end: boolean
created_at: string
updated_at: string
}
export interface UsageQuota {
id: string
account_id: string
period_start: string
period_end: string
input_tokens: number
output_tokens: number
relay_requests: number
hand_executions: number
pipeline_runs: number
max_input_tokens: number | null
max_output_tokens: number | null
max_relay_requests: number | null
max_hand_executions: number | null
max_pipeline_runs: number | null
created_at: string
updated_at: string
}
export interface SubscriptionInfo {
plan: BillingPlan
subscription: Subscription | null
usage: UsageQuota
}
export interface PaymentResult {
payment_id: string
trade_no: string
pay_url: string
amount_cents: number
}
export interface PaymentStatus {
id: string
method: string
amount_cents: number
currency: string
status: string
}
// === Service ===
export const billingService = {
listPlans: (signal?: AbortSignal) =>
request.get<BillingPlan[]>('/billing/plans', withSignal({}, signal))
.then((r) => r.data),
getPlan: (id: string, signal?: AbortSignal) =>
request.get<BillingPlan>(`/billing/plans/${id}`, withSignal({}, signal))
.then((r) => r.data),
getSubscription: (signal?: AbortSignal) =>
request.get<SubscriptionInfo>('/billing/subscription', withSignal({}, signal))
.then((r) => r.data),
getUsage: (signal?: AbortSignal) =>
request.get<UsageQuota>('/billing/usage', withSignal({}, signal))
.then((r) => r.data),
createPayment: (data: { plan_id: string; payment_method: 'alipay' | 'wechat' }) =>
request.post<PaymentResult>('/billing/payments', data).then((r) => r.data),
getPaymentStatus: (id: string, signal?: AbortSignal) =>
request.get<PaymentStatus>(`/billing/payments/${id}`, withSignal({}, signal))
.then((r) => r.data),
}

View File

@@ -0,0 +1,162 @@
import request, { withSignal } from './request'
// === Types ===
export interface CategoryResponse {
id: string
name: string
description: string | null
parent_id: string | null
icon: string | null
sort_order: number
item_count: number
children: CategoryResponse[]
created_at: string
updated_at: string
}
export interface KnowledgeItem {
id: string
category_id: string
title: string
content: string
keywords: string[]
related_questions: string[]
priority: number
status: string
version: number
source: string
tags: string[]
created_by: string
created_at: string
updated_at: string
}
export interface SearchResult {
chunk_id: string
item_id: string
item_title: string
category_name: string
content: string
score: number
keywords: string[]
}
export interface AnalyticsOverview {
total_items: number
active_items: number
total_categories: number
weekly_new_items: number
total_references: number
avg_reference_per_item: number
hit_rate: number
injection_rate: number
positive_feedback_rate: number
stale_items_count: number
}
export interface ListItemsResponse {
items: KnowledgeItem[]
total: number
page: number
page_size: number
}
// === Service ===
export const knowledgeService = {
// 分类
listCategories: (signal?: AbortSignal) =>
request.get<CategoryResponse[]>('/knowledge/categories', withSignal({}, signal))
.then((r) => r.data),
createCategory: (data: { name: string; description?: string; parent_id?: string; icon?: string }) =>
request.post('/knowledge/categories', data).then((r) => r.data),
deleteCategory: (id: string) =>
request.delete(`/knowledge/categories/${id}`).then((r) => r.data),
updateCategory: (id: string, data: { name?: string; description?: string; parent_id?: string; icon?: string }) =>
request.put(`/knowledge/categories/${id}`, data).then((r) => r.data),
reorderCategories: (items: Array<{ id: string; sort_order: number }>) =>
request.patch('/knowledge/categories/reorder', { items }).then((r) => r.data),
getCategoryItems: (id: string, params?: { page?: number; page_size?: number; status?: string }, signal?: AbortSignal) =>
request.get<ListItemsResponse>(`/knowledge/categories/${id}/items`, withSignal({ params }, signal))
.then((r) => r.data),
// 条目
listItems: (params: { page?: number; page_size?: number; category_id?: string; status?: string; keyword?: string }, signal?: AbortSignal) =>
request.get<ListItemsResponse>('/knowledge/items', withSignal({ params }, signal))
.then((r) => r.data),
getItem: (id: string, signal?: AbortSignal) =>
request.get<KnowledgeItem>(`/knowledge/items/${id}`, withSignal({}, signal))
.then((r) => r.data),
createItem: (data: {
category_id: string
title: string
content: string
keywords?: string[]
related_questions?: string[]
priority?: number
tags?: string[]
}) => request.post('/knowledge/items', data).then((r) => r.data),
updateItem: (id: string, data: Record<string, unknown>) =>
request.put(`/knowledge/items/${id}`, data).then((r) => r.data),
deleteItem: (id: string) =>
request.delete(`/knowledge/items/${id}`).then((r) => r.data),
batchCreate: (items: Array<{
category_id: string
title: string
content: string
keywords?: string[]
tags?: string[]
}>) => request.post('/knowledge/items/batch', items).then((r) => r.data),
// 搜索
search: (data: { query: string; category_id?: string; limit?: number }) =>
request.post<SearchResult[]>('/knowledge/search', data).then((r) => r.data),
// 分析
getOverview: (signal?: AbortSignal) =>
request.get<AnalyticsOverview>('/knowledge/analytics/overview', withSignal({}, signal))
.then((r) => r.data),
getTrends: (signal?: AbortSignal) =>
request.get('/knowledge/analytics/trends', withSignal({}, signal))
.then((r) => r.data),
getTopItems: (signal?: AbortSignal) =>
request.get('/knowledge/analytics/top-items', withSignal({}, signal))
.then((r) => r.data),
getQuality: (signal?: AbortSignal) =>
request.get('/knowledge/analytics/quality', withSignal({}, signal))
.then((r) => r.data),
getGaps: (signal?: AbortSignal) =>
request.get('/knowledge/analytics/gaps', withSignal({}, signal))
.then((r) => r.data),
// 版本
getVersions: (itemId: string, signal?: AbortSignal) =>
request.get(`/knowledge/items/${itemId}/versions`, withSignal({}, signal))
.then((r) => r.data),
rollbackVersion: (itemId: string, version: number) =>
request.post(`/knowledge/items/${itemId}/rollback/${version}`).then((r) => r.data),
// 推荐搜索
recommend: (data: { query: string; category_id?: string; limit?: number }) =>
request.post<SearchResult[]>('/knowledge/recommend', data).then((r) => r.data),
// 导入
importItems: (data: { category_id: string; files: Array<{ content: string; title?: string; keywords?: string[]; tags?: string[] }> }) =>
request.post('/knowledge/items/import', data).then((r) => r.data),
}

View File

@@ -0,0 +1,50 @@
// ============================================================
// 角色与权限模板 服务层
// ============================================================
import request, { withSignal } from './request'
import type {
Role,
PermissionTemplate,
CreateRoleRequest,
UpdateRoleRequest,
CreateTemplateRequest,
} from '@/types'
export const roleService = {
// ── Roles ─────────────────────────────────────────────────
list: (signal?: AbortSignal) =>
request.get<Role[]>('/roles', withSignal({}, signal)).then((r) => r.data),
get: (id: string, signal?: AbortSignal) =>
request.get<Role>(`/roles/${id}`, withSignal({}, signal)).then((r) => r.data),
create: (data: CreateRoleRequest, signal?: AbortSignal) =>
request.post<Role>('/roles', data, withSignal({}, signal)).then((r) => r.data),
update: (id: string, data: UpdateRoleRequest, signal?: AbortSignal) =>
request.put<Role>(`/roles/${id}`, data, withSignal({}, signal)).then((r) => r.data),
delete: (id: string, signal?: AbortSignal) =>
request.delete(`/roles/${id}`, withSignal({}, signal)).then((r) => r.data),
// ── Role Permissions ──────────────────────────────────────
getPermissions: (roleId: string, signal?: AbortSignal) =>
request.get<string[]>(`/roles/${roleId}/permissions`, withSignal({}, signal)).then((r) => r.data),
// ── Permission Templates ──────────────────────────────────
listTemplates: (signal?: AbortSignal) =>
request.get<PermissionTemplate[]>('/permission-templates', withSignal({}, signal)).then((r) => r.data),
getTemplate: (id: string, signal?: AbortSignal) =>
request.get<PermissionTemplate>(`/permission-templates/${id}`, withSignal({}, signal)).then((r) => r.data),
createTemplate: (data: CreateTemplateRequest, signal?: AbortSignal) =>
request.post<PermissionTemplate>('/permission-templates', data, withSignal({}, signal)).then((r) => r.data),
deleteTemplate: (id: string, signal?: AbortSignal) =>
request.delete(`/permission-templates/${id}`, withSignal({}, signal)).then((r) => r.data),
applyTemplate: (templateId: string, accountIds: string[], signal?: AbortSignal) =>
request.post(`/permission-templates/${templateId}/apply`, { account_ids: accountIds }, withSignal({}, signal)).then((r) => r.data),
}

View File

@@ -282,3 +282,45 @@ export interface DailyUsageStat {
output_tokens: number
unique_devices: number
}
/** 角色 */
export interface Role {
id: string
name: string
description: string
permissions: string[]
account_count?: number
created_at: string
updated_at: string
}
/** 权限模板 */
export interface PermissionTemplate {
id: string
name: string
description: string
permissions: string[]
created_at: string
updated_at: string
}
/** 创建角色请求 */
export interface CreateRoleRequest {
name: string
description?: string
permissions?: string[]
}
/** 更新角色请求 */
export interface UpdateRoleRequest {
name?: string
description?: string
permissions?: string[]
}
/** 创建权限模板请求 */
export interface CreateTemplateRequest {
name: string
description?: string
permissions?: string[]
}

View File

@@ -0,0 +1,219 @@
// ============================================================
// Config 页面测试
// ============================================================
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import { http, HttpResponse } from 'msw'
import { setupServer } from 'msw/node'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import Config from '@/pages/Config'
// ── Mock data ────────────────────────────────────────────────
const mockConfigItems = [
{
id: 'cfg-001',
category: 'general',
key_path: 'general.app_name',
value_type: 'string',
current_value: 'ZCLAW',
default_value: 'ZCLAW',
source: 'database',
description: '应用程序名称',
requires_restart: false,
created_at: '2026-01-01T00:00:00Z',
updated_at: '2026-01-01T00:00:00Z',
},
{
id: 'cfg-002',
category: 'general',
key_path: 'general.debug_mode',
value_type: 'boolean',
current_value: 'false',
default_value: 'false',
source: 'default',
description: '调试模式开关',
requires_restart: true,
created_at: '2026-01-01T00:00:00Z',
updated_at: '2026-01-01T00:00:00Z',
},
{
id: 'cfg-003',
category: 'general',
key_path: 'general.max_connections',
value_type: 'integer',
current_value: null,
default_value: '100',
source: 'default',
description: '最大连接数',
requires_restart: false,
created_at: '2026-01-01T00:00:00Z',
updated_at: '2026-01-01T00:00:00Z',
},
]
const mockResponse = {
items: mockConfigItems,
total: 3,
page: 1,
page_size: 50,
}
// ── MSW server ───────────────────────────────────────────────
const server = setupServer()
beforeEach(() => {
server.listen({ onUnhandledRequest: 'bypass' })
})
afterEach(() => {
server.close()
})
// ── Helper: render with QueryClient ──────────────────────────
function renderWithProviders(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
// ── Tests ────────────────────────────────────────────────────
describe('Config page', () => {
it('renders page header', async () => {
server.use(
http.get('*/api/v1/config/items', () => {
return HttpResponse.json(mockResponse)
}),
)
renderWithProviders(<Config />)
expect(screen.getByText('系统配置')).toBeInTheDocument()
expect(screen.getByText('管理系统运行参数和功能开关')).toBeInTheDocument()
})
it('fetches and displays config items', async () => {
server.use(
http.get('*/api/v1/config/items', () => {
return HttpResponse.json(mockResponse)
}),
)
renderWithProviders(<Config />)
await waitFor(() => {
expect(screen.getByText('general.app_name')).toBeInTheDocument()
})
expect(screen.getByText('general.debug_mode')).toBeInTheDocument()
})
it('shows loading spinner while fetching', async () => {
server.use(
http.get('*/api/v1/config/items', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockResponse)
}),
)
renderWithProviders(<Config />)
// Ant Design Spin component renders a .ant-spin element
const spinner = document.querySelector('.ant-spin')
expect(spinner).toBeTruthy()
// Wait for loading to complete so afterEach cleanup is clean
await waitFor(() => {
expect(screen.getByText('general.app_name')).toBeInTheDocument()
})
})
it('shows error state on API failure', async () => {
server.use(
http.get('*/api/v1/config/items', () => {
return HttpResponse.json(
{ error: 'internal_error', message: '服务器内部错误' },
{ status: 500 },
)
}),
)
renderWithProviders(<Config />)
// Config page does not have a dedicated ErrorState; the ProTable simply
// renders empty when the query fails. We verify the page header is still
// rendered and the table body has no data rows (shows "暂无数据").
await waitFor(() => {
const emptyElements = screen.queryAllByText('暂无数据')
expect(emptyElements.length).toBeGreaterThanOrEqual(1)
})
// Page header is still present even on error
expect(screen.getByText('系统配置')).toBeInTheDocument()
})
it('renders config key_path and current_value columns', async () => {
server.use(
http.get('*/api/v1/config/items', () => {
return HttpResponse.json(mockResponse)
}),
)
renderWithProviders(<Config />)
// key_path values are rendered in <code> elements
await waitFor(() => {
expect(screen.getByText('general.app_name')).toBeInTheDocument()
})
expect(screen.getByText('general.debug_mode')).toBeInTheDocument()
// current_value "ZCLAW" appears in both the current_value column and default_value column
const zclawElements = screen.getAllByText('ZCLAW')
expect(zclawElements.length).toBeGreaterThanOrEqual(1)
})
it('renders requires_restart column with tags', async () => {
server.use(
http.get('*/api/v1/config/items', () => {
return HttpResponse.json(mockResponse)
}),
)
renderWithProviders(<Config />)
await waitFor(() => {
expect(screen.getByText('general.app_name')).toBeInTheDocument()
})
// requires_restart=true renders "是" (orange tag)
expect(screen.getByText('是')).toBeInTheDocument()
// requires_restart=false renders "否" (may appear multiple times for two items)
const noTags = screen.getAllByText('否')
expect(noTags.length).toBeGreaterThanOrEqual(1)
})
it('renders category tabs', async () => {
server.use(
http.get('*/api/v1/config/items', () => {
return HttpResponse.json(mockResponse)
}),
)
renderWithProviders(<Config />)
expect(screen.getByText('通用')).toBeInTheDocument()
expect(screen.getByText('认证')).toBeInTheDocument()
expect(screen.getByText('中转')).toBeInTheDocument()
expect(screen.getByText('模型')).toBeInTheDocument()
})
})

View File

@@ -0,0 +1,242 @@
// ============================================================
// Dashboard 页面测试
// ============================================================
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import { http, HttpResponse } from 'msw'
import { setupServer } from 'msw/node'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import Dashboard from '@/pages/Dashboard'
// ── Mock data ────────────────────────────────────────────────
const mockStats = {
total_accounts: 12,
active_accounts: 8,
tasks_today: 156,
active_providers: 3,
active_models: 7,
tokens_today_input: 24000,
tokens_today_output: 8500,
}
const mockLogs = {
items: [
{
id: 1,
account_id: 'acc-001',
action: 'login',
target_type: 'account',
target_id: 'acc-001',
details: null,
ip_address: '192.168.1.1',
created_at: '2026-03-30T10:00:00Z',
},
{
id: 2,
account_id: 'acc-002',
action: 'create_provider',
target_type: 'provider',
target_id: 'prov-001',
details: { name: 'OpenAI' },
ip_address: '10.0.0.1',
created_at: '2026-03-30T09:30:00Z',
},
],
total: 2,
page: 1,
page_size: 10,
}
// ── MSW server ───────────────────────────────────────────────
const server = setupServer()
beforeEach(() => {
server.listen({ onUnhandledRequest: 'bypass' })
})
afterEach(() => {
server.close()
})
// ── Helper: render with QueryClient ──────────────────────────
function renderWithProviders(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
// ── Tests ────────────────────────────────────────────────────
describe('Dashboard page', () => {
it('renders page header', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', () => {
return HttpResponse.json(mockStats)
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Dashboard />)
expect(screen.getByText('仪表盘')).toBeInTheDocument()
expect(screen.getByText('系统概览与最近活动')).toBeInTheDocument()
})
it('renders stat cards with correct values', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', () => {
return HttpResponse.json(mockStats)
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Dashboard />)
await waitFor(() => {
expect(screen.getByText('12')).toBeInTheDocument()
})
// Stat titles
expect(screen.getByText('总账号')).toBeInTheDocument()
expect(screen.getByText('活跃服务商')).toBeInTheDocument()
expect(screen.getByText('活跃模型')).toBeInTheDocument()
expect(screen.getByText('今日请求')).toBeInTheDocument()
expect(screen.getByText('今日 Token')).toBeInTheDocument()
// Token total: 24000 + 8500 = 32500
expect(screen.getByText('32,500')).toBeInTheDocument()
})
it('renders recent logs table with action labels', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', () => {
return HttpResponse.json(mockStats)
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Dashboard />)
// Wait for action labels from constants/status.ts
await waitFor(() => {
expect(screen.getByText('登录')).toBeInTheDocument()
})
expect(screen.getByText('创建服务商')).toBeInTheDocument()
})
it('renders target types in logs table', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', () => {
return HttpResponse.json(mockStats)
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Dashboard />)
await waitFor(() => {
expect(screen.getByText('登录')).toBeInTheDocument()
})
expect(screen.getByText('account')).toBeInTheDocument()
expect(screen.getByText('provider')).toBeInTheDocument()
})
it('shows loading spinner before stats load', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockStats)
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Dashboard />)
// Ant Design Spin component renders a .ant-spin element
const spinner = document.querySelector('.ant-spin')
expect(spinner).toBeTruthy()
// Wait for loading to complete so afterEach cleanup is clean
await waitFor(() => {
expect(screen.getByText('总账号')).toBeInTheDocument()
})
})
it('shows error state when stats request fails', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', () => {
return HttpResponse.json(
{ error: 'internal_error', message: '服务器内部错误' },
{ status: 500 },
)
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Dashboard />)
await waitFor(() => {
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
})
})
it('renders stat cards with zero values when stats are null', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', () => {
return HttpResponse.json({})
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json({ items: [], total: 0, page: 1, page_size: 10 })
}),
)
renderWithProviders(<Dashboard />)
// All stats should fallback to 0
await waitFor(() => {
const zeros = screen.getAllByText('0')
expect(zeros.length).toBeGreaterThanOrEqual(2)
})
})
it('renders recent logs section header', async () => {
server.use(
http.get('*/api/v1/stats/dashboard', () => {
return HttpResponse.json(mockStats)
}),
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Dashboard />)
await waitFor(() => {
expect(screen.getByText('最近操作日志')).toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,219 @@
// ============================================================
// Login 页面测试
// ============================================================
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import { MemoryRouter } from 'react-router-dom'
import Login from '@/pages/Login'
// ── Mock data ────────────────────────────────────────────────
const mockLoginResponse = {
token: 'jwt-token-123',
refresh_token: 'refresh-token-456',
account: {
id: 'acc-001',
username: 'testadmin',
email: 'admin@zclaw.ai',
display_name: 'Admin',
role: 'super_admin',
status: 'active',
totp_enabled: false,
last_login_at: null,
created_at: '2026-01-01T00:00:00Z',
llm_routing: 'relay',
},
}
const mockAccount = {
id: 'acc-001',
username: 'testadmin',
email: 'admin@zclaw.ai',
display_name: 'Admin',
role: 'super_admin',
status: 'active',
totp_enabled: false,
last_login_at: null,
created_at: '2026-01-01T00:00:00Z',
llm_routing: 'relay',
}
// ── Hoisted mocks ────────────────────────────────────────────
const { mockLogin, mockNavigate, mockAuthServiceLogin } = vi.hoisted(() => ({
mockLogin: vi.fn(),
mockNavigate: vi.fn(),
mockAuthServiceLogin: vi.fn(),
}))
vi.mock('@/stores/authStore', () => ({
useAuthStore: Object.assign(
vi.fn((selector: (s: Record<string, unknown>) => unknown) =>
selector({ login: mockLogin }),
),
{ getState: () => ({ token: null, refreshToken: null, logout: vi.fn() }) },
),
}))
vi.mock('@/services/auth', () => ({
authService: {
login: mockAuthServiceLogin,
},
}))
vi.mock('react-router-dom', async () => {
const actual = await vi.importActual<typeof import('react-router-dom')>('react-router-dom')
return {
...actual,
useNavigate: () => mockNavigate,
}
})
beforeEach(() => {
mockLogin.mockClear()
mockNavigate.mockClear()
mockAuthServiceLogin.mockClear()
})
// ── Helper: render with providers ────────────────────────────
function renderLogin(initialEntries = ['/login']) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
<MemoryRouter initialEntries={initialEntries}>
<Login />
</MemoryRouter>
</QueryClientProvider>,
)
}
/** Click the LoginForm submit button (Ant Design renders "登 录" with a space) */
function getSubmitButton(): HTMLElement {
const btn = document.querySelector<HTMLButtonElement>(
'button.ant-btn-primary[type="button"]',
)
if (!btn) throw new Error('Submit button not found')
return btn
}
// ── Tests ────────────────────────────────────────────────────
describe('Login page', () => {
it('renders the login form with username and password fields', () => {
renderLogin()
expect(screen.getByText('登录到 ZCLAW')).toBeInTheDocument()
expect(screen.getByPlaceholderText('请输入用户名')).toBeInTheDocument()
expect(screen.getByPlaceholderText('请输入密码')).toBeInTheDocument()
const submitButton = getSubmitButton()
expect(submitButton).toBeTruthy()
})
it('shows the ZCLAW brand logo', () => {
renderLogin()
expect(screen.getByText('Z')).toBeInTheDocument()
expect(screen.getByText(/ZCLAW Admin/)).toBeInTheDocument()
})
it('successful login calls authStore.login and navigates to /', async () => {
const user = userEvent.setup()
mockAuthServiceLogin.mockResolvedValue(mockLoginResponse)
renderLogin()
await user.type(screen.getByPlaceholderText('请输入用户名'), 'testadmin')
await user.type(screen.getByPlaceholderText('请输入密码'), 'password123')
await user.click(getSubmitButton())
await waitFor(() => {
expect(mockLogin).toHaveBeenCalledWith(
'jwt-token-123',
'refresh-token-456',
mockAccount,
)
})
expect(mockNavigate).toHaveBeenCalledWith('/', { replace: true })
})
it('navigates to redirect path after login', async () => {
const user = userEvent.setup()
mockAuthServiceLogin.mockResolvedValue(mockLoginResponse)
renderLogin(['/login?from=/settings'])
await user.type(screen.getByPlaceholderText('请输入用户名'), 'testadmin')
await user.type(screen.getByPlaceholderText('请输入密码'), 'password123')
await user.click(getSubmitButton())
await waitFor(() => {
expect(mockNavigate).toHaveBeenCalledWith('/settings', { replace: true })
})
})
it('shows TOTP field when server returns TOTP-related error', async () => {
const user = userEvent.setup()
const error = new Error('请输入两步验证码 (TOTP)')
Object.assign(error, { status: 403 })
mockAuthServiceLogin.mockRejectedValue(error)
renderLogin()
// Initially no TOTP field
expect(screen.queryByPlaceholderText('请输入 6 位验证码')).not.toBeInTheDocument()
await user.type(screen.getByPlaceholderText('请输入用户名'), 'testadmin')
await user.type(screen.getByPlaceholderText('请输入密码'), 'password123')
await user.click(getSubmitButton())
// After TOTP error, TOTP field appears
await waitFor(() => {
expect(screen.getByPlaceholderText('请输入 6 位验证码')).toBeInTheDocument()
})
})
it('shows error message on invalid credentials', async () => {
const user = userEvent.setup()
const error = new Error('用户名或密码错误')
mockAuthServiceLogin.mockRejectedValue(error)
renderLogin()
await user.type(screen.getByPlaceholderText('请输入用户名'), 'wrong')
await user.type(screen.getByPlaceholderText('请输入密码'), 'wrong')
await user.click(getSubmitButton())
await waitFor(() => {
expect(screen.getByText('用户名或密码错误')).toBeInTheDocument()
})
})
it('does not call authStore.login on failed login', async () => {
const user = userEvent.setup()
const error = new Error('用户名或密码错误')
mockAuthServiceLogin.mockRejectedValue(error)
renderLogin()
await user.type(screen.getByPlaceholderText('请输入用户名'), 'wrong')
await user.type(screen.getByPlaceholderText('请输入密码'), 'wrong')
await user.click(getSubmitButton())
await waitFor(() => {
expect(screen.getByText('用户名或密码错误')).toBeInTheDocument()
})
expect(mockLogin).not.toHaveBeenCalled()
expect(mockNavigate).not.toHaveBeenCalled()
})
})

View File

@@ -0,0 +1,210 @@
// ============================================================
// Logs 页面测试
// ============================================================
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import { http, HttpResponse } from 'msw'
import { setupServer } from 'msw/node'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import Logs from '@/pages/Logs'
// ── Mock data ────────────────────────────────────────────────
const mockLogs = {
items: [
{
id: 1,
account_id: 'acc-001',
action: 'login',
target_type: 'account',
target_id: 'acc-001',
details: null,
ip_address: '192.168.1.1',
created_at: '2026-03-30T10:00:00Z',
},
{
id: 2,
account_id: 'acc-002',
action: 'create_provider',
target_type: 'provider',
target_id: 'prov-001',
details: { name: 'OpenAI' },
ip_address: '10.0.0.1',
created_at: '2026-03-30T09:30:00Z',
},
{
id: 3,
account_id: 'acc-001',
action: 'delete_model',
target_type: 'model',
target_id: 'mdl-001',
details: null,
ip_address: '192.168.1.1',
created_at: '2026-03-29T14:00:00Z',
},
],
total: 3,
page: 1,
page_size: 20,
}
// ── MSW server ───────────────────────────────────────────────
const server = setupServer()
beforeEach(() => {
server.listen({ onUnhandledRequest: 'bypass' })
})
afterEach(() => {
server.close()
})
// ── Helper: render with QueryClient ──────────────────────────
function renderWithProviders(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
// ── Tests ────────────────────────────────────────────────────
describe('Logs page', () => {
it('renders page header', async () => {
server.use(
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Logs />)
expect(screen.getByText('操作日志')).toBeInTheDocument()
expect(screen.getByText('系统审计与操作记录')).toBeInTheDocument()
})
it('fetches and displays log entries', async () => {
server.use(
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Logs />)
// Wait for action labels rendered from constants/status.ts
await waitFor(() => {
expect(screen.getByText('登录')).toBeInTheDocument()
})
expect(screen.getByText('创建服务商')).toBeInTheDocument()
expect(screen.getByText('删除模型')).toBeInTheDocument()
})
it('shows loading spinner while fetching', async () => {
server.use(
http.get('*/api/v1/logs/operations', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Logs />)
// Ant Design Spin component renders a .ant-spin element
const spinner = document.querySelector('.ant-spin')
expect(spinner).toBeTruthy()
// Wait for loading to complete so afterEach cleanup is clean
await waitFor(() => {
expect(screen.getByText('登录')).toBeInTheDocument()
})
})
it('shows ErrorState on API failure with retry button', async () => {
server.use(
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(
{ error: 'internal_error', message: '服务器内部错误' },
{ status: 500 },
)
}),
)
renderWithProviders(<Logs />)
// ErrorState renders the error message
await waitFor(() => {
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
})
// Ant Design Button splits two-character text with a space: "重 试"
const retryButton = screen.getByRole('button', { name: /重.?试/ })
expect(retryButton).toBeInTheDocument()
})
it('renders action as a colored tag', async () => {
server.use(
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Logs />)
await waitFor(() => {
expect(screen.getByText('登录')).toBeInTheDocument()
})
// Verify the action tags have the correct Ant Design color classes
const loginTag = screen.getByText('登录').closest('.ant-tag')
expect(loginTag).toBeTruthy()
// actionColors.login = 'green' → Ant Design renders ant-tag-green or ant-tag-color-green
expect(loginTag?.className).toMatch(/green/)
})
it('renders IP address column', async () => {
server.use(
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Logs />)
await waitFor(() => {
expect(screen.getByText('登录')).toBeInTheDocument()
})
// 192.168.1.1 appears twice (two log entries from the same IP)
const ip1Elements = screen.getAllByText('192.168.1.1')
expect(ip1Elements.length).toBeGreaterThanOrEqual(1)
expect(screen.getByText('10.0.0.1')).toBeInTheDocument()
})
it('renders target_type column', async () => {
server.use(
http.get('*/api/v1/logs/operations', () => {
return HttpResponse.json(mockLogs)
}),
)
renderWithProviders(<Logs />)
await waitFor(() => {
expect(screen.getByText('登录')).toBeInTheDocument()
})
expect(screen.getByText('account')).toBeInTheDocument()
expect(screen.getByText('provider')).toBeInTheDocument()
expect(screen.getByText('model')).toBeInTheDocument()
})
})

View File

@@ -0,0 +1,184 @@
// ============================================================
// ModelServices 页面测试
// ============================================================
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import { http, HttpResponse } from 'msw'
import { setupServer } from 'msw/node'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import ModelServices from '@/pages/ModelServices'
// ── Mock data ────────────────────────────────────────────────
const mockProviders = {
items: [
{
id: 'prov-001',
name: 'openai',
display_name: 'OpenAI',
base_url: 'https://api.openai.com/v1',
api_protocol: 'openai',
enabled: true,
rate_limit_rpm: 500,
rate_limit_tpm: null,
created_at: '2026-01-01T00:00:00Z',
updated_at: '2026-03-15T10:00:00Z',
},
{
id: 'prov-002',
name: 'anthropic',
display_name: 'Anthropic',
base_url: 'https://api.anthropic.com',
api_protocol: 'anthropic',
enabled: false,
rate_limit_rpm: 200,
rate_limit_tpm: null,
created_at: '2026-02-01T00:00:00Z',
updated_at: '2026-03-01T00:00:00Z',
},
{
id: 'prov-003',
name: 'deepseek',
display_name: 'DeepSeek',
base_url: 'https://api.deepseek.com/v1',
api_protocol: 'openai',
enabled: true,
rate_limit_rpm: null,
rate_limit_tpm: null,
created_at: '2026-03-01T00:00:00Z',
updated_at: '2026-03-01T00:00:00Z',
},
],
total: 3,
page: 1,
page_size: 20,
}
// ── MSW server ───────────────────────────────────────────────
const server = setupServer()
beforeEach(() => {
server.listen({ onUnhandledRequest: 'bypass' })
})
afterEach(() => {
server.close()
})
// ── Helper: render with QueryClient ──────────────────────────
function renderWithProviders(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
// ── Tests ────────────────────────────────────────────────────
describe('ModelServices page', () => {
it('renders page header', async () => {
server.use(
http.get('*/api/v1/providers', () => {
return HttpResponse.json(mockProviders)
}),
)
renderWithProviders(<ModelServices />)
expect(screen.getByText('模型服务')).toBeInTheDocument()
expect(screen.getByText('管理 AI 服务商、模型配置和 Key 池')).toBeInTheDocument()
})
it('fetches and displays providers', async () => {
server.use(
http.get('*/api/v1/providers', () => {
return HttpResponse.json(mockProviders)
}),
)
renderWithProviders(<ModelServices />)
await waitFor(() => {
expect(screen.getByText('OpenAI')).toBeInTheDocument()
})
expect(screen.getByText('Anthropic')).toBeInTheDocument()
expect(screen.getByText('DeepSeek')).toBeInTheDocument()
// Provider identifiers rendered as code
// openai also appears in base_url, so use getAllByText
expect(screen.getAllByText('openai').length).toBeGreaterThanOrEqual(1)
expect(screen.getAllByText('anthropic').length).toBeGreaterThanOrEqual(1)
expect(screen.getAllByText('deepseek').length).toBeGreaterThanOrEqual(1)
})
it('shows loading spinner before data arrives', async () => {
server.use(
http.get('*/api/v1/providers', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockProviders)
}),
)
renderWithProviders(<ModelServices />)
const spinner = document.querySelector('.ant-spin')
expect(spinner).toBeTruthy()
// Wait for loading to complete so afterEach cleanup is clean
await waitFor(() => {
expect(screen.getByText('OpenAI')).toBeInTheDocument()
})
})
it('renders provider status as tag', async () => {
server.use(
http.get('*/api/v1/providers', () => {
return HttpResponse.json(mockProviders)
}),
)
renderWithProviders(<ModelServices />)
await waitFor(() => {
expect(screen.getByText('OpenAI')).toBeInTheDocument()
})
// enabled: true -> "启用" tag, enabled: false -> "禁用" tag
const enabledTags = screen.getAllByText('启用')
expect(enabledTags.length).toBe(2) // openai + deepseek
expect(screen.getByText('禁用')).toBeInTheDocument() // anthropic
})
it('shows empty table on API failure', async () => {
server.use(
http.get('*/api/v1/providers', () => {
return HttpResponse.json(
{ error: 'internal_error', message: '获取服务商列表失败' },
{ status: 500 },
)
}),
)
renderWithProviders(<ModelServices />)
// Page header should still render
expect(screen.getByText('模型服务')).toBeInTheDocument()
// Provider names should NOT be rendered
await waitFor(() => {
expect(screen.queryByText('OpenAI')).not.toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,178 @@
// ============================================================
// Prompts 页面测试
// ============================================================
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import { http, HttpResponse } from 'msw'
import { setupServer } from 'msw/node'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import Prompts from '@/pages/Prompts'
// ── Mock data ────────────────────────────────────────────────
const mockPrompts = {
items: [
{
id: 'pt-001',
name: 'system-default',
category: 'system',
description: 'Default system prompt for all agents',
source: 'builtin' as const,
current_version: 3,
status: 'active' as const,
created_at: '2026-01-15T08:00:00Z',
updated_at: '2026-03-20T12:00:00Z',
},
{
id: 'pt-002',
name: 'custom-research',
category: 'tool',
description: 'Custom research prompt template',
source: 'custom' as const,
current_version: 1,
status: 'active' as const,
created_at: '2026-03-01T10:00:00Z',
updated_at: '2026-03-01T10:00:00Z',
},
{
id: 'pt-003',
name: 'legacy-summary',
category: 'system',
description: 'Legacy summary prompt',
source: 'builtin' as const,
current_version: 5,
status: 'archived' as const,
created_at: '2025-06-01T00:00:00Z',
updated_at: '2026-02-28T00:00:00Z',
},
],
total: 3,
page: 1,
page_size: 20,
}
// ── MSW server ───────────────────────────────────────────────
const server = setupServer()
beforeEach(() => {
server.listen({ onUnhandledRequest: 'bypass' })
})
afterEach(() => {
server.close()
})
// ── Helper: render with QueryClient ──────────────────────────
function renderWithProviders(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
// ── Tests ────────────────────────────────────────────────────
describe('Prompts page', () => {
it('renders page title and create button', async () => {
server.use(
http.get('*/api/v1/prompts', () => {
return HttpResponse.json(mockPrompts)
}),
)
renderWithProviders(<Prompts />)
expect(screen.getByText('提示词管理')).toBeInTheDocument()
expect(screen.getByText('管理系统提示词模板和版本历史')).toBeInTheDocument()
expect(screen.getByText('新建提示词')).toBeInTheDocument()
})
it('fetches and displays prompt templates', async () => {
server.use(
http.get('*/api/v1/prompts', () => {
return HttpResponse.json(mockPrompts)
}),
)
renderWithProviders(<Prompts />)
await waitFor(() => {
expect(screen.getByText('system-default')).toBeInTheDocument()
})
expect(screen.getByText('custom-research')).toBeInTheDocument()
expect(screen.getByText('legacy-summary')).toBeInTheDocument()
// Category "tool" appears once in data
expect(screen.getByText('tool')).toBeInTheDocument()
})
it('shows loading spinner before data arrives', async () => {
server.use(
http.get('*/api/v1/prompts', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockPrompts)
}),
)
renderWithProviders(<Prompts />)
const spinner = document.querySelector('.ant-spin')
expect(spinner).toBeTruthy()
// Wait for loading to complete so afterEach cleanup is clean
await waitFor(() => {
expect(screen.getByText('system-default')).toBeInTheDocument()
})
})
it('renders source as tag with correct labels', async () => {
server.use(
http.get('*/api/v1/prompts', () => {
return HttpResponse.json(mockPrompts)
}),
)
renderWithProviders(<Prompts />)
await waitFor(() => {
expect(screen.getByText('system-default')).toBeInTheDocument()
})
// sourceLabels: { builtin: '内置', custom: '自定义' }
// '内置' appears twice (2 builtin items), '自定义' appears once
const builtinTags = screen.getAllByText('内置')
expect(builtinTags.length).toBe(2)
expect(screen.getByText('自定义')).toBeInTheDocument()
})
it('shows error state on API failure', async () => {
server.use(
http.get('*/api/v1/prompts', () => {
return HttpResponse.json(
{ error: 'internal_error', message: '获取提示词列表失败' },
{ status: 500 },
)
}),
)
renderWithProviders(<Prompts />)
// React Query error propagation: ProTable receives empty data
// but the query error should be visible via the table state
// Check that no prompt names are rendered
await waitFor(() => {
expect(screen.queryByText('system-default')).not.toBeInTheDocument()
})
})
})

View File

@@ -0,0 +1,234 @@
// ============================================================
// Relay 页面测试
// ============================================================
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import { http, HttpResponse } from 'msw'
import { setupServer } from 'msw/node'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import Relay from '@/pages/Relay'
// ── Mock data ────────────────────────────────────────────────
const mockRelayTasks = {
items: [
{
id: 'task-001-abcdef',
account_id: 'acc-001',
provider_id: 'prov-001',
model_id: 'gpt-4o',
status: 'completed',
priority: 0,
attempt_count: 1,
max_attempts: 3,
input_tokens: 1500,
output_tokens: 800,
error_message: null,
queued_at: '2026-03-30T10:00:00Z',
started_at: '2026-03-30T10:00:01Z',
completed_at: '2026-03-30T10:00:05Z',
created_at: '2026-03-30T10:00:00Z',
},
{
id: 'task-002-ghijkl',
account_id: 'acc-002',
provider_id: 'prov-002',
model_id: 'claude-3.5-sonnet',
status: 'failed',
priority: 0,
attempt_count: 3,
max_attempts: 3,
input_tokens: 2000,
output_tokens: 0,
error_message: 'Rate limit exceeded',
queued_at: '2026-03-30T09:00:00Z',
started_at: '2026-03-30T09:00:01Z',
completed_at: '2026-03-30T09:01:00Z',
created_at: '2026-03-30T09:00:00Z',
},
{
id: 'task-003-mnopqr',
account_id: 'acc-001',
provider_id: 'prov-001',
model_id: 'gpt-4o-mini',
status: 'queued',
priority: 1,
attempt_count: 0,
max_attempts: 3,
input_tokens: 0,
output_tokens: 0,
error_message: null,
queued_at: '2026-03-30T11:00:00Z',
started_at: null,
completed_at: null,
created_at: '2026-03-30T11:00:00Z',
},
],
total: 3,
page: 1,
page_size: 20,
}
// ── MSW server ───────────────────────────────────────────────
const server = setupServer()
beforeEach(() => {
server.listen({ onUnhandledRequest: 'bypass' })
})
afterEach(() => {
server.close()
})
// ── Helper: render with QueryClient ──────────────────────────
function renderWithProviders(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
// ── Tests ────────────────────────────────────────────────────
describe('Relay page', () => {
it('renders page header', async () => {
server.use(
http.get('*/api/v1/relay/tasks', () => {
return HttpResponse.json(mockRelayTasks)
}),
)
renderWithProviders(<Relay />)
expect(screen.getByText('中转任务')).toBeInTheDocument()
expect(screen.getByText('查看和管理 AI 模型中转请求')).toBeInTheDocument()
})
it('fetches and displays relay tasks', async () => {
server.use(
http.get('*/api/v1/relay/tasks', () => {
return HttpResponse.json(mockRelayTasks)
}),
)
renderWithProviders(<Relay />)
await waitFor(() => {
expect(screen.getByText('已完成')).toBeInTheDocument()
})
expect(screen.getByText('失败')).toBeInTheDocument()
expect(screen.getByText('排队中')).toBeInTheDocument()
})
it('shows loading spinner while fetching', async () => {
server.use(
http.get('*/api/v1/relay/tasks', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockRelayTasks)
}),
)
renderWithProviders(<Relay />)
// Ant Design Spin component renders a .ant-spin element
const spinner = document.querySelector('.ant-spin')
expect(spinner).toBeTruthy()
// Wait for loading to complete so afterEach cleanup is clean
await waitFor(() => {
expect(screen.getByText('已完成')).toBeInTheDocument()
})
})
it('shows ErrorState on API failure with retry button', async () => {
server.use(
http.get('*/api/v1/relay/tasks', () => {
return HttpResponse.json(
{ error: 'internal_error', message: '服务器内部错误' },
{ status: 500 },
)
}),
)
renderWithProviders(<Relay />)
await waitFor(() => {
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
})
// Ant Design Button splits two-character text with a space: "重 试"
const retryButton = screen.getByRole('button', { name: /重.?试/ })
expect(retryButton).toBeInTheDocument()
})
it('renders status as colored tag', async () => {
server.use(
http.get('*/api/v1/relay/tasks', () => {
return HttpResponse.json(mockRelayTasks)
}),
)
renderWithProviders(<Relay />)
await waitFor(() => {
expect(screen.getByText('已完成')).toBeInTheDocument()
})
// Verify the status tags have correct Ant Design color classes
const completedTag = screen.getByText('已完成').closest('.ant-tag')
expect(completedTag).toBeTruthy()
// statusColors.completed = 'green'
expect(completedTag?.className).toMatch(/green/)
const failedTag = screen.getByText('失败').closest('.ant-tag')
expect(failedTag).toBeTruthy()
// statusColors.failed = 'red'
expect(failedTag?.className).toMatch(/red/)
})
it('renders model_id column', async () => {
server.use(
http.get('*/api/v1/relay/tasks', () => {
return HttpResponse.json(mockRelayTasks)
}),
)
renderWithProviders(<Relay />)
await waitFor(() => {
expect(screen.getByText('已完成')).toBeInTheDocument()
})
expect(screen.getByText('gpt-4o')).toBeInTheDocument()
expect(screen.getByText('claude-3.5-sonnet')).toBeInTheDocument()
expect(screen.getByText('gpt-4o-mini')).toBeInTheDocument()
})
it('renders token count column', async () => {
server.use(
http.get('*/api/v1/relay/tasks', () => {
return HttpResponse.json(mockRelayTasks)
}),
)
renderWithProviders(<Relay />)
await waitFor(() => {
expect(screen.getByText('已完成')).toBeInTheDocument()
})
// Token (入/出): 1,500 / 800
expect(screen.getByText(/1,500 \/ 800/)).toBeInTheDocument()
// 2,000 / 0
expect(screen.getByText(/2,000 \/ 0/)).toBeInTheDocument()
})
})

View File

@@ -0,0 +1,248 @@
// ============================================================
// Usage 页面测试
// ============================================================
import { describe, it, expect, beforeEach, afterEach } from 'vitest'
import { render, screen, waitFor } from '@testing-library/react'
import { http, HttpResponse } from 'msw'
import { setupServer } from 'msw/node'
import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
import Usage from '@/pages/Usage'
// ── Mock data ────────────────────────────────────────────────
const mockDailyStats = [
{
day: '2026-03-28',
request_count: 120,
input_tokens: 24000,
output_tokens: 8000,
unique_devices: 5,
},
{
day: '2026-03-29',
request_count: 80,
input_tokens: 16000,
output_tokens: 5000,
unique_devices: 3,
},
{
day: '2026-03-30',
request_count: 200,
input_tokens: 40000,
output_tokens: 12000,
unique_devices: 7,
},
]
const mockModelStats = [
{
model_id: 'gpt-4o',
request_count: 300,
input_tokens: 60000,
output_tokens: 18000,
avg_latency_ms: 450.3,
success_rate: 0.98,
},
{
model_id: 'claude-sonnet-4-20250514',
request_count: 100,
input_tokens: 20000,
output_tokens: 7000,
avg_latency_ms: 620.7,
success_rate: 0.95,
},
]
// ── MSW server ───────────────────────────────────────────────
const server = setupServer()
beforeEach(() => {
server.listen({ onUnhandledRequest: 'bypass' })
})
afterEach(() => {
server.close()
})
// ── Helper: render with QueryClient ──────────────────────────
function renderWithProviders(ui: React.ReactElement) {
const queryClient = new QueryClient({
defaultOptions: {
queries: { retry: false },
},
})
return render(
<QueryClientProvider client={queryClient}>
{ui}
</QueryClientProvider>,
)
}
// ── Tests ────────────────────────────────────────────────────
describe('Usage page', () => {
it('renders page title and summary cards', async () => {
server.use(
http.get('*/api/v1/telemetry/daily', () => {
return HttpResponse.json(mockDailyStats)
}),
http.get('*/api/v1/telemetry/stats', () => {
return HttpResponse.json(mockModelStats)
}),
)
renderWithProviders(<Usage />)
expect(screen.getByText('用量统计')).toBeInTheDocument()
expect(screen.getByText('查看模型使用情况和 Token 消耗')).toBeInTheDocument()
// Summary card titles
expect(screen.getByText('总请求数')).toBeInTheDocument()
expect(screen.getByText('总 Token 数')).toBeInTheDocument()
// Total requests: 120 + 80 + 200 = 400
await waitFor(() => {
expect(screen.getByText('400')).toBeInTheDocument()
})
// Total tokens: (24000+8000) + (16000+5000) + (40000+12000) = 105,000
expect(screen.getByText('105,000')).toBeInTheDocument()
})
it('fetches and displays daily stats table', async () => {
server.use(
http.get('*/api/v1/telemetry/daily', () => {
return HttpResponse.json(mockDailyStats)
}),
http.get('*/api/v1/telemetry/stats', () => {
return HttpResponse.json(mockModelStats)
}),
)
renderWithProviders(<Usage />)
// Table column headers
expect(screen.getByText('每日统计')).toBeInTheDocument()
// Wait for data rows to render
await waitFor(() => {
expect(screen.getByText('2026-03-28')).toBeInTheDocument()
})
// Formatted request counts
expect(screen.getByText('120')).toBeInTheDocument()
expect(screen.getByText('80')).toBeInTheDocument()
expect(screen.getByText('200')).toBeInTheDocument()
// Device counts
expect(screen.getByText('5')).toBeInTheDocument()
})
it('fetches and displays model stats table', async () => {
server.use(
http.get('*/api/v1/telemetry/daily', () => {
return HttpResponse.json(mockDailyStats)
}),
http.get('*/api/v1/telemetry/stats', () => {
return HttpResponse.json(mockModelStats)
}),
)
renderWithProviders(<Usage />)
expect(screen.getByText('按模型统计')).toBeInTheDocument()
await waitFor(() => {
expect(screen.getByText('gpt-4o')).toBeInTheDocument()
})
expect(screen.getByText('claude-sonnet-4-20250514')).toBeInTheDocument()
// Success rate: 0.98 -> "98.0%"
expect(screen.getByText('98.0%')).toBeInTheDocument()
// Avg latency: 450.3 -> "450ms"
expect(screen.getByText('450ms')).toBeInTheDocument()
})
it('shows loading spinner before data loads', async () => {
server.use(
http.get('*/api/v1/telemetry/daily', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockDailyStats)
}),
http.get('*/api/v1/telemetry/stats', async () => {
await new Promise((resolve) => setTimeout(resolve, 500))
return HttpResponse.json(mockModelStats)
}),
)
renderWithProviders(<Usage />)
// Ant Design Spin component renders a .ant-spin element
const spinner = document.querySelector('.ant-spin')
expect(spinner).toBeTruthy()
// Wait for loading to complete so afterEach cleanup is clean
await waitFor(() => {
expect(screen.getByText('用量统计')).toBeInTheDocument()
})
})
it('shows ErrorState when daily stats request fails', async () => {
server.use(
http.get('*/api/v1/telemetry/daily', () => {
return HttpResponse.json(
{ error: 'internal_error', message: '服务器内部错误' },
{ status: 500 },
)
}),
http.get('*/api/v1/telemetry/stats', () => {
return HttpResponse.json(mockModelStats)
}),
)
renderWithProviders(<Usage />)
await waitFor(() => {
expect(screen.getByText('服务器内部错误')).toBeInTheDocument()
})
// ErrorState renders a retry button (antd v6 may split Chinese characters)
expect(screen.getByRole('button', { name: /重.*试/ })).toBeInTheDocument()
})
it('calculates totals correctly from daily data', async () => {
server.use(
http.get('*/api/v1/telemetry/daily', () => {
return HttpResponse.json([
{
day: '2026-03-30',
request_count: 1500,
input_tokens: 10000,
output_tokens: 3000,
unique_devices: 2,
},
])
}),
http.get('*/api/v1/telemetry/stats', () => {
return HttpResponse.json([])
}),
)
renderWithProviders(<Usage />)
// Total requests: 1500 (formatted as "1,500" by Statistic)
await waitFor(() => {
const elements = screen.getAllByText('1,500')
expect(elements.length).toBeGreaterThanOrEqual(1)
})
// Total tokens: 10000 + 3000 = 13,000
expect(screen.getAllByText('13,000').length).toBeGreaterThanOrEqual(1)
})
})

View File

@@ -31,4 +31,5 @@ jobs = [
{ name = "cleanup_rate_limit", interval = "5m", task = "cleanup_rate_limit", run_on_start = false },
{ name = "cleanup_refresh_tokens", interval = "1h", task = "cleanup_refresh_tokens", run_on_start = false },
{ name = "cleanup_devices", interval = "24h", task = "cleanup_devices", run_on_start = false },
{ name = "aggregate_usage", interval = "1h", task = "aggregate_usage", run_on_start = true, args = { account_id = null } },
]

View File

@@ -216,9 +216,10 @@ impl QueryAnalyzer {
expansions
}
/// Get synonyms for a keyword (simplified)
/// Get synonyms for a keyword (simplified, English + Chinese)
fn get_synonyms(&self, keyword: &str) -> Option<Vec<String>> {
let synonyms: &[&str] = match keyword {
// English synonyms
"code" => &["program", "script", "source"],
"error" => &["bug", "issue", "problem", "exception"],
"fix" => &["solve", "resolve", "repair", "patch"],
@@ -226,6 +227,20 @@ impl QueryAnalyzer {
"slow" => &["performance", "optimize", "speed"],
"help" => &["assist", "support", "guide", "aid"],
"learn" => &["study", "understand", "know", "grasp"],
// Chinese synonyms — critical for Chinese-language queries
"错误" => &["问题", "bug", "异常", "故障"],
"修复" => &["解决", "修正", "处理", "fix"],
"优化" => &["改进", "提升", "加速", "improve"],
"配置" => &["设置", "参数", "选项", "config"],
"性能" => &["速度", "效率", "performance"],
"问题" => &["错误", "故障", "issue", "problem"],
"帮助" => &["协助", "支持", "help"],
"学习" => &["了解", "掌握", "learn"],
"代码" => &["程序", "脚本", "code"],
"数据库" => &["DB", "database", "存储"],
"部署" => &["发布", "上线", "deploy"],
"测试" => &["验证", "检验", "test"],
"安全" => &["防护", "加密", "security"],
_ => return None,
};

View File

@@ -111,7 +111,7 @@ impl HandResult {
}
/// Hand execution status
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HandStatus {
Idle,

View File

@@ -134,7 +134,7 @@ impl BrowserHand {
id: "browser".to_string(),
name: "浏览器".to_string(),
description: "网页浏览器自动化,支持导航、交互和数据采集".to_string(),
needs_approval: false,
needs_approval: true,
dependencies: vec!["webdriver".to_string()],
input_schema: Some(serde_json::json!({
"type": "object",
@@ -420,8 +420,211 @@ impl BrowserSequence {
self
}
/// Set whether to stop on error
pub fn stop_on_error(mut self, stop: bool) -> Self {
self.stop_on_error = stop;
self
}
/// Build the sequence
pub fn build(self) -> Vec<BrowserAction> {
self.steps
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Hand;
use std::collections::HashMap;
fn fresh_context() -> HandContext {
HandContext {
agent_id: zclaw_types::AgentId::new(),
working_dir: None,
env: HashMap::new(),
timeout_secs: 30,
callback_url: None,
}
}
#[test]
fn test_browser_config() {
let hand = BrowserHand::new();
let config = hand.config();
assert_eq!(config.id, "browser");
assert!(config.enabled);
}
#[tokio::test]
async fn test_browser_config_needs_approval() {
let hand = BrowserHand::new();
assert!(hand.config().needs_approval, "Browser hand should require approval per TOML config");
}
#[test]
fn test_action_deserialize_navigate() {
let json = serde_json::json!({
"action": "navigate",
"url": "https://example.com",
"wait_for": "body"
});
let action: BrowserAction = serde_json::from_value(json).expect("deserialize navigate");
match action {
BrowserAction::Navigate { url, wait_for } => {
assert_eq!(url, "https://example.com");
assert_eq!(wait_for, Some("body".to_string()));
}
_ => panic!("Expected Navigate action, got {:?}", action),
}
}
#[test]
fn test_action_deserialize_click() {
let json = serde_json::json!({
"action": "click",
"selector": "#submit-btn",
"wait_ms": 500
});
let action: BrowserAction = serde_json::from_value(json).expect("deserialize click");
match action {
BrowserAction::Click { selector, wait_ms } => {
assert_eq!(selector, "#submit-btn");
assert_eq!(wait_ms, Some(500));
}
_ => panic!("Expected Click action, got {:?}", action),
}
}
#[test]
fn test_action_deserialize_type() {
let json = serde_json::json!({
"action": "type",
"selector": "#search",
"text": "hello world",
"clear_first": true
});
let action: BrowserAction = serde_json::from_value(json).expect("deserialize type");
match action {
BrowserAction::Type { selector, text, clear_first } => {
assert_eq!(selector, "#search");
assert_eq!(text, "hello world");
assert!(clear_first);
}
_ => panic!("Expected Type action, got {:?}", action),
}
}
#[test]
fn test_action_deserialize_scrape() {
let json = serde_json::json!({
"action": "scrape",
"selectors": ["h1", ".content", "#price"]
});
let action: BrowserAction = serde_json::from_value(json).expect("deserialize scrape");
match action {
BrowserAction::Scrape { selectors, wait_for } => {
assert_eq!(selectors, vec!["h1", ".content", "#price"]);
assert!(wait_for.is_none());
}
_ => panic!("Expected Scrape action, got {:?}", action),
}
}
#[test]
fn test_action_deserialize_screenshot() {
let json = serde_json::json!({
"action": "screenshot",
"full_page": true
});
let action: BrowserAction = serde_json::from_value(json).expect("deserialize screenshot");
match action {
BrowserAction::Screenshot { selector, full_page } => {
assert!(selector.is_none());
assert!(full_page);
}
_ => panic!("Expected Screenshot action, got {:?}", action),
}
}
#[test]
fn test_all_major_actions_roundtrip() {
let actions = vec![
BrowserAction::Navigate { url: "https://example.com".into(), wait_for: None },
BrowserAction::Click { selector: "#btn".into(), wait_ms: None },
BrowserAction::Type { selector: "#input".into(), text: "test".into(), clear_first: false },
BrowserAction::Scrape { selectors: vec!["h1".into()], wait_for: None },
BrowserAction::Screenshot { selector: None, full_page: false },
BrowserAction::Wait { selector: "#loaded".into(), timeout_ms: 5000 },
BrowserAction::Execute { script: "return 1".into(), args: vec![] },
BrowserAction::FillForm {
fields: vec![FormField { selector: "#name".into(), value: "Alice".into() }],
submit_selector: Some("#submit".into()),
},
];
for original in actions {
let json = serde_json::to_value(&original).expect("serialize action");
let roundtripped: BrowserAction = serde_json::from_value(json).expect("deserialize action");
assert_eq!(
serde_json::to_value(&original).unwrap(),
serde_json::to_value(&roundtripped).unwrap(),
"Roundtrip failed for {:?}",
original
);
}
}
#[tokio::test]
async fn test_browser_sequence_builder() {
let ctx = fresh_context();
let hand = BrowserHand::new();
let sequence = BrowserSequence::new("test_sequence")
.navigate("https://example.com")
.stop_on_error(false);
assert_eq!(sequence.name, "test_sequence");
assert!(!sequence.stop_on_error);
assert_eq!(sequence.steps.len(), 1);
// Execute the navigate step
let action_json = serde_json::to_value(&sequence.steps[0]).expect("serialize step");
let result = hand.execute(&ctx, action_json).await.expect("execute");
assert!(result.success);
assert_eq!(result.output["action"], "navigate");
assert_eq!(result.output["url"], "https://example.com");
}
#[tokio::test]
async fn test_browser_sequence_multiple_steps() {
let ctx = fresh_context();
let hand = BrowserHand::new();
let sequence = BrowserSequence::new("multi_step")
.navigate("https://example.com")
.click("#login-btn")
.type_text("#username", "admin")
.screenshot();
assert_eq!(sequence.steps.len(), 4);
// Verify each step can execute
for (i, step) in sequence.steps.iter().enumerate() {
let action_json = serde_json::to_value(step).expect("serialize step");
let result = hand.execute(&ctx, action_json).await.expect("execute step");
assert!(result.success, "Step {} failed: {:?}", i, result.error);
}
}
#[test]
fn test_form_field_deserialize() {
let json = serde_json::json!({
"selector": "#email",
"value": "user@example.com"
});
let field: FormField = serde_json::from_value(json).expect("deserialize form field");
assert_eq!(field.selector, "#email");
assert_eq!(field.value, "user@example.com");
}
}

View File

@@ -640,3 +640,390 @@ impl Hand for ClipHand {
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// === Config & Defaults ===
#[test]
fn test_hand_config() {
let hand = ClipHand::new();
assert_eq!(hand.config().id, "clip");
assert_eq!(hand.config().name, "视频剪辑");
assert!(!hand.config().needs_approval);
assert!(hand.config().enabled);
assert!(hand.config().tags.contains(&"video".to_string()));
assert!(hand.config().input_schema.is_some());
}
#[test]
fn test_default_impl() {
let hand = ClipHand::default();
assert_eq!(hand.config().id, "clip");
}
#[test]
fn test_needs_approval() {
let hand = ClipHand::new();
assert!(!hand.needs_approval());
}
#[test]
fn test_check_dependencies() {
let hand = ClipHand::new();
let deps = hand.check_dependencies().unwrap();
// May or may not find ffmpeg depending on test environment
// Just verify it doesn't panic
let _ = deps;
}
// === VideoFormat ===
#[test]
fn test_video_format_default() {
assert!(matches!(VideoFormat::default(), VideoFormat::Mp4));
}
#[test]
fn test_video_format_deserialize() {
let fmt: VideoFormat = serde_json::from_value(json!("mp4")).unwrap();
assert!(matches!(fmt, VideoFormat::Mp4));
let fmt: VideoFormat = serde_json::from_value(json!("webm")).unwrap();
assert!(matches!(fmt, VideoFormat::Webm));
let fmt: VideoFormat = serde_json::from_value(json!("gif")).unwrap();
assert!(matches!(fmt, VideoFormat::Gif));
}
#[test]
fn test_video_format_serialize() {
assert_eq!(serde_json::to_value(&VideoFormat::Mp4).unwrap(), "mp4");
assert_eq!(serde_json::to_value(&VideoFormat::Webm).unwrap(), "webm");
}
// === Resolution ===
#[test]
fn test_resolution_default() {
assert!(matches!(Resolution::default(), Resolution::Original));
}
#[test]
fn test_resolution_presets() {
let r: Resolution = serde_json::from_value(json!("p720")).unwrap();
assert!(matches!(r, Resolution::P720));
let r: Resolution = serde_json::from_value(json!("p1080")).unwrap();
assert!(matches!(r, Resolution::P1080));
let r: Resolution = serde_json::from_value(json!("p4k")).unwrap();
assert!(matches!(r, Resolution::P4k));
}
#[test]
fn test_resolution_custom() {
let r: Resolution = serde_json::from_value(json!({"custom": {"width": 800, "height": 600}})).unwrap();
match r {
Resolution::Custom { width, height } => {
assert_eq!(width, 800);
assert_eq!(height, 600);
}
_ => panic!("Expected Custom"),
}
}
#[test]
fn test_resolution_serialize() {
assert_eq!(serde_json::to_value(&Resolution::P720).unwrap(), "p720");
assert_eq!(serde_json::to_value(&Resolution::Original).unwrap(), "original");
}
// === TrimConfig ===
#[test]
fn test_trim_config_deserialize() {
let config: TrimConfig = serde_json::from_value(json!({
"inputPath": "/input.mp4",
"outputPath": "/output.mp4",
"startTime": 5.0,
"duration": 10.0
})).unwrap();
assert_eq!(config.input_path, "/input.mp4");
assert_eq!(config.output_path, "/output.mp4");
assert_eq!(config.start_time, Some(5.0));
assert_eq!(config.duration, Some(10.0));
assert!(config.end_time.is_none());
}
#[test]
fn test_trim_config_minimal() {
let config: TrimConfig = serde_json::from_value(json!({
"inputPath": "/in.mp4",
"outputPath": "/out.mp4"
})).unwrap();
assert!(config.start_time.is_none());
assert!(config.end_time.is_none());
assert!(config.duration.is_none());
}
// === ConvertConfig ===
#[test]
fn test_convert_config_deserialize() {
let config: ConvertConfig = serde_json::from_value(json!({
"inputPath": "/input.avi",
"outputPath": "/output.mp4",
"format": "mp4",
"resolution": "p1080",
"videoBitrate": "4M",
"audioBitrate": "192k"
})).unwrap();
assert_eq!(config.input_path, "/input.avi");
assert!(matches!(config.format, VideoFormat::Mp4));
assert!(matches!(config.resolution, Resolution::P1080));
assert_eq!(config.video_bitrate, Some("4M".to_string()));
assert_eq!(config.audio_bitrate, Some("192k".to_string()));
}
#[test]
fn test_convert_config_defaults() {
let config: ConvertConfig = serde_json::from_value(json!({
"inputPath": "/in.mp4",
"outputPath": "/out.mp4"
})).unwrap();
assert!(matches!(config.format, VideoFormat::Mp4));
assert!(matches!(config.resolution, Resolution::Original));
assert!(config.video_bitrate.is_none());
assert!(config.audio_bitrate.is_none());
}
// === ThumbnailConfig ===
#[test]
fn test_thumbnail_config_deserialize() {
let config: ThumbnailConfig = serde_json::from_value(json!({
"inputPath": "/video.mp4",
"outputPath": "/thumb.jpg",
"time": 5.0,
"width": 320,
"height": 240
})).unwrap();
assert_eq!(config.input_path, "/video.mp4");
assert_eq!(config.time, 5.0);
assert_eq!(config.width, Some(320));
assert_eq!(config.height, Some(240));
}
#[test]
fn test_thumbnail_config_defaults() {
let config: ThumbnailConfig = serde_json::from_value(json!({
"inputPath": "/v.mp4",
"outputPath": "/t.jpg"
})).unwrap();
assert_eq!(config.time, 0.0);
assert!(config.width.is_none());
assert!(config.height.is_none());
}
// === ConcatConfig ===
#[test]
fn test_concat_config_deserialize() {
let config: ConcatConfig = serde_json::from_value(json!({
"inputPaths": ["/a.mp4", "/b.mp4"],
"outputPath": "/merged.mp4"
})).unwrap();
assert_eq!(config.input_paths.len(), 2);
assert_eq!(config.output_path, "/merged.mp4");
}
// === VideoInfo ===
#[test]
fn test_video_info_deserialize() {
let info: VideoInfo = serde_json::from_value(json!({
"path": "/test.mp4",
"durationSecs": 120.5,
"width": 1920,
"height": 1080,
"fps": 30.0,
"format": "mp4",
"videoCodec": "h264",
"audioCodec": "aac",
"bitrateKbps": 5000,
"fileSizeBytes": 75_000_000
})).unwrap();
assert_eq!(info.path, "/test.mp4");
assert_eq!(info.duration_secs, 120.5);
assert_eq!(info.width, 1920);
assert_eq!(info.fps, 30.0);
assert_eq!(info.video_codec, "h264");
assert_eq!(info.audio_codec, Some("aac".to_string()));
assert_eq!(info.bitrate_kbps, Some(5000));
assert_eq!(info.file_size_bytes, 75_000_000);
}
// === ClipAction Deserialization ===
#[test]
fn test_action_trim() {
let action: ClipAction = serde_json::from_value(json!({
"action": "trim",
"config": {
"inputPath": "/in.mp4",
"outputPath": "/out.mp4",
"startTime": 1.0,
"endTime": 5.0
}
})).unwrap();
match action {
ClipAction::Trim { config } => {
assert_eq!(config.input_path, "/in.mp4");
assert_eq!(config.start_time, Some(1.0));
}
_ => panic!("Expected Trim"),
}
}
#[test]
fn test_action_convert() {
let action: ClipAction = serde_json::from_value(json!({
"action": "convert",
"config": {
"inputPath": "/in.avi",
"outputPath": "/out.mp4"
}
})).unwrap();
assert!(matches!(action, ClipAction::Convert { .. }));
}
#[test]
fn test_action_resize() {
let action: ClipAction = serde_json::from_value(json!({
"action": "resize",
"input_path": "/in.mp4",
"output_path": "/out.mp4",
"resolution": "p720"
})).unwrap();
match action {
ClipAction::Resize { input_path, resolution, .. } => {
assert_eq!(input_path, "/in.mp4");
assert!(matches!(resolution, Resolution::P720));
}
_ => panic!("Expected Resize"),
}
}
#[test]
fn test_action_thumbnail() {
let action: ClipAction = serde_json::from_value(json!({
"action": "thumbnail",
"config": {
"inputPath": "/in.mp4",
"outputPath": "/thumb.jpg"
}
})).unwrap();
assert!(matches!(action, ClipAction::Thumbnail { .. }));
}
#[test]
fn test_action_concat() {
let action: ClipAction = serde_json::from_value(json!({
"action": "concat",
"config": {
"inputPaths": ["/a.mp4", "/b.mp4"],
"outputPath": "/out.mp4"
}
})).unwrap();
assert!(matches!(action, ClipAction::Concat { .. }));
}
#[test]
fn test_action_info() {
let action: ClipAction = serde_json::from_value(json!({
"action": "info",
"path": "/video.mp4"
})).unwrap();
match action {
ClipAction::Info { path } => assert_eq!(path, "/video.mp4"),
_ => panic!("Expected Info"),
}
}
#[test]
fn test_action_check_ffmpeg() {
let action: ClipAction = serde_json::from_value(json!({"action": "check_ffmpeg"})).unwrap();
assert!(matches!(action, ClipAction::CheckFfmpeg));
}
#[test]
fn test_action_invalid() {
let result = serde_json::from_value::<ClipAction>(json!({"action": "nonexistent"}));
assert!(result.is_err());
}
// === Hand execute dispatch ===
#[tokio::test]
async fn test_execute_check_ffmpeg() {
let hand = ClipHand::new();
let ctx = HandContext::default();
let result = hand.execute(&ctx, json!({"action": "check_ffmpeg"})).await.unwrap();
// Just verify it doesn't crash and returns a valid result
assert!(result.output.is_object());
// "available" field should exist
assert!(result.output["available"].is_boolean());
}
#[tokio::test]
async fn test_execute_invalid_action() {
let hand = ClipHand::new();
let ctx = HandContext::default();
let result = hand.execute(&ctx, json!({"action": "bogus"})).await;
assert!(result.is_err());
}
// === Status ===
#[test]
fn test_status() {
let hand = ClipHand::new();
let status = hand.status();
// Either Idle (ffmpeg found) or Failed (not found) — just verify it doesn't panic
assert!(matches!(status, crate::HandStatus::Idle | crate::HandStatus::Failed));
}
// === Roundtrip ===
#[test]
fn test_trim_action_roundtrip() {
let json = json!({
"action": "trim",
"config": {
"inputPath": "/in.mp4",
"outputPath": "/out.mp4",
"startTime": 2.0,
"duration": 5.0
}
});
let action: ClipAction = serde_json::from_value(json).unwrap();
let serialized = serde_json::to_value(&action).unwrap();
assert_eq!(serialized["action"], "trim");
assert_eq!(serialized["config"]["inputPath"], "/in.mp4");
assert_eq!(serialized["config"]["startTime"], 2.0);
assert_eq!(serialized["config"]["duration"], 5.0);
}
#[test]
fn test_info_action_roundtrip() {
let json = json!({"action": "info", "path": "/video.mp4"});
let action: ClipAction = serde_json::from_value(json).unwrap();
let serialized = serde_json::to_value(&action).unwrap();
assert_eq!(serialized["action"], "info");
assert_eq!(serialized["path"], "/video.mp4");
}
}

View File

@@ -13,7 +13,7 @@ use zclaw_types::Result;
use crate::{Hand, HandConfig, HandContext, HandResult};
/// Output format options
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OutputFormat {
Json,
@@ -234,16 +234,37 @@ impl CollectorHand {
self.extract_visible_text(html)
}
/// Extract visible text from HTML
/// Extract visible text from HTML, stripping scripts and styles
fn extract_visible_text(&self, html: &str) -> String {
let html_lower = html.to_lowercase();
let mut text = String::new();
let mut in_tag = false;
let mut in_script = false;
let mut in_style = false;
let mut pos: usize = 0;
for c in html.chars() {
let char_len = c.len_utf8();
match c {
'<' => in_tag = true,
'>' => in_tag = false,
'<' => {
let remaining = &html_lower[pos..];
if remaining.starts_with("</script") {
in_script = false;
} else if remaining.starts_with("</style") {
in_style = false;
}
if remaining.starts_with("<script") {
in_script = true;
} else if remaining.starts_with("<style") {
in_style = true;
}
in_tag = true;
}
'>' => {
in_tag = false;
}
_ if in_tag => {}
_ if in_script || in_style => {}
' ' | '\n' | '\t' | '\r' => {
if !text.ends_with(' ') && !text.is_empty() {
text.push(' ');
@@ -251,11 +272,11 @@ impl CollectorHand {
}
_ => text.push(c),
}
pos += char_len;
}
// Limit length
if text.len() > 500 {
text.truncate(500);
if text.len() > 10000 {
text.truncate(10000);
text.push_str("...");
}
@@ -407,3 +428,166 @@ impl Hand for CollectorHand {
crate::HandStatus::Idle
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_collector_config() {
let hand = CollectorHand::new();
assert_eq!(hand.config().id, "collector");
assert_eq!(hand.config().name, "数据采集器");
assert!(hand.config().enabled);
assert!(!hand.config().needs_approval);
}
#[test]
fn test_output_format_serialize() {
let formats = vec![
(OutputFormat::Csv, "\"csv\""),
(OutputFormat::Markdown, "\"markdown\""),
(OutputFormat::Json, "\"json\""),
(OutputFormat::Text, "\"text\""),
];
for (fmt, expected) in formats {
let serialized = serde_json::to_string(&fmt).unwrap();
assert_eq!(serialized, expected);
}
// Verify round-trip deserialization
for json_str in &["\"csv\"", "\"markdown\"", "\"json\"", "\"text\""] {
let deserialized: OutputFormat = serde_json::from_str(json_str).unwrap();
let re_serialized = serde_json::to_string(&deserialized).unwrap();
assert_eq!(&re_serialized, json_str);
}
}
#[test]
fn test_extract_visible_text_basic() {
let hand = CollectorHand::new();
let html = "<html><body><h1>Title</h1><p>Content here</p></body></html>";
let text = hand.extract_visible_text(html);
assert!(text.contains("Title"), "should contain 'Title', got: {}", text);
assert!(text.contains("Content here"), "should contain 'Content here', got: {}", text);
}
#[test]
fn test_extract_visible_text_strips_scripts() {
let hand = CollectorHand::new();
let html = "<html><body><script>alert('xss')</script><p>Safe content</p></body></html>";
let text = hand.extract_visible_text(html);
assert!(!text.contains("alert"), "script content should be removed, got: {}", text);
assert!(!text.contains("xss"), "script content should be removed, got: {}", text);
assert!(text.contains("Safe content"), "visible content should remain, got: {}", text);
}
#[test]
fn test_extract_visible_text_strips_styles() {
let hand = CollectorHand::new();
let html = "<html><head><style>body { color: red; }</style></head><body><p>Text</p></body></html>";
let text = hand.extract_visible_text(html);
assert!(!text.contains("color"), "style content should be removed, got: {}", text);
assert!(!text.contains("red"), "style content should be removed, got: {}", text);
assert!(text.contains("Text"), "visible content should remain, got: {}", text);
}
#[test]
fn test_extract_visible_text_empty() {
let hand = CollectorHand::new();
let text = hand.extract_visible_text("");
assert!(text.is_empty(), "empty HTML should produce empty text, got: '{}'", text);
}
#[tokio::test]
async fn test_aggregate_action_empty_urls() {
let hand = CollectorHand::new();
let config = AggregationConfig {
urls: vec![],
aggregate_fields: vec![],
};
let result = hand.execute_aggregate(&config).await.unwrap();
let results = result.get("results").unwrap().as_array().unwrap();
assert_eq!(results.len(), 0, "empty URLs should produce empty results");
assert_eq!(result.get("source_count").unwrap().as_u64().unwrap(), 0);
}
#[test]
fn test_collector_action_deserialize() {
// Collect action
let collect_json = json!({
"action": "collect",
"target": {
"url": "https://example.com",
"selector": ".article",
"fields": { "title": "h1" },
"maxItems": 10
},
"format": "markdown"
});
let action: CollectorAction = serde_json::from_value(collect_json).unwrap();
match action {
CollectorAction::Collect { target, format } => {
assert_eq!(target.url, "https://example.com");
assert_eq!(target.selector.as_deref(), Some(".article"));
assert_eq!(target.max_items, 10);
assert!(format.is_some());
assert_eq!(format.unwrap(), OutputFormat::Markdown);
}
_ => panic!("Expected Collect action"),
}
// Aggregate action
let aggregate_json = json!({
"action": "aggregate",
"config": {
"urls": ["https://a.com", "https://b.com"],
"aggregateFields": ["title", "content"]
}
});
let action: CollectorAction = serde_json::from_value(aggregate_json).unwrap();
match action {
CollectorAction::Aggregate { config } => {
assert_eq!(config.urls.len(), 2);
assert_eq!(config.aggregate_fields.len(), 2);
}
_ => panic!("Expected Aggregate action"),
}
// Extract action
let extract_json = json!({
"action": "extract",
"url": "https://example.com",
"selectors": { "title": "h1", "body": "p" }
});
let action: CollectorAction = serde_json::from_value(extract_json).unwrap();
match action {
CollectorAction::Extract { url, selectors } => {
assert_eq!(url, "https://example.com");
assert_eq!(selectors.len(), 2);
}
_ => panic!("Expected Extract action"),
}
}
#[test]
fn test_collection_target_deserialize() {
let json = json!({
"url": "https://example.com/page",
"selector": ".content",
"fields": {
"title": "h1",
"author": ".author-name"
},
"maxItems": 50
});
let target: CollectionTarget = serde_json::from_value(json).unwrap();
assert_eq!(target.url, "https://example.com/page");
assert_eq!(target.selector.as_deref(), Some(".content"));
assert_eq!(target.fields.len(), 2);
assert_eq!(target.max_items, 50);
}
}

View File

@@ -344,31 +344,34 @@ impl ResearcherHand {
/// Extract readable text from HTML
fn extract_text_from_html(&self, html: &str) -> String {
// Simple text extraction - remove HTML tags
let html_lower = html.to_lowercase();
let mut text = String::new();
let mut in_tag = false;
let mut in_script = false;
let mut in_style = false;
let mut pos: usize = 0;
for c in html.chars() {
let char_len = c.len_utf8();
match c {
'<' => {
in_tag = true;
let remaining = html[text.len()..].to_lowercase();
// Check for closing tags before entering tag mode
let remaining = &html_lower[pos..];
if remaining.starts_with("</script") {
in_script = false;
} else if remaining.starts_with("</style") {
in_style = false;
}
// Check for opening tags
if remaining.starts_with("<script") {
in_script = true;
} else if remaining.starts_with("<style") {
in_style = true;
}
in_tag = true;
}
'>' => {
in_tag = false;
let remaining = html[text.len()..].to_lowercase();
if remaining.starts_with("</script>") {
in_script = false;
} else if remaining.starts_with("</style>") {
in_style = false;
}
}
_ if in_tag => {}
_ if in_script || in_style => {}
@@ -379,9 +382,9 @@ impl ResearcherHand {
}
_ => text.push(c),
}
pos += char_len;
}
// Limit length
if text.len() > 10000 {
text.truncate(10000);
text.push_str("...");
@@ -445,10 +448,33 @@ impl ResearcherHand {
let duration = start.elapsed().as_millis() as u64;
// Generate summary from top results
let summary = if results.is_empty() {
"未找到相关结果,建议调整搜索关键词后重试".to_string()
} else {
let top_snippets: Vec<&str> = results
.iter()
.take(3)
.filter_map(|r| {
let s = r.snippet.trim();
if s.is_empty() { None } else { Some(s) }
})
.collect();
if top_snippets.is_empty() {
format!("找到 {} 条相关结果,但无摘要信息", results.len())
} else {
format!(
"基于 {} 条搜索结果:{}",
results.len(),
top_snippets.join("")
)
}
};
Ok(ResearchReport {
query: query.query.clone(),
results,
summary: None, // Would require LLM integration
summary: Some(summary),
key_findings,
related_topics,
researched_at: chrono::Utc::now().to_rfc3339(),
@@ -543,3 +569,276 @@ fn url_encode(s: &str) -> String {
})
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_hand() -> ResearcherHand {
ResearcherHand::new()
}
fn test_context() -> HandContext {
HandContext::default()
}
// --- Config & Type Tests ---
#[test]
fn test_config_id() {
let hand = create_test_hand();
assert_eq!(hand.config().id, "researcher");
assert_eq!(hand.config().name, "研究员");
assert!(hand.config().enabled);
assert!(!hand.config().needs_approval);
}
#[test]
fn test_search_engine_default_is_auto() {
let engine = SearchEngine::default();
assert!(matches!(engine, SearchEngine::Auto));
}
#[test]
fn test_research_depth_default_is_standard() {
let depth = ResearchDepth::default();
assert!(matches!(depth, ResearchDepth::Standard));
}
#[test]
fn test_research_depth_serialize() {
let json = serde_json::to_string(&ResearchDepth::Deep).unwrap();
assert_eq!(json, "\"deep\"");
}
#[test]
fn test_research_depth_deserialize() {
let depth: ResearchDepth = serde_json::from_str("\"quick\"").unwrap();
assert!(matches!(depth, ResearchDepth::Quick));
}
#[test]
fn test_search_engine_serialize_roundtrip() {
for engine in [SearchEngine::Google, SearchEngine::Bing, SearchEngine::DuckDuckGo, SearchEngine::Auto] {
let json = serde_json::to_string(&engine).unwrap();
let back: SearchEngine = serde_json::from_str(&json).unwrap();
assert_eq!(json, serde_json::to_string(&back).unwrap());
}
}
// --- Action Deserialization Tests ---
#[test]
fn test_action_search_deserialize() {
let json = json!({
"action": "search",
"query": {
"query": "Rust programming",
"engine": "duckduckgo",
"depth": "quick",
"maxResults": 5
}
});
let action: ResearcherAction = serde_json::from_value(json).unwrap();
match action {
ResearcherAction::Search { query } => {
assert_eq!(query.query, "Rust programming");
assert!(matches!(query.engine, SearchEngine::DuckDuckGo));
assert!(matches!(query.depth, ResearchDepth::Quick));
assert_eq!(query.max_results, 5);
}
_ => panic!("Expected Search action"),
}
}
#[test]
fn test_action_fetch_deserialize() {
let json = json!({
"action": "fetch",
"url": "https://example.com/page"
});
let action: ResearcherAction = serde_json::from_value(json).unwrap();
match action {
ResearcherAction::Fetch { url } => {
assert_eq!(url, "https://example.com/page");
}
_ => panic!("Expected Fetch action"),
}
}
#[test]
fn test_action_report_deserialize() {
let json = json!({
"action": "report",
"query": {
"query": "AI trends 2026",
"depth": "deep"
}
});
let action: ResearcherAction = serde_json::from_value(json).unwrap();
match action {
ResearcherAction::Report { query } => {
assert_eq!(query.query, "AI trends 2026");
assert!(matches!(query.depth, ResearchDepth::Deep));
}
_ => panic!("Expected Report action"),
}
}
#[test]
fn test_action_invalid_rejected() {
let json = json!({
"action": "unknown_action",
"data": "whatever"
});
let result: std::result::Result<ResearcherAction, _> = serde_json::from_value(json);
assert!(result.is_err());
}
// --- URL Encoding Tests ---
#[test]
fn test_url_encode_ascii() {
assert_eq!(url_encode("hello world"), "hello%20world");
}
#[test]
fn test_url_encode_chinese() {
let encoded = url_encode("中文搜索");
assert!(encoded.contains("%"));
// Chinese chars should be percent-encoded
assert!(!encoded.contains("中文"));
}
#[test]
fn test_url_encode_safe_chars() {
assert_eq!(url_encode("abc123-_."), "abc123-_.".to_string());
}
#[test]
fn test_url_encode_empty() {
assert_eq!(url_encode(""), "");
}
// --- HTML Text Extraction Tests ---
#[test]
fn test_extract_text_basic() {
let hand = create_test_hand();
let html = "<html><body><h1>Title</h1><p>Content here</p></body></html>";
let text = hand.extract_text_from_html(html);
assert!(text.contains("Title"));
assert!(text.contains("Content here"));
}
#[test]
fn test_extract_text_strips_scripts() {
let hand = create_test_hand();
let html = "<html><body><script>alert('xss')</script><p>Safe text</p></body></html>";
let text = hand.extract_text_from_html(html);
assert!(!text.contains("alert"));
assert!(text.contains("Safe text"));
}
#[test]
fn test_extract_text_strips_styles() {
let hand = create_test_hand();
let html = "<html><body><style>.class{color:red}</style><p>Visible</p></body></html>";
let text = hand.extract_text_from_html(html);
assert!(!text.contains("color"));
assert!(text.contains("Visible"));
}
#[test]
fn test_extract_text_truncates_long_content() {
let hand = create_test_hand();
let long_body: String = "x".repeat(20000);
let html = format!("<html><body><p>{}</p></body></html>", long_body);
let text = hand.extract_text_from_html(&html);
assert!(text.len() <= 10003); // 10000 + "..."
}
#[test]
fn test_extract_text_empty_body() {
let hand = create_test_hand();
let html = "<html><body></body></html>";
let text = hand.extract_text_from_html(html);
assert!(text.is_empty());
}
// --- Hand Trait Tests ---
#[tokio::test]
async fn test_needs_approval_is_false() {
let hand = create_test_hand();
assert!(!hand.needs_approval());
}
#[tokio::test]
async fn test_status_is_idle() {
let hand = create_test_hand();
assert!(matches!(hand.status(), crate::HandStatus::Idle));
}
#[tokio::test]
async fn test_check_dependencies_ok() {
let hand = create_test_hand();
let missing = hand.check_dependencies().unwrap();
// Default is_dependency_available returns true for all
assert!(missing.is_empty());
}
// --- Default Values Tests ---
#[test]
fn test_research_query_defaults() {
let json = json!({ "query": "test" });
let query: ResearchQuery = serde_json::from_value(json).unwrap();
assert_eq!(query.query, "test");
assert!(matches!(query.engine, SearchEngine::Auto));
assert!(matches!(query.depth, ResearchDepth::Standard));
assert_eq!(query.max_results, 10);
assert_eq!(query.time_limit_secs, 60);
assert!(!query.include_related);
}
#[test]
fn test_search_result_serialization() {
let result = SearchResult {
title: "Test".to_string(),
url: "https://example.com".to_string(),
snippet: "A snippet".to_string(),
source: "TestSource".to_string(),
relevance: 90,
content: None,
fetched_at: None,
};
let json = serde_json::to_string(&result).unwrap();
assert!(json.contains("Test"));
assert!(json.contains("https://example.com"));
}
#[test]
fn test_research_report_summary_is_some_when_results() {
// Verify the struct allows Some value
let report = ResearchReport {
query: "test".to_string(),
results: vec![SearchResult {
title: "R".to_string(),
url: "https://r.co".to_string(),
snippet: "snippet text".to_string(),
source: "S".to_string(),
relevance: 80,
content: None,
fetched_at: None,
}],
summary: Some("基于 1 条搜索结果snippet text".to_string()),
key_findings: vec![],
related_topics: vec![],
researched_at: "2026-01-01T00:00:00Z".to_string(),
duration_ms: 100,
};
assert!(report.summary.is_some());
assert!(report.summary.unwrap().contains("snippet text"));
}
}

View File

@@ -346,13 +346,50 @@ impl Hand for SlideshowHand {
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// === Config & Defaults ===
#[tokio::test]
async fn test_slideshow_creation() {
let hand = SlideshowHand::new();
assert_eq!(hand.config().id, "slideshow");
assert_eq!(hand.config().name, "幻灯片");
assert!(!hand.config().needs_approval);
assert!(hand.config().enabled);
assert!(hand.config().tags.contains(&"presentation".to_string()));
}
#[test]
fn test_default_impl() {
let hand = SlideshowHand::default();
assert_eq!(hand.config().id, "slideshow");
}
#[test]
fn test_needs_approval() {
let hand = SlideshowHand::new();
assert!(!hand.needs_approval());
}
#[test]
fn test_status() {
let hand = SlideshowHand::new();
assert_eq!(hand.status(), HandStatus::Idle);
}
#[test]
fn test_default_state() {
let state = SlideshowState::default();
assert_eq!(state.current_slide, 0);
assert_eq!(state.total_slides, 0);
assert!(!state.is_playing);
assert_eq!(state.auto_play_interval_ms, 5000);
assert!(state.slides.is_empty());
}
// === Navigation ===
#[tokio::test]
async fn test_navigation() {
let hand = SlideshowHand::with_slides_async(vec![
@@ -374,6 +411,53 @@ mod tests {
assert_eq!(hand.get_state().await.current_slide, 1);
}
#[tokio::test]
async fn test_next_slide_at_end() {
let hand = SlideshowHand::with_slides_async(vec![
SlideContent { title: "Only Slide".to_string(), subtitle: None, content: vec![], notes: None, background: None },
]).await;
// At slide 0, should not advance past last slide
hand.execute_action(SlideshowAction::NextSlide).await.unwrap();
assert_eq!(hand.get_state().await.current_slide, 0);
}
#[tokio::test]
async fn test_prev_slide_at_beginning() {
let hand = SlideshowHand::with_slides_async(vec![
SlideContent { title: "Slide 1".to_string(), subtitle: None, content: vec![], notes: None, background: None },
SlideContent { title: "Slide 2".to_string(), subtitle: None, content: vec![], notes: None, background: None },
]).await;
// At slide 0, should not go below 0
hand.execute_action(SlideshowAction::PrevSlide).await.unwrap();
assert_eq!(hand.get_state().await.current_slide, 0);
}
#[tokio::test]
async fn test_goto_slide_out_of_range() {
let hand = SlideshowHand::with_slides_async(vec![
SlideContent { title: "Slide 1".to_string(), subtitle: None, content: vec![], notes: None, background: None },
]).await;
let result = hand.execute_action(SlideshowAction::GotoSlide { slide_number: 5 }).await.unwrap();
assert!(!result.success);
}
#[tokio::test]
async fn test_goto_slide_returns_content() {
let hand = SlideshowHand::with_slides_async(vec![
SlideContent { title: "First".to_string(), subtitle: None, content: vec![], notes: None, background: None },
SlideContent { title: "Second".to_string(), subtitle: None, content: vec![], notes: None, background: None },
]).await;
let result = hand.execute_action(SlideshowAction::GotoSlide { slide_number: 1 }).await.unwrap();
assert!(result.success);
assert_eq!(result.output["slide_content"]["title"], "Second");
}
// === Spotlight & Laser & Highlight ===
#[tokio::test]
async fn test_spotlight() {
let hand = SlideshowHand::new();
@@ -384,6 +468,20 @@ mod tests {
let result = hand.execute_action(action).await.unwrap();
assert!(result.success);
assert_eq!(result.output["element_id"], "title");
assert_eq!(result.output["duration_ms"], 2000);
}
#[tokio::test]
async fn test_spotlight_default_duration() {
let hand = SlideshowHand::new();
let action = SlideshowAction::Spotlight {
element_id: "elem".to_string(),
duration_ms: default_spotlight_duration(),
};
let result = hand.execute_action(action).await.unwrap();
assert_eq!(result.output["duration_ms"], 2000);
}
#[tokio::test]
@@ -397,8 +495,96 @@ mod tests {
let result = hand.execute_action(action).await.unwrap();
assert!(result.success);
assert_eq!(result.output["x"], 100.0);
assert_eq!(result.output["y"], 200.0);
}
#[tokio::test]
async fn test_highlight_default_color() {
let hand = SlideshowHand::new();
let action = SlideshowAction::Highlight {
x: 10.0, y: 20.0, width: 100.0, height: 50.0,
color: None, duration_ms: 2000,
};
let result = hand.execute_action(action).await.unwrap();
assert!(result.success);
assert_eq!(result.output["color"], "#ffcc00");
}
#[tokio::test]
async fn test_highlight_custom_color() {
let hand = SlideshowHand::new();
let action = SlideshowAction::Highlight {
x: 0.0, y: 0.0, width: 50.0, height: 50.0,
color: Some("#ff0000".to_string()), duration_ms: 1000,
};
let result = hand.execute_action(action).await.unwrap();
assert_eq!(result.output["color"], "#ff0000");
}
// === AutoPlay / Pause / Resume ===
#[tokio::test]
async fn test_autoplay_pause_resume() {
let hand = SlideshowHand::new();
// AutoPlay
let result = hand.execute_action(SlideshowAction::AutoPlay { interval_ms: 3000 }).await.unwrap();
assert!(result.success);
assert!(hand.get_state().await.is_playing);
assert_eq!(hand.get_state().await.auto_play_interval_ms, 3000);
// Pause
hand.execute_action(SlideshowAction::Pause).await.unwrap();
assert!(!hand.get_state().await.is_playing);
// Resume
hand.execute_action(SlideshowAction::Resume).await.unwrap();
assert!(hand.get_state().await.is_playing);
// Stop
hand.execute_action(SlideshowAction::StopAutoPlay).await.unwrap();
assert!(!hand.get_state().await.is_playing);
}
#[tokio::test]
async fn test_autoplay_default_interval() {
let hand = SlideshowHand::new();
hand.execute_action(SlideshowAction::AutoPlay { interval_ms: default_interval() }).await.unwrap();
assert_eq!(hand.get_state().await.auto_play_interval_ms, 5000);
}
// === PlayAnimation ===
#[tokio::test]
async fn test_play_animation() {
let hand = SlideshowHand::new();
let result = hand.execute_action(SlideshowAction::PlayAnimation {
animation_id: "fade_in".to_string(),
}).await.unwrap();
assert!(result.success);
assert_eq!(result.output["animation_id"], "fade_in");
}
// === GetState ===
#[tokio::test]
async fn test_get_state() {
let hand = SlideshowHand::with_slides_async(vec![
SlideContent { title: "A".to_string(), subtitle: None, content: vec![], notes: None, background: None },
]).await;
let result = hand.execute_action(SlideshowAction::GetState).await.unwrap();
assert!(result.success);
assert_eq!(result.output["total_slides"], 1);
assert_eq!(result.output["current_slide"], 0);
}
// === SetContent ===
#[tokio::test]
async fn test_set_content() {
let hand = SlideshowHand::new();
@@ -421,5 +607,188 @@ mod tests {
assert!(result.success);
assert_eq!(hand.get_state().await.total_slides, 1);
assert_eq!(hand.get_state().await.slides[0].title, "Test Slide");
}
#[tokio::test]
async fn test_set_content_append() {
let hand = SlideshowHand::with_slides_async(vec![
SlideContent { title: "First".to_string(), subtitle: None, content: vec![], notes: None, background: None },
]).await;
let content = SlideContent {
title: "Appended".to_string(), subtitle: None, content: vec![], notes: None, background: None,
};
let result = hand.execute_action(SlideshowAction::SetContent {
slide_number: 1,
content,
}).await.unwrap();
assert!(result.success);
assert_eq!(result.output["status"], "slide_added");
assert_eq!(hand.get_state().await.total_slides, 2);
}
#[tokio::test]
async fn test_set_content_invalid_index() {
let hand = SlideshowHand::new();
let content = SlideContent {
title: "Gap".to_string(), subtitle: None, content: vec![], notes: None, background: None,
};
let result = hand.execute_action(SlideshowAction::SetContent {
slide_number: 5,
content,
}).await.unwrap();
assert!(!result.success);
}
// === Action Deserialization ===
#[test]
fn test_deserialize_next_slide() {
let action: SlideshowAction = serde_json::from_value(json!({"action": "next_slide"})).unwrap();
assert!(matches!(action, SlideshowAction::NextSlide));
}
#[test]
fn test_deserialize_goto_slide() {
let action: SlideshowAction = serde_json::from_value(json!({"action": "goto_slide", "slide_number": 3})).unwrap();
match action {
SlideshowAction::GotoSlide { slide_number } => assert_eq!(slide_number, 3),
_ => panic!("Expected GotoSlide"),
}
}
#[test]
fn test_deserialize_laser() {
let action: SlideshowAction = serde_json::from_value(json!({
"action": "laser", "x": 50.0, "y": 75.0
})).unwrap();
match action {
SlideshowAction::Laser { x, y, .. } => {
assert_eq!(x, 50.0);
assert_eq!(y, 75.0);
}
_ => panic!("Expected Laser"),
}
}
#[test]
fn test_deserialize_autoplay() {
let action: SlideshowAction = serde_json::from_value(json!({"action": "auto_play"})).unwrap();
match action {
SlideshowAction::AutoPlay { interval_ms } => assert_eq!(interval_ms, 5000),
_ => panic!("Expected AutoPlay"),
}
}
#[test]
fn test_deserialize_invalid_action() {
let result = serde_json::from_value::<SlideshowAction>(json!({"action": "nonexistent"}));
assert!(result.is_err());
}
// === ContentBlock Deserialization ===
#[test]
fn test_content_block_text() {
let block: ContentBlock = serde_json::from_value(json!({
"type": "text", "text": "Hello"
})).unwrap();
match block {
ContentBlock::Text { text, style } => {
assert_eq!(text, "Hello");
assert!(style.is_none());
}
_ => panic!("Expected Text"),
}
}
#[test]
fn test_content_block_list() {
let block: ContentBlock = serde_json::from_value(json!({
"type": "list", "items": ["A", "B"], "ordered": true
})).unwrap();
match block {
ContentBlock::List { items, ordered } => {
assert_eq!(items, vec!["A", "B"]);
assert!(ordered);
}
_ => panic!("Expected List"),
}
}
#[test]
fn test_content_block_code() {
let block: ContentBlock = serde_json::from_value(json!({
"type": "code", "code": "fn main() {}", "language": "rust"
})).unwrap();
match block {
ContentBlock::Code { code, language } => {
assert_eq!(code, "fn main() {}");
assert_eq!(language, Some("rust".to_string()));
}
_ => panic!("Expected Code"),
}
}
#[test]
fn test_content_block_table() {
let block: ContentBlock = serde_json::from_value(json!({
"type": "table",
"headers": ["Name", "Age"],
"rows": [["Alice", "30"]]
})).unwrap();
match block {
ContentBlock::Table { headers, rows } => {
assert_eq!(headers, vec!["Name", "Age"]);
assert_eq!(rows, vec![vec!["Alice", "30"]]);
}
_ => panic!("Expected Table"),
}
}
// === Hand trait via execute ===
#[tokio::test]
async fn test_hand_execute_dispatch() {
let hand = SlideshowHand::with_slides_async(vec![
SlideContent { title: "S1".to_string(), subtitle: None, content: vec![], notes: None, background: None },
SlideContent { title: "S2".to_string(), subtitle: None, content: vec![], notes: None, background: None },
]).await;
let ctx = HandContext::default();
let result = hand.execute(&ctx, json!({"action": "next_slide"})).await.unwrap();
assert!(result.success);
assert_eq!(result.output["current_slide"], 1);
}
#[tokio::test]
async fn test_hand_execute_invalid_action() {
let hand = SlideshowHand::new();
let ctx = HandContext::default();
let result = hand.execute(&ctx, json!({"action": "invalid"})).await.unwrap();
assert!(!result.success);
}
// === add_slide helper ===
#[tokio::test]
async fn test_add_slide() {
let hand = SlideshowHand::new();
hand.add_slide(SlideContent {
title: "Dynamic".to_string(), subtitle: None, content: vec![], notes: None, background: None,
}).await;
hand.add_slide(SlideContent {
title: "Dynamic 2".to_string(), subtitle: None, content: vec![], notes: None, background: None,
}).await;
let state = hand.get_state().await;
assert_eq!(state.total_slides, 2);
assert_eq!(state.slides.len(), 2);
}
}

View File

@@ -823,3 +823,417 @@ impl Hand for TwitterHand {
crate::HandStatus::Idle
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Hand;
use zclaw_types::id::AgentId;
fn make_context() -> HandContext {
HandContext {
agent_id: AgentId::new(),
working_dir: None,
env: std::collections::HashMap::new(),
timeout_secs: 30,
callback_url: None,
}
}
// === Config & Defaults ===
#[test]
fn test_hand_config() {
let hand = TwitterHand::new();
assert_eq!(hand.config().id, "twitter");
assert_eq!(hand.config().name, "Twitter 自动化");
assert!(hand.config().needs_approval);
assert!(hand.config().enabled);
assert!(hand.config().tags.contains(&"twitter".to_string()));
assert!(hand.config().input_schema.is_some());
}
#[test]
fn test_default_impl() {
let hand = TwitterHand::default();
assert_eq!(hand.config().id, "twitter");
}
#[test]
fn test_needs_approval() {
let hand = TwitterHand::new();
assert!(hand.needs_approval());
}
#[test]
fn test_status() {
let hand = TwitterHand::new();
assert_eq!(hand.status(), crate::HandStatus::Idle);
}
#[test]
fn test_check_dependencies() {
let hand = TwitterHand::new();
let deps = hand.check_dependencies().unwrap();
assert!(!deps.is_empty());
}
// === Action Deserialization ===
#[test]
fn test_tweet_action_deserialize() {
let json = json!({
"action": "tweet",
"config": {
"text": "Hello world!"
}
});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Tweet { config } => {
assert_eq!(config.text, "Hello world!");
assert!(config.media_urls.is_empty());
assert!(config.reply_to.is_none());
assert!(config.quote_tweet.is_none());
assert!(config.poll.is_none());
}
_ => panic!("Expected Tweet action"),
}
}
#[test]
fn test_tweet_action_with_reply() {
let json = json!({
"action": "tweet",
"config": {
"text": "@user reply",
"replyTo": "123456"
}
});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Tweet { config } => {
assert_eq!(config.reply_to.as_deref(), Some("123456"));
}
_ => panic!("Expected Tweet action"),
}
}
#[test]
fn test_tweet_action_with_poll() {
let json = json!({
"action": "tweet",
"config": {
"text": "Vote!",
"poll": {
"options": ["A", "B", "C"],
"durationMinutes": 60
}
}
});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Tweet { config } => {
let poll = config.poll.unwrap();
assert_eq!(poll.options, vec!["A", "B", "C"]);
assert_eq!(poll.duration_minutes, 60);
}
_ => panic!("Expected Tweet action"),
}
}
#[test]
fn test_delete_tweet_action() {
let json = json!({"action": "delete_tweet", "tweet_id": "789"});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::DeleteTweet { tweet_id } => assert_eq!(tweet_id, "789"),
_ => panic!("Expected DeleteTweet"),
}
}
#[test]
fn test_like_unlike_actions() {
let like: TwitterAction = serde_json::from_value(json!({"action": "like", "tweet_id": "111"})).unwrap();
match like {
TwitterAction::Like { tweet_id } => assert_eq!(tweet_id, "111"),
_ => panic!("Expected Like"),
}
let unlike: TwitterAction = serde_json::from_value(json!({"action": "unlike", "tweet_id": "111"})).unwrap();
match unlike {
TwitterAction::Unlike { tweet_id } => assert_eq!(tweet_id, "111"),
_ => panic!("Expected Unlike"),
}
}
#[test]
fn test_retweet_unretweet_actions() {
let rt: TwitterAction = serde_json::from_value(json!({"action": "retweet", "tweet_id": "222"})).unwrap();
match rt {
TwitterAction::Retweet { tweet_id } => assert_eq!(tweet_id, "222"),
_ => panic!("Expected Retweet"),
}
let unrt: TwitterAction = serde_json::from_value(json!({"action": "unretweet", "tweet_id": "222"})).unwrap();
match unrt {
TwitterAction::Unretweet { tweet_id } => assert_eq!(tweet_id, "222"),
_ => panic!("Expected Unretweet"),
}
}
#[test]
fn test_search_action_defaults() {
let json = json!({"action": "search", "config": {"query": "rust lang"}});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Search { config } => {
assert_eq!(config.query, "rust lang");
assert_eq!(config.max_results, 10); // default
assert!(config.next_token.is_none());
}
_ => panic!("Expected Search"),
}
}
#[test]
fn test_search_action_custom_max() {
let json = json!({"action": "search", "config": {"query": "test", "maxResults": 50}});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Search { config } => assert_eq!(config.max_results, 50),
_ => panic!("Expected Search"),
}
}
#[test]
fn test_timeline_action_defaults() {
let json = json!({"action": "timeline", "config": {}});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Timeline { config } => {
assert!(config.user_id.is_none());
assert_eq!(config.max_results, 10); // default
assert!(!config.exclude_replies);
assert!(config.include_retweets);
}
_ => panic!("Expected Timeline"),
}
}
#[test]
fn test_get_tweet_action() {
let json = json!({"action": "get_tweet", "tweet_id": "999"});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::GetTweet { tweet_id } => assert_eq!(tweet_id, "999"),
_ => panic!("Expected GetTweet"),
}
}
#[test]
fn test_get_user_action() {
let json = json!({"action": "get_user", "username": "elonmusk"});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::GetUser { username } => assert_eq!(username, "elonmusk"),
_ => panic!("Expected GetUser"),
}
}
#[test]
fn test_followers_action() {
let json = json!({"action": "followers", "user_id": "u1", "max_results": 50});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Followers { user_id, max_results } => {
assert_eq!(user_id, "u1");
assert_eq!(max_results, Some(50));
}
_ => panic!("Expected Followers"),
}
}
#[test]
fn test_following_action_no_max() {
let json = json!({"action": "following", "user_id": "u2"});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::Following { user_id, max_results } => {
assert_eq!(user_id, "u2");
assert!(max_results.is_none());
}
_ => panic!("Expected Following"),
}
}
#[test]
fn test_check_credentials_action() {
let json = json!({"action": "check_credentials"});
let action: TwitterAction = serde_json::from_value(json).unwrap();
match action {
TwitterAction::CheckCredentials => {}
_ => panic!("Expected CheckCredentials"),
}
}
#[test]
fn test_invalid_action() {
let json = json!({"action": "invalid_action"});
let result = serde_json::from_value::<TwitterAction>(json);
assert!(result.is_err());
}
// === Serialization Roundtrip ===
#[test]
fn test_tweet_action_roundtrip() {
let json = json!({
"action": "tweet",
"config": {
"text": "Test tweet",
"mediaUrls": ["https://example.com/img.jpg"],
"replyTo": "123",
"quoteTweet": "456"
}
});
let action: TwitterAction = serde_json::from_value(json).unwrap();
let serialized = serde_json::to_value(&action).unwrap();
// Verify core fields survive roundtrip (camelCase via serde rename)
assert_eq!(serialized["action"], "tweet");
assert_eq!(serialized["config"]["text"], "Test tweet");
assert_eq!(serialized["config"]["mediaUrls"][0], "https://example.com/img.jpg");
assert_eq!(serialized["config"]["replyTo"], "123");
assert_eq!(serialized["config"]["quoteTweet"], "456");
}
#[test]
fn test_search_action_roundtrip() {
let json = json!({
"action": "search",
"config": {
"query": "hello world",
"maxResults": 25
}
});
let action: TwitterAction = serde_json::from_value(json).unwrap();
let serialized = serde_json::to_value(&action).unwrap();
assert_eq!(serialized["action"], "search");
assert_eq!(serialized["config"]["query"], "hello world");
assert_eq!(serialized["config"]["maxResults"], 25);
}
// === Credentials ===
#[tokio::test]
async fn test_set_and_get_credentials() {
let hand = TwitterHand::new();
// Initially no credentials
assert!(hand.get_credentials().await.is_none());
// Set credentials
hand.set_credentials(TwitterCredentials {
api_key: "key".into(),
api_secret: "secret".into(),
access_token: "token".into(),
access_token_secret: "token_secret".into(),
bearer_token: Some("bearer".into()),
}).await;
let creds = hand.get_credentials().await.unwrap();
assert_eq!(creds.api_key, "key");
assert_eq!(creds.bearer_token.as_deref(), Some("bearer"));
}
#[tokio::test]
async fn test_check_credentials_without_config() {
let hand = TwitterHand::new();
let ctx = make_context();
let result = hand.execute(&ctx, json!({"action": "check_credentials"})).await.unwrap();
// No "success" field in output → HandResult.success defaults to false
assert!(!result.success);
assert_eq!(result.output["configured"], false);
}
#[tokio::test]
async fn test_check_credentials_with_config() {
let hand = TwitterHand::new();
hand.set_credentials(TwitterCredentials {
api_key: "key".into(),
api_secret: "secret".into(),
access_token: "token".into(),
access_token_secret: "token_secret".into(),
bearer_token: Some("bearer".into()),
}).await;
let ctx = make_context();
let result = hand.execute(&ctx, json!({"action": "check_credentials"})).await.unwrap();
// execute_check_credentials returns {"configured": true, ...} without "success" field
// HandResult.success = result["success"].as_bool().unwrap_or(false) = false
// But the actual data is in output
assert_eq!(result.output["configured"], true);
assert_eq!(result.output["has_bearer_token"], true);
}
// === Tweet Data Types ===
#[test]
fn test_tweet_deserialize() {
let json = json!({
"id": "t123",
"text": "Hello!",
"authorId": "a456",
"authorName": "Test User",
"authorUsername": "testuser",
"createdAt": "2026-01-01T00:00:00Z",
"publicMetrics": {
"retweetCount": 5,
"replyCount": 2,
"likeCount": 10,
"quoteCount": 1,
"impressionCount": 1000
}
});
let tweet: Tweet = serde_json::from_value(json).unwrap();
assert_eq!(tweet.id, "t123");
assert_eq!(tweet.public_metrics.like_count, 10);
assert!(tweet.media.is_empty());
}
#[test]
fn test_twitter_user_deserialize() {
let json = json!({
"id": "u1",
"name": "Alice",
"username": "alice",
"verified": true,
"publicMetrics": {
"followersCount": 100,
"followingCount": 50,
"tweetCount": 1000,
"listedCount": 5
}
});
let user: TwitterUser = serde_json::from_value(json).unwrap();
assert_eq!(user.username, "alice");
assert!(user.verified);
assert_eq!(user.public_metrics.followers_count, 100);
}
#[test]
fn test_media_info_deserialize() {
let json = json!({
"mediaKey": "mk1",
"mediaType": "photo",
"url": "https://pbs.example.com/photo.jpg",
"width": 1200,
"height": 800
});
let media: MediaInfo = serde_json::from_value(json).unwrap();
assert_eq!(media.media_type, "photo");
assert_eq!(media.width, 1200);
}
}

View File

@@ -71,4 +71,19 @@ CREATE INDEX IF NOT EXISTS idx_kv_agent ON kv_store(agent_id);
CREATE INDEX IF NOT EXISTS idx_hand_runs_hand ON hand_runs(hand_name);
CREATE INDEX IF NOT EXISTS idx_hand_runs_status ON hand_runs(status);
CREATE INDEX IF NOT EXISTS idx_hand_runs_created ON hand_runs(created_at);
-- Structured facts table (extracted from conversations)
CREATE TABLE IF NOT EXISTS facts (
id TEXT PRIMARY KEY,
agent_id TEXT NOT NULL,
content TEXT NOT NULL,
category TEXT NOT NULL,
confidence REAL NOT NULL,
source_session TEXT,
created_at INTEGER NOT NULL,
FOREIGN KEY (agent_id) REFERENCES agents(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_facts_agent ON facts(agent_id);
CREATE INDEX IF NOT EXISTS idx_facts_category ON facts(agent_id, category);
CREATE INDEX IF NOT EXISTS idx_facts_confidence ON facts(agent_id, confidence DESC);
"#;

View File

@@ -482,6 +482,76 @@ impl MemoryStore {
Ok(count as u32)
}
// === Fact CRUD ===
/// Store extracted facts for an agent (upsert by id).
pub async fn store_facts(&self, agent_id: &str, facts: &[crate::fact::Fact]) -> Result<()> {
for fact in facts {
let category_str = serde_json::to_string(&fact.category)
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
// Trim the JSON quotes from serialized enum variant
let category_clean = category_str.trim_matches('"');
sqlx::query(
r#"
INSERT INTO facts (id, agent_id, content, category, confidence, source_session, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
content = excluded.content,
category = excluded.category,
confidence = excluded.confidence,
source_session = excluded.source_session
"#,
)
.bind(&fact.id)
.bind(agent_id)
.bind(&fact.content)
.bind(category_clean)
.bind(fact.confidence)
.bind(&fact.source)
.bind(fact.created_at as i64)
.execute(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
}
Ok(())
}
/// Get top facts for an agent, ordered by confidence descending.
pub async fn get_top_facts(&self, agent_id: &str, limit: usize) -> Result<Vec<crate::fact::Fact>> {
let rows = sqlx::query_as::<_, (String, String, String, f64, Option<String>, i64)>(
r#"
SELECT id, content, category, confidence, source_session, created_at
FROM facts
WHERE agent_id = ?
ORDER BY confidence DESC
LIMIT ?
"#,
)
.bind(agent_id)
.bind(limit as i64)
.fetch_all(&self.pool)
.await
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
let mut facts = Vec::with_capacity(rows.len());
for (id, content, category_str, confidence, source, created_at) in rows {
let category: crate::fact::FactCategory = serde_json::from_value(
serde_json::Value::String(category_str)
).map_err(|e| ZclawError::StorageError(format!("Invalid category: {}", e)))?;
facts.push(crate::fact::Fact {
id,
content,
category,
confidence,
created_at: created_at as u64,
source,
});
}
Ok(facts)
}
fn row_to_hand_run(
row: (String, String, String, String, String, Option<String>, Option<String>, Option<i64>, String, Option<String>, Option<String>),
) -> Result<HandRun> {
@@ -527,10 +597,13 @@ mod tests {
description: None,
model: ModelConfig::default(),
system_prompt: None,
soul: None,
capabilities: vec![],
tools: vec![],
max_tokens: None,
temperature: None,
workspace: None,
compaction_threshold: None,
enabled: true,
}
}

View File

@@ -60,6 +60,22 @@ impl AgentMiddleware for MemoryMiddleware {
fn priority(&self) -> i32 { 150 }
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
// Skip memory injection for very short queries.
// Short queries (e.g., "1+6", "hi", "好") don't benefit from memory context.
// Worse, the retriever's scope-based fallback may return high-importance but
// irrelevant old memories, causing the model to think about past conversations
// instead of answering the current question.
// Use char count (not byte count) so CJK queries are handled correctly:
// a single Chinese char is 3 UTF-8 bytes but 1 meaningful character.
let query = ctx.user_input.trim();
if query.chars().count() < 2 {
tracing::debug!(
"[MemoryMiddleware] Skipping enhancement for short query ({:?}): no memory context needed",
query
);
return Ok(MiddlewareDecision::Continue);
}
match self.growth.enhance_prompt(
&ctx.agent_id,
&ctx.system_prompt,
@@ -92,21 +108,27 @@ impl AgentMiddleware for MemoryMiddleware {
return Ok(());
}
match self.growth.process_conversation(
// Combined extraction: single LLM call produces both memories and structured facts.
// Avoids double LLM extraction ( process_conversation + extract_structured_facts).
match self.growth.extract_combined(
&ctx.agent_id,
&ctx.messages,
ctx.session_id.clone(),
&ctx.session_id,
).await {
Ok(count) => {
Ok(Some((mem_count, facts))) => {
tracing::info!(
"[MemoryMiddleware] Extracted {} memories for agent {}",
count,
"[MemoryMiddleware] Extracted {} memories + {} structured facts for agent {}",
mem_count,
facts.len(),
agent_key
);
}
Ok(None) => {
tracing::debug!("[MemoryMiddleware] No memories or facts extracted");
}
Err(e) => {
// Non-fatal: extraction failure should not affect the response
tracing::warn!("[MemoryMiddleware] Memory extraction failed: {}", e);
tracing::warn!("[MemoryMiddleware] Combined extraction failed: {}", e);
}
}

View File

@@ -26,14 +26,17 @@ chrono = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
sqlx = { workspace = true }
pgvector = { version = "0.4", features = ["sqlx"] }
reqwest = { workspace = true }
secrecy = { workspace = true }
sha2 = { workspace = true }
rand = { workspace = true }
dashmap = { workspace = true }
hex = { workspace = true }
rsa = { workspace = true, features = ["sha2"] }
base64 = { workspace = true }
socket2 = { workspace = true }
url = "2"
url = { workspace = true }
axum = { workspace = true }
axum-extra = { workspace = true }
@@ -47,6 +50,7 @@ data-encoding = "2"
regex = { workspace = true }
aes-gcm = { workspace = true }
bytes = { workspace = true }
async-stream = { workspace = true }
[dev-dependencies]
tempfile = { workspace = true }

View File

@@ -0,0 +1,10 @@
-- Add is_embedding column to models table
-- Distinguishes embedding models from chat/completion models
ALTER TABLE models ADD COLUMN IF NOT EXISTS is_embedding BOOLEAN NOT NULL DEFAULT FALSE;
-- Add model_type column for future extensibility (chat, embedding, image, audio, etc.)
ALTER TABLE models ADD COLUMN IF NOT EXISTS model_type TEXT NOT NULL DEFAULT 'chat';
-- Index for quick filtering of embedding models
CREATE INDEX IF NOT EXISTS idx_models_is_embedding ON models(is_embedding) WHERE is_embedding = TRUE;
CREATE INDEX IF NOT EXISTS idx_models_model_type ON models(model_type);

View File

@@ -0,0 +1,133 @@
-- Migration: Billing tables for subscription management
-- Supports: Free/Pro/Team plans, Alipay + WeChat Pay, usage quotas
-- Plan definitions (Free/Pro/Team)
CREATE TABLE IF NOT EXISTS billing_plans (
id TEXT PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL,
description TEXT,
price_cents INTEGER NOT NULL DEFAULT 0,
currency TEXT NOT NULL DEFAULT 'CNY',
interval TEXT NOT NULL DEFAULT 'month',
features JSONB NOT NULL DEFAULT '{}',
limits JSONB NOT NULL DEFAULT '{}',
is_default BOOLEAN NOT NULL DEFAULT FALSE,
sort_order INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'active',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_billing_plans_status ON billing_plans(status);
-- Account subscriptions
CREATE TABLE IF NOT EXISTS billing_subscriptions (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL,
plan_id TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'active',
current_period_start TIMESTAMPTZ NOT NULL DEFAULT NOW(),
current_period_end TIMESTAMPTZ NOT NULL,
trial_end TIMESTAMPTZ,
canceled_at TIMESTAMPTZ,
cancel_at_period_end BOOLEAN NOT NULL DEFAULT FALSE,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
FOREIGN KEY (plan_id) REFERENCES billing_plans(id)
);
CREATE INDEX IF NOT EXISTS idx_billing_sub_account ON billing_subscriptions(account_id);
CREATE UNIQUE INDEX IF NOT EXISTS idx_billing_sub_active
ON billing_subscriptions(account_id)
WHERE status IN ('trial', 'active', 'past_due');
-- Invoices
CREATE TABLE IF NOT EXISTS billing_invoices (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL,
subscription_id TEXT,
plan_id TEXT,
amount_cents INTEGER NOT NULL,
currency TEXT NOT NULL DEFAULT 'CNY',
description TEXT,
status TEXT NOT NULL DEFAULT 'pending',
due_at TIMESTAMPTZ,
paid_at TIMESTAMPTZ,
voided_at TIMESTAMPTZ,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
FOREIGN KEY (subscription_id) REFERENCES billing_subscriptions(id) ON DELETE SET NULL,
FOREIGN KEY (plan_id) REFERENCES billing_plans(id)
);
CREATE INDEX IF NOT EXISTS idx_billing_inv_account ON billing_invoices(account_id);
CREATE INDEX IF NOT EXISTS idx_billing_inv_status ON billing_invoices(status);
CREATE INDEX IF NOT EXISTS idx_billing_inv_time ON billing_invoices(created_at);
-- Payment records (Alipay / WeChat Pay)
CREATE TABLE IF NOT EXISTS billing_payments (
id TEXT PRIMARY KEY,
invoice_id TEXT NOT NULL,
account_id TEXT NOT NULL,
amount_cents INTEGER NOT NULL,
currency TEXT NOT NULL DEFAULT 'CNY',
method TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending',
external_trade_no TEXT,
paid_at TIMESTAMPTZ,
refunded_at TIMESTAMPTZ,
failure_reason TEXT,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (invoice_id) REFERENCES billing_invoices(id) ON DELETE CASCADE,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS idx_billing_pay_invoice ON billing_payments(invoice_id);
CREATE INDEX IF NOT EXISTS idx_billing_pay_account ON billing_payments(account_id);
CREATE INDEX IF NOT EXISTS idx_billing_pay_trade_no ON billing_payments(external_trade_no);
CREATE INDEX IF NOT EXISTS idx_billing_pay_status ON billing_payments(status);
-- Monthly usage quotas (per account per billing period)
CREATE TABLE IF NOT EXISTS billing_usage_quotas (
id TEXT PRIMARY KEY,
account_id TEXT NOT NULL,
period_start TIMESTAMPTZ NOT NULL,
period_end TIMESTAMPTZ NOT NULL,
input_tokens BIGINT NOT NULL DEFAULT 0,
output_tokens BIGINT NOT NULL DEFAULT 0,
relay_requests INTEGER NOT NULL DEFAULT 0,
hand_executions INTEGER NOT NULL DEFAULT 0,
pipeline_runs INTEGER NOT NULL DEFAULT 0,
max_input_tokens BIGINT,
max_output_tokens BIGINT,
max_relay_requests INTEGER,
max_hand_executions INTEGER,
max_pipeline_runs INTEGER,
metadata JSONB NOT NULL DEFAULT '{}',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE CASCADE,
UNIQUE(account_id, period_start)
);
CREATE INDEX IF NOT EXISTS idx_billing_usage_account ON billing_usage_quotas(account_id);
CREATE INDEX IF NOT EXISTS idx_billing_usage_period ON billing_usage_quotas(period_start, period_end);
-- Seed: default plans
INSERT INTO billing_plans (id, name, display_name, description, price_cents, interval, features, limits, is_default, sort_order)
VALUES
('plan-free', 'free', '免费版', '基础功能,适合个人体验', 0, 'month',
'{"hands": ["browser", "collector", "researcher"], "chat_modes": ["flash", "thinking"], "pipelines": 3, "support": "community"}'::jsonb,
'{"max_input_tokens_monthly": 500000, "max_output_tokens_monthly": 500000, "max_relay_requests_monthly": 100, "max_hand_executions_monthly": 20, "max_pipeline_runs_monthly": 5}'::jsonb,
TRUE, 0),
('plan-pro', 'pro', '专业版', '全功能解锁,适合知识工作者', 4900, 'month',
'{"hands": "all", "chat_modes": "all", "pipelines": -1, "support": "priority", "memory": true, "export": true}'::jsonb,
'{"max_input_tokens_monthly": 5000000, "max_output_tokens_monthly": 5000000, "max_relay_requests_monthly": 2000, "max_hand_executions_monthly": 200, "max_pipeline_runs_monthly": 100}'::jsonb,
FALSE, 1),
('plan-team', 'team', '团队版', '多席位协作,适合企业团队', 19900, 'month',
'{"hands": "all", "chat_modes": "all", "pipelines": -1, "support": "dedicated", "memory": true, "export": true, "sharing": true, "admin": true}'::jsonb,
'{"max_input_tokens_monthly": 50000000, "max_output_tokens_monthly": 50000000, "max_relay_requests_monthly": 20000, "max_hand_executions_monthly": 1000, "max_pipeline_runs_monthly": 500}'::jsonb,
FALSE, 2)
ON CONFLICT (name) DO NOTHING;

View File

@@ -0,0 +1,123 @@
-- Migration: Knowledge Base tables with pgvector support
-- 5 tables: knowledge_categories, knowledge_items, knowledge_chunks,
-- knowledge_versions, knowledge_usage
-- Enable pgvector extension
CREATE EXTENSION IF NOT EXISTS vector;
-- 行业分类树
CREATE TABLE IF NOT EXISTS knowledge_categories (
id TEXT PRIMARY KEY,
name VARCHAR(100) NOT NULL,
description TEXT,
parent_id TEXT REFERENCES knowledge_categories(id) ON DELETE RESTRICT,
icon VARCHAR(50),
sort_order INT DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
CHECK (id != parent_id)
);
CREATE INDEX IF NOT EXISTS idx_kc_parent ON knowledge_categories(parent_id);
-- 知识条目
CREATE TABLE IF NOT EXISTS knowledge_items (
id TEXT PRIMARY KEY,
category_id TEXT NOT NULL REFERENCES knowledge_categories(id) ON DELETE RESTRICT,
title VARCHAR(255) NOT NULL,
content TEXT NOT NULL,
keywords TEXT[] DEFAULT '{}',
related_questions TEXT[] DEFAULT '{}',
priority INT DEFAULT 0,
status VARCHAR(20) DEFAULT 'active' CHECK (status IN ('active', 'archived', 'deprecated', 'draft')),
version INT DEFAULT 1,
source VARCHAR(50) DEFAULT 'manual',
tags TEXT[] DEFAULT '{}',
created_by TEXT NOT NULL REFERENCES accounts(id),
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
CHECK (length(content) <= 100000)
);
CREATE INDEX IF NOT EXISTS idx_ki_category ON knowledge_items(category_id);
CREATE INDEX IF NOT EXISTS idx_ki_status_updated ON knowledge_items(status, updated_at DESC);
CREATE INDEX IF NOT EXISTS idx_ki_keywords ON knowledge_items USING GIN(keywords);
-- 知识分块RAG 检索核心)
CREATE TABLE IF NOT EXISTS knowledge_chunks (
id TEXT PRIMARY KEY,
item_id TEXT NOT NULL REFERENCES knowledge_items(id) ON DELETE CASCADE,
chunk_index INT NOT NULL,
content TEXT NOT NULL,
embedding vector(1536),
keywords TEXT[] DEFAULT '{}',
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS idx_kchunks_item_idx ON knowledge_chunks(item_id, chunk_index);
CREATE INDEX IF NOT EXISTS idx_kchunks_item ON knowledge_chunks(item_id);
CREATE INDEX IF NOT EXISTS idx_kchunks_keywords ON knowledge_chunks USING GIN(keywords);
-- 向量相似度索引HNSW无需预填充数据
-- 仅在有数据后创建此索引可提升性能,这里预创建
CREATE INDEX IF NOT EXISTS idx_kchunks_embedding ON knowledge_chunks
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 128);
-- 版本快照
CREATE TABLE IF NOT EXISTS knowledge_versions (
id TEXT PRIMARY KEY,
item_id TEXT NOT NULL REFERENCES knowledge_items(id) ON DELETE CASCADE,
version INT NOT NULL,
title VARCHAR(255) NOT NULL,
content TEXT NOT NULL,
keywords TEXT[] DEFAULT '{}',
related_questions TEXT[] DEFAULT '{}',
change_summary TEXT,
created_by TEXT NOT NULL REFERENCES accounts(id),
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_kv_item ON knowledge_versions(item_id);
-- 使用追踪
CREATE TABLE IF NOT EXISTS knowledge_usage (
id TEXT PRIMARY KEY,
item_id TEXT REFERENCES knowledge_items(id) ON DELETE SET NULL,
chunk_id TEXT REFERENCES knowledge_chunks(id) ON DELETE SET NULL,
session_id VARCHAR(100),
query_text TEXT,
relevance_score FLOAT,
was_injected BOOLEAN DEFAULT FALSE,
agent_feedback VARCHAR(20) CHECK (agent_feedback IN ('positive', 'negative')),
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_ku_item ON knowledge_usage(item_id) WHERE item_id IS NOT NULL;
-- BRIN 索引:追加写入的时间序列数据比 B-tree 更高效
CREATE INDEX IF NOT EXISTS idx_ku_created_brin ON knowledge_usage USING brin(created_at);
-- 权限种子数据(使用 jsonb 操作避免 REPLACE 脆弱性)
UPDATE roles
SET permissions = (
SELECT '[' || string_agg('"' || elem || '"', ', ') || ']'
FROM (
SELECT DISTINCT elem
FROM json_array_elements_text(permissions::json) AS elem
UNION ALL SELECT 'knowledge:read'
UNION ALL SELECT 'knowledge:write'
UNION ALL SELECT 'knowledge:admin'
UNION ALL SELECT 'knowledge:search'
) sub
)
WHERE id = 'super_admin'
AND permissions NOT LIKE '%knowledge:read%';
UPDATE roles
SET permissions = (
SELECT '[' || string_agg('"' || elem || '"', ', ') || ']'
FROM (
SELECT DISTINCT elem
FROM json_array_elements_text(permissions::json) AS elem
UNION ALL SELECT 'knowledge:read'
UNION ALL SELECT 'knowledge:write'
UNION ALL SELECT 'knowledge:search'
) sub
)
WHERE id = 'admin'
AND permissions NOT LIKE '%knowledge:read%';

View File

@@ -0,0 +1,5 @@
-- Add execution result columns to scheduled_tasks
-- Tracks the output and duration of each task execution for observability
ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_result TEXT;
ALTER TABLE scheduled_tasks ADD COLUMN IF NOT EXISTS last_duration_ms INTEGER;

View File

@@ -67,14 +67,17 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
}
}
// 异步更新 last_used_at(不阻塞请求)
let db = state.db.clone();
tokio::spawn(async move {
let now = chrono::Utc::now().to_rfc3339();
let _ = sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
.bind(&now).bind(&token_hash)
.execute(&db).await;
});
// 异步更新 last_used_at — 通过 Worker 通道派发,受 SpawnLimiter 门控
// 替换原来的 tokio::spawn(DB UPDATE),消除每请求无限制 spawn
{
use crate::workers::update_last_used::UpdateLastUsedArgs;
let args = UpdateLastUsedArgs {
token_hash: token_hash.to_string(),
};
if let Err(e) = state.worker_dispatcher.dispatch("update_last_used", args).await {
tracing::debug!("Failed to dispatch update_last_used: {}", e);
}
}
Ok(AuthContext {
account_id,
@@ -84,23 +87,43 @@ async fn verify_api_token(state: &AppState, raw_token: &str, client_ip: Option<S
})
}
/// 从请求中提取客户端 IP
fn extract_client_ip(req: &Request) -> Option<String> {
// 优先从 ConnectInfo 获取
if let Some(ConnectInfo(addr)) = req.extensions().get::<ConnectInfo<SocketAddr>>() {
return Some(addr.ip().to_string());
/// 从请求中提取客户端 IP(安全版:仅对 trusted_proxies 解析 XFF
fn extract_client_ip(req: &Request, trusted_proxies: &[String]) -> Option<String> {
// 优先从 ConnectInfo 获取直接连接 IP
let connect_ip = req.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ConnectInfo(addr)| addr.ip().to_string());
// 仅当直接连接 IP 在 trusted_proxies 中时,才信任 XFF/X-Real-IP
if let Some(ref ip) = connect_ip {
if trusted_proxies.iter().any(|p| p == ip) {
// 受信代理 → 从 XFF 取真实客户端 IP
if let Some(forwarded) = req.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
{
if let Some(client) = forwarded.split(',').next() {
let trimmed = client.trim();
if !trimmed.is_empty() {
return Some(trimmed.to_string());
}
}
}
// 尝试 X-Real-IP
if let Some(real_ip) = req.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
{
let trimmed = real_ip.trim();
if !trimmed.is_empty() {
return Some(trimmed.to_string());
}
}
}
}
// 回退到 X-Forwarded-For / X-Real-IP
if let Some(forwarded) = req.headers()
.get("x-forwarded-for")
.and_then(|v| v.to_str().ok())
{
return Some(forwarded.split(',').next()?.trim().to_string());
}
req.headers()
.get("x-real-ip")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
// 非受信来源或无代理头 → 返回直接连接 IP
connect_ip
}
/// 认证中间件: 从 JWT Cookie / Authorization Header / API Token 提取身份
@@ -110,7 +133,10 @@ pub async fn auth_middleware(
mut req: Request,
next: Next,
) -> Response {
let client_ip = extract_client_ip(&req);
let client_ip = {
let config = state.config.read().await;
extract_client_ip(&req, &config.server.trusted_proxies)
};
let auth_header = req.headers()
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok());

View File

@@ -0,0 +1,395 @@
//! 计费 HTTP 处理器
use axum::{
extract::{Extension, Form, Path, Query, State},
Json,
};
use serde::Deserialize;
use crate::auth::types::AuthContext;
use crate::error::{SaasError, SaasResult};
use crate::state::AppState;
use super::service;
use super::types::*;
/// GET /api/v1/billing/plans — 列出所有活跃计划
pub async fn list_plans(
State(state): State<AppState>,
) -> SaasResult<Json<Vec<BillingPlan>>> {
let plans = service::list_plans(&state.db).await?;
Ok(Json(plans))
}
/// GET /api/v1/billing/plans/:id — 获取单个计划详情
pub async fn get_plan(
State(state): State<AppState>,
Path(plan_id): Path<String>,
) -> SaasResult<Json<BillingPlan>> {
let plan = service::get_plan(&state.db, &plan_id).await?
.ok_or_else(|| crate::error::SaasError::NotFound("计划不存在".into()))?;
Ok(Json(plan))
}
/// GET /api/v1/billing/subscription — 获取当前订阅
pub async fn get_subscription(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
let plan = service::get_account_plan(&state.db, &ctx.account_id).await?;
let sub = service::get_active_subscription(&state.db, &ctx.account_id).await?;
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
Ok(Json(serde_json::json!({
"plan": plan,
"subscription": sub,
"usage": usage,
})))
}
/// GET /api/v1/billing/usage — 获取当月用量
pub async fn get_usage(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<UsageQuota>> {
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
Ok(Json(usage))
}
/// POST /api/v1/billing/usage/increment — 客户端上报用量Hand/Pipeline 执行后调用)
///
/// 请求体: `{ "dimension": "hand_executions" | "pipeline_runs" | "relay_requests", "count": 1 }`
/// 需要认证 — account_id 从 JWT 提取。
#[derive(Debug, Deserialize)]
pub struct IncrementUsageRequest {
/// 用量维度hand_executions / pipeline_runs / relay_requests
pub dimension: String,
/// 递增数量,默认 1
#[serde(default = "default_count")]
pub count: i32,
}
fn default_count() -> i32 { 1 }
pub async fn increment_usage_dimension(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<IncrementUsageRequest>,
) -> SaasResult<Json<serde_json::Value>> {
// 验证维度白名单
if !["hand_executions", "pipeline_runs", "relay_requests"].contains(&req.dimension.as_str()) {
return Err(SaasError::InvalidInput(
format!("无效的用量维度: {},支持: hand_executions / pipeline_runs / relay_requests", req.dimension)
));
}
// 限制单次递增上限(防滥用)
if req.count < 1 || req.count > 100 {
return Err(SaasError::InvalidInput(
format!("count 必须在 1~100 范围内,得到: {}", req.count)
));
}
// 单次原子更新,避免循环 N 次数据库查询
service::increment_dimension_by(&state.db, &ctx.account_id, &req.dimension, req.count).await?;
// 返回更新后的用量
let usage = service::get_or_create_usage(&state.db, &ctx.account_id).await?;
Ok(Json(serde_json::json!({
"dimension": req.dimension,
"incremented": req.count,
"usage": usage,
})))
}
/// POST /api/v1/billing/payments — 创建支付订单
pub async fn create_payment(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreatePaymentRequest>,
) -> SaasResult<Json<PaymentResult>> {
let config = state.config.read().await;
let result = super::payment::create_payment(
&state.db,
&ctx.account_id,
&req,
&config.payment,
).await?;
Ok(Json(result))
}
/// GET /api/v1/billing/payments/:id — 查询支付状态
pub async fn get_payment_status(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(payment_id): Path<String>,
) -> SaasResult<Json<serde_json::Value>> {
let status = super::payment::query_payment_status(
&state.db,
&payment_id,
&ctx.account_id,
).await?;
Ok(Json(status))
}
/// POST /api/v1/billing/callback/:method — 支付回调(支付宝/微信异步通知)
pub async fn payment_callback(
State(state): State<AppState>,
Path(method): Path<String>,
body: axum::body::Bytes,
) -> SaasResult<String> {
tracing::info!("Payment callback received: method={}, body_len={}", method, body.len());
let body_str = String::from_utf8_lossy(&body);
let config = state.config.read().await;
let (trade_no, status, callback_amount) = if method == "alipay" {
parse_alipay_callback(&body_str, &config.payment)?
} else if method == "wechat" {
parse_wechat_callback(&body_str, &config.payment)?
} else {
tracing::warn!("Unknown payment callback method: {}", method);
return Ok("fail".into());
};
// trade_no 是必填字段,缺失说明回调格式异常
let trade_no = trade_no.ok_or_else(|| {
tracing::warn!("Payment callback missing out_trade_no: method={}", method);
SaasError::InvalidInput("回调缺少交易号".into())
})?;
if let Err(e) = super::payment::handle_payment_callback(&state.db, &trade_no, &status, callback_amount).await {
// 对外返回通用错误,不泄露内部细节
tracing::error!("Payment callback processing failed: method={}, error={}", method, e);
return Ok("fail".into());
}
// 支付宝期望 "success",微信期望 JSON
if method == "alipay" {
Ok("success".into())
} else {
Ok(r#"{"code":"SUCCESS","message":"OK"}"#.into())
}
}
// === Mock 支付(开发模式) ===
#[derive(Debug, Deserialize)]
pub struct MockPayQuery {
trade_no: String,
amount: i32,
subject: String,
}
/// GET /api/v1/billing/mock-pay — 开发模式 Mock 支付页面
pub async fn mock_pay_page(
Query(params): Query<MockPayQuery>,
) -> axum::response::Html<String> {
// HTML 转义防止 XSS
let safe_subject = html_escape(&params.subject);
let safe_trade_no = html_escape(&params.trade_no);
let amount_yuan = params.amount as f64 / 100.0;
axum::response::Html(format!(r#"
<!DOCTYPE html>
<html lang="zh">
<head><meta charset="utf-8"><title>Mock 支付</title>
<style>
body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20px; }}
.card {{ background: #fff; border-radius: 12px; padding: 24px; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }}
.amount {{ font-size: 32px; font-weight: 700; color: #333; text-align: center; margin: 20px 0; }}
.btn {{ display: block; width: 100%; padding: 12px; border: none; border-radius: 8px; font-size: 16px; cursor: pointer; margin-top: 12px; }}
.btn-pay {{ background: #1677ff; color: #fff; }}
.btn-pay:hover {{ background: #0958d9; }}
.btn-fail {{ background: #f5f5f5; color: #999; }}
.subject {{ text-align: center; color: #666; font-size: 14px; }}
</style></head>
<body>
<div class="card">
<div class="subject">{safe_subject}</div>
<div class="amount">¥{amount_yuan}</div>
<div style="text-align:center;color:#999;font-size:12px;margin-bottom:16px;">
订单号: {safe_trade_no}
</div>
<form action="/api/v1/billing/mock-pay/confirm" method="POST">
<input type="hidden" name="trade_no" value="{safe_trade_no}" />
<button type="submit" name="action" value="success" class="btn btn-pay">确认支付 ¥{amount_yuan}</button>
<button type="submit" name="action" value="fail" class="btn btn-fail">模拟失败</button>
</form>
</div>
</body></html>
"#))
}
#[derive(Debug, Deserialize)]
pub struct MockPayConfirm {
trade_no: String,
action: String,
}
/// POST /api/v1/billing/mock-pay/confirm — Mock 支付确认
pub async fn mock_pay_confirm(
State(state): State<AppState>,
Form(form): Form<MockPayConfirm>,
) -> SaasResult<axum::response::Html<String>> {
let status = if form.action == "success" { "success" } else { "failed" };
if let Err(e) = super::payment::handle_payment_callback(&state.db, &form.trade_no, status, None).await {
tracing::error!("Mock payment callback failed: {}", e);
}
let msg = if status == "success" {
"支付成功!您可以关闭此页面。"
} else {
"支付已取消。"
};
Ok(axum::response::Html(format!(r#"
<!DOCTYPE html>
<html lang="zh">
<head><meta charset="utf-8"><title>支付结果</title>
<style>
body {{ font-family: system-ui; max-width: 480px; margin: 40px auto; padding: 20px; text-align: center; }}
.msg {{ font-size: 18px; color: #333; margin: 40px 0; }}
</style></head>
<body><div class="msg">{msg}</div></body>
</html>
"#)))
}
// === 回调解析 ===
/// 解析支付宝回调并验签,返回 (trade_no, status, callback_amount_cents)
fn parse_alipay_callback(
body: &str,
config: &crate::config::PaymentConfig,
) -> SaasResult<(Option<String>, String, Option<i32>)> {
// form-urlencoded → key=value 对
let mut params: Vec<(String, String)> = Vec::new();
for pair in body.split('&') {
if let Some((k, v)) = pair.split_once('=') {
params.push((
k.to_string(),
urlencoding::decode(v).unwrap_or_default().to_string(),
));
}
}
let mut trade_no = None;
let mut callback_amount: Option<i32> = None;
// 验签:生产环境强制,开发环境允许跳过
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if let Some(ref public_key) = config.alipay_public_key {
match super::payment::verify_alipay_callback(&params, public_key) {
Ok(true) => {}
Ok(false) => {
tracing::warn!("Alipay callback signature verification FAILED");
return Err(SaasError::InvalidInput("支付宝回调验签失败".into()));
}
Err(e) => {
tracing::error!("Alipay callback verification error: {}", e);
return Err(SaasError::InvalidInput("支付宝回调验签异常".into()));
}
}
} else if !is_dev {
tracing::error!("Alipay public key not configured in production — rejecting callback");
return Err(SaasError::InvalidInput("支付宝公钥未配置,无法验签".into()));
} else {
tracing::warn!("Alipay public key not configured (dev mode), skipping signature verification");
}
// 提取 trade_no、trade_status 和 total_amount
let mut trade_status = "unknown".to_string();
for (k, v) in &params {
match k.as_str() {
"out_trade_no" => trade_no = Some(v.clone()),
"trade_status" => trade_status = v.clone(),
"total_amount" => {
// 支付宝金额为元(字符串),转为分(整数)
if let Ok(yuan) = v.parse::<f64>() {
callback_amount = Some((yuan * 100.0).round() as i32);
}
}
_ => {}
}
}
// 支付宝成功状态映射
let status = if trade_status == "TRADE_SUCCESS" || trade_status == "TRADE_FINISHED" {
"TRADE_SUCCESS"
} else {
&trade_status
};
Ok((trade_no, status.to_string(), callback_amount))
}
/// 解析微信支付回调,解密 resource 字段,返回 (trade_no, status, callback_amount_cents)
fn parse_wechat_callback(
body: &str,
config: &crate::config::PaymentConfig,
) -> SaasResult<(Option<String>, String, Option<i32>)> {
let v: serde_json::Value = serde_json::from_str(body)
.map_err(|e| SaasError::InvalidInput(format!("微信回调 JSON 解析失败: {}", e)))?;
let event_type = v.get("event_type")
.and_then(|t| t.as_str())
.unwrap_or("");
if event_type != "TRANSACTION.SUCCESS" {
// 非支付成功事件(如退款等),忽略
return Ok((None, event_type.to_string(), None));
}
// 解密 resource 字段
let resource = v.get("resource")
.ok_or_else(|| SaasError::InvalidInput("微信回调缺少 resource 字段".into()))?;
let ciphertext = resource.get("ciphertext")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 ciphertext".into()))?;
let nonce = resource.get("nonce")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::InvalidInput("微信回调 resource 缺少 nonce".into()))?;
let associated_data = resource.get("associated_data")
.and_then(|v| v.as_str())
.unwrap_or("");
let api_v3_key = config.wechat_api_v3_key.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信 API v3 密钥未配置,无法解密回调".into()))?;
let plaintext = super::payment::decrypt_wechat_resource(
ciphertext, nonce, associated_data, api_v3_key,
)?;
let decrypted: serde_json::Value = serde_json::from_str(&plaintext)
.map_err(|e| SaasError::Internal(format!("微信回调解密内容 JSON 解析失败: {}", e)))?;
let trade_no = decrypted.get("out_trade_no")
.and_then(|v| v.as_str())
.map(|s| s.to_string());
let trade_state = decrypted.get("trade_state")
.and_then(|v| v.as_str())
.unwrap_or("UNKNOWN");
// 微信金额已为分(整数)
let callback_amount = decrypted.get("amount")
.and_then(|a| a.get("total"))
.and_then(|v| v.as_i64())
.map(|v| v as i32);
Ok((trade_no, trade_state.to_string(), callback_amount))
}
/// HTML 转义,防止 XSS 注入
fn html_escape(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
.replace('\'', "&#x27;")
}

View File

@@ -0,0 +1,33 @@
//! 计费模块 — 计划管理、订阅、用量配额、支付
pub mod types;
pub mod service;
pub mod handlers;
pub mod payment;
use axum::routing::{get, post};
/// 需要认证的计费路由
pub fn routes() -> axum::Router<crate::state::AppState> {
axum::Router::new()
.route("/api/v1/billing/plans", get(handlers::list_plans))
.route("/api/v1/billing/plans/{id}", get(handlers::get_plan))
.route("/api/v1/billing/subscription", get(handlers::get_subscription))
.route("/api/v1/billing/usage", get(handlers::get_usage))
.route("/api/v1/billing/usage/increment", post(handlers::increment_usage_dimension))
.route("/api/v1/billing/payments", post(handlers::create_payment))
.route("/api/v1/billing/payments/{id}", get(handlers::get_payment_status))
}
/// 支付回调路由(无需 auth — 支付宝/微信服务器回调)
pub fn callback_routes() -> axum::Router<crate::state::AppState> {
axum::Router::new()
.route("/api/v1/billing/callback/{method}", post(handlers::payment_callback))
}
/// mock 支付页面路由(开发模式)
pub fn mock_routes() -> axum::Router<crate::state::AppState> {
axum::Router::new()
.route("/api/v1/billing/mock-pay", get(handlers::mock_pay_page))
.route("/api/v1/billing/mock-pay/confirm", post(handlers::mock_pay_confirm))
}

View File

@@ -0,0 +1,647 @@
//! 支付集成 — 支付宝/微信支付(直连 HTTP 实现)
//!
//! 不依赖第三方 SDK使用 `rsa` crate 做 RSA2 签名,`reqwest` 做 HTTP 调用。
//! 开发模式(`ZCLAW_SAAS_DEV=true`)使用 mock 支付。
use sqlx::PgPool;
use crate::config::PaymentConfig;
use crate::error::{SaasError, SaasResult};
use super::types::*;
// ────────────────────────────────────────────────────────────────
// 公开 API
// ────────────────────────────────────────────────────────────────
/// 创建支付订单,返回支付链接/二维码 URL
///
/// 发票和支付记录在事务中创建,确保原子性。
pub async fn create_payment(
pool: &PgPool,
account_id: &str,
req: &CreatePaymentRequest,
config: &PaymentConfig,
) -> SaasResult<PaymentResult> {
// 1. 获取计划信息
let plan = sqlx::query_as::<_, BillingPlan>(
"SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
)
.bind(&req.plan_id)
.fetch_optional(pool)
.await?
.ok_or_else(|| SaasError::NotFound("计划不存在或已下架".into()))?;
// 检查是否已有活跃订阅
let existing = sqlx::query_scalar::<_, i64>(
"SELECT COUNT(*) FROM billing_subscriptions \
WHERE account_id = $1 AND status IN ('trial', 'active') AND plan_id = $2"
)
.bind(account_id)
.bind(&req.plan_id)
.fetch_one(pool)
.await?;
if existing > 0 {
return Err(SaasError::InvalidInput("已订阅该计划".into()));
}
// 2. 在事务中创建发票和支付记录
let mut tx = pool.begin().await
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
let invoice_id = uuid::Uuid::new_v4().to_string();
let now = chrono::Utc::now();
let due = now + chrono::Duration::days(1);
sqlx::query(
"INSERT INTO billing_invoices \
(id, account_id, plan_id, amount_cents, currency, description, status, due_at, created_at, updated_at) \
VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, $8, $8)"
)
.bind(&invoice_id)
.bind(account_id)
.bind(&req.plan_id)
.bind(plan.price_cents)
.bind(&plan.currency)
.bind(format!("{} - {} ({})", plan.display_name, plan.interval, now.format("%Y-%m")))
.bind(due.to_rfc3339())
.bind(now.to_rfc3339())
.execute(&mut *tx)
.await?;
let payment_id = uuid::Uuid::new_v4().to_string();
let trade_no = format!("ZCLAW-{}-{}", chrono::Utc::now().format("%Y%m%d%H%M%S"), &payment_id[..8]);
sqlx::query(
"INSERT INTO billing_payments \
(id, invoice_id, account_id, amount_cents, currency, method, status, external_trade_no, metadata, created_at, updated_at) \
VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7, '{}', $8, $8)"
)
.bind(&payment_id)
.bind(&invoice_id)
.bind(account_id)
.bind(plan.price_cents)
.bind(&plan.currency)
.bind(req.payment_method.to_string())
.bind(&trade_no)
.bind(now.to_rfc3339())
.execute(&mut *tx)
.await?;
tx.commit().await
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
// 3. 生成支付链接
let pay_url = generate_pay_url(
req.payment_method,
&trade_no,
plan.price_cents,
&plan.display_name,
config,
).await?;
Ok(PaymentResult {
payment_id,
trade_no,
pay_url,
amount_cents: plan.price_cents,
})
}
/// 处理支付回调(支付宝/微信异步通知)
///
/// `callback_amount_cents` 来自回调报文的金额(分),用于与 DB 金额交叉验证。
/// 整个操作在数据库事务中执行,使用 SELECT FOR UPDATE 防止并发竞态。
pub async fn handle_payment_callback(
pool: &PgPool,
trade_no: &str,
status: &str,
callback_amount_cents: Option<i32>,
) -> SaasResult<()> {
// 1. 在事务中锁定支付记录,防止 TOCTOU 竞态
let mut tx = pool.begin().await
.map_err(|e| SaasError::Internal(format!("开启事务失败: {}", e)))?;
let payment: Option<(String, String, String, i32, String)> = sqlx::query_as::<_, (String, String, String, i32, String)>(
"SELECT id, invoice_id, account_id, amount_cents, status \
FROM billing_payments WHERE external_trade_no = $1 FOR UPDATE"
)
.bind(trade_no)
.fetch_optional(&mut *tx)
.await?;
let (payment_id, invoice_id, account_id, db_amount, current_status) = match payment {
Some(p) => p,
None => {
tracing::error!("Payment callback for unknown trade: {}", sanitize_log(trade_no));
tx.rollback().await?;
return Ok(());
}
};
// 幂等性:已处理过直接返回
if current_status != "pending" {
tracing::info!("Payment already processed (idempotent): trade={}, status={}", sanitize_log(trade_no), current_status);
tx.rollback().await?;
return Ok(());
}
// 2. 金额交叉验证(防篡改)
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if let Some(callback_amount) = callback_amount_cents {
if callback_amount != db_amount {
tracing::error!(
"Amount mismatch: trade={}, db_amount={}, callback_amount={}. Rejecting.",
sanitize_log(trade_no), db_amount, callback_amount
);
tx.rollback().await?;
return Err(SaasError::InvalidInput("回调验证失败".into()));
}
} else if !is_dev {
// 非开发环境必须有金额
tracing::error!("Callback without amount in non-dev mode: trade={}", sanitize_log(trade_no));
tx.rollback().await?;
return Err(SaasError::InvalidInput("回调缺少金额验证".into()));
} else {
tracing::warn!("DEV: Skipping amount verification for trade={}", sanitize_log(trade_no));
}
let now = chrono::Utc::now().to_rfc3339();
if status == "success" || status == "TRADE_SUCCESS" || status == "SUCCESS" {
// 3. 更新支付状态
sqlx::query(
"UPDATE billing_payments SET status = 'succeeded', paid_at = $1, updated_at = $1 WHERE id = $2"
)
.bind(&now)
.bind(&payment_id)
.execute(&mut *tx)
.await?;
// 4. 更新发票状态
sqlx::query(
"UPDATE billing_invoices SET status = 'paid', paid_at = $1, updated_at = $1 WHERE id = $2"
)
.bind(&now)
.bind(&invoice_id)
.execute(&mut *tx)
.await?;
// 5. 获取发票关联的计划
let plan_id: Option<String> = sqlx::query_scalar(
"SELECT plan_id FROM billing_invoices WHERE id = $1"
)
.bind(&invoice_id)
.fetch_optional(&mut *tx)
.await?
.flatten();
if let Some(plan_id) = plan_id {
// 6. 取消旧订阅
sqlx::query(
"UPDATE billing_subscriptions SET status = 'canceled', canceled_at = $1, updated_at = $1 \
WHERE account_id = $2 AND status IN ('trial', 'active')"
)
.bind(&now)
.bind(&account_id)
.execute(&mut *tx)
.await?;
// 7. 创建新订阅30 天周期)
let sub_id = uuid::Uuid::new_v4().to_string();
let period_end = (chrono::Utc::now() + chrono::Duration::days(30)).to_rfc3339();
let period_start = chrono::Utc::now().to_rfc3339();
sqlx::query(
"INSERT INTO billing_subscriptions \
(id, account_id, plan_id, status, current_period_start, current_period_end, created_at, updated_at) \
VALUES ($1, $2, $3, 'active', $4, $5, $6, $6)"
)
.bind(&sub_id)
.bind(&account_id)
.bind(&plan_id)
.bind(&period_start)
.bind(&period_end)
.bind(&now)
.execute(&mut *tx)
.await?;
tracing::info!(
"Payment succeeded: account={}, plan={}, subscription={}",
account_id, plan_id, sub_id
);
}
tx.commit().await
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
} else {
// 支付失败:截断 status 防止注入,更新发票为 void
let safe_reason = truncate_str(status, 200);
sqlx::query(
"UPDATE billing_payments SET status = 'failed', failure_reason = $1, updated_at = $2 WHERE id = $3"
)
.bind(&safe_reason)
.bind(&now)
.bind(&payment_id)
.execute(&mut *tx)
.await?;
// 同时将发票标记为 void
sqlx::query(
"UPDATE billing_invoices SET status = 'void', voided_at = $1, updated_at = $1 WHERE id = $2"
)
.bind(&now)
.bind(&invoice_id)
.execute(&mut *tx)
.await?;
tx.commit().await
.map_err(|e| SaasError::Internal(format!("事务提交失败: {}", e)))?;
tracing::warn!("Payment failed: trade={}, status={}", sanitize_log(trade_no), safe_reason);
}
Ok(())
}
/// 查询支付状态
pub async fn query_payment_status(
pool: &PgPool,
payment_id: &str,
account_id: &str,
) -> SaasResult<serde_json::Value> {
let payment: (String, String, i32, String, String) = sqlx::query_as::<_, (String, String, i32, String, String)>(
"SELECT id, method, amount_cents, currency, status \
FROM billing_payments WHERE id = $1 AND account_id = $2"
)
.bind(payment_id)
.bind(account_id)
.fetch_optional(pool)
.await?
.ok_or_else(|| SaasError::NotFound("支付记录不存在".into()))?;
let (id, method, amount, currency, status) = payment;
Ok(serde_json::json!({
"id": id,
"method": method,
"amount_cents": amount,
"currency": currency,
"status": status,
}))
}
// ────────────────────────────────────────────────────────────────
// 支付 URL 生成
// ────────────────────────────────────────────────────────────────
/// 生成支付 URL根据配置决定 mock 或真实支付
async fn generate_pay_url(
method: PaymentMethod,
trade_no: &str,
amount_cents: i32,
subject: &str,
config: &PaymentConfig,
) -> SaasResult<String> {
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if is_dev {
return Ok(mock_pay_url(trade_no, amount_cents, subject));
}
match method {
PaymentMethod::Alipay => generate_alipay_url(trade_no, amount_cents, subject, config),
PaymentMethod::Wechat => generate_wechat_url(trade_no, amount_cents, subject, config).await,
}
}
fn mock_pay_url(trade_no: &str, amount_cents: i32, subject: &str) -> String {
let base = std::env::var("ZCLAW_SAAS_URL")
.unwrap_or_else(|_| "http://localhost:8080".into());
format!(
"{}/api/v1/billing/mock-pay?trade_no={}&amount={}&subject={}",
base,
urlencoding::encode(trade_no),
amount_cents,
urlencoding::encode(subject),
)
}
// ────────────────────────────────────────────────────────────────
// 支付宝 — alipay.trade.page.payRSA2 签名 + 证书模式)
// ────────────────────────────────────────────────────────────────
fn generate_alipay_url(
trade_no: &str,
amount_cents: i32,
subject: &str,
config: &PaymentConfig,
) -> SaasResult<String> {
let app_id = config.alipay_app_id.as_deref()
.ok_or_else(|| SaasError::InvalidInput("支付宝 app_id 未配置".into()))?;
let private_key_pem = config.alipay_private_key.as_deref()
.ok_or_else(|| SaasError::InvalidInput("支付宝商户私钥未配置".into()))?;
let notify_url = config.alipay_notify_url.as_deref()
.ok_or_else(|| SaasError::InvalidInput("支付宝回调 URL 未配置".into()))?;
// 金额:分 → 元(整数运算避免浮点精度问题)
let yuan_part = amount_cents / 100;
let cent_part = amount_cents % 100;
let amount_yuan = format!("{}.{:02}", yuan_part, cent_part);
// 构建请求参数(字典序)
let mut params: Vec<(&str, String)> = vec![
("app_id", app_id.to_string()),
("method", "alipay.trade.page.pay".to_string()),
("charset", "utf-8".to_string()),
("sign_type", "RSA2".to_string()),
("timestamp", chrono::Local::now().format("%Y-%m-%d %H:%M:%S").to_string()),
("version", "1.0".to_string()),
("notify_url", notify_url.to_string()),
("biz_content", serde_json::json!({
"out_trade_no": trade_no,
"total_amount": amount_yuan,
"subject": subject,
"product_code": "FAST_INSTANT_TRADE_PAY",
}).to_string()),
];
// 按 key 字典序排列并拼接
params.sort_by(|a, b| a.0.cmp(b.0));
let sign_str: String = params.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join("&");
// RSA2 签名
let sign = rsa_sign_sha256_base64(private_key_pem, sign_str.as_bytes())?;
// 构建 gateway URL
params.push(("sign", sign));
let query: String = params.iter()
.map(|(k, v)| format!("{}={}", k, urlencoding::encode(v)))
.collect::<Vec<_>>()
.join("&");
Ok(format!("https://openapi.alipay.com/gateway.do?{}", query))
}
// ────────────────────────────────────────────────────────────────
// 微信支付 — V3 Native PayQR 码模式)
// ────────────────────────────────────────────────────────────────
async fn generate_wechat_url(
trade_no: &str,
amount_cents: i32,
subject: &str,
config: &PaymentConfig,
) -> SaasResult<String> {
let mch_id = config.wechat_mch_id.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付商户号未配置".into()))?;
let serial_no = config.wechat_serial_no.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付证书序列号未配置".into()))?;
let private_key_pem = config.wechat_private_key_path.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付私钥路径未配置".into()))?;
let notify_url = config.wechat_notify_url.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付回调 URL 未配置".into()))?;
let app_id = config.wechat_app_id.as_deref()
.ok_or_else(|| SaasError::InvalidInput("微信支付 App ID 未配置".into()))?;
// 读取私钥文件
let private_key = std::fs::read_to_string(private_key_pem)
.map_err(|e| SaasError::InvalidInput(format!("微信支付私钥文件读取失败: {}", e)))?;
let body = serde_json::json!({
"appid": app_id,
"mchid": mch_id,
"description": subject,
"out_trade_no": trade_no,
"notify_url": notify_url,
"amount": {
"total": amount_cents,
"currency": "CNY",
},
});
let body_str = body.to_string();
// 构建签名字符串
let timestamp = chrono::Utc::now().timestamp().to_string();
let nonce_str = uuid::Uuid::new_v4().to_string().replace("-", "");
let sign_message = format!(
"POST\n/v3/pay/transactions/native\n{}\n{}\n{}\n",
timestamp, nonce_str, body_str
);
let signature = rsa_sign_sha256_base64(&private_key, sign_message.as_bytes())?;
// 构建 Authorization 头
let auth_header = format!(
"WECHATPAY2-SHA256-RSA2048 mchid=\"{}\",nonce_str=\"{}\",timestamp=\"{}\",serial_no=\"{}\",signature=\"{}\"",
mch_id, nonce_str, timestamp, serial_no, signature
);
// 发送请求
let client = reqwest::Client::new();
let resp = client
.post("https://api.mch.weixin.qq.com/v3/pay/transactions/native")
.header("Content-Type", "application/json")
.header("Authorization", auth_header)
.header("Accept", "application/json")
.body(body_str)
.send()
.await
.map_err(|e| SaasError::Internal(format!("微信支付请求失败: {}", e)))?;
if !resp.status().is_success() {
let status = resp.status();
let text = resp.text().await.unwrap_or_default();
tracing::error!("WeChat Pay API error: status={}, body={}", status, text);
return Err(SaasError::InvalidInput(format!(
"微信支付创建订单失败 (HTTP {})", status
)));
}
let resp_json: serde_json::Value = resp.json().await
.map_err(|e| SaasError::Internal(format!("微信支付响应解析失败: {}", e)))?;
let code_url = resp_json.get("code_url")
.and_then(|v| v.as_str())
.ok_or_else(|| SaasError::Internal("微信支付响应缺少 code_url".into()))?
.to_string();
Ok(code_url)
}
// ────────────────────────────────────────────────────────────────
// 回调验签
// ────────────────────────────────────────────────────────────────
/// 验证支付宝回调签名
pub fn verify_alipay_callback(
params: &[(String, String)],
alipay_public_key_pem: &str,
) -> SaasResult<bool> {
// 1. 提取 sign 和 sign_type剩余参数字典序拼接
let mut sign = None;
let mut filtered: Vec<(&str, &str)> = Vec::new();
for (k, v) in params {
match k.as_str() {
"sign" => sign = Some(v.clone()),
"sign_type" => {} // 跳过
_ => {
if !v.is_empty() {
filtered.push((k.as_str(), v.as_str()));
}
}
}
}
let sign = match sign {
Some(s) => s,
None => return Ok(false),
};
filtered.sort_by(|a, b| a.0.cmp(b.0));
let sign_str: String = filtered.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join("&");
// 2. 用支付宝公钥验签
rsa_verify_sha256(alipay_public_key_pem, sign_str.as_bytes(), &sign)
}
/// 解密微信支付回调 resource 字段AES-256-GCM
pub fn decrypt_wechat_resource(
ciphertext_b64: &str,
nonce: &str,
associated_data: &str,
api_v3_key: &str,
) -> SaasResult<String> {
use aes_gcm::{Aes256Gcm, KeyInit, Nonce};
use aes_gcm::aead::Aead;
use base64::Engine;
let key_bytes = api_v3_key.as_bytes();
if key_bytes.len() != 32 {
return Err(SaasError::Internal("微信 API v3 密钥必须为 32 字节".into()));
}
let nonce_bytes = nonce.as_bytes();
if nonce_bytes.len() != 12 {
return Err(SaasError::InvalidInput("微信回调 nonce 长度必须为 12 字节".into()));
}
let ciphertext = base64::engine::general_purpose::STANDARD
.decode(ciphertext_b64)
.map_err(|e| SaasError::Internal(format!("base64 解码失败: {}", e)))?;
let cipher = Aes256Gcm::new_from_slice(key_bytes)
.map_err(|e| SaasError::Internal(format!("AES 密钥初始化失败: {}", e)))?;
let nonce = Nonce::from_slice(nonce_bytes);
let plaintext = cipher
.decrypt(nonce, aes_gcm::aead::Payload {
msg: &ciphertext,
aad: associated_data.as_bytes(),
})
.map_err(|e| SaasError::Internal(format!("AES-GCM 解密失败: {}", e)))?;
String::from_utf8(plaintext)
.map_err(|e| SaasError::Internal(format!("解密结果 UTF-8 转换失败: {}", e)))
}
// ────────────────────────────────────────────────────────────────
// RSA 工具函数
// ────────────────────────────────────────────────────────────────
/// SHA256WithRSA 签名 + Base64 编码PKCS#1 v1.5
fn rsa_sign_sha256_base64(
private_key_pem: &str,
message: &[u8],
) -> SaasResult<String> {
use rsa::pkcs8::DecodePrivateKey;
use rsa::signature::{Signer, SignatureEncoding};
use sha2::Sha256;
use rsa::pkcs1v15::SigningKey;
use base64::Engine;
let private_key = rsa::RsaPrivateKey::from_pkcs8_pem(private_key_pem)
.map_err(|e| SaasError::Internal(format!("RSA 私钥解析失败: {}", e)))?;
let signing_key = SigningKey::<Sha256>::new(private_key);
let signature = signing_key.sign(message);
Ok(base64::engine::general_purpose::STANDARD.encode(signature.to_bytes()))
}
/// SHA256WithRSA 验签
fn rsa_verify_sha256(
public_key_pem: &str,
message: &[u8],
signature_b64: &str,
) -> SaasResult<bool> {
use rsa::pkcs8::DecodePublicKey;
use rsa::signature::Verifier;
use sha2::Sha256;
use rsa::pkcs1v15::VerifyingKey;
use base64::Engine;
let public_key = match rsa::RsaPublicKey::from_public_key_pem(public_key_pem) {
Ok(k) => k,
Err(e) => {
tracing::error!("RSA 公钥解析失败: {}", e);
return Ok(false);
}
};
let signature_bytes = match base64::engine::general_purpose::STANDARD.decode(signature_b64) {
Ok(b) => b,
Err(e) => {
tracing::error!("签名 base64 解码失败: {}", e);
return Ok(false);
}
};
let verifying_key = VerifyingKey::<Sha256>::new(public_key);
let signature = match rsa::pkcs1v15::Signature::try_from(signature_bytes.as_slice()) {
Ok(s) => s,
Err(_) => return Ok(false),
};
Ok(verifying_key.verify(message, &signature).is_ok())
}
// ────────────────────────────────────────────────────────────────
// 辅助函数
// ────────────────────────────────────────────────────────────────
/// 日志安全:只保留字母数字和 `-` `_`,防止日志注入
fn sanitize_log(s: &str) -> String {
s.chars()
.filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_')
.collect()
}
/// 截断字符串到指定长度(按字符而非字节)
fn truncate_str(s: &str, max_len: usize) -> String {
let chars: Vec<char> = s.chars().collect();
if chars.len() <= max_len {
s.to_string()
} else {
chars.into_iter().take(max_len).collect()
}
}
impl std::fmt::Display for PaymentMethod {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Alipay => write!(f, "alipay"),
Self::Wechat => write!(f, "wechat"),
}
}
}

View File

@@ -0,0 +1,303 @@
//! 计费服务层 — 计划查询、订阅管理、用量检查
use chrono::{Datelike, Timelike};
use sqlx::PgPool;
use crate::error::SaasResult;
use super::types::*;
/// 获取所有活跃计划
pub async fn list_plans(pool: &PgPool) -> SaasResult<Vec<BillingPlan>> {
let plans = sqlx::query_as::<_, BillingPlan>(
"SELECT * FROM billing_plans WHERE status = 'active' ORDER BY sort_order"
)
.fetch_all(pool)
.await?;
Ok(plans)
}
/// 获取单个计划(公开 API 只返回 active 计划)
pub async fn get_plan(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
let plan = sqlx::query_as::<_, BillingPlan>(
"SELECT * FROM billing_plans WHERE id = $1 AND status = 'active'"
)
.bind(plan_id)
.fetch_optional(pool)
.await?;
Ok(plan)
}
/// 获取单个计划(内部使用,不过滤 status用于已订阅用户查看旧计划
pub async fn get_plan_any_status(pool: &PgPool, plan_id: &str) -> SaasResult<Option<BillingPlan>> {
let plan = sqlx::query_as::<_, BillingPlan>(
"SELECT * FROM billing_plans WHERE id = $1"
)
.bind(plan_id)
.fetch_optional(pool)
.await?;
Ok(plan)
}
/// 获取账户当前有效订阅
pub async fn get_active_subscription(
pool: &PgPool,
account_id: &str,
) -> SaasResult<Option<Subscription>> {
let sub = sqlx::query_as::<_, Subscription>(
"SELECT * FROM billing_subscriptions \
WHERE account_id = $1 AND status IN ('trial', 'active', 'past_due') \
ORDER BY created_at DESC LIMIT 1"
)
.bind(account_id)
.fetch_optional(pool)
.await?;
Ok(sub)
}
/// 获取账户当前计划(有订阅返回订阅计划,否则返回 Free
pub async fn get_account_plan(pool: &PgPool, account_id: &str) -> SaasResult<BillingPlan> {
if let Some(sub) = get_active_subscription(pool, account_id).await? {
if let Some(plan) = get_plan_any_status(pool, &sub.plan_id).await? {
return Ok(plan);
}
}
// 回退到 Free 计划
let free = sqlx::query_as::<_, BillingPlan>(
"SELECT * FROM billing_plans WHERE name = 'free' AND status = 'active' LIMIT 1"
)
.fetch_optional(pool)
.await?;
Ok(free.unwrap_or_else(|| BillingPlan {
id: "plan-free".into(),
name: "free".into(),
display_name: "免费版".into(),
description: Some("基础功能".into()),
price_cents: 0,
currency: "CNY".into(),
interval: "month".into(),
features: serde_json::json!({}),
limits: serde_json::json!({
"max_input_tokens_monthly": 500000,
"max_output_tokens_monthly": 500000,
"max_relay_requests_monthly": 100,
"max_hand_executions_monthly": 20,
"max_pipeline_runs_monthly": 5,
}),
is_default: true,
sort_order: 0,
status: "active".into(),
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
}))
}
/// 获取或创建当月用量记录(原子操作,使用 INSERT ON CONFLICT 防止 TOCTOU 竞态)
pub async fn get_or_create_usage(pool: &PgPool, account_id: &str) -> SaasResult<UsageQuota> {
let now = chrono::Utc::now();
let period_start = now
.with_day(1).unwrap_or(now)
.with_hour(0).unwrap_or(now)
.with_minute(0).unwrap_or(now)
.with_second(0).unwrap_or(now)
.with_nanosecond(0).unwrap_or(now);
// 先尝试获取已有记录
let existing = sqlx::query_as::<_, UsageQuota>(
"SELECT * FROM billing_usage_quotas \
WHERE account_id = $1 AND period_start = $2"
)
.bind(account_id)
.bind(period_start)
.fetch_optional(pool)
.await?;
if let Some(usage) = existing {
return Ok(usage);
}
// 获取当前计划限额
let plan = get_account_plan(pool, account_id).await?;
let limits: PlanLimits = serde_json::from_value(plan.limits.clone())
.unwrap_or_else(|_| PlanLimits::free());
// 计算月末
let period_end = if now.month() == 12 {
now.with_year(now.year() + 1).and_then(|d| d.with_month(1))
} else {
now.with_month(now.month() + 1)
}.unwrap_or(now)
.with_day(1).unwrap_or(now)
.with_hour(0).unwrap_or(now)
.with_minute(0).unwrap_or(now)
.with_second(0).unwrap_or(now)
.with_nanosecond(0).unwrap_or(now);
// 使用 INSERT ON CONFLICT 原子创建(防止并发重复插入)
let id = uuid::Uuid::new_v4().to_string();
let inserted = sqlx::query_as::<_, UsageQuota>(
"INSERT INTO billing_usage_quotas \
(id, account_id, period_start, period_end, \
max_input_tokens, max_output_tokens, max_relay_requests, \
max_hand_executions, max_pipeline_runs) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
ON CONFLICT (account_id, period_start) DO NOTHING \
RETURNING *"
)
.bind(&id)
.bind(account_id)
.bind(period_start)
.bind(period_end)
.bind(limits.max_input_tokens_monthly)
.bind(limits.max_output_tokens_monthly)
.bind(limits.max_relay_requests_monthly)
.bind(limits.max_hand_executions_monthly)
.bind(limits.max_pipeline_runs_monthly)
.fetch_optional(pool)
.await?;
if let Some(usage) = inserted {
return Ok(usage);
}
// ON CONFLICT 说明另一个并发请求已经创建了,直接查询返回
let usage = sqlx::query_as::<_, UsageQuota>(
"SELECT * FROM billing_usage_quotas \
WHERE account_id = $1 AND period_start = $2"
)
.bind(account_id)
.bind(period_start)
.fetch_one(pool)
.await?;
Ok(usage)
}
/// 增加用量计数Relay 请求tokens + relay_requests +1
///
/// 在 relay handler 响应成功后直接调用,实现实时配额更新。
/// 聚合器 `AggregateUsageWorker` 每小时做一次对账修正。
pub async fn increment_usage(
pool: &PgPool,
account_id: &str,
input_tokens: i64,
output_tokens: i64,
) -> SaasResult<()> {
let usage = get_or_create_usage(pool, account_id).await?;
sqlx::query(
"UPDATE billing_usage_quotas \
SET input_tokens = input_tokens + $1, \
output_tokens = output_tokens + $2, \
relay_requests = relay_requests + 1, \
updated_at = NOW() \
WHERE id = $3"
)
.bind(input_tokens)
.bind(output_tokens)
.bind(&usage.id)
.execute(pool)
.await?;
Ok(())
}
/// 增加单一维度用量计数(单次 +1
///
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
pub async fn increment_dimension(
pool: &PgPool,
account_id: &str,
dimension: &str,
) -> SaasResult<()> {
let usage = get_or_create_usage(pool, account_id).await?;
match dimension {
"relay_requests" => {
sqlx::query(
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
"hand_executions" => {
sqlx::query(
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
"pipeline_runs" => {
sqlx::query(
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + 1, updated_at = NOW() WHERE id = $1"
).bind(&usage.id).execute(pool).await?;
}
_ => return Err(crate::error::SaasError::InvalidInput(
format!("Unknown usage dimension: {}", dimension)
)),
}
Ok(())
}
/// 增加单一维度用量计数(批量 +N原子操作替代循环调用
///
/// 使用静态 SQL 分支(白名单),避免动态列名注入风险。
pub async fn increment_dimension_by(
pool: &PgPool,
account_id: &str,
dimension: &str,
count: i32,
) -> SaasResult<()> {
let usage = get_or_create_usage(pool, account_id).await?;
match dimension {
"relay_requests" => {
sqlx::query(
"UPDATE billing_usage_quotas SET relay_requests = relay_requests + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
"hand_executions" => {
sqlx::query(
"UPDATE billing_usage_quotas SET hand_executions = hand_executions + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
"pipeline_runs" => {
sqlx::query(
"UPDATE billing_usage_quotas SET pipeline_runs = pipeline_runs + $1, updated_at = NOW() WHERE id = $2"
).bind(count).bind(&usage.id).execute(pool).await?;
}
_ => return Err(crate::error::SaasError::InvalidInput(
format!("Unknown usage dimension: {}", dimension)
)),
}
Ok(())
}
/// 检查用量配额
pub async fn check_quota(
pool: &PgPool,
account_id: &str,
quota_type: &str,
) -> SaasResult<QuotaCheck> {
let usage = get_or_create_usage(pool, account_id).await?;
let (current, limit) = match quota_type {
"input_tokens" => (usage.input_tokens, usage.max_input_tokens),
"output_tokens" => (usage.output_tokens, usage.max_output_tokens),
"relay_requests" => (usage.relay_requests as i64, usage.max_relay_requests.map(|v| v as i64)),
"hand_executions" => (usage.hand_executions as i64, usage.max_hand_executions.map(|v| v as i64)),
"pipeline_runs" => (usage.pipeline_runs as i64, usage.max_pipeline_runs.map(|v| v as i64)),
_ => return Ok(QuotaCheck {
allowed: true,
reason: None,
current: 0,
limit: None,
remaining: None,
}),
};
let allowed = limit.map_or(true, |lim| current < lim);
let remaining = limit.map(|lim| (lim - current).max(0));
Ok(QuotaCheck {
allowed,
reason: if !allowed { Some(format!("{} 配额已用尽", quota_type)) } else { None },
current,
limit,
remaining,
})
}

View File

@@ -0,0 +1,161 @@
//! 计费类型定义
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
/// 计费计划定义 — 对应 billing_plans 表
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct BillingPlan {
pub id: String,
pub name: String,
pub display_name: String,
pub description: Option<String>,
pub price_cents: i32,
pub currency: String,
pub interval: String,
pub features: serde_json::Value,
pub limits: serde_json::Value,
pub is_default: bool,
pub sort_order: i32,
pub status: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 计划限额(从 limits JSON 反序列化)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlanLimits {
#[serde(default)]
pub max_input_tokens_monthly: Option<i64>,
#[serde(default)]
pub max_output_tokens_monthly: Option<i64>,
#[serde(default)]
pub max_relay_requests_monthly: Option<i32>,
#[serde(default)]
pub max_hand_executions_monthly: Option<i32>,
#[serde(default)]
pub max_pipeline_runs_monthly: Option<i32>,
}
impl PlanLimits {
pub fn free() -> Self {
Self {
max_input_tokens_monthly: Some(500_000),
max_output_tokens_monthly: Some(500_000),
max_relay_requests_monthly: Some(100),
max_hand_executions_monthly: Some(20),
max_pipeline_runs_monthly: Some(5),
}
}
}
/// 账户订阅 — 对应 billing_subscriptions 表
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Subscription {
pub id: String,
pub account_id: String,
pub plan_id: String,
pub status: String,
pub current_period_start: DateTime<Utc>,
pub current_period_end: DateTime<Utc>,
pub trial_end: Option<DateTime<Utc>>,
pub canceled_at: Option<DateTime<Utc>>,
pub cancel_at_period_end: bool,
pub metadata: serde_json::Value,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 发票 — 对应 billing_invoices 表
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Invoice {
pub id: String,
pub account_id: String,
pub subscription_id: Option<String>,
pub plan_id: Option<String>,
pub amount_cents: i32,
pub currency: String,
pub description: Option<String>,
pub status: String,
pub due_at: Option<DateTime<Utc>>,
pub paid_at: Option<DateTime<Utc>>,
pub voided_at: Option<DateTime<Utc>>,
pub metadata: serde_json::Value,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 支付记录 — 对应 billing_payments 表
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct Payment {
pub id: String,
pub invoice_id: String,
pub account_id: String,
pub amount_cents: i32,
pub currency: String,
pub method: String,
pub status: String,
pub external_trade_no: Option<String>,
pub paid_at: Option<DateTime<Utc>>,
pub refunded_at: Option<DateTime<Utc>>,
pub failure_reason: Option<String>,
pub metadata: serde_json::Value,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 月度用量配额 — 对应 billing_usage_quotas 表
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct UsageQuota {
pub id: String,
pub account_id: String,
pub period_start: DateTime<Utc>,
pub period_end: DateTime<Utc>,
pub input_tokens: i64,
pub output_tokens: i64,
pub relay_requests: i32,
pub hand_executions: i32,
pub pipeline_runs: i32,
pub max_input_tokens: Option<i64>,
pub max_output_tokens: Option<i64>,
pub max_relay_requests: Option<i32>,
pub max_hand_executions: Option<i32>,
pub max_pipeline_runs: Option<i32>,
pub metadata: serde_json::Value,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
/// 用量检查结果
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuotaCheck {
pub allowed: bool,
pub reason: Option<String>,
pub current: i64,
pub limit: Option<i64>,
pub remaining: Option<i64>,
}
/// 支付方式
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PaymentMethod {
Alipay,
Wechat,
}
/// 创建支付请求
#[derive(Debug, Deserialize)]
pub struct CreatePaymentRequest {
pub plan_id: String,
pub payment_method: PaymentMethod,
}
/// 支付结果
#[derive(Debug, Serialize)]
pub struct PaymentResult {
pub payment_id: String,
pub trade_no: String,
pub pay_url: String,
pub amount_cents: i32,
}

View File

@@ -167,6 +167,22 @@ impl AppCache {
self.relay_queue_counts.retain(|k, _| db_keys.contains(k));
}
// ============ 快捷查找Phase 2: 减少关键路径 DB 查询) ============
/// 按 model_id 查找已启用的模型。O(1) DashMap 查找。
pub fn get_model(&self, model_id: &str) -> Option<CachedModel> {
self.models.get(model_id)
.filter(|m| m.enabled)
.map(|r| r.value().clone())
}
/// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。
pub fn get_provider(&self, provider_id: &str) -> Option<CachedProvider> {
self.providers.get(provider_id)
.filter(|p| p.enabled)
.map(|r| r.value().clone())
}
// ============ 缓存失效 ============
/// 清除 model 缓存中的指定条目Admin CRUD 后调用)

View File

@@ -4,9 +4,15 @@ use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use secrecy::SecretString;
/// 当前期望的配置版本
const CURRENT_CONFIG_VERSION: u32 = 1;
/// SaaS 服务器完整配置
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SaaSConfig {
/// Configuration schema version
#[serde(default = "default_config_version")]
pub config_version: u32,
pub server: ServerConfig,
pub database: DatabaseConfig,
pub auth: AuthConfig,
@@ -15,6 +21,8 @@ pub struct SaaSConfig {
pub rate_limit: RateLimitConfig,
#[serde(default)]
pub scheduler: SchedulerConfig,
#[serde(default)]
pub payment: PaymentConfig,
}
/// Scheduler 定时任务配置
@@ -66,6 +74,30 @@ pub struct ServerConfig {
pub struct DatabaseConfig {
#[serde(default = "default_db_url")]
pub url: String,
/// 连接池最大连接数
#[serde(default = "default_max_connections")]
pub max_connections: u32,
/// 连接池最小连接数
#[serde(default = "default_min_connections")]
pub min_connections: u32,
/// 获取连接超时 (秒)
#[serde(default = "default_acquire_timeout")]
pub acquire_timeout_secs: u64,
/// 空闲连接回收超时 (秒)
#[serde(default = "default_idle_timeout")]
pub idle_timeout_secs: u64,
/// 连接最大生命周期 (秒)
#[serde(default = "default_max_lifetime")]
pub max_lifetime_secs: u64,
/// Worker 并发上限 (Semaphore permits)
#[serde(default = "default_worker_concurrency")]
pub worker_concurrency: usize,
/// 限流事件批量 flush 间隔 (秒)
#[serde(default = "default_rate_limit_batch_interval")]
pub rate_limit_batch_interval_secs: u64,
/// 限流事件批量 flush 最大条目数
#[serde(default = "default_rate_limit_batch_max")]
pub rate_limit_batch_max_size: usize,
}
/// 认证配置
@@ -97,12 +129,21 @@ pub struct RelayConfig {
pub max_attempts: u32,
}
fn default_config_version() -> u32 { 1 }
fn default_host() -> String { "0.0.0.0".into() }
fn default_port() -> u16 { 8080 }
fn default_db_url() -> String { "postgres://localhost:5432/zclaw".into() }
fn default_jwt_hours() -> i64 { 24 }
fn default_totp_issuer() -> String { "ZCLAW SaaS".into() }
fn default_refresh_hours() -> i64 { 168 }
fn default_max_connections() -> u32 { 100 }
fn default_min_connections() -> u32 { 5 }
fn default_acquire_timeout() -> u64 { 8 }
fn default_idle_timeout() -> u64 { 180 }
fn default_max_lifetime() -> u64 { 900 }
fn default_worker_concurrency() -> usize { 20 }
fn default_rate_limit_batch_interval() -> u64 { 5 }
fn default_rate_limit_batch_max() -> usize { 500 }
fn default_max_queue() -> usize { 1000 }
fn default_max_concurrent() -> usize { 5 }
fn default_batch_window() -> u64 { 50 }
@@ -132,15 +173,115 @@ impl Default for RateLimitConfig {
}
}
/// 支付配置
///
/// 支付宝和微信支付商户配置。所有字段通过环境变量传入(不写入 TOML 文件)。
/// 字段缺失时自动降级为 mock 支付模式。
///
/// 注意:自定义 Debug 和 Serialize 实现会隐藏敏感字段。
#[derive(Clone, Serialize, Deserialize)]
pub struct PaymentConfig {
/// 支付宝 App ID来自支付宝开放平台
#[serde(default)]
pub alipay_app_id: Option<String>,
/// 支付宝商户私钥RSA2— 敏感,不序列化
#[serde(default, skip_serializing)]
pub alipay_private_key: Option<String>,
/// 支付宝公钥证书路径(用于验签)
#[serde(default)]
pub alipay_cert_path: Option<String>,
/// 支付宝回调通知 URL
#[serde(default)]
pub alipay_notify_url: Option<String>,
/// 支付宝公钥用于回调验签PEM 格式)— 敏感,不序列化
#[serde(default, skip_serializing)]
pub alipay_public_key: Option<String>,
/// 微信支付商户号
#[serde(default)]
pub wechat_mch_id: Option<String>,
/// 微信支付商户证书序列号
#[serde(default)]
pub wechat_serial_no: Option<String>,
/// 微信支付商户私钥路径
#[serde(default)]
pub wechat_private_key_path: Option<String>,
/// 微信支付 API v3 密钥 — 敏感,不序列化
#[serde(default, skip_serializing)]
pub wechat_api_v3_key: Option<String>,
/// 微信支付回调通知 URL
#[serde(default)]
pub wechat_notify_url: Option<String>,
/// 微信支付 App ID公众号/小程序)
#[serde(default)]
pub wechat_app_id: Option<String>,
}
impl std::fmt::Debug for PaymentConfig {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("PaymentConfig")
.field("alipay_app_id", &self.alipay_app_id)
.field("alipay_private_key", &self.alipay_private_key.as_ref().map(|_| "***REDACTED***"))
.field("alipay_cert_path", &self.alipay_cert_path)
.field("alipay_notify_url", &self.alipay_notify_url)
.field("alipay_public_key", &self.alipay_public_key.as_ref().map(|_| "***REDACTED***"))
.field("wechat_mch_id", &self.wechat_mch_id)
.field("wechat_serial_no", &self.wechat_serial_no)
.field("wechat_private_key_path", &self.wechat_private_key_path)
.field("wechat_api_v3_key", &self.wechat_api_v3_key.as_ref().map(|_| "***REDACTED***"))
.field("wechat_notify_url", &self.wechat_notify_url)
.field("wechat_app_id", &self.wechat_app_id)
.finish()
}
}
impl Default for PaymentConfig {
fn default() -> Self {
// 优先从环境变量读取,未配置则降级 mock
Self {
alipay_app_id: std::env::var("ALIPAY_APP_ID").ok(),
alipay_private_key: std::env::var("ALIPAY_PRIVATE_KEY").ok(),
alipay_cert_path: std::env::var("ALIPAY_CERT_PATH").ok(),
alipay_notify_url: std::env::var("ALIPAY_NOTIFY_URL").ok(),
alipay_public_key: std::env::var("ALIPAY_PUBLIC_KEY").ok(),
wechat_mch_id: std::env::var("WECHAT_PAY_MCH_ID").ok(),
wechat_serial_no: std::env::var("WECHAT_PAY_SERIAL_NO").ok(),
wechat_private_key_path: std::env::var("WECHAT_PAY_PRIVATE_KEY_PATH").ok(),
wechat_api_v3_key: std::env::var("WECHAT_PAY_API_V3_KEY").ok(),
wechat_notify_url: std::env::var("WECHAT_PAY_NOTIFY_URL").ok(),
wechat_app_id: std::env::var("WECHAT_PAY_APP_ID").ok(),
}
}
}
impl PaymentConfig {
/// 支付宝是否已完整配置
pub fn alipay_configured(&self) -> bool {
self.alipay_app_id.is_some()
&& self.alipay_private_key.is_some()
&& self.alipay_notify_url.is_some()
}
/// 微信支付是否已完整配置
pub fn wechat_configured(&self) -> bool {
self.wechat_mch_id.is_some()
&& self.wechat_serial_no.is_some()
&& self.wechat_private_key_path.is_some()
&& self.wechat_notify_url.is_some()
}
}
impl Default for SaaSConfig {
fn default() -> Self {
Self {
config_version: 1,
server: ServerConfig::default(),
database: DatabaseConfig::default(),
auth: AuthConfig::default(),
relay: RelayConfig::default(),
rate_limit: RateLimitConfig::default(),
scheduler: SchedulerConfig::default(),
payment: PaymentConfig::default(),
}
}
}
@@ -158,7 +299,17 @@ impl Default for ServerConfig {
impl Default for DatabaseConfig {
fn default() -> Self {
Self { url: default_db_url() }
Self {
url: default_db_url(),
max_connections: default_max_connections(),
min_connections: default_min_connections(),
acquire_timeout_secs: default_acquire_timeout(),
idle_timeout_secs: default_idle_timeout(),
max_lifetime_secs: default_max_lifetime(),
worker_concurrency: default_worker_concurrency(),
rate_limit_batch_interval_secs: default_rate_limit_batch_interval(),
rate_limit_batch_max_size: default_rate_limit_batch_max(),
}
}
}
@@ -220,6 +371,26 @@ impl SaaSConfig {
SaaSConfig::default()
};
// 配置版本兼容性检查
if config.config_version < CURRENT_CONFIG_VERSION {
tracing::warn!(
"[Config] config_version ({}) is below current version ({}). \
Some features may not work correctly. \
Please update your saas-config.toml. \
See docs for migration guide.",
config.config_version,
CURRENT_CONFIG_VERSION
);
} else if config.config_version > CURRENT_CONFIG_VERSION {
tracing::error!(
"[Config] config_version ({}) is ahead of supported version ({}). \
This server version may not support all configured features. \
Consider upgrading the server.",
config.config_version,
CURRENT_CONFIG_VERSION
);
}
// 环境变量覆盖数据库 URL (避免在配置文件中存储密码)
if let Ok(db_url) = std::env::var("ZCLAW_DATABASE_URL") {
config.database.url = db_url;

View File

@@ -2,34 +2,44 @@
use sqlx::postgres::PgPoolOptions;
use sqlx::PgPool;
use crate::config::DatabaseConfig;
use crate::error::SaasResult;
const SCHEMA_VERSION: i32 = 11;
const SCHEMA_VERSION: i32 = 13;
/// 初始化数据库
pub async fn init_db(database_url: &str) -> SaasResult<PgPool> {
// 连接池大小可通过环境变量配置,默认 100relay 请求每次 10+ 串行查询50 偏紧
pub async fn init_db(config: &DatabaseConfig) -> SaasResult<PgPool> {
// 环境变量覆盖 URL避免在配置文件中存储密码
let database_url = std::env::var("ZCLAW_DATABASE_URL")
.unwrap_or_else(|_| config.url.clone());
// 环境变量覆盖连接数(向后兼容)
let max_connections: u32 = std::env::var("ZCLAW_DB_MAX_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(100);
.unwrap_or(config.max_connections);
let min_connections: u32 = std::env::var("ZCLAW_DB_MIN_CONNECTIONS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5);
.unwrap_or(config.min_connections);
tracing::info!("Database pool: max={}, min={}", max_connections, min_connections);
tracing::info!(
"Database pool: max={}, min={}, acquire_timeout={}s, idle_timeout={}s, max_lifetime={}s",
max_connections, min_connections,
config.acquire_timeout_secs, config.idle_timeout_secs, config.max_lifetime_secs
);
let pool = PgPoolOptions::new()
.max_connections(max_connections)
.min_connections(min_connections)
.acquire_timeout(std::time::Duration::from_secs(8))
.idle_timeout(std::time::Duration::from_secs(180))
.max_lifetime(std::time::Duration::from_secs(900))
.connect(database_url)
.acquire_timeout(std::time::Duration::from_secs(config.acquire_timeout_secs))
.idle_timeout(std::time::Duration::from_secs(config.idle_timeout_secs))
.max_lifetime(std::time::Duration::from_secs(config.max_lifetime_secs))
.connect(&database_url)
.await?;
run_migrations(&pool).await?;
ensure_security_columns(&pool).await?;
seed_admin_account(&pool).await?;
seed_builtin_prompts(&pool).await?;
seed_demo_data(&pool).await?;
@@ -884,6 +894,56 @@ async fn fix_seed_data(pool: &PgPool) -> SaasResult<()> {
Ok(())
}
/// 防御性检查:确保安全审计新增的列存在(即使 schema_version 显示已是最新)
///
/// 场景:旧数据库的 schema_version 已被手动更新但迁移文件未实际执行,
/// 或者迁移文件在 version check 时被跳过。
async fn ensure_security_columns(pool: &PgPool) -> SaasResult<()> {
// 检查 password_version 列是否存在
let col_exists: bool = sqlx::query_scalar(
"SELECT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name = 'accounts' AND column_name = 'password_version')"
)
.fetch_one(pool)
.await
.unwrap_or(false);
if !col_exists {
tracing::warn!("[DB] 'password_version' column missing — applying security fix migration");
sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS password_version INTEGER NOT NULL DEFAULT 1")
.execute(pool).await?;
sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS failed_login_count INTEGER NOT NULL DEFAULT 0")
.execute(pool).await?;
sqlx::query("ALTER TABLE accounts ADD COLUMN IF NOT EXISTS locked_until TIMESTAMPTZ")
.execute(pool).await?;
tracing::info!("[DB] Security columns (password_version, failed_login_count, locked_until) applied");
}
// 检查 rate_limit_events 表是否存在
let table_exists: bool = sqlx::query_scalar(
"SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'rate_limit_events')"
)
.fetch_one(pool)
.await
.unwrap_or(false);
if !table_exists {
tracing::warn!("[DB] 'rate_limit_events' table missing — applying rate limit migration");
if let Err(e) = sqlx::query(
"CREATE TABLE IF NOT EXISTS rate_limit_events (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
key TEXT NOT NULL,
count BIGINT NOT NULL DEFAULT 1,
window_start TIMESTAMPTZ NOT NULL DEFAULT NOW(),
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)"
).execute(pool).await {
tracing::warn!("[DB] Failed to create rate_limit_events: {}", e);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容

View File

@@ -0,0 +1,591 @@
//! 知识库 HTTP 处理器
use axum::{
extract::{Extension, Path, Query, State},
Json,
};
use crate::auth::types::AuthContext;
use crate::error::{SaasError, SaasResult};
use crate::state::AppState;
use super::service;
use super::types::*;
// === 分类管理 ===
/// GET /api/v1/knowledge/categories — 树形分类列表
pub async fn list_categories(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<Vec<CategoryResponse>>> {
check_permission(&ctx, "knowledge:read")?;
let tree = service::list_categories_tree(&state.db).await?;
Ok(Json(tree))
}
/// POST /api/v1/knowledge/categories — 创建分类
pub async fn create_category(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateCategoryRequest>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:write")?;
if req.name.trim().is_empty() {
return Err(SaasError::InvalidInput("分类名称不能为空".into()));
}
let cat = service::create_category(
&state.db,
req.name.trim(),
req.description.as_deref(),
req.parent_id.as_deref(),
req.icon.as_deref(),
).await?;
Ok(Json(serde_json::json!({
"id": cat.id,
"name": cat.name,
})))
}
/// PUT /api/v1/knowledge/categories/:id — 更新分类
pub async fn update_category(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(id): Path<String>,
Json(req): Json<UpdateCategoryRequest>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:write")?;
if let Some(ref name) = req.name {
if name.trim().is_empty() {
return Err(SaasError::InvalidInput("分类名称不能为空".into()));
}
}
let cat = service::update_category(
&state.db,
&id,
req.name.as_deref().map(|n| n.trim()),
req.description.as_deref(),
req.parent_id.as_deref(),
req.icon.as_deref(),
).await?;
Ok(Json(serde_json::json!({
"id": cat.id,
"name": cat.name,
"updated": true,
})))
}
/// DELETE /api/v1/knowledge/categories/:id — 删除分类
pub async fn delete_category(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(id): Path<String>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:admin")?;
service::delete_category(&state.db, &id).await?;
Ok(Json(serde_json::json!({"deleted": true})))
}
/// GET /api/v1/knowledge/categories/:id/items — 分类下条目列表
pub async fn list_category_items(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(id): Path<String>,
Query(query): Query<ListItemsQuery>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let page = query.page.unwrap_or(1).max(1);
let page_size = query.page_size.unwrap_or(20).max(1).min(100);
let status_filter = query.status.as_deref().unwrap_or("active");
let (items, total) = service::list_items_by_category(
&state.db,
&id,
status_filter,
page,
page_size,
).await?;
Ok(Json(serde_json::json!({
"items": items,
"total": total,
"page": page,
"page_size": page_size,
})))
}
// === 知识条目 CRUD ===
/// GET /api/v1/knowledge/items — 分页列表
pub async fn list_items(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Query(query): Query<ListItemsQuery>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let page = query.page.unwrap_or(1).max(1).min(10000);
let page_size = query.page_size.unwrap_or(20).max(1).min(100);
let offset = (page - 1) * page_size;
// 转义 ILIKE 通配符,防止用户输入的 % 和 _ 被当作通配符
let keyword = query.keyword.as_ref().map(|k| {
k.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_")
});
let items: Vec<KnowledgeItem> = sqlx::query_as(
"SELECT ki.* FROM knowledge_items ki \
JOIN knowledge_categories kc ON ki.category_id = kc.id \
WHERE ($1::text IS NULL OR ki.category_id = $1) \
AND ($2::text IS NULL OR ki.status = $2) \
AND ($3::text IS NULL OR ki.title ILIKE '%' || $3 || '%') \
ORDER BY ki.priority DESC, ki.updated_at DESC \
LIMIT $4 OFFSET $5"
)
.bind(&query.category_id)
.bind(&query.status)
.bind(&keyword)
.bind(page_size)
.bind(offset)
.fetch_all(&state.db)
.await?;
let total: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_items ki \
WHERE ($1::text IS NULL OR ki.category_id = $1) \
AND ($2::text IS NULL OR ki.status = $2) \
AND ($3::text IS NULL OR ki.title ILIKE '%' || $3 || '%')"
)
.bind(&query.category_id)
.bind(&query.status)
.bind(&keyword)
.fetch_one(&state.db)
.await?;
Ok(Json(serde_json::json!({
"items": items,
"total": total.0,
"page": page,
"page_size": page_size,
})))
}
/// POST /api/v1/knowledge/items — 创建条目
pub async fn create_item(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<CreateItemRequest>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:write")?;
if req.title.trim().is_empty() || req.content.trim().is_empty() {
return Err(SaasError::InvalidInput("标题和内容不能为空".into()));
}
if req.content.len() > 100_000 {
return Err(SaasError::InvalidInput("内容不能超过 100KB".into()));
}
let item = service::create_item(&state.db, &ctx.account_id, &req).await?;
// 异步触发 embedding 生成
if let Err(e) = state.worker_dispatcher.dispatch(
"generate_embedding",
serde_json::json!({ "item_id": item.id }),
).await {
tracing::warn!("Failed to dispatch embedding generation: {}", e);
}
Ok(Json(serde_json::json!({
"id": item.id,
"title": item.title,
"version": item.version,
})))
}
/// POST /api/v1/knowledge/items/batch — 批量创建
pub async fn batch_create_items(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(items): Json<Vec<CreateItemRequest>>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:write")?;
if items.len() > 50 {
return Err(SaasError::InvalidInput("单次批量创建不能超过 50 条".into()));
}
let mut created = Vec::new();
for req in &items {
if req.title.trim().is_empty() || req.content.trim().is_empty() {
tracing::warn!("Batch create: skipping item with empty title or content");
continue;
}
if req.content.len() > 100_000 {
tracing::warn!("Batch create: skipping item '{}' (content too long)", req.title);
continue;
}
match service::create_item(&state.db, &ctx.account_id, req).await {
Ok(item) => {
let _ = state.worker_dispatcher.dispatch(
"generate_embedding",
serde_json::json!({ "item_id": item.id }),
).await.map_err(|e| {
tracing::warn!("[Knowledge] Failed to dispatch embedding for item {}: {}", item.id, e);
e
});
created.push(item.id);
}
Err(e) => {
tracing::warn!("Batch create item failed: {}", e);
}
}
}
Ok(Json(serde_json::json!({
"created_count": created.len(),
"ids": created,
})))
}
/// GET /api/v1/knowledge/items/:id — 条目详情
pub async fn get_item(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(id): Path<String>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let item = service::get_item(&state.db, &id).await?
.ok_or_else(|| SaasError::NotFound("知识条目不存在".into()))?;
Ok(Json(serde_json::to_value(item).unwrap_or_default()))
}
/// PUT /api/v1/knowledge/items/:id — 更新条目
pub async fn update_item(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(id): Path<String>,
Json(req): Json<UpdateItemRequest>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:write")?;
if let Some(ref content) = req.content {
if content.len() > 100_000 {
return Err(SaasError::InvalidInput("内容不能超过 100KB".into()));
}
}
let updated = service::update_item(&state.db, &id, &ctx.account_id, &req).await?;
// 触发 re-embedding
if let Err(e) = state.worker_dispatcher.dispatch(
"generate_embedding",
serde_json::json!({ "item_id": id }),
).await {
tracing::warn!("[Knowledge] Failed to dispatch re-embedding for item {}: {}", id, e);
}
Ok(Json(serde_json::json!({
"id": updated.id,
"version": updated.version,
})))
}
/// DELETE /api/v1/knowledge/items/:id — 删除条目
pub async fn delete_item(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(id): Path<String>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:admin")?;
service::delete_item(&state.db, &id).await?;
Ok(Json(serde_json::json!({"deleted": true})))
}
// === 版本控制 ===
/// GET /api/v1/knowledge/items/:id/versions
pub async fn list_versions(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path(id): Path<String>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let versions: Vec<KnowledgeVersion> = sqlx::query_as(
"SELECT * FROM knowledge_versions WHERE item_id = $1 ORDER BY version DESC"
)
.bind(&id)
.fetch_all(&state.db)
.await?;
Ok(Json(serde_json::json!({"versions": versions})))
}
/// GET /api/v1/knowledge/items/:id/versions/:v
pub async fn get_version(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path((id, v)): Path<(String, i32)>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let version: KnowledgeVersion = sqlx::query_as(
"SELECT * FROM knowledge_versions WHERE item_id = $1 AND version = $2"
)
.bind(&id)
.bind(v)
.fetch_optional(&state.db)
.await?
.ok_or_else(|| SaasError::NotFound("版本不存在".into()))?;
Ok(Json(serde_json::to_value(version).unwrap_or_default()))
}
/// POST /api/v1/knowledge/items/:id/rollback/:v
pub async fn rollback_version(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Path((id, v)): Path<(String, i32)>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:admin")?;
let updated = service::rollback_version(&state.db, &id, v, &ctx.account_id).await?;
// 触发 re-embedding
if let Err(e) = state.worker_dispatcher.dispatch(
"generate_embedding",
serde_json::json!({ "item_id": id }),
).await {
tracing::warn!("[Knowledge] Failed to dispatch re-embedding after rollback for item {}: {}", id, e);
}
Ok(Json(serde_json::json!({
"id": updated.id,
"version": updated.version,
"rolled_back_to": v,
})))
}
// === 检索 ===
/// POST /api/v1/knowledge/search — 语义搜索
pub async fn search(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<SearchRequest>,
) -> SaasResult<Json<Vec<SearchResult>>> {
check_permission(&ctx, "knowledge:search")?;
let limit = req.limit.unwrap_or(5).min(10);
let min_score = req.min_score.unwrap_or(0.5);
let results = service::search(
&state.db,
&req.query,
req.category_id.as_deref(),
limit,
min_score,
).await?;
Ok(Json(results))
}
/// POST /api/v1/knowledge/recommend — 关联推荐
pub async fn recommend(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<SearchRequest>,
) -> SaasResult<Json<Vec<SearchResult>>> {
check_permission(&ctx, "knowledge:search")?;
let limit = req.limit.unwrap_or(5).min(10);
let results = service::search(
&state.db,
&req.query,
req.category_id.as_deref(),
limit,
0.3,
).await?;
Ok(Json(results))
}
// === 分析看板 ===
/// GET /api/v1/knowledge/analytics/overview
pub async fn analytics_overview(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<AnalyticsOverview>> {
check_permission(&ctx, "knowledge:read")?;
let overview = service::analytics_overview(&state.db).await?;
Ok(Json(overview))
}
/// GET /api/v1/knowledge/analytics/trends
pub async fn analytics_trends(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
// 使用 serde_json::Value 行来避免 PgRow 序列化
let trends: Vec<(serde_json::Value,)> = sqlx::query_as(
"SELECT json_build_object(
'date', DATE(created_at),
'count', COUNT(*),
'injected_count', SUM(CASE WHEN was_injected THEN 1 ELSE 0 END)
) as row \
FROM knowledge_usage \
WHERE created_at >= NOW() - interval '30 days' \
GROUP BY DATE(created_at) ORDER BY DATE(created_at)"
)
.fetch_all(&state.db)
.await
.unwrap_or_default();
let trends: Vec<serde_json::Value> = trends.into_iter().map(|(v,)| v).collect();
Ok(Json(serde_json::json!({"trends": trends})))
}
/// GET /api/v1/knowledge/analytics/top-items
pub async fn analytics_top_items(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let items: Vec<(serde_json::Value,)> = sqlx::query_as(
"SELECT json_build_object(
'id', ki.id,
'title', ki.title,
'category', kc.name,
'ref_count', COUNT(ku.id)
) as row \
FROM knowledge_items ki \
JOIN knowledge_categories kc ON ki.category_id = kc.id \
LEFT JOIN knowledge_usage ku ON ku.item_id = ki.id \
WHERE ki.status = 'active' \
GROUP BY ki.id, ki.title, kc.name \
ORDER BY COUNT(ku.id) DESC LIMIT 20"
)
.fetch_all(&state.db)
.await
.unwrap_or_default();
let items: Vec<serde_json::Value> = items.into_iter().map(|(v,)| v).collect();
Ok(Json(serde_json::json!({"items": items})))
}
/// GET /api/v1/knowledge/analytics/quality
pub async fn analytics_quality(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let quality = service::analytics_quality(&state.db).await?;
Ok(Json(quality))
}
/// GET /api/v1/knowledge/analytics/gaps
pub async fn analytics_gaps(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:read")?;
let gaps = service::analytics_gaps(&state.db).await?;
Ok(Json(gaps))
}
// === 批量操作 ===
/// PATCH /api/v1/knowledge/categories/reorder — 批量排序
pub async fn reorder_categories(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(items): Json<Vec<ReorderItem>>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:write")?;
if items.is_empty() {
return Ok(Json(serde_json::json!({"reordered": false, "count": 0})));
}
if items.len() > 100 {
return Err(SaasError::InvalidInput("单次排序不能超过 100 个".into()));
}
// 使用事务保证原子性
let mut tx = state.db.begin().await?;
for item in &items {
sqlx::query("UPDATE knowledge_categories SET sort_order = $1, updated_at = NOW() WHERE id = $2")
.bind(item.sort_order)
.bind(&item.id)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
Ok(Json(serde_json::json!({"reordered": true, "count": items.len()})))
}
/// POST /api/v1/knowledge/items/import — Markdown 文件导入
pub async fn import_items(
State(state): State<AppState>,
Extension(ctx): Extension<AuthContext>,
Json(req): Json<ImportRequest>,
) -> SaasResult<Json<serde_json::Value>> {
check_permission(&ctx, "knowledge:write")?;
if req.files.len() > 20 {
return Err(SaasError::InvalidInput("单次导入不能超过 20 个文件".into()));
}
let mut created = Vec::new();
for file in &req.files {
// 内容长度检查(数据库限制 100KB
if file.content.len() > 100_000 {
tracing::warn!("跳过文件 '{}': 内容超长 ({} bytes)", file.title.as_deref().unwrap_or("未命名"), file.content.len());
continue;
}
// 空内容检查
if file.content.trim().is_empty() {
tracing::warn!("跳过空文件: '{}'", file.title.as_deref().unwrap_or("未命名"));
continue;
}
let title = file.title.clone().unwrap_or_else(|| {
file.content.lines().next()
.map(|l| l.trim_start_matches('#').trim().to_string())
.unwrap_or_else(|| format!("导入条目 {}", created.len() + 1))
});
let item_req = CreateItemRequest {
category_id: req.category_id.clone(),
title,
content: file.content.clone(),
keywords: file.keywords.clone(),
related_questions: None,
priority: None,
tags: file.tags.clone(),
};
match service::create_item(&state.db, &ctx.account_id, &item_req).await {
Ok(item) => {
let _ = state.worker_dispatcher.dispatch(
"generate_embedding",
serde_json::json!({ "item_id": item.id }),
).await.map_err(|e| {
tracing::warn!("[Knowledge] Failed to dispatch embedding for item {}: {}", item.id, e);
e
});
created.push(item.id);
}
Err(e) => {
tracing::warn!("Import item '{}' failed: {}", item_req.title, e);
}
}
}
Ok(Json(serde_json::json!({
"created_count": created.len(),
"ids": created,
})))
}
// === 辅助函数 ===
fn check_permission(ctx: &AuthContext, permission: &str) -> SaasResult<()> {
crate::auth::handlers::check_permission(ctx, permission)
}

View File

@@ -0,0 +1,39 @@
//! 知识库模块 — 行业知识管理、RAG 检索、版本控制
pub mod types;
pub mod service;
pub mod handlers;
use axum::routing::{delete, get, patch, post, put};
pub fn routes() -> axum::Router<crate::state::AppState> {
axum::Router::new()
// 分类管理
.route("/api/v1/knowledge/categories", get(handlers::list_categories))
.route("/api/v1/knowledge/categories", post(handlers::create_category))
.route("/api/v1/knowledge/categories/{id}", put(handlers::update_category))
.route("/api/v1/knowledge/categories/{id}", delete(handlers::delete_category))
.route("/api/v1/knowledge/categories/{id}/items", get(handlers::list_category_items))
.route("/api/v1/knowledge/categories/reorder", patch(handlers::reorder_categories))
// 知识条目 CRUD
.route("/api/v1/knowledge/items", get(handlers::list_items))
.route("/api/v1/knowledge/items", post(handlers::create_item))
.route("/api/v1/knowledge/items/batch", post(handlers::batch_create_items))
.route("/api/v1/knowledge/items/import", post(handlers::import_items))
.route("/api/v1/knowledge/items/{id}", get(handlers::get_item))
.route("/api/v1/knowledge/items/{id}", put(handlers::update_item))
.route("/api/v1/knowledge/items/{id}", delete(handlers::delete_item))
// 版本控制
.route("/api/v1/knowledge/items/{id}/versions", get(handlers::list_versions))
.route("/api/v1/knowledge/items/{id}/versions/{v}", get(handlers::get_version))
.route("/api/v1/knowledge/items/{id}/rollback/{v}", post(handlers::rollback_version))
// 检索
.route("/api/v1/knowledge/search", post(handlers::search))
.route("/api/v1/knowledge/recommend", post(handlers::recommend))
// 分析看板
.route("/api/v1/knowledge/analytics/overview", get(handlers::analytics_overview))
.route("/api/v1/knowledge/analytics/trends", get(handlers::analytics_trends))
.route("/api/v1/knowledge/analytics/top-items", get(handlers::analytics_top_items))
.route("/api/v1/knowledge/analytics/quality", get(handlers::analytics_quality))
.route("/api/v1/knowledge/analytics/gaps", get(handlers::analytics_gaps))
}

View File

@@ -0,0 +1,783 @@
//! 知识库服务层 — CRUD、检索、分析
use sqlx::PgPool;
use crate::error::SaasResult;
use super::types::*;
// === 分类管理 ===
/// 获取分类树(带条目计数)
pub async fn list_categories_tree(pool: &PgPool) -> SaasResult<Vec<CategoryResponse>> {
let categories: Vec<KnowledgeCategory> = sqlx::query_as(
"SELECT * FROM knowledge_categories ORDER BY sort_order, name"
)
.fetch_all(pool)
.await?;
// 获取每个分类的条目计数
let counts: Vec<(String, i64)> = sqlx::query_as(
"SELECT category_id, COUNT(*) FROM knowledge_items WHERE status = 'active' GROUP BY category_id"
)
.fetch_all(pool)
.await?;
let count_map: std::collections::HashMap<String, i64> = counts.into_iter().collect();
// 构建树形结构
let mut roots = Vec::new();
let mut all: Vec<CategoryResponse> = categories.into_iter().map(|c| {
let count = *count_map.get(&c.id).unwrap_or(&0);
CategoryResponse {
id: c.id,
name: c.name,
description: c.description,
parent_id: c.parent_id,
icon: c.icon,
sort_order: c.sort_order,
item_count: count,
children: Vec::new(),
created_at: c.created_at.to_rfc3339(),
updated_at: c.updated_at.to_rfc3339(),
}
}).collect();
// 构建子节点映射
let mut children_map: std::collections::HashMap<String, Vec<CategoryResponse>> =
std::collections::HashMap::new();
for cat in all.drain(..) {
if let Some(ref parent_id) = cat.parent_id {
children_map.entry(parent_id.clone()).or_default().push(cat);
} else {
roots.push(cat);
}
}
// 递归填充子节点
fn fill_children(
cats: &mut Vec<CategoryResponse>,
children_map: &mut std::collections::HashMap<String, Vec<CategoryResponse>>,
) {
for cat in cats.iter_mut() {
if let Some(children) = children_map.remove(&cat.id) {
cat.children = children;
fill_children(&mut cat.children, children_map);
}
// 累加子节点条目数到父节点
let child_count: i64 = cat.children.iter().map(|c| c.item_count).sum();
cat.item_count += child_count;
}
}
fill_children(&mut roots, &mut children_map);
Ok(roots)
}
/// 创建分类
pub async fn create_category(
pool: &PgPool,
name: &str,
description: Option<&str>,
parent_id: Option<&str>,
icon: Option<&str>,
) -> SaasResult<KnowledgeCategory> {
// 验证 parent_id 存在性
if let Some(pid) = parent_id {
let exists: bool = sqlx::query_scalar(
"SELECT EXISTS(SELECT 1 FROM knowledge_categories WHERE id = $1)"
)
.bind(pid)
.fetch_one(pool)
.await?;
if !exists {
return Err(crate::error::SaasError::InvalidInput(
format!("父分类 '{}' 不存在", pid),
));
}
}
let id = uuid::Uuid::new_v4().to_string();
let category = sqlx::query_as::<_, KnowledgeCategory>(
"INSERT INTO knowledge_categories (id, name, description, parent_id, icon) \
VALUES ($1, $2, $3, $4, $5) RETURNING *"
)
.bind(&id)
.bind(name)
.bind(description)
.bind(parent_id)
.bind(icon)
.fetch_one(pool)
.await?;
Ok(category)
}
/// 删除分类(有子分类或条目时拒绝)
pub async fn delete_category(pool: &PgPool, category_id: &str) -> SaasResult<()> {
// 检查子分类
let child_count: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_categories WHERE parent_id = $1"
)
.bind(category_id)
.fetch_one(pool)
.await?;
if child_count.0 > 0 {
return Err(crate::error::SaasError::InvalidInput(
"该分类下有子分类,无法删除".into(),
));
}
// 检查条目
let item_count: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_items WHERE category_id = $1"
)
.bind(category_id)
.fetch_one(pool)
.await?;
if item_count.0 > 0 {
return Err(crate::error::SaasError::InvalidInput(
"该分类下有知识条目,无法删除".into(),
));
}
let result = sqlx::query("DELETE FROM knowledge_categories WHERE id = $1")
.bind(category_id)
.execute(pool)
.await?;
if result.rows_affected() == 0 {
return Err(crate::error::SaasError::NotFound("分类不存在".into()));
}
Ok(())
}
/// 更新分类(含循环引用检测 + 深度限制)
pub async fn update_category(
pool: &PgPool,
category_id: &str,
name: Option<&str>,
description: Option<&str>,
parent_id: Option<&str>,
icon: Option<&str>,
) -> SaasResult<KnowledgeCategory> {
if let Some(pid) = parent_id {
if pid == category_id {
return Err(crate::error::SaasError::InvalidInput(
"分类不能成为自身的子分类".into(),
));
}
// 检查新的父级不是当前分类的后代(循环检测)
let mut check_id = pid.to_string();
let mut depth = 0;
loop {
if check_id == category_id {
return Err(crate::error::SaasError::InvalidInput(
"循环引用:父级分类不能是当前分类的后代".into(),
));
}
let parent: Option<(Option<String>,)> = sqlx::query_as(
"SELECT parent_id FROM knowledge_categories WHERE id = $1"
)
.bind(&check_id)
.fetch_optional(pool)
.await?;
match parent {
Some((Some(gp),)) => {
check_id = gp;
depth += 1;
if depth > 10 { break; }
}
_ => break,
}
}
// 检查深度限制(最多 3 层)
let mut current_depth = 0;
let mut check = pid.to_string();
while let Some((Some(p),)) = sqlx::query_as::<_, (Option<String>,)>(
"SELECT parent_id FROM knowledge_categories WHERE id = $1"
)
.bind(&check)
.fetch_optional(pool)
.await?
{
check = p;
current_depth += 1;
if current_depth > 10 { break; }
}
if current_depth >= 3 {
return Err(crate::error::SaasError::InvalidInput(
"分类层级不能超过 3 层".into(),
));
}
}
let category = sqlx::query_as::<_, KnowledgeCategory>(
"UPDATE knowledge_categories SET \
name = COALESCE($1, name), \
description = COALESCE($2, description), \
parent_id = COALESCE($3, parent_id), \
icon = COALESCE($4, icon), \
updated_at = NOW() \
WHERE id = $5 RETURNING *"
)
.bind(name)
.bind(description)
.bind(parent_id)
.bind(icon)
.bind(category_id)
.fetch_optional(pool)
.await?
.ok_or_else(|| crate::error::SaasError::NotFound("分类不存在".into()))?;
Ok(category)
}
// === 知识条目 CRUD ===
/// 按分类分页查询条目列表
pub async fn list_items_by_category(
pool: &PgPool,
category_id: &str,
status_filter: &str,
page: i64,
page_size: i64,
) -> SaasResult<(Vec<KnowledgeItem>, i64)> {
let offset = (page - 1) * page_size;
let items: Vec<KnowledgeItem> = sqlx::query_as(
"SELECT * FROM knowledge_items \
WHERE category_id = $1 AND status = $2 \
ORDER BY priority DESC, updated_at DESC \
LIMIT $3 OFFSET $4"
)
.bind(category_id)
.bind(status_filter)
.bind(page_size)
.bind(offset)
.fetch_all(pool)
.await?;
let total: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_items WHERE category_id = $1 AND status = $2"
)
.bind(category_id)
.bind(status_filter)
.fetch_one(pool)
.await?;
Ok((items, total.0))
}
/// 创建知识条目
pub async fn create_item(
pool: &PgPool,
account_id: &str,
req: &CreateItemRequest,
) -> SaasResult<KnowledgeItem> {
let id = uuid::Uuid::new_v4().to_string();
let keywords = req.keywords.as_deref().unwrap_or(&[]);
let related_questions = req.related_questions.as_deref().unwrap_or(&[]);
let priority = req.priority.unwrap_or(0);
let tags = req.tags.as_deref().unwrap_or(&[]);
// 验证 category_id 存在性
let cat_exists: bool = sqlx::query_scalar(
"SELECT EXISTS(SELECT 1 FROM knowledge_categories WHERE id = $1)"
)
.bind(&req.category_id)
.fetch_one(pool)
.await?;
if !cat_exists {
return Err(crate::error::SaasError::InvalidInput(
format!("分类 '{}' 不存在", req.category_id),
));
}
// 使用事务保证 item + version 原子性
let mut tx = pool.begin().await?;
let item = sqlx::query_as::<_, KnowledgeItem>(
"INSERT INTO knowledge_items \
(id, category_id, title, content, keywords, related_questions, priority, tags, created_by) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) \
RETURNING *"
)
.bind(&id)
.bind(&req.category_id)
.bind(&req.title)
.bind(&req.content)
.bind(keywords)
.bind(related_questions)
.bind(priority)
.bind(tags)
.bind(account_id)
.fetch_one(&mut *tx)
.await?;
// 创建初始版本快照
let version_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO knowledge_versions \
(id, item_id, version, title, content, keywords, related_questions, created_by) \
VALUES ($1, $2, 1, $3, $4, $5, $6, $7)"
)
.bind(&version_id)
.bind(&id)
.bind(&req.title)
.bind(&req.content)
.bind(keywords)
.bind(related_questions)
.bind(account_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(item)
}
/// 获取条目详情
pub async fn get_item(pool: &PgPool, item_id: &str) -> SaasResult<Option<KnowledgeItem>> {
let item = sqlx::query_as::<_, KnowledgeItem>(
"SELECT * FROM knowledge_items WHERE id = $1"
)
.bind(item_id)
.fetch_optional(pool)
.await?;
Ok(item)
}
/// 更新条目(含版本快照)— 事务保护防止并发竞态
pub async fn update_item(
pool: &PgPool,
item_id: &str,
account_id: &str,
req: &UpdateItemRequest,
) -> SaasResult<KnowledgeItem> {
// status 验证在事务之前,避免无谓锁占用
const VALID_STATUSES: &[&str] = &["active", "draft", "archived", "deprecated"];
if let Some(ref status) = &req.status {
if !VALID_STATUSES.contains(&status.as_str()) {
return Err(crate::error::SaasError::InvalidInput(
format!("无效的状态值: {},有效值: {}", status, VALID_STATUSES.join(", "))
));
}
}
let mut tx = pool.begin().await?;
// 获取当前条目并锁定行防止并发修改
let current = sqlx::query_as::<_, KnowledgeItem>(
"SELECT * FROM knowledge_items WHERE id = $1 FOR UPDATE"
)
.bind(item_id)
.fetch_optional(&mut *tx)
.await?
.ok_or_else(|| crate::error::SaasError::NotFound("知识条目不存在".into()))?;
// 合并更新
let title = req.title.as_deref().unwrap_or(&current.title);
let content = req.content.as_deref().unwrap_or(&current.content);
let keywords: Vec<String> = req.keywords.as_ref()
.or(Some(&current.keywords))
.unwrap_or(&vec![])
.clone();
let related_questions: Vec<String> = req.related_questions.as_ref()
.or(Some(&current.related_questions))
.unwrap_or(&vec![])
.clone();
let priority = req.priority.unwrap_or(current.priority);
let tags: Vec<String> = req.tags.as_ref()
.or(Some(&current.tags))
.unwrap_or(&vec![])
.clone();
// 更新条目
let updated = sqlx::query_as::<_, KnowledgeItem>(
"UPDATE knowledge_items SET \
title = $1, content = $2, keywords = $3, related_questions = $4, \
priority = $5, tags = $6, status = COALESCE($7, status), \
version = version + 1, updated_at = NOW() \
WHERE id = $8 RETURNING *"
)
.bind(title)
.bind(content)
.bind(&keywords)
.bind(&related_questions)
.bind(priority)
.bind(&tags)
.bind(req.status.as_deref())
.bind(item_id)
.fetch_one(&mut *tx)
.await?;
// 创建版本快照
let version_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO knowledge_versions \
(id, item_id, version, title, content, keywords, related_questions, \
change_summary, created_by) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
)
.bind(&version_id)
.bind(item_id)
.bind(updated.version)
.bind(title)
.bind(content)
.bind(&keywords)
.bind(&related_questions)
.bind(req.change_summary.as_deref())
.bind(account_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(updated)
}
/// 删除条目(级联删除 chunks + versions
pub async fn delete_item(pool: &PgPool, item_id: &str) -> SaasResult<()> {
let result = sqlx::query("DELETE FROM knowledge_items WHERE id = $1")
.bind(item_id)
.execute(pool)
.await?;
if result.rows_affected() == 0 {
return Err(crate::error::SaasError::NotFound("知识条目不存在".into()));
}
Ok(())
}
// === 分块 ===
/// 将内容按 Markdown 标题 + 固定长度分块
pub fn chunk_content(content: &str, max_tokens: usize, overlap: usize) -> Vec<String> {
// 先按 Markdown 标题分段
let sections: Vec<&str> = content.split("\n# ").collect();
let mut chunks = Vec::new();
for (i, section) in sections.iter().enumerate() {
// 第一个片段保留原始内容,其余片段重新添加标题标记
let section_content = if i == 0 {
section.to_string()
} else {
format!("# {}", section)
};
// 磁盘估算 token中文约 1.5 字符/token)
let estimated_tokens = section_content.len() / 2;
if estimated_tokens <= max_tokens {
if !section_content.trim().is_empty() {
chunks.push(section_content.trim().to_string());
}
} else {
// 超长段落按固定长度切分
let chars: Vec<char> = section_content.chars().collect();
let chunk_chars = max_tokens * 2; // 近似字符数
let overlap_chars = overlap * 2;
let mut pos = 0;
while pos < chars.len() {
let end = (pos + chunk_chars).min(chars.len());
let chunk_str: String = chars[pos..end].iter().collect();
if !chunk_str.trim().is_empty() {
chunks.push(chunk_str.trim().to_string());
}
pos = if end >= chars.len() { end} else { end.saturating_sub(overlap_chars) };
}
}
}
chunks}
// === 搜索 ===
/// 语义搜索(向量 + 关键词混合)
pub async fn search(
pool: &PgPool,
query: &str,
category_id: Option<&str>,
limit: i64,
min_score: f64,
) -> SaasResult<Vec<SearchResult>> {
// 暂时使用关键词匹配(向量搜索需要 embedding 生成)
let pattern = format!("%{}%", query.replace('\\', "\\\\").replace('%', "\\%").replace('_', "\\_"));
let results = if let Some(cat_id) = category_id {
sqlx::query_as::<_, (String, String, String, String, String, Vec<String>)>(
"SELECT kc.id, kc.item_id, ki.title, kcat.name, kc.content, kc.keywords \
FROM knowledge_chunks kc \
JOIN knowledge_items ki ON kc.item_id = ki.id \
JOIN knowledge_categories kcat ON ki.category_id = kcat.id \
WHERE ki.status = 'active' \
AND ki.category_id = $1 \
AND (kc.content ILIKE $2 OR $3 = ANY(kc.keywords)) \
ORDER BY ki.priority DESC \
LIMIT $4"
)
.bind(cat_id)
.bind(&pattern)
.bind(query)
.bind(limit)
.fetch_all(pool)
.await?
} else {
sqlx::query_as::<_, (String, String, String, String, String, Vec<String>)>(
"SELECT kc.id, kc.item_id, ki.title, kcat.name, kc.content, kc.keywords \
FROM knowledge_chunks kc \
JOIN knowledge_items ki ON kc.item_id = ki.id \
JOIN knowledge_categories kcat ON ki.category_id = kcat.id \
WHERE ki.status = 'active' \
AND (kc.content ILIKE $1 OR $2 = ANY(kc.keywords)) \
ORDER BY ki.priority DESC \
LIMIT $3"
)
.bind(&pattern)
.bind(query)
.bind(limit)
.fetch_all(pool)
.await?
};
Ok(results.into_iter().map(|(chunk_id, item_id, title, cat_name, content, keywords)| {
// 基于关键词匹配数计算分数:匹配数 / 总查询关键词数
let query_keywords: Vec<&str> = query.split_whitespace().collect();
let matched_count = keywords.iter()
.filter(|k| query_keywords.iter().any(|qk| k.to_lowercase().contains(&qk.to_lowercase())))
.count();
let score = if keywords.is_empty() || query_keywords.is_empty() {
0.5
} else {
(matched_count as f64 / keywords.len().max(query_keywords.len()) as f64).min(1.0)
};
SearchResult {
chunk_id,
item_id,
item_title: title,
category_name: cat_name,
content,
score,
keywords,
}
}).filter(|r| r.score >= min_score).collect())
}
// === 分析 ===
/// 分析总览
pub async fn analytics_overview(pool: &PgPool) -> SaasResult<AnalyticsOverview> {
let total_items: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_items"
)
.fetch_one(pool)
.await?;
let active_items: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_items WHERE status = 'active'"
)
.fetch_one(pool)
.await?;
let total_categories: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_categories"
)
.fetch_one(pool)
.await?;
let weekly_new: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_items WHERE created_at >= NOW() - interval '7 days'"
)
.fetch_one(pool)
.await?;
let total_refs: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_usage"
)
.fetch_one(pool)
.await?;
let injected: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_usage WHERE was_injected = true"
)
.fetch_one(pool)
.await?;
let positive: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_usage WHERE agent_feedback = 'positive'"
)
.fetch_one(pool)
.await?;
let with_feedback: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_usage WHERE agent_feedback IS NOT NULL"
)
.fetch_one(pool)
.await?;
let stale: (i64,) = sqlx::query_as(
"SELECT COUNT(*) FROM knowledge_items ki \
WHERE ki.status = 'active' \
AND NOT EXISTS (SELECT 1 FROM knowledge_usage ku WHERE ku.item_id = ki.id AND ku.created_at >= NOW() - interval '90 days')"
)
.fetch_one(pool)
.await?;
let hit_rate = if total_refs.0 > 0 { with_feedback.0 as f64 / total_refs.0 as f64 } else { 0.0 };
let injection_rate = if total_refs.0 > 0 { injected.0 as f64 / total_refs.0 as f64 } else { 0.0 };
let positive_rate = if total_refs.0 > 0 { positive.0 as f64 / total_refs.0 as f64 } else { 0.0 };
Ok(AnalyticsOverview {
total_items: total_items.0,
active_items: active_items.0,
total_categories: total_categories.0,
weekly_new_items: weekly_new.0,
total_references: total_refs.0,
avg_reference_per_item: if total_items.0 > 0 { total_refs.0 as f64 / total_items.0 as f64 } else { 0.0 },
hit_rate,
injection_rate,
positive_feedback_rate: positive_rate,
stale_items_count: stale.0,
})
}
/// 回滚到指定版本(创建新版本快照)
pub async fn rollback_version(
pool: &PgPool,
item_id: &str,
target_version: i32,
account_id: &str,
) -> SaasResult<KnowledgeItem> {
// 使用事务保证原子性,防止并发回滚冲突
let mut tx = pool.begin().await?;
// 获取目标版本
let version: KnowledgeVersion = sqlx::query_as(
"SELECT * FROM knowledge_versions WHERE item_id = $1 AND version = $2"
)
.bind(item_id)
.bind(target_version)
.fetch_optional(&mut *tx)
.await?
.ok_or_else(|| crate::error::SaasError::NotFound("版本不存在".into()))?;
// 锁定当前条目行防止并发修改SELECT FOR UPDATE
let current: Option<(i32,)> = sqlx::query_as(
"SELECT version FROM knowledge_items WHERE id = $1 FOR UPDATE"
)
.bind(item_id)
.fetch_optional(&mut *tx)
.await?;
let current_version = current
.ok_or_else(|| crate::error::SaasError::NotFound("知识条目不存在".into()))?
.0;
// 防止版本无限递增: 最多 100 个版本
if current_version >= 100 {
return Err(crate::error::SaasError::InvalidInput(
"版本数已达上限(100),请考虑合并历史版本".into(),
));
}
let new_version = current_version + 1;
// 更新条目为该版本内容
let updated = sqlx::query_as::<_, KnowledgeItem>(
"UPDATE knowledge_items SET \
title = $1, content = $2, keywords = $3, related_questions = $4, \
version = $5, updated_at = NOW() \
WHERE id = $6 RETURNING *"
)
.bind(&version.title)
.bind(&version.content)
.bind(&version.keywords)
.bind(&version.related_questions)
.bind(new_version)
.bind(item_id)
.fetch_one(&mut *tx)
.await?;
// 创建新版本快照(记录回滚来源)
let version_id = uuid::Uuid::new_v4().to_string();
let summary = format!("回滚到版本 {}(当前版本 {} → 新版本 {}", target_version, current_version, new_version);
sqlx::query(
"INSERT INTO knowledge_versions \
(id, item_id, version, title, content, keywords, related_questions, \
change_summary, created_by) \
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
)
.bind(&version_id)
.bind(item_id)
.bind(new_version)
.bind(&updated.title)
.bind(&updated.content)
.bind(&updated.keywords)
.bind(&updated.related_questions)
.bind(&summary)
.bind(account_id)
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(updated)
}
/// 质量指标(按分类分组)
pub async fn analytics_quality(pool: &PgPool) -> SaasResult<serde_json::Value> {
let quality: Vec<(serde_json::Value,)> = sqlx::query_as(
"SELECT json_build_object(
'category', kc.name,
'total', COUNT(ki.id),
'active', COUNT(CASE WHEN ki.status = 'active' THEN 1 END),
'with_keywords', COUNT(CASE WHEN array_length(ki.keywords, 1) > 0 THEN 1 END),
'avg_priority', COALESCE(AVG(ki.priority), 0)
) as row \
FROM knowledge_categories kc \
LEFT JOIN knowledge_items ki ON ki.category_id = kc.id \
GROUP BY kc.id, kc.name \
ORDER BY COUNT(ki.id) DESC"
)
.fetch_all(pool)
.await
.unwrap_or_else(|e| {
tracing::warn!("analytics_quality query failed: {}", e);
vec![]
});
Ok(serde_json::json!({
"categories": quality.into_iter().map(|(v,)| v).collect::<Vec<_>>()
}))
}
/// 知识缺口检测(低分查询聚类)
pub async fn analytics_gaps(pool: &PgPool) -> SaasResult<serde_json::Value> {
let gaps: Vec<(serde_json::Value,)> = sqlx::query_as(
"SELECT json_build_object(
'query', ku.query_text,
'count', COUNT(*),
'avg_score', COALESCE(AVG(ku.relevance_score), 0)
) as row \
FROM knowledge_usage ku \
WHERE ku.created_at >= NOW() - interval '30 days' \
AND (ku.relevance_score IS NULL OR ku.relevance_score < 0.5) \
AND ku.query_text IS NOT NULL \
GROUP BY ku.query_text \
ORDER BY COUNT(*) DESC \
LIMIT 20"
)
.fetch_all(pool)
.await
.unwrap_or_else(|e| {
tracing::warn!("analytics_gaps query failed: {}", e);
vec![]
});
Ok(serde_json::json!({
"gaps": gaps.into_iter().map(|(v,)| v).collect::<Vec<_>>()
}))
}

View File

@@ -0,0 +1,225 @@
//! 知识库类型定义
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
// === 分类 ===
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct KnowledgeCategory {
pub id: String,
pub name: String,
pub description: Option<String>,
pub parent_id: Option<String>,
pub icon: Option<String>,
pub sort_order: i32,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Deserialize)]
pub struct CreateCategoryRequest {
pub name: String,
pub description: Option<String>,
pub parent_id: Option<String>,
pub icon: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateCategoryRequest {
pub name: Option<String>,
pub description: Option<String>,
pub parent_id: Option<String>,
pub icon: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct CategoryResponse {
pub id: String,
pub name: String,
pub description: Option<String>,
pub parent_id: Option<String>,
pub icon: Option<String>,
pub sort_order: i32,
pub item_count: i64,
pub children: Vec<CategoryResponse>,
pub created_at: String,
pub updated_at: String,
}
// === 知识条目 ===
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct KnowledgeItem {
pub id: String,
pub category_id: String,
pub title: String,
pub content: String,
pub keywords: Vec<String>,
pub related_questions: Vec<String>,
pub priority: i32,
pub status: String,
pub version: i32,
pub source: String,
pub tags: Vec<String>,
pub created_by: String,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
#[derive(Debug, Deserialize)]
pub struct CreateItemRequest {
pub category_id: String,
pub title: String,
pub content: String,
pub keywords: Option<Vec<String>>,
pub related_questions: Option<Vec<String>>,
pub priority: Option<i32>,
pub tags: Option<Vec<String>>,
}
#[derive(Debug, Deserialize)]
pub struct UpdateItemRequest {
pub category_id: Option<String>,
pub title: Option<String>,
pub content: Option<String>,
pub keywords: Option<Vec<String>>,
pub related_questions: Option<Vec<String>>,
pub priority: Option<i32>,
pub status: Option<String>,
pub tags: Option<Vec<String>>,
pub change_summary: Option<String>,
}
#[derive(Debug, Deserialize)]
pub struct ListItemsQuery {
pub page: Option<i64>,
pub page_size: Option<i64>,
pub category_id: Option<String>,
pub status: Option<String>,
pub keyword: Option<String>,
}
#[derive(Debug, Serialize)]
pub struct ItemResponse {
pub id: String,
pub category_id: String,
pub category_name: String,
pub title: String,
pub content: String,
pub keywords: Vec<String>,
pub related_questions: Vec<String>,
pub priority: i32,
pub status: String,
pub version: i32,
pub source: String,
pub tags: Vec<String>,
pub created_by: String,
pub reference_count: i64,
pub created_at: String,
pub updated_at: String,
}
// === 知识分块 ===
// 注意DB 表含 embedding vector(1536) 列,但当前所有查询均显式指定列,
// 故 struct 暂不映射该字段。若未来使用 SELECT * 需添加 embedding: Option<pgvector::Vector>。
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct KnowledgeChunk {
pub id: String,
pub item_id: String,
pub chunk_index: i32,
pub content: String,
pub keywords: Vec<String>,
pub created_at: DateTime<Utc>,
}
// === 版本快照 ===
#[derive(Debug, Clone, Serialize, Deserialize, sqlx::FromRow)]
pub struct KnowledgeVersion {
pub id: String,
pub item_id: String,
pub version: i32,
pub title: String,
pub content: String,
pub keywords: Vec<String>,
pub related_questions: Vec<String>,
pub change_summary: Option<String>,
pub created_by: String,
pub created_at: DateTime<Utc>,
}
// === 使用追踪 ===
#[derive(Debug, Serialize, Deserialize, sqlx::FromRow)]
pub struct KnowledgeUsage {
pub id: String,
pub item_id: String,
pub chunk_id: Option<String>,
pub session_id: Option<String>,
pub query_text: Option<String>,
pub relevance_score: Option<f64>,
pub was_injected: bool,
pub agent_feedback: Option<String>,
pub created_at: DateTime<Utc>,
}
// === 搜索 ===
#[derive(Debug, Deserialize)]
pub struct SearchRequest {
pub query: String,
pub category_id: Option<String>,
pub limit: Option<i64>,
pub min_score: Option<f64>,
}
#[derive(Debug, Serialize)]
pub struct SearchResult {
pub chunk_id: String,
pub item_id: String,
pub item_title: String,
pub category_name: String,
pub content: String,
pub score: f64,
pub keywords: Vec<String>,
}
// === 分析 ===
#[derive(Debug, Serialize)]
pub struct AnalyticsOverview {
pub total_items: i64,
pub active_items: i64,
pub total_categories: i64,
pub weekly_new_items: i64,
pub total_references: i64,
pub avg_reference_per_item: f64,
pub hit_rate: f64,
pub injection_rate: f64,
pub positive_feedback_rate: f64,
pub stale_items_count: i64,
}
// === 批量操作 ===
#[derive(Debug, Deserialize)]
pub struct ReorderItem {
pub id: String,
pub sort_order: i32,
}
#[derive(Debug, Deserialize)]
pub struct ImportFile {
pub content: String,
pub title: Option<String>,
pub keywords: Option<Vec<String>>,
pub tags: Option<Vec<String>>,
}
#[derive(Debug, Deserialize)]
pub struct ImportRequest {
pub category_id: String,
pub files: Vec<ImportFile>,
}

View File

@@ -25,3 +25,5 @@ pub mod prompt;
pub mod agent_template;
pub mod scheduled_task;
pub mod telemetry;
pub mod billing;
pub mod knowledge;

View File

@@ -11,9 +11,14 @@ use zclaw_saas::workers::cleanup_refresh_tokens::CleanupRefreshTokensWorker;
use zclaw_saas::workers::cleanup_rate_limit::CleanupRateLimitWorker;
use zclaw_saas::workers::record_usage::RecordUsageWorker;
use zclaw_saas::workers::update_last_used::UpdateLastUsedWorker;
use zclaw_saas::workers::aggregate_usage::AggregateUsageWorker;
use zclaw_saas::workers::generate_embedding::GenerateEmbeddingWorker;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Load .env file from project root (walk up from current dir)
load_dotenv();
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
@@ -24,26 +29,36 @@ async fn main() -> anyhow::Result<()> {
let config = SaaSConfig::load()?;
info!("SaaS config loaded: {}:{}", config.server.host, config.server.port);
let db = init_db(&config.database.url).await?;
let db = init_db(&config.database).await?;
info!("Database initialized");
// 创建 Worker spawn 限制器(门控并发 DB 操作数量)
let worker_limiter = zclaw_saas::state::SpawnLimiter::new(
"worker",
config.database.worker_concurrency,
);
info!("Worker spawn limiter: {} permits", config.database.worker_concurrency);
// 初始化 Worker 调度器 + 注册所有 Worker
let mut dispatcher = WorkerDispatcher::new(db.clone());
let mut dispatcher = WorkerDispatcher::new(db.clone(), worker_limiter.clone());
dispatcher.register(LogOperationWorker);
dispatcher.register(CleanupRefreshTokensWorker);
dispatcher.register(CleanupRateLimitWorker);
dispatcher.register(RecordUsageWorker);
dispatcher.register(UpdateLastUsedWorker);
info!("Worker dispatcher initialized (5 workers registered)");
dispatcher.register(AggregateUsageWorker);
dispatcher.register(GenerateEmbeddingWorker);
info!("Worker dispatcher initialized (7 workers registered)");
// 优雅停机令牌 — 取消后所有 SSE 流和长连接立即终止
let shutdown_token = CancellationToken::new();
let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone())?;
let state = AppState::new(db.clone(), config.clone(), dispatcher, shutdown_token.clone(), worker_limiter.clone())?;
// Restore rate limit counts from DB so limits survive server restarts
// 仅恢复最近 60s 的计数(与 middleware 的 60s 滑动窗口一致),避免过于保守的限流
{
let rows: Vec<(String, i64)> = sqlx::query_as(
"SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '1 hour' GROUP BY key"
"SELECT key, SUM(count) FROM rate_limit_events WHERE window_start > NOW() - interval '60 seconds' GROUP BY key"
)
.fetch_all(&db)
.await
@@ -51,18 +66,17 @@ async fn main() -> anyhow::Result<()> {
let mut restored_count = 0usize;
for (key, count) in rows {
let mut entries = Vec::new();
// Approximate: insert count timestamps at "now" — the DashMap will
// expire them naturally via the retain() call in the middleware.
// This is intentionally approximate; exact window alignment is not
// required for rate limiting correctness.
for _ in 0..count as usize {
// 限制恢复计数不超过 RPM 配额,避免重启后过于保守
let rpm = state.rate_limit_rpm() as usize;
let capped = (count as usize).min(rpm);
let mut entries = Vec::with_capacity(capped);
for _ in 0..capped {
entries.push(std::time::Instant::now());
}
state.rate_limit_entries.insert(key, entries);
restored_count += 1;
}
info!("Restored rate limit state from DB: {} keys", restored_count);
info!("Restored rate limit state from DB: {} keys (60s window, capped at RPM)", restored_count);
}
// 迁移旧格式 TOTP secret明文 → 加密 enc: 格式)
@@ -117,20 +131,64 @@ async fn main() -> anyhow::Result<()> {
});
}
let app = build_router(state).await;
// 限流事件批量 flush (可配置间隔,默认 5s)
{
let flush_state = state.clone();
let batch_interval = config.database.rate_limit_batch_interval_secs;
let batch_max = config.database.rate_limit_batch_max_size;
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(batch_interval));
loop {
interval.tick().await;
flush_state.flush_rate_limit_batch(batch_max).await;
}
});
}
// 连接池可观测性 (30s 指标日志)
{
let metrics_db = db.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
loop {
interval.tick().await;
let pool = &metrics_db;
let total = pool.options().get_max_connections() as usize;
let idle = pool.num_idle() as usize;
let used = total.saturating_sub(idle);
let usage_pct = if total > 0 { used * 100 / total } else { 0 };
tracing::info!(
"[PoolMetrics] total={} idle={} used={} usage_pct={}%",
total, idle, used, usage_pct,
);
if usage_pct >= 80 {
tracing::warn!(
"[PoolMetrics] HIGH USAGE: {}% of connections in use!",
usage_pct,
);
}
}
});
}
let app = build_router(state.clone()).await;
// 配置 TCP keepalive + 短 SO_LINGER防止 CLOSE_WAIT 累积
let listener = create_listener(&config.server.host, config.server.port)?;
info!("SaaS server listening on {}:{}", config.server.host, config.server.port);
// 优雅停机: Ctrl+C → 取消 CancellationToken → SSE 流终止 → 连接排空
// 优雅停机: Ctrl+C → 最终批量 flush → 取消 CancellationToken → SSE 流终止 → 连接排空
let token = shutdown_token.clone();
let flush_state = state;
let batch_max = config.database.rate_limit_batch_max_size;
axum::serve(listener, app.into_make_service_with_connect_info::<std::net::SocketAddr>())
.with_graceful_shutdown(async move {
tokio::signal::ctrl_c()
.await
.expect("Failed to install Ctrl+C handler");
info!("Received shutdown signal, cancelling SSE streams and draining connections...");
info!("Received shutdown signal, flushing pending rate limit batch...");
flush_state.flush_rate_limit_batch(batch_max).await;
info!("Cancelling SSE streams and draining connections...");
token.cancel();
})
.await?;
@@ -265,6 +323,7 @@ async fn build_router(state: AppState) -> axum::Router {
let public_routes = zclaw_saas::auth::routes()
.route("/api/health", axum::routing::get(health_handler))
.merge(zclaw_saas::billing::callback_routes())
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::public_rate_limit_middleware,
@@ -280,6 +339,8 @@ async fn build_router(state: AppState) -> axum::Router {
.merge(zclaw_saas::agent_template::routes())
.merge(zclaw_saas::scheduled_task::routes())
.merge(zclaw_saas::telemetry::routes())
.merge(zclaw_saas::billing::routes())
.merge(zclaw_saas::knowledge::routes())
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::api_version_middleware,
@@ -313,6 +374,10 @@ async fn build_router(state: AppState) -> axum::Router {
state.clone(),
zclaw_saas::middleware::request_id_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::quota_check_middleware,
))
.layer(middleware::from_fn_with_state(
state.clone(),
zclaw_saas::middleware::rate_limit_middleware,
@@ -322,10 +387,55 @@ async fn build_router(state: AppState) -> axum::Router {
zclaw_saas::auth::auth_middleware,
));
axum::Router::new()
let mut router = axum::Router::new()
.merge(non_streaming_routes)
.merge(relay_routes)
.merge(relay_routes);
// 开发模式挂载 mock 支付页面
{
let is_dev = std::env::var("ZCLAW_SAAS_DEV")
.map(|v| v == "true" || v == "1")
.unwrap_or(false);
if is_dev {
router = router.merge(zclaw_saas::billing::mock_routes());
info!("Mock payment routes mounted (dev mode)");
}
}
router
.layer(TraceLayer::new_for_http())
.layer(cors)
.with_state(state)
}
/// Load `.env` file from project root by walking up from current directory.
/// Sets environment variables that are not already set (does not override).
fn load_dotenv() {
let mut dir = std::env::current_dir().unwrap_or_default();
loop {
let env_path = dir.join(".env");
if env_path.is_file() {
if let Ok(content) = std::fs::read_to_string(&env_path) {
for line in content.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
if let Some((key, value)) = line.split_once('=') {
let key = key.trim();
let value = value.trim();
// Only set if not already defined in environment
if std::env::var(key).is_err() {
std::env::set_var(key, value);
}
}
}
tracing::debug!("Loaded .env from {}", env_path.display());
}
return;
}
if !dir.pop() {
break;
}
}
}

View File

@@ -93,17 +93,56 @@ pub async fn rate_limit_middleware(
)).into_response();
}
// Write-through to DB for persistence across restarts (fire-and-forget)
// Write-through to batch accumulator (memory-only, flushed periodically by background task)
// 替换原来的 fire-and-forget tokio::spawn(DB INSERT),消除每请求 1 个 DB 连接消耗
if should_persist {
let db = state.db.clone();
tokio::spawn(async move {
let _ = sqlx::query(
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
)
.bind(&key)
.execute(&db)
.await;
});
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
*entry += 1;
}
next.run(req).await
}
/// 配额检查中间件
/// 在 Relay 请求前检查账户月度用量配额
/// 仅对 /api/v1/relay/chat/completions 生效
pub async fn quota_check_middleware(
State(state): State<AppState>,
req: Request<Body>,
next: Next,
) -> Response<Body> {
let path = req.uri().path();
// 仅对 relay 请求检查配额
if !path.starts_with("/api/v1/relay/") {
return next.run(req).await;
}
// 从扩展中获取认证上下文
let account_id = match req.extensions().get::<AuthContext>() {
Some(ctx) => ctx.account_id.clone(),
None => return next.run(req).await,
};
// 检查 relay_requests 配额
match crate::billing::service::check_quota(&state.db, &account_id, "relay_requests").await {
Ok(check) if !check.allowed => {
tracing::warn!(
"Quota exceeded for account {}: {} ({}/{})",
account_id,
check.reason.as_deref().unwrap_or("配额已用尽"),
check.current,
check.limit.map(|l| l.to_string()).unwrap_or_else(|| "".into()),
);
return SaasError::RateLimited(
check.reason.unwrap_or_else(|| "月度配额已用尽".into()),
).into_response();
}
Err(e) => {
// 配额检查失败不阻断请求(降级策略)
tracing::warn!("Quota check failed for account {}: {}", account_id, e);
}
_ => {}
}
next.run(req).await
@@ -192,17 +231,10 @@ pub async fn public_rate_limit_middleware(
return SaasError::RateLimited(error_msg.into()).into_response();
}
// Write-through to DB for persistence across restarts (fire-and-forget)
// Write-through to batch accumulator (memory-only, flushed periodically)
if should_persist {
let db = state.db.clone();
tokio::spawn(async move {
let _ = sqlx::query(
"INSERT INTO rate_limit_events (key, window_start, count) VALUES ($1, NOW(), 1)"
)
.bind(&key)
.execute(&db)
.await;
});
let mut entry = state.rate_limit_batch.entry(key).or_insert(0);
*entry += 1;
}
next.run(req).await

View File

@@ -82,6 +82,10 @@ pub async fn create_provider(
let provider = service::create_provider(&state.db, &req, &enc_key).await?;
log_operation(&state.db, &ctx.account_id, "provider.create", "provider", &provider.id,
Some(serde_json::json!({"name": &req.name})), ctx.client_ip.as_deref()).await?;
// Admin mutation 后立即刷新缓存,消除 60s 陈旧窗口
if let Err(e) = state.cache.load_from_db(&state.db).await {
tracing::warn!("Cache reload failed after provider.create: {}", e);
}
Ok((StatusCode::CREATED, Json(provider)))
}
@@ -102,6 +106,9 @@ pub async fn update_provider(
drop(config);
let provider = service::update_provider(&state.db, &id, &req, &enc_key).await?;
log_operation(&state.db, &ctx.account_id, "provider.update", "provider", &id, None, ctx.client_ip.as_deref()).await?;
if let Err(e) = state.cache.load_from_db(&state.db).await{
tracing::warn!("Cache reload failed after provider.update: {}", e);
}
Ok(Json(provider))
}
@@ -114,6 +121,9 @@ pub async fn delete_provider(
check_permission(&ctx, "provider:manage")?;
service::delete_provider(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "provider.delete", "provider", &id, None, ctx.client_ip.as_deref()).await?;
if let Err(e) = state.cache.load_from_db(&state.db).await{
tracing::warn!("Cache reload failed after provider.delete: {}", e);
}
Ok(Json(serde_json::json!({"ok": true})))
}
@@ -150,6 +160,9 @@ pub async fn create_model(
let model = service::create_model(&state.db, &req).await?;
log_operation(&state.db, &ctx.account_id, "model.create", "model", &model.id,
Some(serde_json::json!({"model_id": &req.model_id, "provider_id": &req.provider_id})), ctx.client_ip.as_deref()).await?;
if let Err(e) = state.cache.load_from_db(&state.db).await{
tracing::warn!("Cache reload failed after model.create: {}", e);
}
Ok((StatusCode::CREATED, Json(model)))
}
@@ -163,6 +176,9 @@ pub async fn update_model(
check_permission(&ctx, "model:manage")?;
let model = service::update_model(&state.db, &id, &req).await?;
log_operation(&state.db, &ctx.account_id, "model.update", "model", &id, None, ctx.client_ip.as_deref()).await?;
if let Err(e) = state.cache.load_from_db(&state.db).await{
tracing::warn!("Cache reload failed after model.update: {}", e);
}
Ok(Json(model))
}
@@ -175,6 +191,9 @@ pub async fn delete_model(
check_permission(&ctx, "model:manage")?;
service::delete_model(&state.db, &id).await?;
log_operation(&state.db, &ctx.account_id, "model.delete", "model", &id, None, ctx.client_ip.as_deref()).await?;
if let Err(e) = state.cache.load_from_db(&state.db).await{
tracing::warn!("Cache reload failed after model.delete: {}", e);
}
Ok(Json(serde_json::json!({"ok": true})))
}

View File

@@ -29,3 +29,12 @@ pub struct PromptVersionRow {
pub min_app_version: Option<String>,
pub created_at: String,
}
/// prompt_sync_status 表行
#[derive(Debug, FromRow)]
pub struct PromptSyncStatusRow {
pub device_id: String,
pub template_id: String,
pub synced_version: i32,
pub synced_at: String,
}

View File

@@ -2,6 +2,24 @@
use sqlx::FromRow;
/// telemetry_reports 表行
#[derive(Debug, FromRow)]
pub struct TelemetryReportRow {
pub id: String,
pub account_id: String,
pub device_id: String,
pub app_version: Option<String>,
pub model_id: String,
pub input_tokens: i64,
pub output_tokens: i64,
pub latency_ms: Option<i32>,
pub success: bool,
pub error_type: Option<String>,
pub connection_mode: Option<String>,
pub reported_at: String,
pub created_at: String,
}
/// telemetry 按 model 分组统计
#[derive(Debug, FromRow)]
pub struct TelemetryModelStatsRow {

View File

@@ -4,7 +4,7 @@ use sqlx::PgPool;
use crate::error::{SaasError, SaasResult};
use crate::common::PaginatedResponse;
use crate::common::normalize_pagination;
use crate::models::{PromptTemplateRow, PromptVersionRow};
use crate::models::{PromptTemplateRow, PromptVersionRow, PromptSyncStatusRow};
use super::types::*;
/// 创建提示词模板 + 初始版本
@@ -310,3 +310,21 @@ pub async fn check_updates(
server_time: chrono::Utc::now().to_rfc3339(),
})
}
/// 查询设备的提示词同步状态
pub async fn get_sync_status(
db: &PgPool,
device_id: &str,
) -> SaasResult<Vec<PromptSyncStatusRow>> {
let rows = sqlx::query_as::<_, PromptSyncStatusRow>(
"SELECT device_id, template_id, synced_version, synced_at \
FROM prompt_sync_status \
WHERE device_id = $1 \
ORDER BY synced_at DESC \
LIMIT 50"
)
.bind(device_id)
.fetch_all(db)
.await?;
Ok(rows)
}

View File

@@ -23,18 +23,12 @@ pub async fn chat_completions(
) -> SaasResult<Response> {
check_permission(&ctx, "relay:use")?;
// 队列容量检查:防止过载(立即释放读锁)
// 队列容量检查:使用内存 AtomicI64 计数器,消除 DB COUNT 查询
let max_queue_size = {
let config = state.config.read().await;
config.relay.max_queue_size
};
let queued_count: i64 = sqlx::query_scalar(
"SELECT COUNT(*) FROM relay_tasks WHERE account_id = $1 AND status IN ('queued', 'processing')"
)
.bind(&ctx.account_id)
.fetch_one(&state.db)
.await
.unwrap_or(0);
let queued_count = state.cache.relay_queue_count(&ctx.account_id);
if queued_count >= max_queue_size as i64 {
return Err(SaasError::RateLimited(
@@ -128,18 +122,8 @@ pub async fn chat_completions(
.and_then(|v| v.as_bool())
.unwrap_or(false);
// 查找 model 对应的 provider — 使用精准查询避免全量加载
let target_model: Option<crate::models::ModelRow> = sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
supports_streaming, supports_vision, enabled, pricing_input, pricing_output,
created_at, updated_at
FROM models WHERE model_id = $1 AND enabled = true LIMIT 1"
)
.bind(&model_name)
.fetch_optional(&state.db)
.await?;
let target_model = target_model
// 查找 model — 使用内存缓存O(1) DashMap消除关键路径 DB 查询
let target_model = state.cache.get_model(model_name)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// Stream compatibility check: reject stream requests for non-streaming models
@@ -149,8 +133,9 @@ pub async fn chat_completions(
));
}
// 获取 provider 信息
let provider = model_service::get_provider(&state.db, &target_model.provider_id).await?;
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
let provider = state.cache.get_provider(&target_model.provider_id)
.ok_or_else(|| SaasError::NotFound(format!("Provider {} 不存在", target_model.provider_id)))?;
if !provider.enabled {
return Err(SaasError::Forbidden(format!("Provider {} 已禁用", provider.name)));
}
@@ -171,6 +156,9 @@ pub async fn chat_completions(
max_attempts,
).await?;
// 递增内存队列计数器(替代 DB COUNT 查询)
state.cache.relay_enqueue(&ctx.account_id);
// 异步派发操作日志(非阻塞,不占用关键路径 DB 连接)
state.dispatch_log_operation(
&ctx.account_id, "relay.request", "relay_task", &task.id,
@@ -186,8 +174,7 @@ pub async fn chat_completions(
&enc_key,
).await;
// 克隆用于异步 usage 记录
let db_usage = state.db.clone();
// 克隆用于 Worker dispatch usage 记录(受 SpawnLimiter 门控,不再直接 spawn
let account_id_usage = ctx.account_id.clone();
let provider_id_usage = target_model.provider_id.clone();
let model_id_usage = target_model.model_id.clone();
@@ -195,30 +182,62 @@ pub async fn chat_completions(
match response {
Ok(service::RelayResponse::Json(body)) => {
let (input_tokens, output_tokens) = service::extract_token_usage_from_json(&body);
// 异步记录 usage不阻塞响应)
tokio::spawn(async move {
if let Err(e) = model_service::record_usage(
&db_usage, &account_id_usage, &provider_id_usage,
&model_id_usage, input_tokens, output_tokens,
None, "success", None,
).await {
tracing::warn!("Failed to record relay usage: {}", e);
// 通过 Worker dispatch 记录 usage受 SpawnLimiter 门控,不阻塞响应)
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: input_tokens as i32,
output_tokens: output_tokens as i32,
latency_ms: None,
status: "success".to_string(),
error_message: None,
};
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch record_usage: {}", e);
}
});
}
// 实时更新计费配额relay_requests + tokens 同步递增)
if let Err(e) = crate::billing::service::increment_usage(
&state.db, &account_id_usage, input_tokens as i64, output_tokens as i64,
).await {
tracing::warn!("Failed to increment billing usage for {}: {}", account_id_usage, e);
}
// 任务完成,递减队列计数器
state.cache.relay_dequeue(&account_id_usage);
Ok((StatusCode::OK, [(axum::http::header::CONTENT_TYPE, "application/json")], body).into_response())
}
Ok(service::RelayResponse::Sse(body)) => {
// 异步记录 SSE 占位 usage
tokio::spawn(async move {
if let Err(e) = model_service::record_usage(
&db_usage, &account_id_usage, &provider_id_usage,
&model_id_usage, 0, 0,
None, "streaming", None,
).await {
tracing::warn!("Failed to record SSE usage placeholder: {}", e);
// 通过 Worker dispatch 记录 SSE 占位 usage
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "streaming".to_string(),
error_message: None,
};
if let Err(e) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch SSE usage: {}", e);
}
});
}
// SSE: relay_requests 实时递增tokens 由 AggregateUsageWorker 对账修正)
if let Err(e) = crate::billing::service::increment_dimension(
&state.db, &account_id_usage, "relay_requests",
).await {
tracing::warn!("Failed to increment billing relay_requests for {}: {}", account_id_usage, e);
}
// SSE 流已返回,递减队列计数器(流式任务开始处理)
state.cache.relay_dequeue(&account_id_usage);
let response = axum::response::Response::builder()
.status(StatusCode::OK)
@@ -230,17 +249,25 @@ pub async fn chat_completions(
Ok(response)
}
Err(e) => {
// 异步记录失败 usage(不阻塞错误响应)
// 通过 Worker dispatch 记录失败 usage
let error_msg = e.to_string();
tokio::spawn(async move {
if let Err(e2) = model_service::record_usage(
&db_usage, &account_id_usage, &provider_id_usage,
&model_id_usage, 0, 0,
None, "failed", Some(&error_msg),
).await {
tracing::warn!("Failed to record relay failure usage: {}", e2);
{
let args = crate::workers::record_usage::RecordUsageArgs {
account_id: account_id_usage.clone(),
provider_id: provider_id_usage.clone(),
model_id: model_id_usage.clone(),
input_tokens: 0,
output_tokens: 0,
latency_ms: None,
status: "failed".to_string(),
error_message: Some(error_msg),
};
if let Err(e2) = state.worker_dispatcher.dispatch("record_usage", args).await {
tracing::warn!("Failed to dispatch failure usage: {}", e2);
}
});
}
// 任务失败,递减队列计数器(失败请求不计费)
state.cache.relay_dequeue(&account_id_usage);
Err(e)
}
}

View File

@@ -281,6 +281,39 @@ pub async fn delete_provider_key(
Ok(())
}
/// Key 使用窗口统计
#[derive(Debug, Clone)]
pub struct KeyUsageStats {
pub key_id: String,
pub window_minute: String,
pub request_count: i32,
pub token_count: i64,
}
/// 查询指定 Key 的最近使用窗口统计
pub async fn get_key_usage_stats(
db: &PgPool,
key_id: &str,
limit: i64,
) -> SaasResult<Vec<KeyUsageStats>> {
let limit = limit.min(60).max(1);
let rows: Vec<(String, String, i32, i64)> = sqlx::query_as(
"SELECT key_id, window_minute, request_count, token_count \
FROM key_usage_window \
WHERE key_id = $1 \
ORDER BY window_minute DESC \
LIMIT $2"
)
.bind(key_id)
.bind(limit)
.fetch_all(db)
.await?;
Ok(rows.into_iter().map(|(key_id, window_minute, request_count, token_count)| {
KeyUsageStats { key_id, window_minute, request_count, token_count }
}).collect())
}
/// 解析冷却剩余时间(秒)
fn parse_cooldown_remaining(cooldown_until: &str, now: &str) -> i64 {
let cooldown = chrono::DateTime::parse_from_rfc3339(cooldown_until);

View File

@@ -2,11 +2,23 @@
use sqlx::PgPool;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use crate::error::{SaasError, SaasResult};
use crate::models::RelayTaskRow;
use super::types::*;
// ============ StreamBridge 背压常量 ============
/// 上游无数据时,发送 SSE 心跳注释行的间隔
const STREAMBRIDGE_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(15);
/// 上游无数据时,丢弃连接的超时阈值
const STREAMBRIDGE_TIMEOUT: Duration = Duration::from_secs(30);
/// 流结束后延迟清理的时间窗口
const STREAMBRIDGE_CLEANUP_DELAY: Duration = Duration::from_secs(60);
/// 判断 HTTP 状态码是否为可重试的瞬态错误 (5xx + 429)
fn is_retryable_status(status: u16) -> bool {
status == 429 || (500..600).contains(&status)
@@ -33,15 +45,24 @@ pub async fn create_relay_task(
let request_hash = hash_request(request_body);
let max_attempts = max_attempts.max(1).min(5);
sqlx::query(
// INSERT ... RETURNING 合并两次 DB 往返为一次
let row: RelayTaskRow = sqlx::query_as(
"INSERT INTO relay_tasks (id, account_id, provider_id, model_id, request_hash, request_body, status, priority, attempt_count, max_attempts, queued_at, created_at)
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)"
VALUES ($1, $2, $3, $4, $5, $6, 'queued', $7, 0, $8, $9, $9)
RETURNING id, account_id, provider_id, model_id, status, priority, attempt_count, max_attempts, input_tokens, output_tokens, error_message, queued_at, started_at, completed_at, created_at"
)
.bind(&id).bind(account_id).bind(provider_id).bind(model_id)
.bind(&request_hash).bind(request_body).bind(priority).bind(max_attempts as i64).bind(&now)
.execute(db).await?;
.fetch_one(db)
.await?;
get_relay_task(db, &id).await
Ok(RelayTaskInfo {
id: row.id, account_id: row.account_id, provider_id: row.provider_id, model_id: row.model_id,
status: row.status, priority: row.priority, attempt_count: row.attempt_count,
max_attempts: row.max_attempts, input_tokens: row.input_tokens, output_tokens: row.output_tokens,
error_message: row.error_message, queued_at: row.queued_at, started_at: row.started_at,
completed_at: row.completed_at, created_at: row.created_at,
})
}
pub async fn get_relay_task(db: &PgPool, task_id: &str) -> SaasResult<RelayTaskInfo> {
@@ -295,9 +316,9 @@ pub async fn execute_relay(
}
});
// Convert mpsc::Receiver into a Body stream
let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let body = axum::body::Body::from_stream(body_stream);
// Build StreamBridge: wraps the bounded receiver with heartbeat,
// timeout, and delayed cleanup (DeerFlow-inspired backpressure).
let body = build_stream_bridge(rx, task_id.to_string());
// SSE 流结束后异步记录 usage + Key 使用量
// 使用全局 Arc<Semaphore> 限制并发 spawned tasks防止高并发时耗尽连接池
@@ -335,6 +356,14 @@ pub async fn execute_relay(
if tokio::time::timeout(std::time::Duration::from_secs(5), db_op).await.is_err() {
tracing::warn!("SSE usage recording timed out for task {}", task_id_clone);
}
// StreamBridge 延迟清理:流结束 60s 后释放残留资源
// (主要是 Arc<SseUsageCapture> 等,通过 drop(_permit) 归还信号量)
tokio::time::sleep(STREAMBRIDGE_CLEANUP_DELAY).await;
tracing::debug!(
"[StreamBridge] Cleanup delay elapsed for task {}",
task_id_clone
);
});
return Ok(RelayResponse::Sse(body));
@@ -346,7 +375,9 @@ pub async fn execute_relay(
// 记录 Key 使用量
let _ = super::key_pool::record_key_usage(
db, &key_id, Some(input_tokens + output_tokens),
).await;
).await.map_err(|e| {
tracing::warn!("[Relay] Failed to record key usage for billing: {}", e);
});
return Ok(RelayResponse::Json(body));
}
}
@@ -423,6 +454,98 @@ pub enum RelayResponse {
Sse(axum::body::Body),
}
// ============ StreamBridge ============
/// 构建 StreamBridge将 mpsc::Receiver 包装为带心跳、超时的 axum Body。
///
/// 借鉴 DeerFlow StreamBridge 背压机制:
/// - 15s 心跳:上游长时间无输出时,发送 SSE 注释行 `: heartbeat\n\n` 保持连接活跃
/// - 30s 超时:上游连续 30s 无真实数据时,发送超时事件并关闭流
/// - 60s 延迟清理:由调用方的 spawned task 在流结束后延迟释放资源
fn build_stream_bridge(
mut rx: tokio::sync::mpsc::Receiver<Result<bytes::Bytes, std::io::Error>>,
task_id: String,
) -> axum::body::Body {
// SSE heartbeat comment bytes: `: heartbeat\n\n`
// SSE spec: lines starting with `:` are comments and ignored by clients
const HEARTBEAT_BYTES: &[u8] = b": heartbeat\n\n";
// SSE timeout error event
const TIMEOUT_EVENT: &[u8] = b"data: {\"error\":\"stream_timeout\",\"message\":\"upstream timed out\"}\n\n";
let stream = async_stream::stream! {
// Track how many consecutive heartbeat-only cycles have elapsed.
// Real data resets this counter; after 2 heartbeats (30s) without
// real data, we terminate the stream.
let mut idle_heartbeats: u32 = 0;
loop {
// tokio::select! races the next data chunk against a heartbeat timer.
// The timer resets on every iteration, ensuring heartbeats only fire
// during genuine idle periods.
tokio::select! {
biased; // prioritize data over heartbeat
chunk = rx.recv() => {
match chunk {
Some(Ok(data)) => {
// Real data received — reset idle counter
idle_heartbeats = 0;
yield Ok::<bytes::Bytes, std::io::Error>(data);
}
Some(Err(e)) => {
tracing::warn!(
"[StreamBridge] Upstream error for task {}: {}",
task_id, e
);
yield Err(e);
break;
}
None => {
// Channel closed = upstream finished normally
tracing::debug!(
"[StreamBridge] Upstream completed for task {}",
task_id
);
break;
}
}
}
// Heartbeat: send SSE comment if no data for 15s
_ = tokio::time::sleep(STREAMBRIDGE_HEARTBEAT_INTERVAL) => {
idle_heartbeats += 1;
tracing::trace!(
"[StreamBridge] Heartbeat #{} for task {} (idle {}s)",
idle_heartbeats,
task_id,
idle_heartbeats as u64 * STREAMBRIDGE_HEARTBEAT_INTERVAL.as_secs(),
);
// After 2 consecutive heartbeats without real data (30s),
// terminate the stream to prevent connection leaks.
if idle_heartbeats >= 2 {
tracing::warn!(
"[StreamBridge] Timeout ({:?}) no real data, closing stream for task {}",
STREAMBRIDGE_TIMEOUT,
task_id,
);
yield Ok(bytes::Bytes::from_static(TIMEOUT_EVENT));
break;
}
yield Ok(bytes::Bytes::from_static(HEARTBEAT_BYTES));
}
}
}
};
// Pin the stream to a Box<dyn Stream + Send> to satisfy Body::from_stream
let boxed: std::pin::Pin<Box<dyn futures::Stream<Item = Result<bytes::Bytes, std::io::Error>> + Send>> =
Box::pin(stream);
axum::body::Body::from_stream(boxed)
}
// ============ Helpers ============
fn hash_request(body: &str) -> String {

View File

@@ -20,7 +20,9 @@ struct ScheduledTaskRow {
last_run_at: Option<String>,
next_run_at: Option<String>,
run_count: i32,
last_result: Option<String>,
last_error: Option<String>,
last_duration_ms: Option<i64>,
input_payload: Option<serde_json::Value>,
created_at: String,
}
@@ -41,7 +43,9 @@ impl ScheduledTaskRow {
last_run: self.last_run_at.clone(),
next_run: self.next_run_at.clone(),
run_count: self.run_count,
last_result: self.last_result.clone(),
last_error: self.last_error.clone(),
last_duration_ms: self.last_duration_ms,
created_at: self.created_at.clone(),
}
}
@@ -86,7 +90,9 @@ pub async fn create_task(
last_run: None,
next_run: None,
run_count: 0,
last_result: None,
last_error: None,
last_duration_ms: None,
created_at: now,
})
}
@@ -99,7 +105,7 @@ pub async fn list_tasks(
let rows: Vec<ScheduledTaskRow> = sqlx::query_as(
"SELECT id, account_id, name, description, schedule, schedule_type,
target_type, target_id, enabled, last_run_at, next_run_at,
run_count, last_error, input_payload, created_at
run_count, last_result, last_error, last_duration_ms, input_payload, created_at
FROM scheduled_tasks WHERE account_id = $1 ORDER BY created_at DESC"
)
.bind(account_id)
@@ -118,7 +124,7 @@ pub async fn get_task(
let row: Option<ScheduledTaskRow> = sqlx::query_as(
"SELECT id, account_id, name, description, schedule, schedule_type,
target_type, target_id, enabled, last_run_at, next_run_at,
run_count, last_error, input_payload, created_at
run_count, last_result, last_error, last_duration_ms, input_payload, created_at
FROM scheduled_tasks WHERE id = $1 AND account_id = $2"
)
.bind(task_id)

View File

@@ -58,6 +58,8 @@ pub struct ScheduledTaskResponse {
pub last_run: Option<String>,
pub next_run: Option<String>,
pub run_count: i32,
pub last_result: Option<String>,
pub last_error: Option<String>,
pub last_duration_ms: Option<i64>,
pub created_at: String,
}

View File

@@ -3,11 +3,18 @@
//! 通过 TOML 配置定时任务,无需改代码调整调度时间。
//! 配置格式在 config.rs 的 SchedulerConfig / JobConfig 中定义。
use std::time::Duration;
use std::time::{Duration, Instant};
use sqlx::PgPool;
use crate::config::SchedulerConfig;
use crate::workers::WorkerDispatcher;
/// 单次任务执行的产出
struct TaskExecution {
result: Option<String>,
error: Option<String>,
duration_ms: i64,
}
/// 解析时间间隔字符串为 Duration
pub fn parse_duration(s: &str) -> Result<Duration, String> {
let s = s.trim().to_lowercase();
@@ -143,23 +150,42 @@ pub fn start_user_task_scheduler(db: PgPool) {
});
}
/// 执行单个调度任务
/// 执行单个调度任务,返回执行产出(结果/错误/耗时)
async fn execute_scheduled_task(
db: &PgPool,
task_id: &str,
target_type: &str,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let task_info: Option<(String, Option<String>)> = sqlx::query_as(
) -> TaskExecution {
let start = Instant::now();
let task_info: Option<(String, Option<String>)> = match sqlx::query_as(
"SELECT name, config_json FROM scheduled_tasks WHERE id = $1"
)
.bind(task_id)
.fetch_optional(db)
.await
.map_err(|e| format!("Failed to fetch task {}: {}", task_id, e))?;
{
Ok(info) => info,
Err(e) => {
let elapsed = start.elapsed().as_millis() as i64;
return TaskExecution {
result: None,
error: Some(format!("Failed to fetch task {}: {}", task_id, e)),
duration_ms: elapsed,
};
}
};
let (task_name, _config_json) = match task_info {
Some(info) => info,
None => return Err(format!("Task {} not found", task_id).into()),
None => {
let elapsed = start.elapsed().as_millis() as i64;
return TaskExecution {
result: None,
error: Some(format!("Task {} not found", task_id)),
duration_ms: elapsed,
};
}
};
tracing::info!(
@@ -167,22 +193,39 @@ async fn execute_scheduled_task(
task_name, target_type
);
match target_type {
let exec_result = match target_type {
t if t == "agent" => {
tracing::info!("[UserScheduler] Agent task '{}' queued for execution", task_name);
Ok("agent_dispatched".to_string())
}
t if t == "hand" => {
tracing::info!("[UserScheduler] Hand task '{}' queued for execution", task_name);
Ok("hand_dispatched".to_string())
}
t if t == "workflow" => {
tracing::info!("[UserScheduler] Workflow task '{}' queued for execution", task_name);
Ok("workflow_dispatched".to_string())
}
other => {
tracing::warn!("[UserScheduler] Unknown target_type '{}' for task '{}'", other, task_name);
Err(format!("Unknown target_type: {}", other))
}
}
};
Ok(())
let elapsed = start.elapsed().as_millis() as i64;
match exec_result {
Ok(msg) => TaskExecution {
result: Some(msg),
error: None,
duration_ms: elapsed,
},
Err(err) => TaskExecution {
result: None,
error: Some(err),
duration_ms: elapsed,
},
}
}
async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
@@ -206,17 +249,19 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
task_id, target_type, schedule_type
);
// 执行任务
match execute_scheduled_task(db, &task_id, &target_type).await {
Ok(()) => {
tracing::info!("[UserScheduler] task {} executed successfully", task_id);
}
Err(e) => {
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, e);
}
// 执行任务并收集产出
let exec = execute_scheduled_task(db, &task_id, &target_type).await;
if let Some(ref err) = exec.error {
tracing::error!("[UserScheduler] task {} execution failed: {}", task_id, err);
} else {
tracing::info!(
"[UserScheduler] task {} executed successfully ({}ms)",
task_id, exec.duration_ms
);
}
// 更新任务状态
// 更新任务状态(含执行产出)
let result = sqlx::query(
"UPDATE scheduled_tasks
SET last_run_at = NOW(),
@@ -228,10 +273,16 @@ async fn tick_user_tasks(db: &PgPool) -> Result<(), sqlx::Error> {
WHEN schedule_type = 'interval' AND interval_seconds IS NOT NULL
THEN NOW() + (interval_seconds || ' seconds')::INTERVAL
ELSE NULL
END
END,
last_result = $2,
last_error = $3,
last_duration_ms = $4
WHERE id = $1"
)
.bind(&task_id)
.bind(&exec.result)
.bind(&exec.error)
.bind(exec.duration_ms)
.execute(db)
.await;

View File

@@ -10,6 +10,44 @@ use crate::config::SaaSConfig;
use crate::workers::WorkerDispatcher;
use crate::cache::AppCache;
// ============ SpawnLimiter ============
/// 可复用的并发限制器,基于 Arc<Semaphore>。
/// 复用 SSE_SPAWN_SEMAPHORE 模式,为 Worker、中间件等场景提供统一门控。
#[derive(Clone)]
pub struct SpawnLimiter {
semaphore: Arc<tokio::sync::Semaphore>,
name: &'static str,
}
impl SpawnLimiter {
pub fn new(name: &'static str, max_permits: usize) -> Self {
Self {
semaphore: Arc::new(tokio::sync::Semaphore::new(max_permits)),
name,
}
}
/// 尝试获取 permit满时返回 None适用于可丢弃的操作如 usage 记录)
pub fn try_acquire(&self) -> Option<tokio::sync::OwnedSemaphorePermit> {
self.semaphore.clone().try_acquire_owned().ok()
}
/// 异步等待 permit适用于不可丢弃的操作如 Worker 任务)
pub async fn acquire(&self) -> tokio::sync::OwnedSemaphorePermit {
self.semaphore
.clone()
.acquire_owned()
.await
.expect("SpawnLimiter semaphore closed unexpectedly")
}
pub fn name(&self) -> &'static str { self.name }
pub fn available(&self) -> usize { self.semaphore.available_permits() }
}
// ============ AppState ============
/// 全局应用状态,通过 Axum State 共享
#[derive(Clone)]
pub struct AppState {
@@ -33,10 +71,20 @@ pub struct AppState {
pub shutdown_token: CancellationToken,
/// 应用缓存: Model/Provider/队列计数器
pub cache: AppCache,
/// Worker spawn 并发限制器
pub worker_limiter: SpawnLimiter,
/// 限流事件批量累加器: key → 待写入计数
pub rate_limit_batch: Arc<dashmap::DashMap<String, i64>>,
}
impl AppState {
pub fn new(db: PgPool, config: SaaSConfig, worker_dispatcher: WorkerDispatcher, shutdown_token: CancellationToken) -> anyhow::Result<Self> {
pub fn new(
db: PgPool,
config: SaaSConfig,
worker_dispatcher: WorkerDispatcher,
shutdown_token: CancellationToken,
worker_limiter: SpawnLimiter,
) -> anyhow::Result<Self> {
let jwt_secret = config.jwt_secret()?;
let rpm = config.rate_limit.requests_per_minute;
Ok(Self {
@@ -50,6 +98,8 @@ impl AppState {
worker_dispatcher,
shutdown_token,
cache: AppCache::new(),
worker_limiter,
rate_limit_batch: Arc::new(dashmap::DashMap::new()),
})
}
@@ -96,4 +146,60 @@ impl AppState {
tracing::warn!("Failed to dispatch log_operation: {}", e);
}
}
/// 限流事件批量 flush 到 DB
///
/// 使用 swap-to-zero 模式先将计数器原子归零DB 写入成功后删除条目。
/// 如果 DB 写入失败,归零的计数会在下次 flush 时重新累加(因 middleware 持续写入)。
pub async fn flush_rate_limit_batch(&self, max_batch: usize) {
// 阶段1: 收集非零 key将计数器原子归零而非删除
// 这样如果 DB 写入失败middleware 的新累加会在已有 key 上继续
let mut batch: Vec<(String, i64)> = Vec::with_capacity(max_batch.min(64));
let keys: Vec<String> = self.rate_limit_batch.iter()
.filter(|e| *e.value() > 0)
.take(max_batch)
.map(|e| e.key().clone())
.collect();
for key in &keys {
// 原子交换为 0取走当前值
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
if *entry > 0 {
batch.push((key.clone(), *entry));
*entry = 0; // 归零而非删除
}
}
}
if batch.is_empty() { return; }
let keys_buf: Vec<String> = batch.iter().map(|(k, _)| k.clone()).collect();
let counts: Vec<i64> = batch.iter().map(|(_, c)| *c).collect();
let result = sqlx::query(
"INSERT INTO rate_limit_events (key, window_start, count)
SELECT u.key, NOW(), u.cnt FROM UNNEST($1::text[], $2::bigint[]) AS u(key, cnt)"
)
.bind(&keys_buf)
.bind(&counts)
.execute(&self.db)
.await;
if let Err(e) = result {
// DB 写入失败:将归零的计数加回去,避免数据丢失
tracing::warn!("[RateLimitBatch] flush failed ({} entries), restoring counts: {}", batch.len(), e);
for (key, count) in &batch {
if let Some(mut entry) = self.rate_limit_batch.get_mut(key) {
*entry += *count;
}
}
} else {
// DB 写入成功:删除已归零的条目
for (key, _) in &batch {
self.rate_limit_batch.remove_if(key, |_, v| *v == 0);
}
tracing::debug!("[RateLimitBatch] flushed {} entries", batch.len());
}
}
}

View File

@@ -2,7 +2,7 @@
use sqlx::PgPool;
use crate::error::SaasResult;
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow};
use crate::models::{TelemetryModelStatsRow, TelemetryDailyStatsRow, TelemetryReportRow};
use super::types::*;
const CHUNK_SIZE: usize = 100;
@@ -270,3 +270,27 @@ pub async fn get_daily_stats(
Ok(stats)
}
/// 查询账号最近的遥测报告
pub async fn get_recent_reports(
db: &PgPool,
account_id: &str,
limit: i64,
) -> SaasResult<Vec<TelemetryReportRow>> {
let limit = limit.min(100).max(1);
let rows = sqlx::query_as::<_, TelemetryReportRow>(
"SELECT id, account_id, device_id, app_version, model_id, \
input_tokens, output_tokens, latency_ms, success, \
error_type, connection_mode, \
reported_at::text, created_at::text \
FROM telemetry_reports \
WHERE account_id = $1 \
ORDER BY reported_at DESC \
LIMIT $2"
)
.bind(account_id)
.bind(limit)
.fetch_all(db)
.await?;
Ok(rows)
}

View File

@@ -0,0 +1,123 @@
//! 计费用量聚合 Worker
//!
//! 从 usage_records 聚合当月用量到 billing_usage_quotas 表。
//! 由 Scheduler 每小时触发,或在 relay 请求完成时直接派发。
use async_trait::async_trait;
use chrono::{Datelike, Timelike};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use crate::error::SaasResult;
use super::Worker;
/// 用量聚合参数
#[derive(Debug, Serialize, Deserialize)]
pub struct AggregateUsageArgs {
/// 聚合的目标账户 IDNone = 聚合所有活跃账户)
pub account_id: Option<String>,
}
pub struct AggregateUsageWorker;
#[async_trait]
impl Worker for AggregateUsageWorker {
type Args = AggregateUsageArgs;
fn name(&self) -> &str {
"aggregate_usage"
}
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
match args.account_id {
Some(account_id) => {
aggregate_single_account(db, &account_id).await?;
}
None => {
aggregate_all_accounts(db).await?;
}
}
Ok(())
}
}
/// 聚合单个账户的当月用量
async fn aggregate_single_account(db: &PgPool, account_id: &str) -> SaasResult<()> {
// 获取或创建用量记录(确保存在)
let usage = crate::billing::service::get_or_create_usage(db, account_id).await?;
// 从 usage_records 聚合当月实际 token 用量
let now = chrono::Utc::now();
let period_start = now
.with_day(1).unwrap_or(now)
.with_hour(0).unwrap_or(now)
.with_minute(0).unwrap_or(now)
.with_second(0).unwrap_or(now)
.with_nanosecond(0).unwrap_or(now);
let aggregated: Option<(i64, i64, i64)> = sqlx::query_as(
"SELECT COALESCE(SUM(input_tokens), 0), \
COALESCE(SUM(output_tokens), 0), \
COUNT(*) \
FROM usage_records \
WHERE account_id = $1 AND created_at >= $2 AND status = 'success'"
)
.bind(account_id)
.bind(period_start)
.fetch_optional(db)
.await?;
if let Some((input_tokens, output_tokens, request_count)) = aggregated {
sqlx::query(
"UPDATE billing_usage_quotas \
SET input_tokens = $1, \
output_tokens = $2, \
relay_requests = GREATEST(relay_requests, $3::int), \
updated_at = NOW() \
WHERE id = $4"
)
.bind(input_tokens)
.bind(output_tokens)
.bind(request_count as i32)
.bind(&usage.id)
.execute(db)
.await?;
tracing::debug!(
"Aggregated usage for account {}: in={}, out={}, reqs={}",
account_id, input_tokens, output_tokens, request_count
);
}
Ok(())
}
/// 聚合所有活跃账户
async fn aggregate_all_accounts(db: &PgPool) -> SaasResult<()> {
let account_ids: Vec<String> = sqlx::query_scalar(
"SELECT DISTINCT account_id FROM billing_subscriptions \
WHERE status IN ('trial', 'active', 'past_due') \
UNION \
SELECT DISTINCT account_id FROM billing_usage_quotas \
WHERE period_start >= date_trunc('month', NOW())"
)
.fetch_all(db)
.await?;
let total = account_ids.len();
let mut errors = 0;
for account_id in &account_ids {
if let Err(e) = aggregate_single_account(db, account_id).await {
tracing::warn!("Failed to aggregate usage for {}: {}", account_id, e);
errors += 1;
}
}
tracing::info!(
"Usage aggregation complete: {} accounts, {} errors",
total, errors
);
Ok(())
}

View File

@@ -0,0 +1,168 @@
//! 知识条目分块 + Embedding 生成 Worker
//!
//! 当知识条目创建/更新时触发:
//! 1. 按 Markdown 标题 + 固定长度分块
//! 2. 提取关键词(从 item 的 keywords 字段继承 + 内容提取)
//! 3. 写入 knowledge_chunks 表
//! 4. 如果配置了 embedding provider生成向量 embeddingPhase 2
use async_trait::async_trait;
use sqlx::PgPool;
use serde::{Serialize, Deserialize};
use crate::error::SaasResult;
use super::Worker;
#[derive(Debug, Serialize, Deserialize)]
pub struct GenerateEmbeddingArgs {
pub item_id: String,
}
pub struct GenerateEmbeddingWorker;
#[async_trait]
impl Worker for GenerateEmbeddingWorker {
type Args = GenerateEmbeddingArgs;
fn name(&self) -> &str {
"generate_embedding"
}
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
// 1. 加载条目
let item: Option<(String, String, Vec<String>)> = sqlx::query_as(
"SELECT content, title, keywords FROM knowledge_items WHERE id = $1"
)
.bind(&args.item_id)
.fetch_optional(db)
.await?;
let (content, title, keywords) = match item {
Some(row) => row,
None => {
tracing::warn!("GenerateEmbedding: item {} not found, skipping", args.item_id);
return Ok(());
}
};
// 2. 分块
let chunks = crate::knowledge::service::chunk_content(&content, 512, 64);
if chunks.is_empty() {
tracing::debug!("GenerateEmbedding: item {} has no content to chunk", args.item_id);
return Ok(());
}
// 3. 在事务中删除旧分块 + 插入新分块(防止并发竞争条件)
let mut tx = db.begin().await?;
// 锁定条目行防止并发 worker 同时处理同一条目
let locked: Option<(String,)> = sqlx::query_as(
"SELECT id FROM knowledge_items WHERE id = $1 FOR UPDATE"
)
.bind(&args.item_id)
.fetch_optional(&mut *tx)
.await?;
if locked.is_none() {
tx.rollback().await?;
tracing::warn!("GenerateEmbedding: item {} was deleted during processing", args.item_id);
return Ok(());
}
sqlx::query("DELETE FROM knowledge_chunks WHERE item_id = $1")
.bind(&args.item_id)
.execute(&mut *tx)
.await?;
for (idx, chunk) in chunks.iter().enumerate() {
let chunk_id = uuid::Uuid::new_v4().to_string();
let mut chunk_keywords = keywords.clone();
extract_chunk_keywords(chunk, &mut chunk_keywords);
sqlx::query(
"INSERT INTO knowledge_chunks (id, item_id, chunk_index, content, keywords, created_at)
VALUES ($1, $2, $3, $4, $5, NOW())"
)
.bind(&chunk_id)
.bind(&args.item_id)
.bind(idx as i32)
.bind(chunk)
.bind(&chunk_keywords)
.execute(&mut *tx)
.await?;
}
tx.commit().await?;
tracing::info!(
"GenerateEmbedding: item '{}' → {} chunks (keywords: {})",
title,
chunks.len(),
keywords.len(),
);
// Phase 2: 如果配置了 embedding provider在此处调用 embedding API
// 并更新 chunks 的 embedding 列
// TODO: let _ = generate_vectors(db, &args.item_id, &chunks).await;
Ok(())
}
}
/// 从 chunk 内容中提取高频中文词组作为补充关键词
///
/// 简单策略:提取 2-4 字的连续中文字符段,取出现频率 > 1 的
fn extract_chunk_keywords(content: &str, keywords: &mut Vec<String>) {
let chars: Vec<char> = content.chars().collect();
let mut i = 0;
while i < chars.len() {
// 寻找连续中文字符段
if is_cjk(chars[i]) {
let start = i;
while i < chars.len() && is_cjk(chars[i]) {
i += 1;
}
let segment: String = chars[start..i].iter().collect();
// 提取 2-4 字的子串
let seg_chars: Vec<char> = segment.chars().collect();
if seg_chars.len() >= 2 {
// 只取前 2-4 字的短语(避免过长无意义词组)
for len in 2..=4.min(seg_chars.len()) {
let phrase: String = seg_chars[..len].iter().collect();
// 过滤常见停用词(简单版)
if !is_stop_word(&phrase) && !keywords.contains(&phrase) {
keywords.push(phrase);
}
}
}
} else {
i += 1;
}
}
// 限制关键词总数
keywords.truncate(50);
}
/// 判断是否为 CJK 字符
fn is_cjk(c: char) -> bool {
matches!(c,
'\u{4E00}'..='\u{9FFF}' | // CJK Unified Ideographs
'\u{3400}'..='\u{4DBF}' | // CJK Unified Ideographs Extension A
'\u{F900}'..='\u{FAFF}' // CJK Compatibility Ideographs
)
}
/// 简单停用词表
fn is_stop_word(s: &str) -> bool {
matches!(s,
"" | "" | "" | "" | "" | "" | "" | "" | "" | "" |
"" | "" | "一个" | "" | "" | "" | "" | "" | "" | "" |
"" | "" | "" | "没有" | "" | "" | "自己" | "" | "" | "" |
"" | "" | "" | "" | "什么" | "" | "所以" | "但是" | "因为" |
"如果" | "可以" | "能够" | "需要" | "应该" | "已经" | "还是" | "或者"
)
}

View File

@@ -42,10 +42,12 @@ struct TaskMessage {
/// Worker 调度器 — 管理所有 Worker 的注册和派发
///
/// 使用 Arc 包装,可安全跨任务共享。
/// 通过 SpawnLimiter 限制并发执行的任务数,防止连接池耗尽。
pub struct WorkerDispatcher {
db: PgPool,
sender: mpsc::Sender<TaskMessage>,
handlers: HashMap<String, Arc<dyn DynWorker>>,
spawn_limiter: crate::state::SpawnLimiter,
}
impl Clone for WorkerDispatcher {
@@ -54,6 +56,7 @@ impl Clone for WorkerDispatcher {
db: self.db.clone(),
sender: self.sender.clone(),
handlers: self.handlers.clone(),
spawn_limiter: self.spawn_limiter.clone(),
}
}
}
@@ -90,7 +93,7 @@ where
impl WorkerDispatcher {
/// 创建新的调度器
pub fn new(db: PgPool) -> Self {
pub fn new(db: PgPool, spawn_limiter: crate::state::SpawnLimiter) -> Self {
// channel 容量 1024足够缓冲高峰期任务
let (sender, receiver) = mpsc::channel(1024);
@@ -98,6 +101,7 @@ impl WorkerDispatcher {
db,
sender,
handlers: HashMap::new(),
spawn_limiter,
};
// 启动消费循环
@@ -152,10 +156,15 @@ impl WorkerDispatcher {
}
/// 启动消费循环
///
/// 通过 SpawnLimiter 门控并发:消费者循环在 spawn 之前获取 permit
/// 信号量满时阻塞消费者循环(而非 spawn 无限任务),提供真正的背压。
/// 重试时先 drop permit 再 sleep避免浪费 permit 在等待期间。
fn start_consumer(&self, mut receiver: mpsc::Receiver<TaskMessage>) {
let db = self.db.clone();
let handlers = self.handlers.clone();
let sender = self.sender.clone();
let limiter = self.spawn_limiter.clone();
tokio::spawn(async move {
while let Some(msg) = receiver.recv().await {
@@ -171,21 +180,34 @@ impl WorkerDispatcher {
let max_retries = handler.max_retries();
let db = db.clone();
let sender = sender.clone();
let limiter = limiter.clone();
// 关键:在 spawn 之前获取 permit
// 信号量满时阻塞消费者循环,限制 tokio::spawn 调用数量
let permit = limiter.acquire().await;
tracing::trace!(
"Worker '{}' acquired permit ({} available), spawning task",
worker_name, limiter.available()
);
tokio::spawn(async move {
// permit 已预获取,任务立即执行
let _permit = permit;
match handler.perform(&db, &msg.args_json).await {
Ok(()) => {
tracing::debug!("Worker {} completed successfully", worker_name);
}
Err(e) => {
if msg.attempt < max_retries {
// 先 drop permit不占用并发配额在 sleep 期间
drop(_permit);
let delay = std::time::Duration::from_secs(1 << msg.attempt.min(4));
tracing::warn!(
"Worker {} failed (attempt {}/{}): {}. Re-queuing after {:?}.",
worker_name, msg.attempt, max_retries, e, delay
);
tokio::time::sleep(delay).await;
// 重新入队(递增 attempt 计数)
let retry_msg = TaskMessage {
worker_name: msg.worker_name.clone(),
args_json: msg.args_json.clone(),
@@ -218,6 +240,8 @@ pub mod cleanup_rate_limit;
pub mod cleanup_refresh_tokens;
pub mod update_last_used;
pub mod record_usage;
pub mod aggregate_usage;
pub mod generate_embedding;
// 便捷导出
pub use log_operation::LogOperationWorker;
@@ -225,3 +249,4 @@ pub use cleanup_rate_limit::CleanupRateLimitWorker;
pub use cleanup_refresh_tokens::CleanupRefreshTokensWorker;
pub use update_last_used::UpdateLastUsedWorker;
pub use record_usage::RecordUsageWorker;
pub use aggregate_usage::AggregateUsageWorker;

View File

@@ -8,7 +8,8 @@ use super::Worker;
#[derive(Debug, Serialize, Deserialize)]
pub struct UpdateLastUsedArgs {
pub token_id: String,
/// token_hash 用于 WHERE 条件匹配
pub token_hash: String,
}
pub struct UpdateLastUsedWorker;
@@ -23,9 +24,9 @@ impl Worker for UpdateLastUsedWorker {
async fn perform(&self, db: &PgPool, args: Self::Args) -> SaasResult<()> {
let now = chrono::Utc::now().to_rfc3339();
sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE id = $2")
sqlx::query("UPDATE api_tokens SET last_used_at = $1 WHERE token_hash = $2")
.bind(&now)
.bind(&args.token_id)
.bind(&args.token_hash)
.execute(db)
.await?;
Ok(())

View File

@@ -0,0 +1,223 @@
//! Classroom multi-agent chat commands
//!
//! - `classroom_chat` — send a message and receive multi-agent responses
//! - `classroom_chat_history` — retrieve chat history for a classroom
use std::sync::Arc;
use tokio::sync::Mutex;
use serde::{Deserialize, Serialize};
use tauri::State;
use zclaw_kernel::generation::{
AgentProfile, AgentRole,
ClassroomChatMessage, ClassroomChatState,
ClassroomChatRequest,
build_chat_prompt, parse_chat_responses,
};
use zclaw_runtime::CompletionRequest;
use super::ClassroomStore;
use crate::kernel_commands::KernelState;
// ---------------------------------------------------------------------------
// State
// ---------------------------------------------------------------------------
/// Chat state store: classroom_id → chat state
pub type ChatStore = Arc<Mutex<std::collections::HashMap<String, ClassroomChatState>>>;
pub fn create_chat_state() -> ChatStore {
Arc::new(Mutex::new(std::collections::HashMap::new()))
}
// ---------------------------------------------------------------------------
// Request / Response
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClassroomChatCmdRequest {
pub classroom_id: String,
pub user_message: String,
pub scene_context: Option<String>,
}
// ---------------------------------------------------------------------------
// Commands
// ---------------------------------------------------------------------------
/// Send a message in the classroom chat and get multi-agent responses.
#[tauri::command]
pub async fn classroom_chat(
store: State<'_, ClassroomStore>,
chat_store: State<'_, ChatStore>,
kernel_state: State<'_, KernelState>,
request: ClassroomChatCmdRequest,
) -> Result<Vec<ClassroomChatMessage>, String> {
if request.user_message.trim().is_empty() {
return Err("Message cannot be empty".to_string());
}
// Get classroom data
let classroom = {
let s = store.lock().await;
s.get(&request.classroom_id)
.cloned()
.ok_or_else(|| format!("Classroom '{}' not found", request.classroom_id))?
};
// Create user message
let user_msg = ClassroomChatMessage::user_message(&request.user_message);
// Get chat history for context
let history: Vec<ClassroomChatMessage> = {
let cs = chat_store.lock().await;
cs.get(&request.classroom_id)
.map(|s| s.messages.clone())
.unwrap_or_default()
};
// Try LLM-powered multi-agent responses, fallback to placeholder
let agent_responses = match generate_llm_responses(&kernel_state, &classroom.agents, &request.user_message, request.scene_context.as_deref(), &history).await {
Ok(responses) => responses,
Err(e) => {
tracing::warn!("LLM chat generation failed, using placeholders: {}", e);
generate_placeholder_responses(
&classroom.agents,
&request.user_message,
request.scene_context.as_deref(),
)
}
};
// Store in chat state
{
let mut cs = chat_store.lock().await;
let state = cs.entry(request.classroom_id.clone())
.or_insert_with(|| ClassroomChatState {
messages: vec![],
active: true,
});
state.messages.push(user_msg);
state.messages.extend(agent_responses.clone());
}
Ok(agent_responses)
}
/// Retrieve chat history for a classroom
#[tauri::command]
pub async fn classroom_chat_history(
chat_store: State<'_, ChatStore>,
classroom_id: String,
) -> Result<Vec<ClassroomChatMessage>, String> {
let cs = chat_store.lock().await;
Ok(cs.get(&classroom_id)
.map(|s| s.messages.clone())
.unwrap_or_default())
}
// ---------------------------------------------------------------------------
// Placeholder response generation
// ---------------------------------------------------------------------------
fn generate_placeholder_responses(
agents: &[AgentProfile],
user_message: &str,
scene_context: Option<&str>,
) -> Vec<ClassroomChatMessage> {
let mut responses = Vec::new();
// Teacher always responds
if let Some(teacher) = agents.iter().find(|a| a.role == AgentRole::Teacher) {
let context_hint = scene_context
.map(|ctx| format!("关于「{}」,", ctx))
.unwrap_or_default();
responses.push(ClassroomChatMessage::agent_message(
teacher,
&format!("{}这是一个很好的问题!让我来详细解释一下「{}」的核心概念...", context_hint, user_message),
));
}
// Assistant chimes in
if let Some(assistant) = agents.iter().find(|a| a.role == AgentRole::Assistant) {
responses.push(ClassroomChatMessage::agent_message(
assistant,
"我来补充一下要点 📌",
));
}
// One student responds
if let Some(student) = agents.iter().find(|a| a.role == AgentRole::Student) {
responses.push(ClassroomChatMessage::agent_message(
student,
&format!("谢谢老师!我大概理解了{}", user_message),
));
}
responses
}
// ---------------------------------------------------------------------------
// LLM-powered response generation
// ---------------------------------------------------------------------------
async fn generate_llm_responses(
kernel_state: &State<'_, KernelState>,
agents: &[AgentProfile],
user_message: &str,
scene_context: Option<&str>,
history: &[ClassroomChatMessage],
) -> Result<Vec<ClassroomChatMessage>, String> {
let driver = {
let ks = kernel_state.lock().await;
ks.as_ref()
.map(|k| k.driver())
.ok_or_else(|| "Kernel not initialized".to_string())?
};
if !driver.is_configured() {
return Err("LLM driver not configured".to_string());
}
// Build the chat request for prompt generation (include history)
let chat_request = ClassroomChatRequest {
classroom_id: String::new(),
user_message: user_message.to_string(),
agents: agents.to_vec(),
scene_context: scene_context.map(|s| s.to_string()),
history: history.to_vec(),
};
let prompt = build_chat_prompt(&chat_request);
let request = CompletionRequest {
model: "default".to_string(),
system: Some("你是一个课堂多智能体讨论的协调器。".to_string()),
messages: vec![zclaw_types::Message::User {
content: prompt,
}],
..Default::default()
};
let response = driver.complete(request).await
.map_err(|e| format!("LLM call failed: {}", e))?;
// Extract text from response
let text = response.content.iter()
.filter_map(|block| match block {
zclaw_runtime::ContentBlock::Text { text } => Some(text.as_str()),
_ => None,
})
.collect::<Vec<_>>()
.join("");
let responses = parse_chat_responses(&text, agents);
if responses.is_empty() {
return Err("LLM returned no parseable agent responses".to_string());
}
Ok(responses)
}

View File

@@ -0,0 +1,152 @@
//! Classroom export commands
//!
//! - `classroom_export` — export classroom as HTML, Markdown, or JSON
use serde::{Deserialize, Serialize};
use tauri::State;
use zclaw_kernel::generation::Classroom;
use super::ClassroomStore;
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClassroomExportRequest {
pub classroom_id: String,
pub format: String, // "html" | "markdown" | "json"
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClassroomExportResponse {
pub content: String,
pub filename: String,
pub mime_type: String,
}
// ---------------------------------------------------------------------------
// Command
// ---------------------------------------------------------------------------
#[tauri::command]
pub async fn classroom_export(
store: State<'_, ClassroomStore>,
request: ClassroomExportRequest,
) -> Result<ClassroomExportResponse, String> {
let classroom = {
let s = store.lock().await;
s.get(&request.classroom_id)
.cloned()
.ok_or_else(|| format!("Classroom '{}' not found", request.classroom_id))?
};
match request.format.as_str() {
"json" => export_json(&classroom),
"html" => export_html(&classroom),
"markdown" | "md" => export_markdown(&classroom),
_ => Err(format!("Unsupported export format: '{}'. Use html, markdown, or json.", request.format)),
}
}
// ---------------------------------------------------------------------------
// Exporters
// ---------------------------------------------------------------------------
fn export_json(classroom: &Classroom) -> Result<ClassroomExportResponse, String> {
let content = serde_json::to_string_pretty(classroom)
.map_err(|e| format!("JSON serialization failed: {}", e))?;
Ok(ClassroomExportResponse {
filename: format!("{}.json", sanitize_filename(&classroom.title)),
content,
mime_type: "application/json".to_string(),
})
}
fn export_html(classroom: &Classroom) -> Result<ClassroomExportResponse, String> {
let mut html = String::from(r#"<!DOCTYPE html><html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width,initial-scale=1">"#);
html.push_str(&format!("<title>{}</title>", html_escape(&classroom.title)));
html.push_str(r#"<style>body{font-family:system-ui,sans-serif;max-width:800px;margin:0 auto;padding:2rem;color:#333}h1{color:#4F46E5}h2{color:#7C3AED;border-bottom:2px solid #E5E7EB;padding-bottom:0.5rem}.scene{margin:2rem 0;padding:1rem;border-left:4px solid #4F46E5;background:#F9FAFB}.quiz{border-left-color:#F59E0B}.discussion{border-left-color:#10B981}.agent{display:inline-flex;align-items:center;gap:0.5rem;margin:0.25rem;padding:0.25rem 0.75rem;border-radius:9999px;font-size:0.875rem;font-weight:500}</style></head><body>"#);
html.push_str(&format!("<h1>{}</h1>", html_escape(&classroom.title)));
html.push_str(&format!("<p>{}</p>", html_escape(&classroom.description)));
// Agents
html.push_str("<h2>课堂角色</h2><div>");
for agent in &classroom.agents {
html.push_str(&format!(
r#"<span class="agent" style="background:{};color:white">{} {}</span>"#,
agent.color, agent.avatar, html_escape(&agent.name)
));
}
html.push_str("</div>");
// Scenes
html.push_str("<h2>课程内容</h2>");
for scene in &classroom.scenes {
let type_class = match scene.content.scene_type {
zclaw_kernel::generation::SceneType::Quiz => "quiz",
zclaw_kernel::generation::SceneType::Discussion => "discussion",
_ => "",
};
html.push_str(&format!(
r#"<div class="scene {}"><h3>{}</h3><p>类型: {:?} | 时长: {}秒</p></div>"#,
type_class,
html_escape(&scene.content.title),
scene.content.scene_type,
scene.content.duration_seconds
));
}
html.push_str("</body></html>");
Ok(ClassroomExportResponse {
filename: format!("{}.html", sanitize_filename(&classroom.title)),
content: html,
mime_type: "text/html".to_string(),
})
}
fn export_markdown(classroom: &Classroom) -> Result<ClassroomExportResponse, String> {
let mut md = String::new();
md.push_str(&format!("# {}\n\n", &classroom.title));
md.push_str(&format!("{}\n\n", &classroom.description));
md.push_str("## 课堂角色\n\n");
for agent in &classroom.agents {
md.push_str(&format!("- {} **{}** ({:?})\n", agent.avatar, agent.name, agent.role));
}
md.push('\n');
md.push_str("## 课程内容\n\n");
for (i, scene) in classroom.scenes.iter().enumerate() {
md.push_str(&format!("### {}. {}\n\n", i + 1, scene.content.title));
md.push_str(&format!("- 类型: `{:?}`\n", scene.content.scene_type));
md.push_str(&format!("- 时长: {}\n\n", scene.content.duration_seconds));
}
Ok(ClassroomExportResponse {
filename: format!("{}.md", sanitize_filename(&classroom.title)),
content: md,
mime_type: "text/markdown".to_string(),
})
}
fn sanitize_filename(name: &str) -> String {
name.chars()
.map(|c| if c.is_alphanumeric() || c == '-' || c == '_' { c } else { '_' })
.collect::<String>()
.trim_matches('_')
.to_string()
}
fn html_escape(s: &str) -> String {
s.replace('&', "&amp;")
.replace('<', "&lt;")
.replace('>', "&gt;")
.replace('"', "&quot;")
}

View File

@@ -0,0 +1,286 @@
//! Classroom generation commands
//!
//! - `classroom_generate` — start 4-stage pipeline, emit progress events
//! - `classroom_generation_progress` — query current progress
//! - `classroom_cancel_generation` — cancel active generation
//! - `classroom_get` — retrieve generated classroom data
//! - `classroom_list` — list all generated classrooms
use serde::{Deserialize, Serialize};
use tauri::{AppHandle, Emitter, State};
use zclaw_kernel::generation::{
Classroom, GenerationPipeline, GenerationRequest as KernelGenRequest, GenerationStage,
TeachingStyle, DifficultyLevel,
};
use super::{ClassroomStore, GenerationTasks};
use crate::kernel_commands::KernelState;
// ---------------------------------------------------------------------------
// Request / Response types
// ---------------------------------------------------------------------------
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClassroomGenerateRequest {
pub topic: String,
pub document: Option<String>,
pub style: Option<String>,
pub level: Option<String>,
pub target_duration_minutes: Option<u32>,
pub scene_count: Option<usize>,
pub custom_instructions: Option<String>,
pub language: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClassroomGenerateResponse {
pub classroom_id: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClassroomProgressResponse {
pub stage: String,
pub progress: u8,
pub activity: String,
pub items_progress: Option<(usize, usize)>,
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
fn parse_style(s: Option<&str>) -> TeachingStyle {
match s.unwrap_or("lecture") {
"discussion" => TeachingStyle::Discussion,
"pbl" => TeachingStyle::Pbl,
"flipped" => TeachingStyle::Flipped,
"socratic" => TeachingStyle::Socratic,
_ => TeachingStyle::Lecture,
}
}
fn parse_level(l: Option<&str>) -> DifficultyLevel {
match l.unwrap_or("intermediate") {
"beginner" => DifficultyLevel::Beginner,
"advanced" => DifficultyLevel::Advanced,
"expert" => DifficultyLevel::Expert,
_ => DifficultyLevel::Intermediate,
}
}
fn stage_name(stage: &GenerationStage) -> &'static str {
match stage {
GenerationStage::AgentProfiles => "agent_profiles",
GenerationStage::Outline => "outline",
GenerationStage::Scene => "scene",
GenerationStage::Complete => "complete",
}
}
// ---------------------------------------------------------------------------
// Commands
// ---------------------------------------------------------------------------
/// Start classroom generation (4-stage pipeline).
/// Progress events are emitted via `classroom:progress`.
/// Supports cancellation between stages by removing the task from GenerationTasks.
#[tauri::command]
pub async fn classroom_generate(
app: AppHandle,
store: State<'_, ClassroomStore>,
tasks: State<'_, GenerationTasks>,
kernel_state: State<'_, KernelState>,
request: ClassroomGenerateRequest,
) -> Result<ClassroomGenerateResponse, String> {
if request.topic.trim().is_empty() {
return Err("Topic is required".to_string());
}
let topic_clone = request.topic.clone();
let kernel_request = KernelGenRequest {
topic: request.topic.clone(),
document: request.document.clone(),
style: parse_style(request.style.as_deref()),
level: parse_level(request.level.as_deref()),
target_duration_minutes: request.target_duration_minutes.unwrap_or(30),
scene_count: request.scene_count,
custom_instructions: request.custom_instructions.clone(),
language: request.language.clone().or_else(|| Some("zh-CN".to_string())),
};
// Register generation task so cancellation can check it
{
use zclaw_kernel::generation::GenerationProgress;
let mut t = tasks.lock().await;
t.insert(topic_clone.clone(), GenerationProgress {
stage: zclaw_kernel::generation::GenerationStage::AgentProfiles,
progress: 0,
activity: "Starting generation...".to_string(),
items_progress: None,
eta_seconds: None,
});
}
// Get LLM driver from kernel if available, otherwise use placeholder mode
let pipeline = {
let ks = kernel_state.lock().await;
if let Some(kernel) = ks.as_ref() {
GenerationPipeline::with_driver(kernel.driver())
} else {
GenerationPipeline::new()
}
};
// Helper: check if cancelled
let is_cancelled = || {
let t = tasks.blocking_lock();
!t.contains_key(&topic_clone)
};
// Helper: emit progress event
let emit_progress = |stage: &str, progress: u8, activity: &str| {
let _ = app.emit("classroom:progress", serde_json::json!({
"topic": &topic_clone,
"stage": stage,
"progress": progress,
"activity": activity
}));
};
// ── Stage 0: Agent Profiles ──
emit_progress("agent_profiles", 5, "生成课堂角色...");
let agents = pipeline.generate_agent_profiles(&kernel_request).await;
emit_progress("agent_profiles", 25, "角色生成完成");
if is_cancelled() {
return Err("Generation cancelled".to_string());
}
// ── Stage 1: Outline ──
emit_progress("outline", 30, "分析主题,生成大纲...");
let outline = pipeline.generate_outline(&kernel_request).await
.map_err(|e| format!("Outline generation failed: {}", e))?;
emit_progress("outline", 50, &format!("大纲完成:{} 个场景", outline.len()));
if is_cancelled() {
return Err("Generation cancelled".to_string());
}
// ── Stage 2: Scenes (parallel) ──
emit_progress("scene", 55, &format!("并行生成 {} 个场景...", outline.len()));
let scenes = pipeline.generate_scenes(&outline).await
.map_err(|e| format!("Scene generation failed: {}", e))?;
if is_cancelled() {
return Err("Generation cancelled".to_string());
}
// ── Stage 3: Assemble ──
emit_progress("complete", 90, "组装课堂...");
// Build classroom directly (pipeline.build_classroom is private)
let total_duration: u32 = scenes.iter().map(|s| s.content.duration_seconds).sum();
let objectives = outline.iter()
.take(3)
.map(|item| format!("理解: {}", item.title))
.collect::<Vec<_>>();
let classroom_id = uuid::Uuid::new_v4().to_string();
let classroom = Classroom {
id: classroom_id.clone(),
title: format!("课堂: {}", kernel_request.topic),
description: format!("{:?} 风格课堂 — {}", kernel_request.style, kernel_request.topic),
topic: kernel_request.topic.clone(),
style: kernel_request.style,
level: kernel_request.level,
total_duration,
objectives,
scenes,
agents,
metadata: zclaw_kernel::generation::ClassroomMetadata {
generated_at: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis() as i64,
source_document: kernel_request.document.map(|_| "user_document".to_string()),
model: None,
version: "2.0.0".to_string(),
custom: serde_json::Map::new(),
},
};
// Store classroom
{
let mut s = store.lock().await;
s.insert(classroom_id.clone(), classroom);
}
// Clear generation task
{
let mut t = tasks.lock().await;
t.remove(&topic_clone);
}
// Emit completion
emit_progress("complete", 100, "课堂生成完成");
Ok(ClassroomGenerateResponse {
classroom_id,
})
}
/// Get current generation progress for a topic
#[tauri::command]
pub async fn classroom_generation_progress(
tasks: State<'_, GenerationTasks>,
topic: String,
) -> Result<ClassroomProgressResponse, String> {
let t = tasks.lock().await;
let progress = t.get(&topic);
Ok(ClassroomProgressResponse {
stage: progress.map(|p| stage_name(&p.stage).to_string()).unwrap_or_else(|| "none".to_string()),
progress: progress.map(|p| p.progress).unwrap_or(0),
activity: progress.map(|p| p.activity.clone()).unwrap_or_default(),
items_progress: progress.and_then(|p| p.items_progress),
})
}
/// Cancel an active generation
#[tauri::command]
pub async fn classroom_cancel_generation(
tasks: State<'_, GenerationTasks>,
topic: String,
) -> Result<(), String> {
let mut t = tasks.lock().await;
t.remove(&topic);
Ok(())
}
/// Retrieve a generated classroom by ID
#[tauri::command]
pub async fn classroom_get(
store: State<'_, ClassroomStore>,
classroom_id: String,
) -> Result<Classroom, String> {
let s = store.lock().await;
s.get(&classroom_id)
.cloned()
.ok_or_else(|| format!("Classroom '{}' not found", classroom_id))
}
/// List all generated classrooms (id + title only)
#[tauri::command]
pub async fn classroom_list(
store: State<'_, ClassroomStore>,
) -> Result<Vec<serde_json::Value>, String> {
let s = store.lock().await;
Ok(s.values().map(|c| serde_json::json!({
"id": c.id,
"title": c.title,
"topic": c.topic,
"totalDuration": c.total_duration,
"sceneCount": c.scenes.len(),
})).collect())
}

View File

@@ -0,0 +1,41 @@
//! Classroom generation and interaction commands
//!
//! Tauri commands for the OpenMAIC-style interactive classroom:
//! - Generate classroom (4-stage pipeline with progress events)
//! - Multi-agent chat
//! - Export (HTML/Markdown/JSON)
use std::sync::Arc;
use tokio::sync::Mutex;
use zclaw_kernel::generation::Classroom;
pub mod chat;
pub mod export;
pub mod generate;
// ---------------------------------------------------------------------------
// Shared state types
// ---------------------------------------------------------------------------
/// In-memory classroom store: classroom_id → Classroom
pub type ClassroomStore = Arc<Mutex<std::collections::HashMap<String, Classroom>>>;
/// Active generation tasks: topic → progress
pub type GenerationTasks = Arc<Mutex<std::collections::HashMap<String, zclaw_kernel::generation::GenerationProgress>>>;
// Re-export chat state type
// Re-export chat state type — used by lib.rs to construct managed state
#[allow(unused_imports)]
pub use chat::ChatStore;
// ---------------------------------------------------------------------------
// State constructors
// ---------------------------------------------------------------------------
pub fn create_classroom_state() -> ClassroomStore {
Arc::new(Mutex::new(std::collections::HashMap::new()))
}
pub fn create_generation_tasks() -> GenerationTasks {
Arc::new(Mutex::new(std::collections::HashMap::new()))
}

View File

@@ -258,11 +258,18 @@ impl AgentIdentityManager {
if !identity.instructions.is_empty() {
sections.push(identity.instructions.clone());
}
if !identity.user_profile.is_empty()
&& identity.user_profile != default_user_profile()
{
sections.push(format!("## 用户画像\n{}", identity.user_profile));
}
// NOTE: user_profile injection is intentionally disabled.
// The reflection engine may accumulate overly specific details from past
// conversations (e.g., "广东光华", "汕头玩具产业") into user_profile.
// These details then leak into every new conversation's system prompt,
// causing the model to think about old topics instead of the current query.
// Memory injection should only happen via MemoryMiddleware with relevance
// filtering, not unconditionally via user_profile.
// if !identity.user_profile.is_empty()
// && identity.user_profile != default_user_profile()
// {
// sections.push(format!("## 用户画像\n{}", identity.user_profile));
// }
if let Some(ctx) = memory_context {
sections.push(ctx.to_string());
}

View File

@@ -34,6 +34,7 @@ pub struct ChatResponse {
#[serde(rename_all = "camelCase", tag = "type")]
pub enum StreamChatEvent {
Delta { delta: String },
ThinkingDelta { delta: String },
ToolStart { name: String, input: serde_json::Value },
ToolEnd { name: String, output: serde_json::Value },
IterationStart { iteration: usize, max_iterations: usize },
@@ -218,6 +219,10 @@ pub async fn agent_chat_stream(
tracing::trace!("[agent_chat_stream] Delta: {} bytes", delta.len());
StreamChatEvent::Delta { delta: delta.clone() }
}
LoopEvent::ThinkingDelta(delta) => {
tracing::trace!("[agent_chat_stream] ThinkingDelta: {} bytes", delta.len());
StreamChatEvent::ThinkingDelta { delta: delta.clone() }
}
LoopEvent::ToolStart { name, input } => {
tracing::debug!("[agent_chat_stream] ToolStart: {}", name);
if name.starts_with("hand_") {

View File

@@ -249,3 +249,130 @@ pub async fn kernel_shutdown(
Ok(())
}
/// Apply SaaS-synced configuration to the Kernel config file.
///
/// Writes relevant config values (agent, llm categories) to the TOML config file.
/// The changes take effect on the next Kernel restart.
#[tauri::command]
pub async fn kernel_apply_saas_config(
configs: Vec<SaasConfigItem>,
) -> Result<u32, String> {
use std::io::Write;
let config_path = zclaw_kernel::config::KernelConfig::find_config_path()
.ok_or_else(|| "No config file path found".to_string())?;
// Read existing config or create empty
let existing = if config_path.exists() {
std::fs::read_to_string(&config_path).unwrap_or_default()
} else {
String::new()
};
let mut updated = existing;
let mut applied: u32 = 0;
for config in &configs {
// Only process kernel-relevant categories
if !matches!(config.category.as_str(), "agent" | "llm") {
continue;
}
// Write key=value to the [llm] or [agent] section
let section = &config.category;
let key = config.key.replace('.', "_");
let value = &config.value;
// Simple TOML patching: find or create section, update key
let section_header = format!("[{}]", section);
let line_to_set = format!("{} = {}", key, toml_quote_value(value));
if let Some(section_start) = updated.find(&section_header) {
// Section exists, find or add the key within it
let after_header = section_start + section_header.len();
let next_section = updated[after_header..].find("\n[")
.map(|i| after_header + i)
.unwrap_or(updated.len());
let section_content = &updated[after_header..next_section];
let key_prefix = format!("\n{} =", key);
let key_prefix_alt = format!("\n{}=", key);
if let Some(key_pos) = section_content.find(&key_prefix)
.or_else(|| section_content.find(&key_prefix_alt))
{
// Key exists, replace the line
let line_start = after_header + key_pos + 1; // skip \n
let line_end = updated[line_start..].find('\n')
.map(|i| line_start + i)
.unwrap_or(updated.len());
updated = format!(
"{}{}{}\n{}",
&updated[..line_start],
line_to_set,
if line_end < updated.len() { "" } else { "" },
&updated[line_end..]
);
// Remove the extra newline if line_end included one
updated = updated.replace(&format!("{}\n\n", line_to_set), &format!("{}\n", line_to_set));
} else {
// Key doesn't exist, append to section
updated.insert_str(next_section, format!("\n{}", line_to_set).as_str());
}
} else {
// Section doesn't exist, append it
updated = format!("{}\n{}\n{}\n", updated.trim_end(), section_header, line_to_set);
}
applied += 1;
}
if applied > 0 {
// Ensure parent directory exists
if let Some(parent) = config_path.parent() {
std::fs::create_dir_all(parent).map_err(|e| format!("Failed to create config dir: {}", e))?;
}
let mut file = std::fs::File::create(&config_path)
.map_err(|e| format!("Failed to write config: {}", e))?;
file.write_all(updated.as_bytes())
.map_err(|e| format!("Failed to write config: {}", e))?;
tracing::info!(
"[kernel_apply_saas_config] Applied {} config items to {:?} (restart required)",
applied,
config_path
);
}
Ok(applied)
}
/// Single config item from SaaS sync
#[derive(Debug, Clone, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SaasConfigItem {
pub category: String,
pub key: String,
pub value: String,
}
/// Quote a value for TOML format
fn toml_quote_value(value: &str) -> String {
// Try to parse as number or boolean
if value == "true" || value == "false" {
return value.to_string();
}
if let Ok(n) = value.parse::<i64>() {
return n.to_string();
}
if let Ok(n) = value.parse::<f64>() {
return n.to_string();
}
// Handle multi-line strings with TOML triple-quote syntax
if value.contains('\n') {
return format!("\"\"\"\n{}\"\"\"", value.replace('\\', "\\\\").replace("\"\"\"", "'\"'\"'\""));
}
// Default: quote as string
format!("\"{}\"", value.replace('\\', "\\\\").replace('"', "\\\""))
}

View File

@@ -34,6 +34,9 @@ mod kernel_commands;
// Pipeline commands (DSL-based workflows)
mod pipeline_commands;
// Classroom generation and interaction commands
mod classroom_commands;
// Gateway sub-modules (runtime, config, io, commands)
mod gateway;
@@ -99,6 +102,11 @@ pub fn run() {
// Initialize Pipeline state (DSL-based workflows)
let pipeline_state = pipeline_commands::create_pipeline_state();
// Initialize Classroom state (generation + chat)
let classroom_state = classroom_commands::create_classroom_state();
let classroom_chat_state = classroom_commands::chat::create_chat_state();
let classroom_gen_tasks = classroom_commands::create_generation_tasks();
tauri::Builder::default()
.plugin(tauri_plugin_opener::init())
.manage(browser_state)
@@ -110,11 +118,15 @@ pub fn run() {
.manage(scheduler_state)
.manage(kernel_commands::SessionStreamGuard::default())
.manage(pipeline_state)
.manage(classroom_state)
.manage(classroom_chat_state)
.manage(classroom_gen_tasks)
.invoke_handler(tauri::generate_handler![
// Internal ZCLAW Kernel commands (preferred)
kernel_commands::lifecycle::kernel_init,
kernel_commands::lifecycle::kernel_status,
kernel_commands::lifecycle::kernel_shutdown,
kernel_commands::lifecycle::kernel_apply_saas_config,
kernel_commands::agent::agent_create,
kernel_commands::agent::agent_list,
kernel_commands::agent::agent_get,
@@ -300,7 +312,16 @@ pub fn run() {
intelligence::identity::identity_get_snapshots,
intelligence::identity::identity_restore_snapshot,
intelligence::identity::identity_list_agents,
intelligence::identity::identity_delete_agent
intelligence::identity::identity_delete_agent,
// Classroom generation and interaction commands
classroom_commands::generate::classroom_generate,
classroom_commands::generate::classroom_generation_progress,
classroom_commands::generate::classroom_cancel_generation,
classroom_commands::generate::classroom_get,
classroom_commands::generate::classroom_list,
classroom_commands::chat::classroom_chat,
classroom_commands::chat::classroom_chat_history,
classroom_commands::export::classroom_export
])
.run(tauri::generate_context!())
.expect("error while running tauri application");

View File

@@ -29,6 +29,7 @@ import { useProposalNotifications, ProposalNotificationHandler } from './lib/use
import { useToast } from './components/ui/Toast';
import type { Clone } from './store/agentStore';
import { createLogger } from './lib/logger';
import { startOfflineMonitor } from './store/offlineStore';
const log = createLogger('App');
@@ -86,6 +87,8 @@ function App() {
useEffect(() => {
document.title = 'ZCLAW';
const stopOfflineMonitor = startOfflineMonitor();
return () => { stopOfflineMonitor(); };
}, []);
// Restore SaaS session from OS keyring on startup (before auth gate)
@@ -152,8 +155,11 @@ function App() {
let mounted = true;
const bootstrap = async () => {
// 未登录时不启动 bootstrap
if (!useSaaSStore.getState().isLoggedIn) return;
// 未登录时不启动 bootstrap,直接结束 loading
if (!useSaaSStore.getState().isLoggedIn) {
setBootstrapping(false);
return;
}
try {
// Step 1: Check and start local gateway in Tauri environment

View File

@@ -2,6 +2,7 @@ import { useState, useEffect, useRef, useCallback, useMemo, type MutableRefObjec
import { motion, AnimatePresence } from 'framer-motion';
import { List, type ListImperativeAPI } from 'react-window';
import { useChatStore, Message } from '../store/chatStore';
import { useArtifactStore } from '../store/chat/artifactStore';
import { useConnectionStore } from '../store/connectionStore';
import { useAgentStore } from '../store/agentStore';
import { useConfigStore } from '../store/configStore';
@@ -12,6 +13,8 @@ import { ArtifactPanel } from './ai/ArtifactPanel';
import { ToolCallChain } from './ai/ToolCallChain';
import { listItemVariants, defaultTransition, fadeInVariants } from '../lib/animations';
import { FirstConversationPrompt } from './FirstConversationPrompt';
import { ClassroomPlayer } from './classroom_player';
import { useClassroomStore } from '../store/classroomStore';
// MessageSearch temporarily removed during DeerFlow redesign
import { OfflineIndicator } from './OfflineIndicator';
import {
@@ -45,11 +48,14 @@ export function ChatArea() {
messages, currentAgent, isStreaming, isLoading, currentModel,
sendMessage: sendToGateway, setCurrentModel, initStreamListener,
newConversation, chatMode, setChatMode, suggestions,
artifacts, selectedArtifactId, artifactPanelOpen,
selectArtifact, setArtifactPanelOpen,
totalInputTokens, totalOutputTokens,
} = useChatStore();
const {
artifacts, selectedArtifactId, artifactPanelOpen,
selectArtifact, setArtifactPanelOpen,
} = useArtifactStore();
const connectionState = useConnectionStore((s) => s.connectionState);
const { activeClassroom, classroomOpen, closeClassroom, generating, progressPercent, progressActivity, error: classroomError, clearError: clearClassroomError } = useClassroomStore();
const clones = useAgentStore((s) => s.clones);
const models = useConfigStore((s) => s.models);
@@ -203,9 +209,76 @@ export function ChatArea() {
);
return (
<div className="relative h-full">
{/* Generation progress overlay */}
<AnimatePresence>
{generating && (
<motion.div
key="generation-overlay"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
className="absolute inset-0 z-40 bg-white/80 dark:bg-gray-900/80 backdrop-blur-sm flex items-center justify-center"
>
<div className="text-center space-y-4">
<div className="w-12 h-12 border-4 border-indigo-200 border-t-indigo-500 rounded-full animate-spin mx-auto" />
<div>
<p className="text-lg font-medium text-gray-900 dark:text-white">
...
</p>
<p className="text-sm text-gray-500 dark:text-gray-400 mt-1">
{progressActivity || '准备中...'}
</p>
</div>
{progressPercent > 0 && (
<div className="w-64 mx-auto">
<div className="h-2 bg-gray-200 dark:bg-gray-700 rounded-full overflow-hidden">
<div
className="h-full bg-indigo-500 rounded-full transition-all duration-500"
style={{ width: `${progressPercent}%` }}
/>
</div>
<p className="text-xs text-gray-400 mt-1">{progressPercent}%</p>
</div>
)}
<button
onClick={() => useClassroomStore.getState().cancelGeneration()}
className="px-4 py-2 text-sm text-gray-500 hover:text-gray-700 dark:hover:text-gray-300 border border-gray-300 dark:border-gray-600 rounded-lg"
>
</button>
</div>
</motion.div>
)}
</AnimatePresence>
{/* ClassroomPlayer overlay */}
<AnimatePresence>
{classroomOpen && activeClassroom && (
<motion.div
key="classroom-overlay"
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
exit={{ opacity: 0 }}
className="absolute inset-0 z-50 bg-white dark:bg-gray-900"
>
<ClassroomPlayer
onClose={closeClassroom}
/>
</motion.div>
)}
</AnimatePresence>
<ResizableChatLayout
chatPanel={
<div className="flex flex-col h-full">
{/* Classroom generation error banner */}
{classroomError && (
<div className="mx-4 mt-2 px-4 py-2 bg-red-50 dark:bg-red-900/20 border border-red-200 dark:border-red-800 rounded-lg flex items-center justify-between text-sm">
<span className="text-red-600 dark:text-red-400">: {classroomError}</span>
<button onClick={clearClassroomError} className="text-red-400 hover:text-red-600 ml-3 text-xs"></button>
</div>
)}
{/* Header — DeerFlow-style: minimal */}
<div className="h-14 border-b border-transparent flex items-center justify-between px-6 flex-shrink-0 bg-white dark:bg-gray-900">
<div className="flex items-center gap-2 text-sm text-gray-500">
@@ -298,6 +371,7 @@ export function ChatArea() {
getHeight={getHeight}
onHeightChange={setHeight}
messageRefs={messageRefs}
setInput={setInput}
/>
) : (
messages.map((message) => (
@@ -310,7 +384,7 @@ export function ChatArea() {
layout
transition={defaultTransition}
>
<MessageBubble message={message} />
<MessageBubble message={message} setInput={setInput} />
</motion.div>
))
)}
@@ -433,19 +507,16 @@ export function ChatArea() {
rightPanelOpen={artifactPanelOpen}
onRightPanelToggle={setArtifactPanelOpen}
/>
</div>
);
}
function MessageBubble({ message }: { message: Message }) {
// Tool messages are now absorbed into the assistant message's toolSteps chain.
// Legacy standalone tool messages (from older sessions) still render as before.
function MessageBubble({ message, setInput }: { message: Message; setInput: (text: string) => void }) {
if (message.role === 'tool') {
return null;
}
const isUser = message.role === 'user';
// 思考中状态streaming 且内容为空时显示思考指示器
const isThinking = message.streaming && !message.content;
// Download message as Markdown file
@@ -518,7 +589,20 @@ function MessageBubble({ message }: { message: Message }) {
: '...'}
</div>
{message.error && (
<p className="text-xs text-red-500 mt-2">{message.error}</p>
<div className="flex items-center gap-2 mt-2">
<p className="text-xs text-red-500">{message.error}</p>
<button
onClick={() => {
const text = typeof message.content === 'string' ? message.content : '';
if (text) {
setInput(text);
}
}}
className="text-xs px-2 py-0.5 rounded bg-red-100 dark:bg-red-900/30 text-red-600 dark:text-red-400 hover:bg-red-200 dark:hover:bg-red-900/50 transition-colors"
>
</button>
</div>
)}
{/* Download button for AI messages - show on hover */}
{!isUser && message.content && !message.streaming && (
@@ -543,6 +627,7 @@ interface VirtualizedMessageRowProps {
message: Message;
onHeightChange: (height: number) => void;
messageRefs: MutableRefObject<Map<string, HTMLDivElement>>;
setInput: (text: string) => void;
}
/**
@@ -553,6 +638,7 @@ function VirtualizedMessageRow({
message,
onHeightChange,
messageRefs,
setInput,
style,
ariaAttributes,
}: VirtualizedMessageRowProps & {
@@ -587,7 +673,7 @@ function VirtualizedMessageRow({
className="py-3"
{...ariaAttributes}
>
<MessageBubble message={message} />
<MessageBubble message={message} setInput={setInput} />
</div>
);
}
@@ -598,6 +684,7 @@ interface VirtualizedMessageListProps {
getHeight: (id: string, role: string) => number;
onHeightChange: (id: string, height: number) => void;
messageRefs: MutableRefObject<Map<string, HTMLDivElement>>;
setInput: (text: string) => void;
}
/**
@@ -610,6 +697,7 @@ function VirtualizedMessageList({
getHeight,
onHeightChange,
messageRefs,
setInput,
}: VirtualizedMessageListProps) {
// Row component for react-window v2
const RowComponent = (props: {
@@ -625,6 +713,7 @@ function VirtualizedMessageList({
message={messages[props.index]}
onHeightChange={(h) => onHeightChange(messages[props.index].id, h)}
messageRefs={messageRefs}
setInput={setInput}
style={props.style}
ariaAttributes={props.ariaAttributes}
/>

View File

@@ -67,6 +67,7 @@ interface ClassroomPreviewerProps {
data: ClassroomData;
onClose?: () => void;
onExport?: (format: 'pptx' | 'html' | 'pdf') => void;
onOpenFullPlayer?: () => void;
}
// === Sub-Components ===
@@ -271,6 +272,7 @@ function OutlinePanel({
export function ClassroomPreviewer({
data,
onExport,
onOpenFullPlayer,
}: ClassroomPreviewerProps) {
const [currentSceneIndex, setCurrentSceneIndex] = useState(0);
const [isPlaying, setIsPlaying] = useState(false);
@@ -398,6 +400,15 @@ export function ClassroomPreviewer({
</p>
</div>
<div className="flex items-center gap-2">
{onOpenFullPlayer && (
<button
onClick={onOpenFullPlayer}
className="flex items-center gap-1.5 px-3 py-1.5 text-sm bg-indigo-100 dark:bg-indigo-900/30 text-indigo-700 dark:text-indigo-300 rounded-md hover:bg-indigo-200 dark:hover:bg-indigo-900/50 transition-colors"
>
<Play className="w-4 h-4" />
</button>
)}
<button
onClick={() => handleExport('pptx')}
className="flex items-center gap-1.5 px-3 py-1.5 text-sm bg-orange-100 dark:bg-orange-900/30 text-orange-700 dark:text-orange-300 rounded-md hover:bg-orange-200 dark:hover:bg-orange-900/50 transition-colors"

View File

@@ -22,23 +22,26 @@ import {
} from '../lib/personality-presets';
import type { Clone } from '../store/agentStore';
import { useChatStore } from '../store/chatStore';
import { useClassroomStore } from '../store/classroomStore';
import { useHandStore } from '../store/handStore';
// Quick action chip definitions — DeerFlow-style colored pills
// handId maps to actual Hand names in the runtime
const QUICK_ACTIONS = [
{ key: 'surprise', label: '小惊喜', icon: Sparkles, color: 'text-orange-500' },
{ key: 'write', label: '写作', icon: PenLine, color: 'text-blue-500' },
{ key: 'research', label: '研究', icon: Microscope, color: 'text-purple-500' },
{ key: 'collect', label: '收集', icon: Layers, color: 'text-green-500' },
{ key: 'research', label: '研究', icon: Microscope, color: 'text-purple-500', handId: 'researcher' },
{ key: 'collect', label: '收集', icon: Layers, color: 'text-green-500', handId: 'collector' },
{ key: 'learn', label: '学习', icon: GraduationCap, color: 'text-indigo-500' },
];
// Pre-filled prompts for each quick action
// Pre-filled prompts for each quick action — tailored for target industries
const QUICK_ACTION_PROMPTS: Record<string, string> = {
surprise: '给我一个小惊喜吧!来点创意的',
write: '帮我写一篇文章,主题你来定',
research: '帮我做一个深度研究分析',
collect: '帮我收集整理一些有用的信息',
learn: '我想学点新东西,教我一些有趣的知识',
write: '帮我写一份关于"远程医疗行政管理优化方案"的提案大纲',
research: '帮我深度研究"2026年教育数字化转型趋势",包括政策、技术和实践三个维度',
collect: '帮我采集 5 个主流 AI 教育工具的产品信息,对比功能和价格',
learn: '我想了解汕头玩具产业 2026 年出口趋势,能帮我分析一下吗?',
};
interface FirstConversationPromptProps {
@@ -69,6 +72,41 @@ export function FirstConversationPrompt({
});
const handleQuickAction = (key: string) => {
if (key === 'learn') {
// Trigger classroom generation flow
const classroomStore = useClassroomStore.getState();
// Extract a clean topic from the prompt
const prompt = QUICK_ACTION_PROMPTS[key] || '';
const topic = prompt
.replace(/^[你我].*?(想了解|想学|了解|学习|分析|研究|探索)\s*/g, '')
.replace(/[,。?!].*$/g, '')
.replace(/^(能|帮|请|可不可以).*/g, '')
.trim() || '互动课堂';
classroomStore.startGeneration({
topic,
style: 'lecture',
level: 'intermediate',
language: 'zh-CN',
}).catch(() => {
// Error is already stored in classroomStore.error and displayed in ChatArea
});
return;
}
// Check if this action maps to a Hand
const actionDef = QUICK_ACTIONS.find((a) => a.key === key);
if (actionDef?.handId) {
const handStore = useHandStore.getState();
handStore.triggerHand(actionDef.handId, {
action: key === 'research' ? 'report' : 'collect',
query: { query: QUICK_ACTION_PROMPTS[key] || '' },
}).catch(() => {
// Fallback: fill prompt into input bar
onSelectSuggestion?.(QUICK_ACTION_PROMPTS[key] || '你好!');
});
return;
}
const prompt = QUICK_ACTION_PROMPTS[key] || '你好!';
onSelectSuggestion?.(prompt);
};

View File

@@ -25,6 +25,8 @@ import { PipelineRunResponse } from '../lib/pipeline-client';
import { useToast } from './ui/Toast';
import DOMPurify from 'dompurify';
import { ClassroomPreviewer, type ClassroomData } from './ClassroomPreviewer';
import { useClassroomStore } from '../store/classroomStore';
import { adaptToClassroom } from '../lib/classroom-adapter';
// === Types ===
@@ -286,6 +288,11 @@ export function PipelineResultPreview({
// Handle export
handleClassroomExport(format, classroomData);
}}
onOpenFullPlayer={() => {
const classroom = adaptToClassroom(classroomData);
useClassroomStore.getState().setActiveClassroom(classroom);
useClassroomStore.getState().openClassroom();
}}
/>
</div>
);

View File

@@ -18,6 +18,7 @@ import {
Filter,
X,
} from 'lucide-react';
import { PipelineResultPreview } from './PipelineResultPreview';
import {
PipelineClient,
PipelineInfo,
@@ -28,7 +29,7 @@ import {
formatInputType,
} from '../lib/pipeline-client';
import { useToast } from './ui/Toast';
import { PresentationContainer } from './presentation';
import { saasClient } from '../lib/saas-client';
// === Category Badge Component ===
@@ -117,64 +118,6 @@ function PipelineCard({ pipeline, onRun }: PipelineCardProps) {
);
}
// === Pipeline Result Modal ===
interface ResultModalProps {
result: PipelineRunResponse;
pipeline: PipelineInfo;
onClose: () => void;
}
function ResultModal({ result, pipeline, onClose }: ResultModalProps) {
return (
<div className="fixed inset-0 bg-black/50 flex items-center justify-center z-50">
<div className="bg-white dark:bg-gray-800 rounded-lg shadow-xl w-[90vw] max-w-4xl h-[85vh] flex flex-col mx-4">
{/* Header */}
<div className="flex items-center justify-between p-4 border-b border-gray-200 dark:border-gray-700">
<div className="flex items-center gap-3">
<span className="text-2xl">{pipeline.icon}</span>
<div>
<h2 className="text-lg font-semibold text-gray-900 dark:text-white">
{pipeline.displayName} -
</h2>
<p className="text-sm text-gray-500 dark:text-gray-400">
: {result.status === 'completed' ? '已完成' : '失败'}
</p>
</div>
</div>
<button
onClick={onClose}
className="p-1 hover:bg-gray-100 dark:hover:bg-gray-700 rounded"
>
<X className="w-5 h-5 text-gray-500" />
</button>
</div>
{/* Content */}
<div className="flex-1 overflow-hidden">
{result.outputs ? (
<PresentationContainer
data={result.outputs}
pipelineId={pipeline.id}
supportedTypes={['document', 'chart', 'quiz', 'slideshow']}
/>
) : result.error ? (
<div className="p-6 text-center text-red-500">
<XCircle className="w-8 h-8 mx-auto mb-2" />
<p>{result.error}</p>
</div>
) : (
<div className="p-6 text-center text-gray-500">
<Package className="w-8 h-8 mx-auto mb-2" />
<p></p>
</div>
)}
</div>
</div>
</div>
);
}
// === Pipeline Run Modal ===
interface RunModalProps {
@@ -489,6 +432,13 @@ export function PipelinesPanel() {
if (result.status === 'completed') {
toast('Pipeline 执行完成', 'success');
setRunResult({ result, pipeline: selectedPipeline! });
// Report pipeline execution to billing (fire-and-forget)
try {
if (saasClient.isAuthenticated()) {
saasClient.reportUsageFireAndForget('pipeline_runs');
}
} catch { /* billing reporting must never block */ }
} else {
toast(`Pipeline 执行失败: ${result.error}`, 'error');
}
@@ -602,11 +552,11 @@ export function PipelinesPanel() {
/>
)}
{/* Result Modal */}
{/* Result Preview */}
{runResult && (
<ResultModal
<PipelineResultPreview
result={runResult.result}
pipeline={runResult.pipeline}
pipelineId={runResult.pipeline.id}
onClose={() => setRunResult(null)}
/>
)}

View File

@@ -109,7 +109,7 @@ export function Conversation({ children, className = '' }: ConversationProps) {
<div
ref={containerRef}
onScroll={handleScroll}
className={`overflow-y-auto custom-scrollbar ${className}`}
className={`overflow-y-auto custom-scrollbar min-h-0 ${className}`}
>
{children}
</div>

View File

@@ -62,7 +62,7 @@ export function ResizableChatLayout({
if (!rightPanelOpen || !rightPanel) {
return (
<div className="flex-1 flex flex-col overflow-hidden relative">
<div className="h-full flex flex-col overflow-hidden relative">
{chatPanel}
<button
onClick={handleToggle}
@@ -76,7 +76,7 @@ export function ResizableChatLayout({
}
return (
<div className="flex-1 flex flex-col overflow-hidden">
<div className="h-full flex flex-col overflow-hidden">
<Group
orientation="horizontal"
onLayoutChanged={(layout) => savePanelSizes(layout)}

View File

@@ -0,0 +1,121 @@
/**
* AgentChat — Multi-agent chat panel for classroom interaction.
*
* Displays chat bubbles from different agents (teacher, assistant, students)
* with distinct colors and avatars. Users can send messages.
*/
import { useState, useRef, useEffect } from 'react';
import type { ClassroomChatMessage as ChatMessage, AgentProfile } from '../../types/classroom';
interface AgentChatProps {
messages: ChatMessage[];
agents: AgentProfile[];
loading: boolean;
onSend: (message: string) => Promise<void>;
}
export function AgentChat({ messages, loading, onSend }: AgentChatProps) {
const [input, setInput] = useState('');
const scrollRef = useRef<HTMLDivElement>(null);
// Auto-scroll to bottom
useEffect(() => {
if (scrollRef.current) {
scrollRef.current.scrollTop = scrollRef.current.scrollHeight;
}
}, [messages]);
const handleSend = async () => {
const trimmed = input.trim();
if (!trimmed || loading) return;
setInput('');
await onSend(trimmed);
};
const handleKeyDown = (e: React.KeyboardEvent) => {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
handleSend();
}
};
return (
<div className="flex flex-col w-80 border-l border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800">
{/* Header */}
<div className="px-3 py-2 border-b border-gray-200 dark:border-gray-700">
<h3 className="text-sm font-medium text-gray-700 dark:text-gray-300">
Classroom Chat
</h3>
</div>
{/* Messages */}
<div ref={scrollRef} className="flex-1 overflow-auto p-3 space-y-3">
{messages.length === 0 ? (
<div className="text-center text-xs text-gray-400 py-8">
Start a conversation with the classroom
</div>
) : (
messages.map((msg) => {
const isUser = msg.role === 'user';
return (
<div key={msg.id} className={`flex gap-2 ${isUser ? 'justify-end' : ''}`}>
{/* Avatar */}
{!isUser && (
<span
className="flex-shrink-0 w-7 h-7 rounded-full flex items-center justify-center text-xs"
style={{ backgroundColor: msg.color + '20' }}
>
{msg.agentAvatar}
</span>
)}
{/* Message bubble */}
<div className={`max-w-[200px] ${isUser ? 'text-right' : ''}`}>
{!isUser && (
<span className="text-xs font-medium" style={{ color: msg.color }}>
{msg.agentName}
</span>
)}
<div
className={`text-sm px-3 py-1.5 rounded-lg ${
isUser
? 'bg-indigo-500 text-white'
: 'bg-gray-100 dark:bg-gray-700 text-gray-800 dark:text-gray-200'
}`}
>
{msg.content}
</div>
</div>
</div>
);
})
)}
</div>
{/* Input */}
<div className="px-3 py-2 border-t border-gray-200 dark:border-gray-700">
<div className="flex gap-2">
<input
type="text"
value={input}
onChange={(e) => setInput(e.target.value)}
onKeyDown={handleKeyDown}
placeholder="Ask a question..."
disabled={loading}
className="flex-1 px-2 py-1.5 text-sm rounded border border-gray-300 dark:border-gray-600 bg-white dark:bg-gray-700 text-gray-900 dark:text-white focus:outline-none focus:ring-1 focus:ring-indigo-400 disabled:opacity-50"
/>
<button
onClick={handleSend}
disabled={loading || !input.trim()}
className="px-3 py-1.5 text-sm rounded bg-indigo-500 text-white disabled:opacity-50 hover:bg-indigo-600"
>
{loading ? '...' : 'Send'}
</button>
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,231 @@
/**
* ClassroomPlayer — Full-screen interactive classroom player.
*
* Layout: Notes sidebar | Main stage | Chat panel
* Top: Title + Agent avatars
* Bottom: Scene navigation + playback controls
*/
import { useState, useCallback, useEffect } from 'react';
import { invoke } from '@tauri-apps/api/core';
import { useClassroom } from '../../hooks/useClassroom';
import { SceneRenderer } from './SceneRenderer';
import { AgentChat } from './AgentChat';
import { NotesSidebar } from './NotesSidebar';
import { TtsPlayer } from './TtsPlayer';
import { Download } from 'lucide-react';
interface ClassroomPlayerProps {
onClose: () => void;
}
export function ClassroomPlayer({ onClose }: ClassroomPlayerProps) {
const {
activeClassroom,
chatMessages,
chatLoading,
sendChatMessage,
} = useClassroom();
const [currentSceneIndex, setCurrentSceneIndex] = useState(0);
const [sidebarOpen, setSidebarOpen] = useState(true);
const [chatOpen, setChatOpen] = useState(true);
const [exporting, setExporting] = useState(false);
const classroom = activeClassroom;
const scenes = classroom?.scenes ?? [];
const agents = classroom?.agents ?? [];
const currentScene = scenes[currentSceneIndex] ?? null;
// Navigate to next/prev scene
const goNext = useCallback(() => {
setCurrentSceneIndex((i) => Math.min(i + 1, scenes.length - 1));
}, [scenes.length]);
const goPrev = useCallback(() => {
setCurrentSceneIndex((i) => Math.max(i - 1, 0));
}, []);
// Keyboard shortcuts
useEffect(() => {
const handler = (e: KeyboardEvent) => {
if (e.key === 'ArrowRight') goNext();
else if (e.key === 'ArrowLeft') goPrev();
else if (e.key === 'Escape') onClose();
};
window.addEventListener('keydown', handler);
return () => window.removeEventListener('keydown', handler);
}, [goNext, goPrev, onClose]);
// Chat handler
const handleChatSend = useCallback(async (message: string) => {
const sceneContext = currentScene?.content.title;
await sendChatMessage(message, sceneContext);
}, [sendChatMessage, currentScene]);
// Export handler
const handleExport = useCallback(async (format: 'html' | 'markdown' | 'json') => {
if (!classroom) return;
setExporting(true);
try {
const result = await invoke<{ content: string; filename: string; mimeType: string }>(
'classroom_export',
{ request: { classroomId: classroom.id, format } }
);
// Download the exported file
const blob = new Blob([result.content], { type: result.mimeType });
const url = URL.createObjectURL(blob);
const a = document.createElement('a');
a.href = url;
a.download = result.filename;
a.click();
URL.revokeObjectURL(url);
} catch (e) {
console.error('Export failed:', e);
} finally {
setExporting(false);
}
}, [classroom]);
if (!classroom) {
return (
<div className="flex items-center justify-center h-full text-gray-500">
No classroom loaded
</div>
);
}
return (
<div className="flex flex-col h-full bg-gray-50 dark:bg-gray-900">
{/* Header */}
<header className="flex items-center justify-between px-4 py-2 border-b border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800">
<div className="flex items-center gap-3">
<button
onClick={onClose}
className="p-1 rounded hover:bg-gray-100 dark:hover:bg-gray-700"
aria-label="Close classroom"
>
</button>
<h1 className="text-lg font-semibold text-gray-900 dark:text-white truncate max-w-md">
{classroom.title}
</h1>
</div>
{/* Agent avatars */}
<div className="flex items-center gap-1">
{agents.map((agent) => (
<span
key={agent.id}
className="inline-flex items-center justify-center w-8 h-8 rounded-full text-sm"
style={{ backgroundColor: agent.color + '20', color: agent.color }}
title={agent.name}
>
{agent.avatar}
</span>
))}
</div>
<div className="flex items-center gap-2">
<button
onClick={() => setSidebarOpen(!sidebarOpen)}
className={`px-2 py-1 rounded text-xs ${sidebarOpen ? 'bg-indigo-100 text-indigo-700' : 'text-gray-500'}`}
>
Notes
</button>
<button
onClick={() => setChatOpen(!chatOpen)}
className={`px-2 py-1 rounded text-xs ${chatOpen ? 'bg-indigo-100 text-indigo-700' : 'text-gray-500'}`}
>
Chat
</button>
{/* Export dropdown */}
<div className="relative group">
<button
disabled={exporting}
className="px-2 py-1 rounded text-xs text-gray-500 hover:text-gray-700 flex items-center gap-1"
title="导出课堂"
>
<Download className="w-3.5 h-3.5" />
{exporting ? '...' : '导出'}
</button>
<div className="absolute right-0 top-full mt-1 bg-white dark:bg-gray-800 border border-gray-200 dark:border-gray-700 rounded shadow-lg hidden group-hover:block z-10">
<button onClick={() => handleExport('html')} className="block w-full text-left px-3 py-1.5 text-xs text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700">HTML</button>
<button onClick={() => handleExport('markdown')} className="block w-full text-left px-3 py-1.5 text-xs text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700">Markdown</button>
<button onClick={() => handleExport('json')} className="block w-full text-left px-3 py-1.5 text-xs text-gray-700 dark:text-gray-300 hover:bg-gray-100 dark:hover:bg-gray-700">JSON</button>
</div>
</div>
</div>
</header>
{/* Main content */}
<div className="flex flex-1 overflow-hidden">
{/* Notes sidebar */}
{sidebarOpen && (
<NotesSidebar
scenes={scenes}
currentIndex={currentSceneIndex}
onSelectScene={setCurrentSceneIndex}
/>
)}
{/* Main stage */}
<main className="flex-1 overflow-auto p-4">
{currentScene ? (
<SceneRenderer key={currentScene.id} scene={currentScene} agents={agents} />
) : (
<div className="flex items-center justify-center h-full text-gray-400">
No scenes available
</div>
)}
</main>
{/* Chat panel */}
{chatOpen && (
<AgentChat
messages={chatMessages}
agents={agents}
loading={chatLoading}
onSend={handleChatSend}
/>
)}
</div>
{/* Bottom navigation */}
<footer className="flex items-center justify-between px-4 py-2 border-t border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800">
<div className="flex items-center gap-2">
<button
onClick={goPrev}
disabled={currentSceneIndex === 0}
className="px-3 py-1 rounded text-sm bg-gray-100 dark:bg-gray-700 disabled:opacity-50"
>
Previous
</button>
<span className="text-sm text-gray-500">
{currentSceneIndex + 1} / {scenes.length}
</span>
<button
onClick={goNext}
disabled={currentSceneIndex >= scenes.length - 1}
className="px-3 py-1 rounded text-sm bg-gray-100 dark:bg-gray-700 disabled:opacity-50"
>
Next
</button>
</div>
{/* TTS + Scene info */}
<div className="flex items-center gap-3">
{currentScene?.content.notes && (
<TtsPlayer text={currentScene.content.notes} />
)}
<div className="text-xs text-gray-400">
{currentScene?.content.sceneType ?? ''}
{currentScene?.content.durationSeconds
? ` · ${Math.floor(currentScene.content.durationSeconds / 60)}:${String(currentScene.content.durationSeconds % 60).padStart(2, '0')}`
: ''}
</div>
</div>
</footer>
</div>
);
}

View File

@@ -0,0 +1,71 @@
/**
* NotesSidebar — Scene outline navigation + notes.
*
* Left panel showing all scenes as clickable items with notes.
*/
import type { GeneratedScene } from '../../types/classroom';
interface NotesSidebarProps {
scenes: GeneratedScene[];
currentIndex: number;
onSelectScene: (index: number) => void;
}
export function NotesSidebar({ scenes, currentIndex, onSelectScene }: NotesSidebarProps) {
return (
<div className="w-64 border-r border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800 overflow-auto">
<div className="px-3 py-2 border-b border-gray-200 dark:border-gray-700">
<h3 className="text-xs font-semibold text-gray-500 uppercase tracking-wider">
Outline
</h3>
</div>
<nav className="py-1">
{scenes.map((scene, i) => {
const isActive = i === currentIndex;
const typeColor = getTypeColor(scene.content.sceneType);
return (
<button
key={scene.id}
onClick={() => onSelectScene(i)}
className={`w-full text-left px-3 py-2 text-sm border-l-2 transition-colors ${
isActive
? 'border-indigo-500 bg-indigo-50 dark:bg-indigo-900/20'
: 'border-transparent hover:bg-gray-50 dark:hover:bg-gray-700/50'
}`}
>
<div className="flex items-center gap-2">
<span
className="inline-block w-1.5 h-1.5 rounded-full"
style={{ backgroundColor: typeColor }}
/>
<span className={`font-medium ${isActive ? 'text-indigo-700 dark:text-indigo-300' : 'text-gray-700 dark:text-gray-300'}`}>
{i + 1}. {scene.content.title}
</span>
</div>
{scene.content.notes && (
<p className="text-xs text-gray-400 mt-0.5 ml-3.5 line-clamp-2">
{scene.content.notes}
</p>
)}
</button>
);
})}
</nav>
</div>
);
}
function getTypeColor(type: string): string {
switch (type) {
case 'slide': return '#6366F1';
case 'quiz': return '#F59E0B';
case 'discussion': return '#10B981';
case 'interactive': return '#8B5CF6';
case 'pbl': return '#EF4444';
case 'media': return '#06B6D4';
default: return '#9CA3AF';
}
}

View File

@@ -0,0 +1,219 @@
/**
* SceneRenderer — Renders a single classroom scene.
*
* Supports scene types: slide, quiz, discussion, interactive, text, pbl, media.
* Executes scene actions (speech, whiteboard, quiz, discussion).
*/
import { useState, useEffect, useCallback } from 'react';
import type { GeneratedScene, SceneContent, SceneAction, AgentProfile } from '../../types/classroom';
interface SceneRendererProps {
scene: GeneratedScene;
agents: AgentProfile[];
autoPlay?: boolean;
}
export function SceneRenderer({ scene, agents, autoPlay = true }: SceneRendererProps) {
const { content } = scene;
const [actionIndex, setActionIndex] = useState(0);
const [isPlaying, setIsPlaying] = useState(autoPlay);
const [whiteboardItems, setWhiteboardItems] = useState<Array<{
type: string;
data: SceneAction;
}>>([]);
const actions = content.actions ?? [];
const currentAction = actions[actionIndex] ?? null;
// Auto-advance through actions
useEffect(() => {
if (!isPlaying || actions.length === 0) return;
if (actionIndex >= actions.length) {
setIsPlaying(false);
return;
}
const delay = getActionDelay(actions[actionIndex]);
const timer = setTimeout(() => {
processAction(actions[actionIndex]);
setActionIndex((i) => i + 1);
}, delay);
return () => clearTimeout(timer);
}, [actionIndex, isPlaying, actions]);
const processAction = useCallback((action: SceneAction) => {
switch (action.type) {
case 'whiteboard_draw_text':
case 'whiteboard_draw_shape':
case 'whiteboard_draw_chart':
case 'whiteboard_draw_latex':
setWhiteboardItems((prev) => [...prev, { type: action.type, data: action }]);
break;
case 'whiteboard_clear':
setWhiteboardItems([]);
break;
}
}, []);
// Render scene based on type
return (
<div className="flex flex-col h-full">
{/* Scene title */}
<div className="mb-4">
<h2 className="text-2xl font-bold text-gray-900 dark:text-white">
{content.title}
</h2>
{content.notes && (
<p className="text-sm text-gray-500 mt-1">{content.notes}</p>
)}
</div>
{/* Main content area */}
<div className="flex-1 flex gap-4 overflow-hidden">
{/* Content panel */}
<div className="flex-1 overflow-auto">
{renderContent(content)}
</div>
{/* Whiteboard area */}
{whiteboardItems.length > 0 && (
<div className="w-80 border border-gray-200 dark:border-gray-700 rounded-lg bg-white dark:bg-gray-800 p-2 overflow-auto">
<svg viewBox="0 0 800 600" className="w-full h-full">
{whiteboardItems.map((item, i) => (
<g key={i}>{renderWhiteboardItem(item)}</g>
))}
</svg>
</div>
)}
</div>
{/* Current action indicator */}
{currentAction && (
<div className="mt-4 p-3 rounded-lg bg-indigo-50 dark:bg-indigo-900/20 border border-indigo-100 dark:border-indigo-800">
{renderCurrentAction(currentAction, agents)}
</div>
)}
{/* Playback controls */}
<div className="flex items-center justify-center gap-2 mt-4">
<button
onClick={() => { setActionIndex(0); setWhiteboardItems([]); }}
className="px-2 py-1 text-xs rounded bg-gray-100 dark:bg-gray-700"
>
Restart
</button>
<button
onClick={() => setIsPlaying(!isPlaying)}
className="px-3 py-1 text-sm rounded bg-indigo-500 text-white"
>
{isPlaying ? 'Pause' : 'Play'}
</button>
<span className="text-xs text-gray-400">
Action {Math.min(actionIndex + 1, actions.length)} / {actions.length}
</span>
</div>
</div>
);
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
function getActionDelay(action: SceneAction): number {
switch (action.type) {
case 'speech': return 2000;
case 'whiteboard_draw_text': return 800;
case 'whiteboard_draw_shape': return 600;
case 'quiz_show': return 5000;
case 'discussion': return 10000;
default: return 1000;
}
}
function renderContent(content: SceneContent) {
const data = content.content;
if (!data || typeof data !== 'object') return null;
// Handle slide content
const keyPoints = data.key_points as string[] | undefined;
const description = data.description as string | undefined;
const slides = data.slides as Array<{ title: string; content: string }> | undefined;
return (
<div className="space-y-4">
{description && (
<p className="text-gray-700 dark:text-gray-300 leading-relaxed">{description}</p>
)}
{keyPoints && keyPoints.length > 0 && (
<ul className="space-y-2">
{keyPoints.map((point, i) => (
<li key={i} className="flex items-start gap-2">
<span className="text-indigo-500 mt-0.5"></span>
<span className="text-gray-700 dark:text-gray-300">{point}</span>
</li>
))}
</ul>
)}
{slides && slides.map((slide, i) => (
<div key={i} className="p-3 rounded border border-gray-200 dark:border-gray-700">
<h4 className="font-medium text-gray-900 dark:text-white">{slide.title}</h4>
<p className="text-sm text-gray-600 dark:text-gray-400 mt-1">{slide.content}</p>
</div>
))}
</div>
);
}
function renderCurrentAction(action: SceneAction, agents: AgentProfile[]) {
switch (action.type) {
case 'speech': {
const agent = agents.find(a => a.role === action.agentRole);
return (
<div className="flex items-start gap-2">
<span className="text-lg">{agent?.avatar ?? '💬'}</span>
<div>
<span className="text-xs font-medium text-gray-600">{agent?.name ?? action.agentRole}</span>
<p className="text-sm text-gray-700 dark:text-gray-300">{action.text}</p>
</div>
</div>
);
}
case 'quiz_show':
return <div className="text-sm text-amber-600">Quiz: {action.quizId}</div>;
case 'discussion':
return <div className="text-sm text-green-600">Discussion: {action.topic}</div>;
default:
return <div className="text-xs text-gray-400">{action.type}</div>;
}
}
function renderWhiteboardItem(item: { type: string; data: Record<string, unknown> }) {
switch (item.type) {
case 'whiteboard_draw_text': {
const d = item.data;
if ('text' in d && 'x' in d && 'y' in d) {
return (
<text x={typeof d.x === 'number' ? d.x : 100} y={typeof d.y === 'number' ? d.y : 100} fontSize={typeof d.fontSize === 'number' ? d.fontSize : 16} fill={typeof d.color === 'string' ? d.color : '#333'}>
{String(d.text ?? '')}
</text>
);
}
return null;
}
case 'whiteboard_draw_shape': {
const d = item.data as Record<string, unknown>;
const x = typeof d.x === 'number' ? d.x : 0;
const y = typeof d.y === 'number' ? d.y : 0;
const w = typeof d.width === 'number' ? d.width : 100;
const h = typeof d.height === 'number' ? d.height : 50;
const fill = typeof d.fill === 'string' ? d.fill : '#e5e5e5';
return (
<rect x={x} y={y} width={w} height={h} fill={fill} />
);
}
}
}

View File

@@ -0,0 +1,155 @@
/**
* TtsPlayer — Text-to-Speech playback controls for classroom narration.
*
* Uses the browser's built-in SpeechSynthesis API.
* Provides play/pause, speed, and volume controls.
*/
import { useState, useEffect, useCallback, useRef } from 'react';
import { Volume2, VolumeX, Pause, Play, Gauge } from 'lucide-react';
interface TtsPlayerProps {
text: string;
autoPlay?: boolean;
onEnd?: () => void;
}
export function TtsPlayer({ text, autoPlay = false, onEnd }: TtsPlayerProps) {
const [isPlaying, setIsPlaying] = useState(false);
const [isPaused, setIsPaused] = useState(false);
const [rate, setRate] = useState(1.0);
const [isMuted, setIsMuted] = useState(false);
const utteranceRef = useRef<SpeechSynthesisUtterance | null>(null);
const speak = useCallback(() => {
if (!text || typeof window === 'undefined') return;
window.speechSynthesis.cancel();
const utterance = new SpeechSynthesisUtterance(text);
utterance.lang = 'zh-CN';
utterance.rate = rate;
utterance.volume = isMuted ? 0 : 1;
utterance.onend = () => {
setIsPlaying(false);
setIsPaused(false);
onEnd?.();
};
utterance.onerror = () => {
setIsPlaying(false);
setIsPaused(false);
};
utteranceRef.current = utterance;
window.speechSynthesis.speak(utterance);
setIsPlaying(true);
setIsPaused(false);
}, [text, rate, isMuted, onEnd]);
const togglePlay = useCallback(() => {
if (isPlaying && !isPaused) {
window.speechSynthesis.pause();
setIsPaused(true);
} else if (isPaused) {
window.speechSynthesis.resume();
setIsPaused(false);
} else {
speak();
}
}, [isPlaying, isPaused, speak]);
const stop = useCallback(() => {
window.speechSynthesis.cancel();
setIsPlaying(false);
setIsPaused(false);
}, []);
// Auto-play when text changes
useEffect(() => {
if (autoPlay && text) {
speak();
}
return () => {
if (typeof window !== 'undefined') {
window.speechSynthesis.cancel();
}
};
}, [text, autoPlay, speak]);
// Cleanup on unmount
useEffect(() => {
return () => {
if (typeof window !== 'undefined') {
window.speechSynthesis.cancel();
}
};
}, []);
if (!text) return null;
return (
<div className="flex items-center gap-3 px-3 py-2 rounded-lg bg-gray-50 dark:bg-gray-800 border border-gray-200 dark:border-gray-700">
{/* Play/Pause button */}
<button
onClick={togglePlay}
className="w-8 h-8 flex items-center justify-center rounded-full bg-indigo-500 text-white hover:bg-indigo-600 transition-colors"
aria-label={isPlaying && !isPaused ? '暂停' : '播放'}
>
{isPlaying && !isPaused ? (
<Pause className="w-4 h-4" />
) : (
<Play className="w-4 h-4" />
)}
</button>
{/* Stop button */}
{isPlaying && (
<button
onClick={stop}
className="w-6 h-6 flex items-center justify-center rounded text-gray-500 hover:text-gray-700 dark:hover:text-gray-300"
aria-label="停止"
>
<VolumeX className="w-3.5 h-3.5" />
</button>
)}
{/* Speed control */}
<div className="flex items-center gap-1.5">
<Gauge className="w-3.5 h-3.5 text-gray-400" />
<select
value={rate}
onChange={(e) => setRate(Number(e.target.value))}
className="text-xs bg-transparent border-none text-gray-600 dark:text-gray-400 cursor-pointer"
>
<option value={0.5}>0.5x</option>
<option value={0.75}>0.75x</option>
<option value={1}>1x</option>
<option value={1.25}>1.25x</option>
<option value={1.5}>1.5x</option>
<option value={2}>2x</option>
</select>
</div>
{/* Mute toggle */}
<button
onClick={() => setIsMuted(!isMuted)}
className="text-gray-400 hover:text-gray-600 dark:hover:text-gray-300"
aria-label={isMuted ? '取消静音' : '静音'}
>
{isMuted ? (
<VolumeX className="w-4 h-4" />
) : (
<Volume2 className="w-4 h-4" />
)}
</button>
{/* Status indicator */}
{isPlaying && (
<span className="text-xs text-gray-400">
{isPaused ? '已暂停' : '朗读中...'}
</span>
)}
</div>
);
}

View File

@@ -0,0 +1,295 @@
/**
* WhiteboardCanvas — SVG-based whiteboard for classroom scene rendering.
*
* Supports incremental drawing operations:
* - Text (positioned labels)
* - Shapes (rectangles, circles, arrows)
* - Charts (bar/line/pie via simple SVG)
* - LaTeX (rendered as styled text blocks)
*/
import { useCallback } from 'react';
import type { SceneAction } from '../../types/classroom';
interface WhiteboardCanvasProps {
items: WhiteboardItem[];
width?: number;
height?: number;
}
export interface WhiteboardItem {
type: string;
data: SceneAction;
}
export function WhiteboardCanvas({
items,
width = 800,
height = 600,
}: WhiteboardCanvasProps) {
const renderItem = useCallback((item: WhiteboardItem, index: number) => {
switch (item.type) {
case 'whiteboard_draw_text':
return <TextItem key={index} data={item.data as TextDrawData} />;
case 'whiteboard_draw_shape':
return <ShapeItem key={index} data={item.data as ShapeDrawData} />;
case 'whiteboard_draw_chart':
return <ChartItem key={index} data={item.data as ChartDrawData} />;
case 'whiteboard_draw_latex':
return <LatexItem key={index} data={item.data as LatexDrawData} />;
default:
return null;
}
}, []);
return (
<div className="w-full h-full border border-gray-200 dark:border-gray-700 rounded-lg bg-white dark:bg-gray-900 overflow-auto">
<svg
viewBox={`0 0 ${width} ${height}`}
className="w-full h-full"
xmlns="http://www.w3.org/2000/svg"
>
{/* Grid background */}
<defs>
<pattern id="grid" width="40" height="40" patternUnits="userSpaceOnUse">
<path d="M 40 0 L 0 0 0 40" fill="none" stroke="#f0f0f0" strokeWidth="0.5" />
</pattern>
</defs>
<rect width={width} height={height} fill="url(#grid)" />
{/* Rendered items */}
{items.map((item, i) => renderItem(item, i))}
</svg>
</div>
);
}
// ---------------------------------------------------------------------------
// Sub-components
// ---------------------------------------------------------------------------
interface TextDrawData {
type: 'whiteboard_draw_text';
x: number;
y: number;
text: string;
fontSize?: number;
color?: string;
}
function TextItem({ data }: { data: TextDrawData }) {
return (
<text
x={data.x}
y={data.y}
fontSize={data.fontSize ?? 16}
fill={data.color ?? '#333333'}
fontFamily="system-ui, sans-serif"
>
{data.text}
</text>
);
}
interface ShapeDrawData {
type: 'whiteboard_draw_shape';
shape: string;
x: number;
y: number;
width: number;
height: number;
fill?: string;
}
function ShapeItem({ data }: { data: ShapeDrawData }) {
switch (data.shape) {
case 'circle':
return (
<ellipse
cx={data.x + data.width / 2}
cy={data.y + data.height / 2}
rx={data.width / 2}
ry={data.height / 2}
fill={data.fill ?? '#e5e7eb'}
stroke="#9ca3af"
strokeWidth={1}
/>
);
case 'arrow':
return (
<g>
<line
x1={data.x}
y1={data.y + data.height / 2}
x2={data.x + data.width}
y2={data.y + data.height / 2}
stroke={data.fill ?? '#6b7280'}
strokeWidth={2}
markerEnd="url(#arrowhead)"
/>
<defs>
<marker id="arrowhead" markerWidth="10" markerHeight="7" refX="10" refY="3.5" orient="auto">
<polygon points="0 0, 10 3.5, 0 7" fill={data.fill ?? '#6b7280'} />
</marker>
</defs>
</g>
);
default: // rectangle
return (
<rect
x={data.x}
y={data.y}
width={data.width}
height={data.height}
fill={data.fill ?? '#e5e7eb'}
stroke="#9ca3af"
strokeWidth={1}
rx={4}
/>
);
}
}
interface ChartDrawData {
type: 'whiteboard_draw_chart';
chartType: string;
data: Record<string, unknown>;
x: number;
y: number;
width: number;
height: number;
}
function ChartItem({ data }: { data: ChartDrawData }) {
const chartData = data.data;
const labels = (chartData?.labels as string[]) ?? [];
const values = (chartData?.values as number[]) ?? [];
if (labels.length === 0 || values.length === 0) return null;
switch (data.chartType) {
case 'bar':
return <BarChart data={data} labels={labels} values={values} />;
case 'line':
return <LineChart data={data} labels={labels} values={values} />;
default:
return <BarChart data={data} labels={labels} values={values} />;
}
}
function BarChart({ data, labels, values }: {
data: ChartDrawData;
labels: string[];
values: number[];
}) {
const maxVal = Math.max(...values, 1);
const barWidth = data.width / (labels.length * 2);
const chartHeight = data.height - 30;
return (
<g transform={`translate(${data.x}, ${data.y})`}>
{values.map((val, i) => {
const barHeight = (val / maxVal) * chartHeight;
return (
<g key={i}>
<rect
x={i * (barWidth * 2) + barWidth / 2}
y={chartHeight - barHeight}
width={barWidth}
height={barHeight}
fill="#6366f1"
rx={2}
/>
<text
x={i * (barWidth * 2) + barWidth}
y={data.height - 5}
textAnchor="middle"
fontSize={10}
fill="#666"
>
{labels[i]}
</text>
</g>
);
})}
</g>
);
}
function LineChart({ data, labels, values }: {
data: ChartDrawData;
labels: string[];
values: number[];
}) {
const maxVal = Math.max(...values, 1);
const chartHeight = data.height - 30;
const stepX = data.width / Math.max(labels.length - 1, 1);
const points = values.map((val, i) => {
const x = i * stepX;
const y = chartHeight - (val / maxVal) * chartHeight;
return `${x},${y}`;
}).join(' ');
return (
<g transform={`translate(${data.x}, ${data.y})`}>
<polyline
points={points}
fill="none"
stroke="#6366f1"
strokeWidth={2}
/>
{values.map((val, i) => {
const x = i * stepX;
const y = chartHeight - (val / maxVal) * chartHeight;
return (
<g key={i}>
<circle cx={x} cy={y} r={3} fill="#6366f1" />
<text
x={x}
y={data.height - 5}
textAnchor="middle"
fontSize={10}
fill="#666"
>
{labels[i]}
</text>
</g>
);
})}
</g>
);
}
interface LatexDrawData {
type: 'whiteboard_draw_latex';
latex: string;
x: number;
y: number;
}
function LatexItem({ data }: { data: LatexDrawData }) {
return (
<g transform={`translate(${data.x}, ${data.y})`}>
<rect
x={-4}
y={-20}
width={data.latex.length * 10 + 8}
height={28}
fill="#fef3c7"
stroke="#f59e0b"
strokeWidth={1}
rx={4}
/>
<text
x={0}
y={0}
fontSize={14}
fill="#92400e"
fontFamily="'Courier New', monospace"
>
{data.latex}
</text>
</g>
);
}

View File

@@ -0,0 +1,12 @@
/**
* Classroom Player Components
*
* Re-exports all classroom player components.
*/
export { ClassroomPlayer } from './ClassroomPlayer';
export { SceneRenderer } from './SceneRenderer';
export { AgentChat } from './AgentChat';
export { NotesSidebar } from './NotesSidebar';
export { WhiteboardCanvas } from './WhiteboardCanvas';
export { TtsPlayer } from './TtsPlayer';

View File

@@ -0,0 +1,76 @@
/**
* useClassroom — React hook wrapping the classroom store for component consumption.
*
* Provides a simplified interface for classroom generation and chat.
*/
import { useCallback } from 'react';
import {
useClassroomStore,
type GenerationRequest,
} from '../store/classroomStore';
import type {
Classroom,
ClassroomChatMessage,
} from '../types/classroom';
export interface UseClassroomReturn {
/** Is generation in progress */
generating: boolean;
/** Current generation stage name */
progressStage: string | null;
/** Progress percentage 0-100 */
progressPercent: number;
/** The active classroom */
activeClassroom: Classroom | null;
/** Chat messages for active classroom */
chatMessages: ClassroomChatMessage[];
/** Is a chat request loading */
chatLoading: boolean;
/** Error message, if any */
error: string | null;
/** Start classroom generation */
startGeneration: (request: GenerationRequest) => Promise<string>;
/** Cancel active generation */
cancelGeneration: () => void;
/** Send a chat message in the active classroom */
sendChatMessage: (message: string, sceneContext?: string) => Promise<void>;
/** Clear current error */
clearError: () => void;
}
/**
* Hook for classroom generation and multi-agent chat.
*
* Components should use this hook rather than accessing the store directly,
* to keep the rendering logic decoupled from state management.
*/
export function useClassroom(): UseClassroomReturn {
const {
generating,
progressStage,
progressPercent,
activeClassroom,
chatMessages,
chatLoading,
error,
startGeneration,
cancelGeneration,
sendChatMessage,
clearError,
} = useClassroomStore();
return {
generating,
progressStage,
progressPercent,
activeClassroom,
chatMessages,
chatLoading,
error,
startGeneration: useCallback((req: GenerationRequest) => startGeneration(req), [startGeneration]),
cancelGeneration: useCallback(() => cancelGeneration(), [cancelGeneration]),
sendChatMessage: useCallback((msg, ctx) => sendChatMessage(msg, ctx), [sendChatMessage]),
clearError: useCallback(() => clearError(), [clearError]),
};
}

View File

@@ -1,27 +1,5 @@
@import "tailwindcss";
/* Aurora gradient animation for welcome title (DeerFlow-inspired) */
@keyframes gradient-shift {
0%, 100% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
}
.aurora-title {
background: linear-gradient(
135deg,
#f97316 0%, /* orange-500 */
#ef4444 25%, /* red-500 */
#f97316 50%, /* orange-500 */
#fb923c 75%, /* orange-400 */
#f97316 100% /* orange-500 */
);
background-size: 200% 200%;
-webkit-background-clip: text;
background-clip: text;
-webkit-text-fill-color: transparent;
animation: gradient-shift 4s ease infinite;
}
:root {
/* Brand Colors - 中性灰色系 */
--color-primary: #374151; /* gray-700 */
@@ -154,3 +132,38 @@ textarea:focus-visible {
outline: none !important;
box-shadow: none !important;
}
/* === Accessibility: reduced motion === */
@media (prefers-reduced-motion: reduce) {
*, *::before, *::after {
animation-duration: 0.01ms !important;
animation-iteration-count: 1 !important;
transition-duration: 0.01ms !important;
scroll-behavior: auto !important;
}
}
/* === Responsive breakpoints for small windows/tablets === */
@media (max-width: 768px) {
/* Auto-collapse sidebar aside on narrow viewports */
aside.w-64 {
width: 0 !important;
min-width: 0 !important;
overflow: hidden;
border-right: none !important;
}
aside.w-64.sidebar-open {
width: 260px !important;
min-width: 260px !important;
position: fixed;
z-index: 50;
height: 100vh;
}
}
@media (max-width: 480px) {
.chat-bubble-assistant,
.chat-bubble-user {
max-width: 95% !important;
}
}

View File

@@ -3,6 +3,10 @@
*
* 为 ZCLAW 前端操作提供统一的审计日志记录功能。
* 记录关键操作Hand 触发、Agent 创建等)到本地存储。
*
* @reserved This module is reserved for future audit logging integration.
* It is not currently imported by any component. When audit logging is needed,
* import { logAudit, logAuditSuccess, logAuditFailure } from this module.
*/
import { createLogger } from './logger';

View File

@@ -0,0 +1,142 @@
/**
* Classroom Adapter
*
* Bridges the old ClassroomData type (ClassroomPreviewer) with the new
* Classroom type (ClassroomPlayer + Tauri backend).
*/
import type { Classroom, GeneratedScene } from '../types/classroom';
import { SceneType, TeachingStyle, DifficultyLevel } from '../types/classroom';
import type { ClassroomData, ClassroomScene } from '../components/ClassroomPreviewer';
// ---------------------------------------------------------------------------
// Old → New (ClassroomData → Classroom)
// ---------------------------------------------------------------------------
/**
* Convert a legacy ClassroomData to the new Classroom format.
* Used when opening ClassroomPlayer from Pipeline result previews.
*/
export function adaptToClassroom(data: ClassroomData): Classroom {
const scenes: GeneratedScene[] = data.scenes.map((scene, index) => ({
id: scene.id,
outlineId: `outline-${index}`,
content: {
title: scene.title,
sceneType: mapSceneType(scene.type),
content: {
heading: scene.content.heading ?? scene.title,
key_points: scene.content.bullets ?? [],
description: scene.content.explanation,
quiz: scene.content.quiz ?? undefined,
},
actions: [],
durationSeconds: scene.duration ?? 60,
notes: scene.narration,
},
order: index,
})) as GeneratedScene[];
return {
id: data.id,
title: data.title,
description: data.subject,
topic: data.subject,
style: TeachingStyle.Lecture,
level: mapDifficulty(data.difficulty),
totalDuration: data.duration * 60,
objectives: [],
scenes,
agents: [],
metadata: {
generatedAt: new Date(data.createdAt).getTime(),
version: '1.0',
custom: {},
},
};
}
// ---------------------------------------------------------------------------
// New → Old (Classroom → ClassroomData)
// ---------------------------------------------------------------------------
/**
* Convert a new Classroom to the legacy ClassroomData format.
* Used when rendering ClassroomPreviewer from new pipeline results.
*/
export function adaptToClassroomData(classroom: Classroom): ClassroomData {
const scenes: ClassroomScene[] = classroom.scenes.map((scene) => {
const data = scene.content.content as Record<string, unknown>;
return {
id: scene.id,
title: scene.content.title,
type: mapToLegacySceneType(scene.content.sceneType),
content: {
heading: (data?.heading as string) ?? scene.content.title,
bullets: (data?.key_points as string[]) ?? [],
explanation: (data?.description as string) ?? '',
quiz: (data?.quiz as ClassroomScene['content']['quiz']) ?? undefined,
},
narration: scene.content.notes,
duration: scene.content.durationSeconds,
};
});
return {
id: classroom.id,
title: classroom.title,
subject: classroom.topic,
difficulty: mapToLegacyDifficulty(classroom.level),
duration: Math.ceil(classroom.totalDuration / 60),
scenes,
outline: {
sections: classroom.scenes.map((scene) => ({
title: scene.content.title,
scenes: [scene.id],
})),
},
createdAt: new Date(classroom.metadata.generatedAt).toISOString(),
};
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
function mapSceneType(type: ClassroomScene['type']): SceneType {
switch (type) {
case 'title': return SceneType.Slide;
case 'content': return SceneType.Slide;
case 'quiz': return SceneType.Quiz;
case 'interactive': return SceneType.Interactive;
case 'summary': return SceneType.Text;
default: return SceneType.Slide;
}
}
function mapToLegacySceneType(sceneType: string): ClassroomScene['type'] {
switch (sceneType) {
case 'quiz': return 'quiz';
case 'interactive': return 'interactive';
case 'text': return 'summary';
default: return 'content';
}
}
function mapDifficulty(difficulty: string): DifficultyLevel {
switch (difficulty) {
case '初级': return DifficultyLevel.Beginner;
case '中级': return DifficultyLevel.Intermediate;
case '高级': return DifficultyLevel.Advanced;
default: return DifficultyLevel.Intermediate;
}
}
function mapToLegacyDifficulty(level: string): ClassroomData['difficulty'] {
switch (level) {
case 'beginner': return '初级';
case 'advanced': return '高级';
case 'expert': return '高级';
default: return '中级';
}
}

View File

@@ -56,12 +56,19 @@ function initErrorStore(): void {
errors: [],
addError: (error: AppError) => {
// Dedup: skip if same title+message already exists and undismissed
const isDuplicate = errorStore.errors.some(
(e) => !e.dismissed && e.title === error.title && e.message === error.message
);
if (isDuplicate) return;
const storedError: StoredError = {
...error,
dismissed: false,
reported: false,
};
errorStore.errors = [storedError, ...errorStore.errors];
// Cap at 50 errors to prevent unbounded growth
errorStore.errors = [storedError, ...errorStore.errors].slice(0, 50);
// Notify listeners
notifyErrorListeners(error);
},

View File

@@ -103,6 +103,12 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo
callbacks.onDelta(streamEvent.delta);
break;
case 'thinkingDelta':
if (callbacks.onThinkingDelta) {
callbacks.onThinkingDelta(streamEvent.delta);
}
break;
case 'tool_start':
log.debug('Tool started:', streamEvent.name, streamEvent.input);
if (callbacks.onTool) {

View File

@@ -5,8 +5,20 @@
*/
import { invoke } from '@tauri-apps/api/core';
import { listen, type UnlistenFn } from '@tauri-apps/api/event';
import { createLogger } from './logger';
import type { KernelClient } from './kernel-client';
const log = createLogger('KernelHands');
/** Payload emitted by the Rust backend on `hand-execution-complete` events. */
export interface HandExecutionCompletePayload {
approvalId: string;
handId: string;
success: boolean;
error?: string | null;
}
export function installHandMethods(ClientClass: { prototype: KernelClient }): void {
const proto = ClientClass.prototype as any;
@@ -92,7 +104,7 @@ export function installHandMethods(ClientClass: { prototype: KernelClient }): vo
*/
proto.getHandStatus = async function (this: KernelClient, name: string, runId: string): Promise<{ status: string; result?: unknown }> {
try {
return await invoke('hand_run_status', { handName: name, runId });
return await invoke('hand_run_status', { runId });
} catch (e) {
const { createLogger } = await import('./logger');
createLogger('KernelHands').debug('hand_run_status failed', { name, runId, error: e });
@@ -171,4 +183,26 @@ export function installHandMethods(ClientClass: { prototype: KernelClient }): vo
proto.respondToApproval = async function (this: KernelClient, approvalId: string, approved: boolean, reason?: string): Promise<void> {
return invoke('approval_respond', { id: approvalId, approved, reason });
};
// ─── Event Listeners ───
/**
* Listen for `hand-execution-complete` events emitted by the Rust backend
* after a hand finishes executing (both from direct trigger and approval flow).
*
* Returns an unlisten function for cleanup.
*/
proto.onHandExecutionComplete = async function (
this: KernelClient,
callback: (payload: HandExecutionCompletePayload) => void,
): Promise<UnlistenFn> {
const unlisten = await listen<HandExecutionCompletePayload>(
'hand-execution-complete',
(event) => {
log.debug('hand-execution-complete', event.payload);
callback(event.payload);
},
);
return unlisten;
};
}

View File

@@ -109,7 +109,11 @@ export function installSkillMethods(ClientClass: { prototype: KernelClient }): v
}> {
return invoke('skill_execute', {
id,
context: {},
context: {
agentId: '',
sessionId: '',
workingDir: '',
},
input: input || {},
});
};

Some files were not shown because too many files have changed in this diff Show More