Compare commits
95 Commits
chore/sqlx
...
b60b96225d
| Author | SHA1 | Date |
|---|---|---|
| | b60b96225d | |
| | 06e93a21af | |
| | 9060935401 | |
| | 6d6673bf5b | |
| | 15f84bf8c1 | |
| | 9a313e3c92 | |
| | ee5611a2f8 | |
| | 5cf7adff69 | |
| | 10497362bb | |
| | d7dbdf8600 | |
| | 8c25b20fe2 | |
| | 87110ffdff | |
| | 980a8135fa | |
| | e9e7ffd609 | |
| | 00ebf18f23 | |
| | aa84172ca4 | |
| | 1c0029001d | |
| | 0bb526509d | |
| | 394cb66311 | |
| | b56d1a4c34 | |
| | 3e78dacef3 | |
| | e64a3ea9a3 | |
| | 08812e541c | |
| | 17a7a36608 | |
| | 5485404c70 | |
| | a09a4c0e0a | |
| | 62578d9df4 | |
| | 9756d9d995 | |
| | 7ba7389093 | |
| | c10e50d58e | |
| | 5d88d129d1 | |
| | 36612eac53 | |
| | b864973a54 | |
| | 73139da57a | |
| | de7d88afcc | |
| | 8fd8c02953 | |
| | fa5ab4e161 | |
| | 14f2f497b6 | |
| | 4328e74157 | |
| | adf0251cb1 | |
| | 52078512a2 | |
| | 7afd64f536 | |
| | 73d50fda21 | |
| | 8b3e43710b | |
| | 81005c39f9 | |
| | 5816f56039 | |
| | 3cb9709caf | |
| | bc9537cd80 | |
| | bb1869bb1b | |
| | 46fee4b2c8 | |
| | 6d7457de56 | |
| | eede45b13d | |
| | ee56bf6087 | |
| | 5a0c652f4f | |
| | 95a05bc6dc | |
| | 0fd981905d | |
| | 39a7ac3356 | |
| | 8691837608 | |
| | ed77095a37 | |
| | 58ff0bdde7 | |
| | 27006157da | |
| | 191cc3097c | |
| | ae7322e610 | |
| | 591af5802c | |
| | 317b8254e4 | |
| | 751ec000d5 | |
| | c5f98beb7c | |
| | b2908791f6 | |
| | 79e7cd3446 | |
| | b726d0cd5e | |
| | 13507682f7 | |
| | ae56aba366 | |
| | a43806ccc2 | |
| | 5b5491a08f | |
| | 74ce6d4adc | |
| | ec22f0f357 | |
| | d95fda3b76 | |
| | f11ac6e434 | |
| | 9a2611d122 | |
| | 2f5e9f1755 | |
| | c1dea6e07a | |
| | f89b2263d1 | |
| | 3b97bc0746 | |
| | f2917366a8 | |
| | 24b866fc28 | |
| | 39768ff598 | |
| | 3ee68fa763 | |
| | 891d972e20 | |
| | e12766794b | |
| | d9f8850083 | |
| | 0bd50aad8c | |
| | 4ee587d070 | |
| | 8b1b08be82 | |
| | beeb529d8f | |
| | 226beb708b | |
211
CLAUDE.md
@@ -132,19 +132,60 @@ desktop/src-tauri (→ kernel, skills, hands, protocols)
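 4. **Configuration issues** - TOML parsing, environment variables
 5. **Runtime issues** - service startup, port conflicts

-Do not pile on patches blindly while the root cause is still unknown.
+Do not pile on patches blindly while the root cause is still unknown. This step is done in "Stage 2: Make a plan" of the four-stage workflow.

-### 3.3 Closed-loop workflow (mandatory)
+### 3.3 Four-stage workflow (mandatory, no stage may be skipped)

-Every change **must** complete the following steps in order; skipping is not allowed:
+Every operation, whether fixing a bug, adding a feature, refactoring, or answering a technical question, must follow the 4 stages below. Stages may not be skipped or merged.

-1. **Locate the problem** - understand the root cause; no blind patching
-2. **Minimal fix** - change only the code that must change
-3. **Automated verification** - `tsc --noEmit` / `cargo check` / `vitest run` must pass
-4. **Commit and push** - commit per §11, then **`git push` immediately**; do not let commits pile up
-5. **Documentation sync** - check and update the relevant docs per §8.3, then commit and push
+#### Stage 1: Understand the background (read the wiki first)

-**Iron rule: steps 4 and 5 are hard conditions for a task to count as done. "I'll commit later" and "I'll push everything at the end" are not allowed.**
+**On receiving a task, the first thing to do is read the wiki for context, not start working right away.**

+1. Read `wiki/index.md` - understand the overall architecture and use the **symptom navigation table** to locate the relevant module quickly
+2. Read the matching module page - every module page follows the same 5-section structure: design decisions → key files + integration contracts → code logic (invariants) → active issues + pitfalls → change log
+3. If a known issue is involved, check the module page's "Active issues" section (global index: `wiki/known-issues.md`)
+
+**Test:** Can you state in one sentence "which module this change touches, which data path it follows, and which components it affects"? If not, you have not finished reading.
+
+#### Stage 2: Make a plan (think it through before acting)
+
+Based on the understanding from Stage 1, draw up an execution plan:
+
+1. **Locate the root cause** - confirm which category the problem belongs to (protocol/state/UI/configuration/runtime); no blind patching
+2. **Determine the scope of impact** - which files need to change? Which crates are affected? Are there upstream or downstream dependencies?
+3. **List the execution steps** - list, in order, the files to change and the verification points
+4. **Anticipate risks** - what could this change break? Which tests need to run?
+
+**Test:** Can you state in 3 sentences "what to change, why, and how to verify it afterwards"? If not, the plan is not mature yet.
+
+#### Stage 3: Execute + verify
+
+1. **Minimal fix** - change only the code that must change
+2. **Automated verification** - `cargo check` / `cargo test` / `tsc --noEmit` / `vitest run` must pass
+3. **Regression testing** - run the full test suite of the affected crates and confirm there are no regressions
+
+#### Stage 4: Wiki sync + commit (immediately, no backlog)
+
+**Wiki sync assessment (hard gate, cannot be skipped)**
+
+After the code change and before committing, answer each question below. Any "yes" → the corresponding wiki page must be updated:
+
+| Assessment question | Update when "yes" |
+|----------|-------------|
+| Does this change fix or introduce a bug? | Module page "Active issues + pitfalls" section + `wiki/known-issues.md` |
+| Does this change alter a module's behavior or design rationale? | Module page "Design decisions" section |
+| Does this change add or remove files, or change the directory layout? | Module page "Key files" table |
+| Does this change affect a cross-module interface (who calls whom, parameter shape, trigger timing)? | "Integration contracts" table on both sides |
+| Does this change involve a constraint that must always hold? | ⚡ invariants in the module page's "Code logic" section |
+| Does this change alter a feature path (the full frontend → backend route)? | `wiki/feature-map.md` index table |
+| Does this change alter a key number (command count/Store count/test count, etc.)? | `wiki/index.md` key-numbers table + `docs/TRUTH.md` |
+
+After answering everything, regardless of whether anything was updated, append an entry to `wiki/log.md` and update the module page's "Change log" section (keep 5 entries).
+
+**Commit and push** - commit per §11, then **`git push` immediately**. See §8.3 for the detailed documentation sync rules.
+
+**Iron rule: "I'll commit later" and "I'll push everything at the end" are not allowed. Push as soon as each independent unit of work is complete.**
+
 ***
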
@@ -348,31 +389,44 @@ docs/
|
||||
|
||||
每次完成功能实现、架构变更、问题修复后,**必须立即执行以下收尾**:
|
||||
|
||||
#### 步骤 A:文档同步(代码提交前)
|
||||
#### 步骤 A:Wiki 同步(最高优先,代码提交前)
|
||||
|
||||
检查以下文档是否需要更新,有变更则立即修改:
|
||||
> **为什么 wiki 排第一**:wiki 是新 AI 会话的启动燃料。如果 wiki 与代码不一致,后续所有会话都会基于错误上下文工作,错误会积累放大。
|
||||
|
||||
在 §3.3 阶段 4 的评估表基础上,执行具体更新:
|
||||
|
||||
| 触发事件 | 更新目标 | 更新内容 |
|
||||
|----------|---------|---------|
|
||||
| 修复 bug | 对应模块页"活跃问题+陷阱" | 修复→移除条目;新增→添加条目 |
|
||||
| 架构/设计变更 | 对应模块页"设计决策" | WHY 变了 + 新的权衡取舍 |
|
||||
| 文件增删/移动 | 对应模块页"关键文件"表 | 更新文件列表 |
|
||||
| 跨模块接口变化 | **涉及双方**的"集成契约"表 | 方向/接口/触发时机 |
|
||||
| 发现新的不变量 | 对应模块页"代码逻辑"节 | ⚡ 标记 + 一句话描述 |
|
||||
| 功能链路变化 | `wiki/feature-map.md` | 更新索引表对应行 |
|
||||
| 关键数字变化 | `wiki/index.md` + `docs/TRUTH.md` | 更新数字 + 验证命令 |
|
||||
| **每次收尾** | `wiki/log.md` + 模块页"变更记录" | 追加日志条目 + 变更记录保持 5 条 |
|
||||
|
||||
**wiki 更新原则**:
|
||||
- 只记录代码不能告诉你的东西(WHY、跨模块关系、不变量、历史教训)
|
||||
- 模块页控制在 100-200 行,超出则归档到 `wiki/archive/`
|
||||
- 同一信息只出现在一个页面(单一真相源),其他页面只引用
|
||||
|
||||
#### 步骤 B:其他文档同步
|
||||
|
||||
1. **CLAUDE.md** — 项目结构、技术栈、工作流程、命令变化时
|
||||
2. **CLAUDE.md §13 架构快照** — 涉及子系统变更时,更新 `<!-- ARCH-SNAPSHOT-START/END -->` 标记区域(可执行 `/sync-arch` 技能自动分析)
|
||||
2. **CLAUDE.md §13 架构快照** — 涉及子系统变更时(可执行 `/sync-arch` 技能自动分析)
|
||||
3. **docs/ARCHITECTURE_BRIEF.md** — 架构决策或关键组件变更时
|
||||
4. **docs/features/** — 功能状态变化时
|
||||
5. **docs/knowledge-base/** — 新的排查经验或配置说明
|
||||
6. **wiki/** — 编译后知识库维护(按触发规则更新对应页面):
|
||||
- 修复 bug → 更新 `wiki/known-issues.md`
|
||||
- 架构变更 → 更新 `wiki/architecture.md` + `wiki/data-flows.md`
|
||||
- 文件结构变化 → 更新 `wiki/file-map.md`
|
||||
- 模块状态变化 → 更新 `wiki/module-status.md`
|
||||
- 每次更新 → 在 `wiki/log.md` 追加一条记录
|
||||
6. **docs/TRUTH.md** — 数字(命令数、Store 数、crates 数等)变化时
|
||||
|
||||
#### 步骤 B:提交(按逻辑分组)
|
||||
#### 步骤 C:提交(按逻辑分组)
|
||||
|
||||
```
|
||||
代码变更 → 一个或多个逻辑提交
|
||||
文档变更 → 独立提交(如果和代码分开更清晰)
|
||||
```
|
||||
|
||||
#### 步骤 C:推送(立即)
|
||||
#### 步骤 D:推送(立即)
|
||||
|
||||
```
|
||||
git push
|
||||
@@ -530,7 +584,7 @@ refactor(store): 统一 Store 数据获取方式
|
||||
***
|
||||
|
||||
<!-- ARCH-SNAPSHOT-START -->
|
||||
<!-- 此区域由 auto-sync 自动更新,请勿手动编辑。更新时间: 2026-04-15 -->
|
||||
<!-- 此区域由 auto-sync 自动更新,请勿手动编辑。更新时间: 2026-04-23 -->
|
||||
|
||||
## 13. 当前架构快照
|
||||
|
||||
@@ -538,49 +592,53 @@ refactor(store): 统一 Store 数据获取方式
|
||||
|
||||
| 子系统 | 状态 | 最新变更 |
|
||||
|--------|------|----------|
|
||||
| 管家模式 (Butler) | ✅ 活跃 | 04-12 行业配置4行业 + 跨会话连续性 + <butler-context> XML fencing |
|
||||
| Hermes 管线 | ✅ 活跃 | 04-12 触发信号持久化 + 经验行业维度 + 注入格式优化 |
|
||||
| 管家模式 (Butler) | ✅ 活跃 | 04-23 跨会话身份(soul.md) + 动态建议(4路并行LLM驱动) + Agent tab 移除 |
|
||||
| Hermes 管线 | ✅ 活跃 | 04-23 experience_find_relevant Tauri 命令 + ExperienceBrief + OnceLock 单例 |
|
||||
| Intelligence Heartbeat | ✅ 活跃 | 04-15 统一健康快照 (health_snapshot.rs) + HeartbeatManager 重构 + HealthPanel 前端 |
|
||||
| 聊天流 (ChatStream) | ✅ 稳定 | 04-02 ChatStore 拆分为 4 Store (stream/conversation/message/chat) |
|
||||
| 记忆管道 (Memory) | ✅ 稳定 | 04-17 E2E 验证: 存储+FTS5+TF-IDF+注入闭环,去重+跨会话注入已修复 |
|
||||
| 聊天流 (ChatStream) | ✅ 活跃 | 04-23 LLM 动态建议(替换硬编码) + 澄清卡片 UX 优化 |
|
||||
| 记忆管道 (Memory) | ✅ 活跃 | 04-23 身份信号提取(agent_name/user_name) + ProfileSignals 增强 |
|
||||
| SaaS 认证 (Auth) | ✅ 稳定 | Token池 RPM/TPM 轮换 + JWT password_version 失效机制 |
|
||||
| Pipeline DSL | ✅ 稳定 | 04-01 17 个 YAML 模板 + DAG 执行器 |
|
||||
| Hands 系统 | ✅ 稳定 | 7 注册 (6 HAND.toml + _reminder),Whiteboard/Slideshow/Speech 开发中 |
|
||||
| Pipeline DSL | ✅ 稳定 | 04-01 18 个 YAML 模板 + DAG 执行器 |
|
||||
| Hands 系统 | ✅ 稳定 | 7 注册 (6 HAND.toml + _reminder),Whiteboard/Slideshow/Speech 已删除 |
|
||||
| 技能系统 (Skills) | ✅ 稳定 | 75 个 SKILL.md + 语义路由 |
|
||||
| 中间件链 | ✅ 稳定 | 14 层 (ButlerRouter@80, DataMasking@90, Compaction@100, Memory@150, Title@180, SkillIndex@200, DanglingTool@300, ToolError@350, ToolOutputGuard@360, Guardrail@400, LoopGuard@500, SubagentLimit@550, TrajectoryRecorder@650, TokenCalibration@700) |
|
||||
| 中间件链 | ✅ 稳定 | 14 层 + 分波并行 (Evolution@78✅, ButlerRouter@80✅, Compaction@100, Memory@150✅, Title@180✅, SkillIndex@200✅, DanglingTool@300, ToolError@350, ToolOutputGuard@360, Guardrail@400, LoopGuard@500, SubagentLimit@550, TrajectoryRecorder@650, TokenCalibration@700) — ✅=parallel_safe |
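The middleware row above lists priority-tagged layers and marks some of them as `parallel_safe`, with the note "分波并行" (wave-based parallelism). A minimal sketch of how such a chain could be split into waves, under the assumption (not confirmed by this diff) that adjacent `parallel_safe` layers share a wave and every other layer runs alone. `Layer` and `plan_waves` are hypothetical names used only for illustration:

```rust
/// Hypothetical description of one middleware layer.
#[derive(Debug, Clone, PartialEq)]
struct Layer {
    name: &'static str,
    priority: u32,
    parallel_safe: bool,
}

/// Sort layers by priority, then group them into waves.
/// Assumption: consecutive parallel-safe layers may share a wave;
/// a layer that is not parallel-safe always forms its own wave.
fn plan_waves(mut layers: Vec<Layer>) -> Vec<Vec<Layer>> {
    layers.sort_by_key(|l| l.priority);
    let mut waves: Vec<Vec<Layer>> = Vec::new();
    for layer in layers {
        // The layer can join the previous wave only if both the layer
        // and everything already in that wave are parallel-safe.
        let joins_previous = layer.parallel_safe
            && waves
                .last()
                .map_or(false, |w| w.iter().all(|l| l.parallel_safe));
        if joins_previous {
            waves.last_mut().unwrap().push(layer);
        } else {
            waves.push(vec![layer]);
        }
    }
    waves
}
```

With the priorities from the table, `Evolution@78` and `ButlerRouter@80` would land in one wave, `Compaction@100` in its own wave, and `Memory@150` would start a new one. Whether the real scheduler groups layers this way is an open question.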
|
||||
|
||||
### 关键架构模式
|
||||
|
||||
- **Hermes 管线**: 4模块闭环 — ExperienceStore(FTS5经验存取) + UserProfiler(结构化用户画像) + NlScheduleParser(中文时间→cron) + TrajectoryRecorder+Compressor(轨迹记录压缩)。通过中间件链+intelligence hooks调用
|
||||
- **管家模式**: 双模式UI (默认简洁/解锁专业) + ButlerRouter 动态行业关键词(4内置+自定义) + <butler-context> XML fencing注入 + 跨会话连续性(痛点回访+经验检索) + 触发信号持久化(VikingStorage) + 冷启动4阶段hook
|
||||
- **聊天流**: 3种实现 → GatewayClient(WebSocket) / KernelClient(Tauri Event) / SaaSRelay(SSE) + 5min超时守护。详见 [ARCHITECTURE_BRIEF.md](docs/ARCHITECTURE_BRIEF.md)
|
||||
- **管家模式**: 双模式UI (默认简洁/解锁专业) + ButlerRouter 动态行业关键词(4内置+自定义) + <butler-context> XML fencing注入 + 跨会话连续性(痛点回访+经验检索) + 触发信号持久化(VikingStorage) + 冷启动4阶段hook + 跨会话身份(soul.md) + 动态建议(4路并行LLM驱动2续问+1关怀)
|
||||
- **聊天流**: 3种实现 → GatewayClient(WebSocket) / KernelClient(Tauri Event) / SaaSRelay(SSE) + 5min超时守护。动态建议: prefetch context + generateLLMSuggestions(1追问+1行动+1关怀) 与 memory extraction 解耦。详见 [ARCHITECTURE_BRIEF.md](docs/ARCHITECTURE_BRIEF.md)
|
||||
- **客户端路由**: `getClient()` 4分支决策树 → Admin路由 / SaaS Relay(可降级到本地) / Local Kernel / External Gateway
|
||||
- **SaaS 认证**: JWT→OS keyring 存储 + HttpOnly cookie + Token池 RPM/TPM 限流轮换 + SaaS unreachable 自动降级
|
||||
- **记忆闭环**: 对话→extraction_adapter→FTS5全文+TF-IDF权重→检索→注入系统提示(E2E 04-17 验证通过,去重+跨会话注入已修复)
|
||||
- **记忆闭环**: 对话→extraction_adapter→FTS5全文+TF-IDF权重→检索→注入系统提示 + 身份信号提取(agent_name/user_name)→VikingStorage→soul.md→跨会话名字记忆
|
||||
- **LLM 驱动**: 4 Rust Driver (Anthropic/OpenAI/Gemini/Local) + 国内兼容 (DeepSeek/Qwen/Moonshot 通过 base_url)
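The memory-loop bullet above mentions FTS5 full-text search combined with TF-IDF weights. As a reminder of what a TF-IDF weight is, here is a textbook sketch. It is not the project's scoring code; `tf_idf` and the smoothed IDF formula `ln((1 + N) / (1 + df)) + 1` are assumptions chosen for illustration:

```rust
/// Textbook TF-IDF of `term` in `docs[doc_idx]`.
/// Each document is a pre-tokenized list of terms.
fn tf_idf(term: &str, doc_idx: usize, docs: &[Vec<&str>]) -> f64 {
    let doc = &docs[doc_idx];
    if doc.is_empty() {
        return 0.0;
    }
    // Term frequency: share of tokens in this document equal to `term`.
    let tf = doc.iter().filter(|t| **t == term).count() as f64 / doc.len() as f64;
    // Document frequency: number of documents containing `term`.
    let df = docs.iter().filter(|d| d.contains(&term)).count() as f64;
    // Smoothed inverse document frequency (never divides by zero).
    let idf = ((1.0 + docs.len() as f64) / (1.0 + df)).ln() + 1.0;
    tf * idf
}
```

A term that appears in fewer documents gets a larger weight than an equally frequent term that appears everywhere, which is the property a retrieval step relies on when ranking memories.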
|
||||
|
||||
### 最近变更
|
||||
|
||||
1. [04-17] 全系统 E2E 测试 129 链路: 82 PASS / 20 PARTIAL / 1 FAIL / 26 SKIP,有效通过率 79.1%。7 项 Bug 修复 (Dashboard 404/记忆去重/记忆注入/invoice_id/Prompt版本/agent隔离/行业字段)
|
||||
2. [04-16] 3 项 P0 修复 + 5 项 E2E Bug 修复 + Agent 面板刷新 + TRUTH.md 数字校准
|
||||
3. [04-15] Heartbeat 统一健康系统: health_snapshot.rs 统一收集器(LLM连接/记忆/会话/系统资源) + heartbeat.rs HeartbeatManager 重构 + HealthPanel.tsx 前端面板 + Tauri 命令 182→183 + intelligence 模块 15→16 文件 + 删除 intelligence-client/ 9 废弃文件
|
||||
4. [04-12] 行业配置+管家主动性 全栈 5 Phase: 行业数据模型+4内置配置+ButlerRouter动态关键词+触发信号+Tauri加载+Admin管理页面+跨会话连续性+XML fencing注入格式
|
||||
5. [04-09] Hermes Intelligence Pipeline 4 Chunk: ExperienceStore+Extractor, UserProfileStore+Profiler, NlScheduleParser, TrajectoryRecorder+Compressor (684 tests, 0 failed)
|
||||
6. [04-09] 管家模式6交付物完成: ButlerRouter + 冷启动 + 简洁模式UI + 桥测试 + 发布文档
|
||||
1. [04-23] 回复效率+建议生成并行化: identity prompt 缓存 + pre-hook 并行(tokio::join!) + middleware 分波并行(parallel_safe, 5层✅) + suggestion context 预取 + 建议与 memory 解耦 + prompt 重写(1追问+1行动+1关怀)
|
||||
2. [04-23] 动态建议智能化: fetchSuggestionContext 4路并行(用户画像/痛点/经验/技能匹配) + generateLLMSuggestions 混合型 prompt (2续问+1管家关怀) + experience_find_relevant Tauri 命令 + ExperienceBrief
|
||||
3. [04-23] 跨会话身份: detectAgentNameSuggestion trigger+extract 两步法(10 trigger) + ProfileSignals agent_name/user_name + soul.md 写回 + Agent tab 移除 (~280 行 dead code 清理)
|
||||
4. [04-22] Wiki 全面重构: 5节模板+集成契约+症状导航+归档压缩,净减 ~1,200 行
|
||||
4. [04-22] 跨会话记忆断裂修复 + DataMasking 中间件移除 + 搜索功能修复(多引擎+质量过滤+SSE行缓冲)
|
||||
5. [04-21] Embedding 接通 + 自学习自动化 A线+B线 + Phase 0+1 突破之路 8 项链路修复。验证: 934 tests PASS
|
||||
6. [04-20] 50 轮功能链路审计 7 项断链修复 (42/50 = 84% 通过率)
|
||||
7. [04-17] 全系统 E2E 测试 129 链路: 82 PASS / 20 PARTIAL / 1 FAIL / 26 SKIP,有效通过率 79.1%
|
||||
|
||||
<!-- ARCH-SNAPSHOT-END -->
|
||||
|
||||
<!-- ARCH-SNAPSHOT-END -->
|
||||
|
||||
<!-- ANTI-PATTERN-START -->
|
||||
<!-- 此区域由 auto-sync 自动更新,请勿手动编辑。更新时间: 2026-04-09 -->
|
||||
<!-- 此区域由 auto-sync 自动更新,请勿手动编辑。更新时间: 2026-04-23 -->
|
||||
|
||||
## 14. AI 协作注意事项
|
||||
|
||||
### 反模式警告
|
||||
|
||||
- ❌ **不要**建议新增 SaaS API 端点 — 已有 140 个,稳定化约束禁止新增
|
||||
- ❌ **不要**建议新增 SaaS API 端点 — 已有 137 个,稳定化约束禁止新增
|
||||
- ❌ **不要**忽略管家模式 — 已上线且为默认模式,所有聊天经过 ButlerRouter
|
||||
- ❌ **不要**假设 Tauri 直连 LLM — 实际通过 SaaS Token 池中转,SaaS unreachable 时降级到本地 Kernel
|
||||
- ❌ **不要**建议从零实现已有能力 — 先查 Hand(9个)/Skill(75个)/Pipeline(17模板) 现有库
|
||||
- ❌ **不要**建议从零实现已有能力 — 先查 Hand(7注册)/Skill(75个)/Pipeline(18模板) 现有库
|
||||
- ❌ **不要**在 CLAUDE.md 以外创建项目级配置或规则文件 — 单一入口原则
|
||||
|
||||
### 场景化指令
|
||||
@@ -589,6 +647,75 @@ refactor(store): 统一 Store 数据获取方式
|
||||
- 当遇到**认证相关** → 记住 Tauri 模式用 OS keyring 存 JWT,SaaS 模式用 HttpOnly cookie
|
||||
- 当遇到**新功能建议** → 先查 [TRUTH.md](docs/TRUTH.md) 确认可用能力清单,避免重复建设
|
||||
- 当遇到**记忆/上下文相关** → 记住闭环已接通: FTS5+TF-IDF+embedding,不是空壳
|
||||
- 当遇到**管家/Butler** → 管家模式是默认模式,ButlerRouter 在中间件链中做关键词分类+system prompt 增强
|
||||
- 当遇到**管家/Butler** → 管家模式是默认模式,ButlerRouter 在中间件链中做关键词分类+system prompt 增强。跨会话身份走 soul.md,动态建议走 4 路并行上下文+LLM
|
||||
|
||||
<!-- ANTI-PATTERN-END -->
+
+***
+
+## 15. Karpathy Coding Principles
+
+> Derived from Andrej Karpathy's observations on LLM coding problems. Favor caution over speed; use judgment on simple tasks.
+
+### 15.1 Think Before Coding
+
+**Don't assume. Don't hide confusion. Surface tradeoffs.**
+
+- State assumptions explicitly. If uncertain, ask.
+- If multiple interpretations exist, present them; don't pick silently.
+- If a simpler approach exists, say so. Push back when warranted.
+- If something is unclear, stop. Name what's confusing. Ask.
+
+### 15.2 Simplicity First
+
+**Minimum code that solves the problem. Nothing speculative.**
+
+- No features beyond what was asked.
+- No abstractions for single-use code.
+- No "flexibility" or "configurability" that wasn't requested.
+- No error handling for impossible scenarios.
+- If you write 200 lines and it could be 50, rewrite it.
+
+Ask yourself: "Would a senior engineer say this is overcomplicated?" If yes, simplify.
+
+### 15.3 Surgical Changes
+
+**Touch only what you must. Clean up only your own mess.**
+
+When editing existing code:
+
+- Don't "improve" adjacent code, comments, or formatting.
+- Don't refactor things that aren't broken.
+- Match existing style, even if you'd do it differently.
+- If you notice unrelated dead code, mention it; don't delete it.
+
+When your changes create orphans:
+
+- Remove imports/variables/functions that YOUR changes made unused.
+- Don't remove pre-existing dead code unless asked.
+
+The test: Every changed line should trace directly to the user's request.
+
+### 15.4 Goal-Driven Execution
+
+**Define success criteria. Loop until verified.**
+
+Transform tasks into verifiable goals:
+
+- "Add validation" → "Write tests for invalid inputs, then make them pass"
+- "Fix the bug" → "Write a test that reproduces it, then make it pass"
+- "Refactor X" → "Ensure tests pass before and after"
+
+For multi-step tasks, state a brief plan:
+
+```
+1. [Step] → verify: [check]
+2. [Step] → verify: [check]
+3. [Step] → verify: [check]
+```
+
+Strong success criteria let you loop independently. Weak criteria ("make it work") require constant clarification.
+
+---
+
+**These guidelines are working if:** fewer unnecessary changes in diffs, fewer rewrites due to overcomplication, and clarifying questions come before implementation rather than after mistakes.
4
Cargo.lock
generated
@@ -9485,12 +9485,15 @@ dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"chrono",
|
||||
"dirs",
|
||||
"reqwest 0.12.28",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"toml 0.8.2",
|
||||
"tracing",
|
||||
"url",
|
||||
"uuid",
|
||||
"zclaw-runtime",
|
||||
"zclaw-types",
|
||||
@@ -9515,6 +9518,7 @@ dependencies = [
|
||||
"toml 0.8.2",
|
||||
"tracing",
|
||||
"uuid",
|
||||
"zclaw-growth",
|
||||
"zclaw-hands",
|
||||
"zclaw-memory",
|
||||
"zclaw-protocols",
|
||||
|
||||
@@ -223,8 +223,10 @@ timeout = "30s"
|
||||
[tools.web]
|
||||
[tools.web.search]
|
||||
enabled = true
|
||||
default_engine = "duckduckgo"
|
||||
default_engine = "auto"
|
||||
max_results = 10
|
||||
searxng_url = "http://localhost:8888"
|
||||
searxng_timeout = 15
|
||||
|
||||
# File system tool
|
||||
[tools.fs]
|
||||
|
||||
@@ -295,7 +295,7 @@ mod tests {
|
||||
industry_context: None,
|
||||
};
|
||||
|
||||
let json = r##"{"name":"报表技能","description":"生成报表","triggers":["报表","日报"],"tools":["researcher"],"body_markdown":"# 报表\n步骤","confidence":0.9}"##;
|
||||
let json = r##"{"name":"报表技能","description":"生成报表","triggers":["报表","日报"],"tools":["researcher"],"body_markdown":"# 报表生成技能\n\n## 步骤一\n收集数据源并验证完整性。\n\n## 步骤二\n按模板格式化输出报表。\n\n## 步骤三\n发送至相关接收人。","confidence":0.9}"##;
|
||||
let (candidate, report) = engine
|
||||
.validate_skill_candidate(json, &pattern, vec!["搜索".to_string()])
|
||||
.unwrap();
|
||||
|
||||
@@ -118,10 +118,49 @@ impl ExperienceStore {
|
||||
&self.viking
|
||||
}
|
||||
|
||||
/// Store (or overwrite) an experience. The URI is derived from
|
||||
/// `agent_id + pain_pattern`, ensuring one experience per pattern.
|
||||
/// Store an experience, merging with existing if the same pain pattern
|
||||
/// already exists for this agent. Reuse-count is preserved and incremented
|
||||
/// rather than reset to zero on re-extraction.
|
||||
pub async fn store_experience(&self, exp: &Experience) -> zclaw_types::Result<()> {
|
||||
let uri = exp.uri();
|
||||
|
||||
// If an experience with this URI already exists, merge instead of overwrite.
|
||||
if let Some(existing_entry) = self.viking.get(&uri).await? {
|
||||
let existing = match serde_json::from_str::<Experience>(&existing_entry.content) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
warn!("[ExperienceStore] Failed to deserialize existing experience at {}: {}, overwriting", uri, e);
|
||||
// Fall through to store new experience as overwrite
|
||||
self.write_entry(&uri, exp).await?;
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
{
|
||||
let merged = Experience {
|
||||
id: existing.id.clone(),
|
||||
reuse_count: existing.reuse_count + 1,
|
||||
created_at: existing.created_at,
|
||||
updated_at: Utc::now(),
|
||||
// New data takes precedence for content fields
|
||||
pain_pattern: exp.pain_pattern.clone(),
|
||||
agent_id: exp.agent_id.clone(),
|
||||
context: exp.context.clone(),
|
||||
solution_steps: exp.solution_steps.clone(),
|
||||
outcome: exp.outcome.clone(),
|
||||
industry_context: exp.industry_context.clone().or(existing.industry_context.clone()),
|
||||
source_trigger: exp.source_trigger.clone().or(existing.source_trigger.clone()),
|
||||
tool_used: exp.tool_used.clone().or(existing.tool_used.clone()),
|
||||
};
|
||||
return self.write_entry(&uri, &merged).await;
|
||||
}
|
||||
}
|
||||
|
||||
self.write_entry(&uri, exp).await
|
||||
}
|
||||
|
||||
/// Low-level write: serialises the experience into a MemoryEntry and
|
||||
/// persists it through the VikingAdapter.
|
||||
async fn write_entry(&self, uri: &str, exp: &Experience) -> zclaw_types::Result<()> {
|
||||
let content = serde_json::to_string(exp)?;
|
||||
let mut keywords = vec![exp.pain_pattern.clone()];
|
||||
keywords.extend(exp.solution_steps.iter().take(3).cloned());
|
||||
@@ -133,7 +172,7 @@ impl ExperienceStore {
|
||||
}
|
||||
|
||||
let entry = MemoryEntry {
|
||||
uri,
|
||||
uri: uri.to_string(),
|
||||
memory_type: MemoryType::Experience,
|
||||
content,
|
||||
keywords,
|
||||
@@ -197,7 +236,7 @@ impl ExperienceStore {
|
||||
let mut updated = exp.clone();
|
||||
updated.reuse_count += 1;
|
||||
updated.updated_at = Utc::now();
|
||||
if let Err(e) = self.store_experience(&updated).await {
|
||||
if let Err(e) = self.write_entry(&exp.uri(), &updated).await {
|
||||
warn!("[ExperienceStore] Failed to increment reuse for {}: {}", exp.id, e);
|
||||
}
|
||||
}
|
||||
@@ -209,6 +248,20 @@ impl ExperienceStore {
|
||||
debug!("[ExperienceStore] Deleted experience {} for agent {}", exp.id, exp.agent_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find experiences for an agent created since the given datetime.
|
||||
/// Filters by deserializing each entry and checking `created_at`.
|
||||
pub async fn find_since(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
since: DateTime<Utc>,
|
||||
) -> zclaw_types::Result<Vec<Experience>> {
|
||||
let all = self.find_by_agent(agent_id).await?;
|
||||
Ok(all
|
||||
.into_iter()
|
||||
.filter(|exp| exp.created_at >= since)
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -289,7 +342,7 @@ mod tests {
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_store_overwrites_same_pattern() {
|
||||
async fn test_store_merges_same_pattern() {
|
||||
let viking = Arc::new(VikingAdapter::in_memory());
|
||||
let store = ExperienceStore::new(viking);
|
||||
|
||||
@@ -303,13 +356,19 @@ mod tests {
|
||||
"agent-1", "packaging", "v2 updated",
|
||||
vec!["new step".into()], "better",
|
||||
);
|
||||
// Force same URI by reusing the ID logic — same pattern → same URI.
|
||||
// Same pattern → same URI → should merge, not overwrite.
|
||||
store.store_experience(&exp_v2).await.unwrap();
|
||||
|
||||
let found = store.find_by_agent("agent-1").await.unwrap();
|
||||
// Should be overwritten, not duplicated (same URI).
|
||||
// Should be merged into one entry, not duplicated.
|
||||
assert_eq!(found.len(), 1);
|
||||
// Content fields updated to v2.
|
||||
assert_eq!(found[0].context, "v2 updated");
|
||||
assert_eq!(found[0].solution_steps[0], "new step");
|
||||
// Reuse count incremented (was 0, now 1).
|
||||
assert_eq!(found[0].reuse_count, 1);
|
||||
// Original ID and created_at preserved.
|
||||
assert_eq!(found[0].id, exp_v1.id);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -376,4 +435,48 @@ mod tests {
|
||||
assert_eq!(found_a.len(), 1);
|
||||
assert_eq!(found_a[0].pain_pattern, "packaging");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_reuse_count_accumulates_across_repeated_patterns() {
|
||||
let viking = Arc::new(VikingAdapter::in_memory());
|
||||
let store = ExperienceStore::new(viking);
|
||||
|
||||
// Store the same pattern 4 times (simulating 4 conversations)
|
||||
for i in 0..4 {
|
||||
let exp = Experience::new(
|
||||
"agent-1", "logistics delay", &format!("context v{}", i),
|
||||
vec![format!("step {}", i)], &format!("outcome {}", i),
|
||||
);
|
||||
store.store_experience(&exp).await.unwrap();
|
||||
}
|
||||
|
||||
let found = store.find_by_agent("agent-1").await.unwrap();
|
||||
assert_eq!(found.len(), 1);
|
||||
// First store: reuse_count=0, then 1, 2, 3 after each re-store.
|
||||
assert_eq!(found[0].reuse_count, 3);
|
||||
// Content should reflect the latest version.
|
||||
assert_eq!(found[0].context, "context v3");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_since_filters_by_date() {
|
||||
let viking = Arc::new(VikingAdapter::in_memory());
|
||||
let store = ExperienceStore::new(viking);
|
||||
|
||||
let exp = Experience::new(
|
||||
"agent-1", "recent pattern", "ctx",
|
||||
vec!["step".into()], "ok",
|
||||
);
|
||||
store.store_experience(&exp).await.unwrap();
|
||||
|
||||
// Query with since=far past → should find it
|
||||
let old_since = Utc::now() - chrono::Duration::days(365);
|
||||
let found = store.find_since("agent-1", old_since).await.unwrap();
|
||||
assert_eq!(found.len(), 1);
|
||||
|
||||
// Query with since=far future → should not find it
|
||||
let future_since = Utc::now() + chrono::Duration::days(365);
|
||||
let found = store.find_since("agent-1", future_since).await.unwrap();
|
||||
assert!(found.is_empty());
|
||||
}
|
||||
}
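The merge rule that `store_experience` introduces above (and that `test_reuse_count_accumulates_across_repeated_patterns` exercises) can be isolated as a pure function. The sketch below is a simplification under stated assumptions: `Exp` is a trimmed-down, hypothetical stand-in for `Experience`, timestamps are plain `u64` values instead of `DateTime<Utc>`, and `merge` is a name invented for this example:

```rust
/// Hypothetical, reduced version of the `Experience` record.
#[derive(Debug, Clone, PartialEq)]
struct Exp {
    id: String,
    pain_pattern: String,
    context: String,
    reuse_count: u32,
    created_at: u64,
    updated_at: u64,
    tool_used: Option<String>,
}

/// Merge a freshly extracted experience into the stored one.
fn merge(existing: &Exp, incoming: &Exp, now: u64) -> Exp {
    Exp {
        // Identity and history come from the stored entry.
        id: existing.id.clone(),
        created_at: existing.created_at,
        reuse_count: existing.reuse_count + 1,
        updated_at: now,
        // Content fields come from the new extraction.
        pain_pattern: incoming.pain_pattern.clone(),
        context: incoming.context.clone(),
        // Optional fields: prefer the new value, fall back to the old one.
        tool_used: incoming
            .tool_used
            .clone()
            .or_else(|| existing.tool_used.clone()),
    }
}
```

Applying `merge` three times after an initial store yields `reuse_count == 3`, which matches the expectation in the four-iteration test above. Keeping the rule in a pure function like this would also make it testable without a `VikingAdapter`.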
|
||||
|
||||
@@ -253,6 +253,18 @@ impl MemoryExtractor {
|
||||
Ok(stored)
|
||||
}
|
||||
|
||||
/// Store a single pre-built MemoryEntry to VikingStorage
|
||||
pub async fn store_memory_entry(&self, entry: &crate::types::MemoryEntry) -> Result<()> {
|
||||
let viking = match &self.viking {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
tracing::warn!("[MemoryExtractor] No VikingAdapter configured");
|
||||
return Err(zclaw_types::ZclawError::Internal("No VikingAdapter".to_string()));
|
||||
}
|
||||
};
|
||||
viking.store(entry).await
|
||||
}
|
||||
|
||||
/// Unified extraction: a single LLM call produces memories + experiences + profile_signals
|
||||
///
|
||||
/// Prefers a single call via `extract_with_prompt()`; if the driver does not support it,
|
||||
@@ -481,6 +493,16 @@ fn parse_profile_signals(obj: &serde_json::Value) -> crate::types::ProfileSignal
|
||||
.and_then(|s| s.get("communication_style"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
agent_name: signals
|
||||
.and_then(|s| s.get("agent_name"))
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(String::from),
|
||||
user_name: signals
|
||||
.and_then(|s| s.get("user_name"))
|
||||
.and_then(|v| v.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(String::from),
|
||||
}
|
||||
}
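The `.filter(|s| !s.is_empty())` step added for `agent_name` and `user_name` above turns an empty string into `None`, so a blank value returned by the LLM never overwrites a stored name. A minimal sketch of the same chain on a plain map, assuming a `HashMap<String, String>` in place of the `serde_json::Value` used by the real code. `non_empty_field` is a hypothetical helper:

```rust
use std::collections::HashMap;

/// Look up `key` and return it only when it is present and non-empty.
fn non_empty_field(signals: Option<&HashMap<String, String>>, key: &str) -> Option<String> {
    signals
        .and_then(|s| s.get(key)) // missing map or missing key -> None
        .map(|v| v.as_str())
        .filter(|s| !s.is_empty()) // empty string -> None
        .map(String::from)
}
```

Callers can then treat "absent", "null", and "empty" uniformly as `None`.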
|
||||
|
||||
@@ -525,6 +547,22 @@ fn infer_profile_signals_from_memories(
|
||||
signals.communication_style = Some(m.content.clone());
|
||||
}
|
||||
}
|
||||
// Identity-signal fallback: detect naming/addressing keywords in preference memories
|
||||
let lower = m.content.to_lowercase();
|
||||
if lower.contains("叫你") || lower.contains("助手名字") || lower.contains("称呼") {
|
||||
if signals.agent_name.is_none() {
|
||||
// Try to extract the name inside quotes
|
||||
signals.agent_name = extract_quoted_name(&m.content)
|
||||
.or_else(|| extract_name_after_pattern(&lower, &m.content, "叫你"));
|
||||
}
|
||||
}
|
||||
if lower.contains("我叫") || lower.contains("我的名字") || lower.contains("用户名") {
|
||||
if signals.user_name.is_none() {
|
||||
signals.user_name = extract_name_after_pattern(&lower, &m.content, "我叫")
.or_else(|| extract_name_after_pattern(&lower, &m.content, "我的名字是"));
|
||||
}
|
||||
}
|
||||
}
|
||||
crate::types::MemoryType::Knowledge => {
|
||||
if signals.recent_topic.is_none() && !m.keywords.is_empty() {
|
||||
@@ -547,6 +585,38 @@ fn infer_profile_signals_from_memories(
|
||||
signals
|
||||
}
|
||||
|
||||
/// Extract a name from quotes (e.g. "以后叫你'小马'" → "小马")
|
||||
fn extract_quoted_name(text: &str) -> Option<String> {
|
||||
for delim in ['"', '\'', '「', '」', '『', '』'] {
|
||||
let mut parts = text.split(delim);
|
||||
parts.next(); // skip before first delimiter
|
||||
if let Some(name) = parts.next() {
|
||||
let trimmed = name.trim();
|
||||
if !trimmed.is_empty() && trimmed.chars().count() <= 20 {
|
||||
return Some(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Extract the name that follows a given pattern (e.g. "叫你小马" → "小马")
|
||||
fn extract_name_after_pattern(lower: &str, original: &str, pattern: &str) -> Option<String> {
|
||||
if let Some(pos) = lower.find(pattern) {
|
||||
let after = &original[pos + pattern.len()..];
|
||||
// Take the first word (Chinese or English, at most 10 characters)
|
||||
let name: String = after
|
||||
.chars()
|
||||
.take_while(|c| !c.is_whitespace() && !matches!(c, ',' | '。' | '!' | '?' | ',' | '.' | '!' | '?'))
|
||||
.take(10)
|
||||
.collect();
|
||||
if !name.is_empty() {
|
||||
return Some(name);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
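Two details of the helpers above are worth a closer look. `extract_quoted_name` splits on each delimiter independently, so for a bracket pair such as `「小马」` the text after the opening `「` still carries the closing `」`. And `extract_name_after_pattern` takes a byte offset found in the lowercased copy and applies it to the original string, which is only safe while lowercasing does not change byte lengths. The sketch below shows one possible pair-aware, offset-safe variant. `extract_paired_quote` and `name_after` are hypothetical names, and the search is case-sensitive, which is enough for the Chinese patterns used here:

```rust
/// Return the trimmed text between the first matching open/close quote pair.
fn extract_paired_quote(text: &str) -> Option<String> {
    const PAIRS: [(char, char); 6] = [
        ('"', '"'),
        ('\'', '\''),
        ('「', '」'),
        ('『', '』'),
        ('“', '”'),
        ('‘', '’'),
    ];
    for (open, close) in PAIRS {
        if let Some(start) = text.find(open) {
            // Skip past the opening quote using its UTF-8 width.
            let rest = &text[start + open.len_utf8()..];
            if let Some(end) = rest.find(close) {
                let name = rest[..end].trim();
                if !name.is_empty() && name.chars().count() <= 20 {
                    return Some(name.to_string());
                }
            }
        }
    }
    None
}

/// Take the run of characters right after `pattern`, stopping at
/// whitespace or punctuation, capped at `max_chars` characters.
fn name_after(text: &str, pattern: &str, max_chars: usize) -> Option<String> {
    // Offsets come from `text` itself, so slicing stays on a char boundary.
    let start = text.find(pattern)? + pattern.len();
    let name: String = text[start..]
        .chars()
        .take_while(|c| !c.is_whitespace() && !",。!?,.!?".contains(*c))
        .take(max_chars)
        .collect();
    if name.is_empty() {
        None
    } else {
        Some(name)
    }
}
```

For ASCII patterns that must match case-insensitively, a different search that still returns offsets into the original string would be needed. That part is left out of this sketch.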
|
||||
|
||||
/// Default extraction prompts for LLM
|
||||
pub mod prompts {
|
||||
use crate::types::MemoryType;
|
||||
@@ -594,7 +664,9 @@ pub mod prompts {
|
||||
"recent_topic": "最近讨论的主要话题(可选)",
|
||||
"pain_point": "用户当前痛点(可选)",
|
||||
"preferred_tool": "用户偏好的工具/技能(可选)",
|
||||
"communication_style": "沟通风格: concise|detailed|formal|casual(可选)"
|
||||
"communication_style": "沟通风格: concise|detailed|formal|casual(可选)",
|
||||
"agent_name": "用户给助手起的名称(可选,仅在用户明确命名时填写,如'以后叫你小马')",
|
||||
"user_name": "用户提到的自己的名字(可选,仅在用户明确自我介绍时填写,如'我叫张三')"
|
||||
}
|
||||
}
|
||||
```
|
||||
@@ -604,8 +676,9 @@ pub mod prompts {
|
||||
1. **memories**: 提取用户偏好(沟通风格/格式/语言)、知识(事实/领域知识/经验教训)、使用经验(技能/工具使用模式和结果)
|
||||
2. **experiences**: 仅提取明确的"问题→解决"模式,要求有清晰的痛点和步骤,confidence >= 0.6
|
||||
3. **profile_signals**: 从对话中推断用户画像信息,只在有明确信号时填写,留空则不填
|
||||
4. 每个字段都要有实际内容,不确定的宁可省略
|
||||
5. 只返回 JSON,不要附加其他文本
|
||||
4. **identity**: 检测用户是否给助手命名(如"你叫X"/"以后叫你X"/"你的名字是X")或自我介绍(如"我叫X"/"我的名字是X"),填入 agent_name 或 user_name 字段
|
||||
5. 每个字段都要有实际内容,不确定的宁可省略
|
||||
6. 只返回 JSON,不要附加其他文本
|
||||
|
||||
对话内容:
|
||||
"#;
|
||||
|
||||
@@ -63,6 +63,19 @@ impl QualityGate {
|
||||
issues.push("技能正文不能为空".to_string());
|
||||
}
|
||||
|
||||
// 6. body_markdown minimum length + structure check
|
||||
if candidate.body_markdown.trim().len() < 100 {
|
||||
issues.push("技能正文太短,至少需要100个字符".to_string());
|
||||
}
|
||||
if !candidate.body_markdown.contains('#') {
|
||||
issues.push("技能正文必须包含至少一个标题 (#)".to_string());
|
||||
}
|
||||
|
||||
// 7. Confidence upper-bound check (guards against an LLM hallucinating an overly high confidence)
|
||||
if candidate.confidence > 1.0 {
|
||||
issues.push(format!("置信度 {:.2} 超过上限 1.0", candidate.confidence));
|
||||
}
|
||||
|
||||
QualityReport {
|
||||
passed: issues.is_empty(),
|
||||
issues,
|
||||
@@ -81,7 +94,7 @@ mod tests {
|
||||
description: "生成每日报表".to_string(),
|
||||
triggers: vec!["报表".to_string(), "日报".to_string()],
|
||||
tools: vec!["researcher".to_string()],
|
||||
body_markdown: "# 每日报表\n步骤1\n步骤2".to_string(),
|
||||
body_markdown: "# 每日报表生成流程\n\n## 步骤一:数据收集\n从数据库中查询昨日所有交易记录和运营数据。\n\n## 步骤二:数据整理\n将原始数据按部门、类型进行分类汇总。\n\n## 步骤三:报表输出\n生成标准化报表并发送至相关部门邮箱。".to_string(),
|
||||
source_pattern: "报表生成".to_string(),
|
||||
confidence: 0.85,
|
||||
version: 1,
|
||||
@@ -157,4 +170,24 @@ mod tests {
|
||||
assert!(!report.passed);
|
||||
assert!(report.issues.len() >= 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_body_too_short() {
|
||||
let gate = QualityGate::new(0.5, vec![]);
|
||||
let mut candidate = make_valid_candidate();
|
||||
candidate.body_markdown = "# 短内容\n步骤1".to_string();
|
||||
let report = gate.validate_skill(&candidate);
|
||||
assert!(!report.passed);
|
||||
assert!(report.issues.iter().any(|i| i.contains("太短")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_body_no_heading() {
|
||||
let gate = QualityGate::new(0.5, vec![]);
|
||||
let mut candidate = make_valid_candidate();
|
||||
candidate.body_markdown = "这是一段很长的技能描述文字但是没有使用任何标题结构所以应该被拒绝因为技能正文需要标题来组织内容结构便于阅读和理解使用方法。".to_string();
|
||||
let report = gate.validate_skill(&candidate);
|
||||
assert!(!report.passed);
|
||||
assert!(report.issues.iter().any(|i| i.contains("标题")));
|
||||
}
|
||||
}
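The length check in `validate_skill` above compares `body_markdown.trim().len()` with 100, while its message speaks of 100 characters ("100个字符"). In Rust, `str::len` counts UTF-8 bytes, and a CJK character typically takes 3 bytes, so a Chinese body of roughly 34 characters already passes. If the intent is a character count, `chars().count()` is one option. The two helper names below are hypothetical and only contrast the two measures:

```rust
/// Length gate measured in UTF-8 bytes (what `str::len` returns).
fn long_enough_bytes(body: &str, min: usize) -> bool {
    body.trim().len() >= min
}

/// Length gate measured in Unicode scalar values (`char`s).
fn long_enough_chars(body: &str, min: usize) -> bool {
    body.trim().chars().count() >= min
}
```

For pure ASCII bodies the two gates agree; they diverge only for multi-byte text, which is the common case for the Chinese skill bodies in these tests.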
|
||||
|
||||
@@ -19,6 +19,8 @@ pub struct AnalyzedQuery {
|
||||
pub target_types: Vec<MemoryType>,
|
||||
/// Expanded search terms
|
||||
pub expansions: Vec<String>,
|
||||
/// Whether weak identity signals were detected (personal pronouns, possessives)
|
||||
pub weak_identity: bool,
|
||||
}
|
||||
|
||||
/// Query intent classification
|
||||
@@ -36,6 +38,9 @@ pub enum QueryIntent {
|
||||
Code,
|
||||
/// Configuration query
|
||||
Configuration,
|
||||
/// Identity/personal recall — user asks about themselves or past conversations
|
||||
/// Triggers broad retrieval of all preference + knowledge memories
|
||||
IdentityRecall,
|
||||
}
|
||||
|
||||
/// Query analyzer
|
||||
@@ -50,6 +55,10 @@ pub struct QueryAnalyzer {
|
||||
code_indicators: HashSet<String>,
|
||||
/// Stop words to filter out
|
||||
stop_words: HashSet<String>,
|
||||
/// Patterns indicating identity/personal recall queries
|
||||
identity_patterns: Vec<String>,
|
||||
/// Weak identity signals (pronouns, possessives) that boost broad retrieval
|
||||
weak_identity_indicators: Vec<String>,
|
||||
}
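The `identity_patterns` and `weak_identity_indicators` fields added above suggest a two-tier check: a strong match marks the query as identity recall, and a weak match only counts when no strong pattern matched. A minimal sketch of that decision using plain substring matching on a lowercased query. `IdentitySignal` and `detect_identity` are hypothetical names, and the pattern lists are passed in rather than stored on a struct:

```rust
/// Hypothetical result of the two-tier identity check.
#[derive(Debug, PartialEq)]
enum IdentitySignal {
    Strong,
    Weak,
    Absent,
}

/// Classify a query by substring match against two pattern lists.
/// Strong patterns take priority; weak ones apply only otherwise.
fn detect_identity(query: &str, strong: &[&str], weak: &[&str]) -> IdentitySignal {
    let q = query.to_lowercase();
    if strong.iter().any(|p| q.contains(&p.to_lowercase())) {
        IdentitySignal::Strong
    } else if weak.iter().any(|p| q.contains(&p.to_lowercase())) {
        IdentitySignal::Weak
    } else {
        IdentitySignal::Absent
    }
}
```

Substring matching is cheap but broad: a weak indicator such as `"my "` also fires on queries like "my deploy script fails", which is presumably why the diff treats it only as a hint for broader retrieval rather than as a full identity intent.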
|
||||
|
||||
impl QueryAnalyzer {
|
||||
@@ -99,13 +108,60 @@ impl QueryAnalyzer {
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
identity_patterns: [
|
||||
// Chinese identity recall patterns — direct identity queries
|
||||
"我是谁", "我叫什么", "我的名字", "我的身份", "我的信息",
|
||||
"关于我", "了解我", "记得我",
|
||||
// Chinese — cross-session recall ("what did we discuss before")
|
||||
"我之前", "我告诉过你", "我之前告诉", "我之前说过",
|
||||
"还记得我", "你还记得", "你记得吗", "记得之前",
|
||||
"我们之前聊过", "我们讨论过", "我们聊过", "上次聊",
|
||||
"之前说过", "之前告诉", "以前说过", "以前聊过",
|
||||
// Chinese — preferences/settings queries
|
||||
"我的偏好", "我喜欢什么", "我的工作", "我在哪",
|
||||
"我的设置", "我的习惯", "我的爱好", "我的职业",
|
||||
"我记得", "我想起来", "我忘了",
|
||||
// English identity recall patterns
|
||||
"who am i", "what is my name", "what do you know about me",
|
||||
"what did i tell", "do you remember me", "what do you remember",
|
||||
"my preferences", "about me", "what have i shared",
|
||||
"remind me", "what we discussed", "my settings", "my profile",
|
||||
"tell me about myself", "what did we talk about", "what was my",
|
||||
"i mentioned before", "we talked about", "i told you before",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
// Weak identity signals — pronouns that hint at personal context
|
||||
weak_identity_indicators: [
|
||||
"我的", "我之前", "我们之前", "我们上次",
|
||||
"my ", "i told", "i said", "we discussed", "we talked",
|
||||
]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Analyze a query string
|
||||
pub fn analyze(&self, query: &str) -> AnalyzedQuery {
|
||||
let keywords = self.extract_keywords(query);
|
||||
let intent = self.classify_intent(&keywords);
|
||||
|
||||
// Check for identity recall patterns first (highest priority)
|
||||
let query_lower = query.to_lowercase();
|
||||
let is_identity = self.identity_patterns.iter()
|
||||
.any(|pattern| query_lower.contains(&pattern.to_lowercase()));
|
||||
|
||||
// Check for weak identity signals (personal pronouns, possessives)
|
||||
let weak_identity = !is_identity && self.weak_identity_indicators.iter()
|
||||
.any(|indicator| query_lower.contains(&indicator.to_lowercase()));
|
||||
|
||||
let intent = if is_identity {
|
||||
QueryIntent::IdentityRecall
|
||||
} else {
|
||||
self.classify_intent(&keywords)
|
||||
};
|
||||
|
||||
let target_types = self.infer_memory_types(intent, &keywords);
|
||||
let expansions = self.expand_query(&keywords);
|
||||
|
||||
@@ -115,6 +171,7 @@ impl QueryAnalyzer {
|
||||
intent,
|
||||
target_types,
|
||||
expansions,
|
||||
weak_identity,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,6 +246,12 @@ impl QueryAnalyzer {
|
||||
types.push(MemoryType::Preference);
|
||||
types.push(MemoryType::Knowledge);
|
||||
}
|
||||
QueryIntent::IdentityRecall => {
|
||||
// Identity recall needs all memory types
|
||||
types.push(MemoryType::Preference);
|
||||
types.push(MemoryType::Knowledge);
|
||||
types.push(MemoryType::Experience);
|
||||
}
|
||||
}
|
||||
|
||||
types
|
||||
@@ -364,4 +427,48 @@ mod tests {
|
||||
// Chinese characters should be extracted
|
||||
assert!(!keywords.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identity_recall_expanded_patterns() {
|
||||
let analyzer = QueryAnalyzer::new();
|
||||
|
||||
// New Chinese patterns should trigger IdentityRecall
|
||||
assert_eq!(analyzer.analyze("我们之前聊过什么").intent, QueryIntent::IdentityRecall);
|
||||
assert_eq!(analyzer.analyze("你记得吗上次说的").intent, QueryIntent::IdentityRecall);
|
||||
assert_eq!(analyzer.analyze("我的设置是什么").intent, QueryIntent::IdentityRecall);
|
||||
assert_eq!(analyzer.analyze("我们讨论过这个话题").intent, QueryIntent::IdentityRecall);
|
||||
|
||||
// New English patterns
|
||||
assert_eq!(analyzer.analyze("what did we talk about yesterday").intent, QueryIntent::IdentityRecall);
|
||||
assert_eq!(analyzer.analyze("remind me what I said").intent, QueryIntent::IdentityRecall);
|
||||
assert_eq!(analyzer.analyze("my settings").intent, QueryIntent::IdentityRecall);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_weak_identity_detection() {
|
||||
let analyzer = QueryAnalyzer::new();
|
||||
|
||||
// Queries with "我的" but not matching full identity patterns
|
||||
let analyzed = analyzer.analyze("我的项目进度怎么样了");
|
||||
assert!(analyzed.weak_identity, "Should detect weak identity from '我的'");
|
||||
assert_ne!(analyzed.intent, QueryIntent::IdentityRecall);
|
||||
|
||||
// Queries without personal signals should not trigger weak identity
|
||||
let analyzed = analyzer.analyze("解释一下Rust的所有权");
|
||||
assert!(!analyzed.weak_identity);
|
||||
|
||||
// Full identity pattern should NOT set weak_identity (it's already IdentityRecall)
|
||||
let analyzed = analyzer.analyze("我是谁");
|
||||
assert!(!analyzed.weak_identity);
|
||||
assert_eq!(analyzed.intent, QueryIntent::IdentityRecall);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_false_identity_on_general_queries() {
|
||||
let analyzer = QueryAnalyzer::new();
|
||||
|
||||
// General queries should not trigger identity recall or weak identity
|
||||
assert_ne!(analyzer.analyze("什么是机器学习").intent, QueryIntent::IdentityRecall);
|
||||
assert!(!analyzer.analyze("什么是机器学习").weak_identity);
|
||||
}
|
||||
}
|
||||
|
||||
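The two-tier matching in `analyze` can be sketched in isolation. This is a minimal illustration, not the crate's code: the `IdentitySignal` enum and `classify_identity` function are hypothetical names, and the pattern lists are passed in as plain slices, whereas the diff stores them on `QueryAnalyzer` and reports the weak tier through a separate `weak_identity` flag.

```rust
/// Hypothetical three-way outcome used only in this sketch. The diff models
/// the same idea as `QueryIntent::IdentityRecall` plus `weak_identity: bool`.
#[derive(Debug, PartialEq)]
enum IdentitySignal {
    Strong,
    Weak,
    None,
}

/// Classify a query by case-insensitive substring matching.
/// Strong patterns are checked first; weak indicators only apply otherwise,
/// mirroring `!is_identity && ...` in the diff.
fn classify_identity(query: &str, strong: &[&str], weak: &[&str]) -> IdentitySignal {
    let q = query.to_lowercase();
    if strong.iter().any(|p| q.contains(&p.to_lowercase())) {
        IdentitySignal::Strong
    } else if weak.iter().any(|p| q.contains(&p.to_lowercase())) {
        IdentitySignal::Weak
    } else {
        IdentitySignal::None
    }
}
```

Because the check is a plain substring test, an indicator such as `"my "` also matches inside longer words (for example `"army base"`), so the weak tier is best treated as a hint rather than a classification.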
@@ -122,13 +122,65 @@ impl SemanticScorer {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Tokenize text into words
|
||||
/// Tokenize text into words with CJK-aware bigram support.
|
||||
///
|
||||
/// For ASCII/latin text, splits on non-alphanumeric boundaries as before.
|
||||
/// For CJK text, generates character-level bigrams (e.g. "北京工作" → ["北京", "京工", "工作"])
|
||||
/// so that TF-IDF cosine similarity works for CJK queries.
|
||||
fn tokenize(text: &str) -> Vec<String> {
|
||||
text.to_lowercase()
|
||||
.split(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty() && s.len() > 1)
|
||||
.map(|s| s.to_string())
|
||||
.collect()
|
||||
let lower = text.to_lowercase();
|
||||
let mut tokens = Vec::new();
|
||||
|
||||
// Split into segments: each segment is either pure CJK or non-CJK
|
||||
let mut cjk_buf = String::new();
|
||||
let mut latin_buf = String::new();
|
||||
|
||||
let flush_latin = |buf: &mut String, tokens: &mut Vec<String>| {
|
||||
if !buf.is_empty() {
|
||||
for word in buf.split(|c: char| !c.is_alphanumeric()) {
|
||||
if !word.is_empty() && word.len() > 1 {
|
||||
tokens.push(word.to_string());
|
||||
}
|
||||
}
|
||||
buf.clear();
|
||||
}
|
||||
};
|
||||
|
||||
let flush_cjk = |buf: &mut String, tokens: &mut Vec<String>| {
|
||||
if buf.is_empty() {
|
||||
return;
|
||||
}
|
||||
let chars: Vec<char> = buf.chars().collect();
|
||||
// Generate bigrams for CJK
|
||||
if chars.len() >= 2 {
|
||||
for i in 0..chars.len() - 1 {
|
||||
tokens.push(format!("{}{}", chars[i], chars[i + 1]));
|
||||
}
|
||||
}
|
||||
// Also include the full CJK segment as a single token for exact-match bonus
|
||||
if chars.len() > 1 {
|
||||
tokens.push(buf.clone());
|
||||
}
|
||||
buf.clear();
|
||||
};
|
||||
|
||||
for c in lower.chars() {
|
||||
if is_cjk_char(c) {
|
||||
flush_latin(&mut latin_buf, &mut tokens);
|
||||
cjk_buf.push(c);
|
||||
} else if c.is_alphanumeric() {
|
||||
flush_cjk(&mut cjk_buf, &mut tokens);
|
||||
latin_buf.push(c);
|
||||
} else {
|
||||
// Non-alphanumeric, non-CJK: flush both
|
||||
flush_latin(&mut latin_buf, &mut tokens);
|
||||
flush_cjk(&mut cjk_buf, &mut tokens);
|
||||
}
|
||||
}
|
||||
flush_latin(&mut latin_buf, &mut tokens);
|
||||
flush_cjk(&mut cjk_buf, &mut tokens);
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
/// Remove stop words from tokens
|
||||
@@ -409,6 +461,20 @@ impl Default for SemanticScorer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a character is a CJK ideograph
|
||||
fn is_cjk_char(c: char) -> bool {
|
||||
matches!(c,
|
||||
'\u{4E00}'..='\u{9FFF}' |
|
||||
'\u{3400}'..='\u{4DBF}' |
|
||||
'\u{20000}'..='\u{2A6DF}' |
|
||||
'\u{2A700}'..='\u{2B73F}' |
|
||||
'\u{2B740}'..='\u{2B81F}' |
|
||||
'\u{2B820}'..='\u{2CEAF}' |
|
||||
'\u{F900}'..='\u{FAFF}' |
|
||||
'\u{2F800}'..='\u{2FA1F}'
|
||||
)
|
||||
}
|
||||
|
||||
/// Index statistics
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IndexStats {
|
||||
@@ -430,6 +496,42 @@ mod tests {
|
||||
assert_eq!(tokens, vec!["hello", "world", "this", "is", "test"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenize_cjk_bigrams() {
|
||||
// CJK text should produce bigrams + full segment token
|
||||
let tokens = SemanticScorer::tokenize("北京工作");
|
||||
assert!(tokens.contains(&"北京".to_string()), "should contain bigram 北京");
|
||||
assert!(tokens.contains(&"京工".to_string()), "should contain bigram 京工");
|
||||
assert!(tokens.contains(&"工作".to_string()), "should contain bigram 工作");
|
||||
assert!(tokens.contains(&"北京工作".to_string()), "should contain full segment");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tokenize_mixed_cjk_latin() {
|
||||
// Mixed CJK and latin should handle both
|
||||
let tokens = SemanticScorer::tokenize("我在北京工作,用Python写脚本");
|
||||
// CJK bigrams
|
||||
assert!(tokens.contains(&"我在".to_string()));
|
||||
assert!(tokens.contains(&"北京".to_string()));
|
||||
// Latin word
|
||||
assert!(tokens.contains(&"python".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cjk_similarity() {
|
||||
let mut scorer = SemanticScorer::new();
|
||||
|
||||
let entry = MemoryEntry::new(
|
||||
"test", MemoryType::Preference, "test",
|
||||
"用户在北京工作,做AI产品经理".to_string(),
|
||||
);
|
||||
scorer.index_entry(&entry);
|
||||
|
||||
// Query "北京" should have non-zero similarity after bigram fix
|
||||
let score = scorer.score_similarity("北京", &entry);
|
||||
assert!(score > 0.0, "CJK query should score > 0 after bigram tokenization, got {}", score);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_stop_words_removal() {
|
||||
let scorer = SemanticScorer::new();
|
||||
|
||||
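The CJK-aware tokenizer above can be reproduced as a small standalone sketch. It assumes a simplified `is_cjk` that covers only the U+4E00..U+9FFF block (the diff's `is_cjk_char` covers more ranges), and the helper names `flush_word` and `flush_cjk` are illustrative free functions standing in for the closures in the diff.

```rust
/// Simplified CJK test for this sketch: BMP unified ideographs only.
fn is_cjk(c: char) -> bool {
    matches!(c, '\u{4E00}'..='\u{9FFF}')
}

/// Push the buffered latin/digit word if it is longer than one byte, then reset.
fn flush_word(word: &mut String, out: &mut Vec<String>) {
    if word.len() > 1 {
        out.push(word.clone());
    }
    word.clear();
}

/// Push overlapping bigrams for the buffered CJK run, then the whole run.
/// A run of a single character produces no token, as in the diff.
fn flush_cjk(run: &mut Vec<char>, out: &mut Vec<String>) {
    if run.len() >= 2 {
        for pair in run.windows(2) {
            out.push(pair.iter().collect());
        }
        out.push(run.iter().collect());
    }
    run.clear();
}

/// Tokenize mixed text: words for non-CJK alphanumerics, bigrams for CJK runs.
fn tokenize(text: &str) -> Vec<String> {
    let lower = text.to_lowercase();
    let mut tokens = Vec::new();
    let mut run: Vec<char> = Vec::new();
    let mut word = String::new();
    for c in lower.chars() {
        if is_cjk(c) {
            flush_word(&mut word, &mut tokens);
            run.push(c);
        } else if c.is_alphanumeric() {
            flush_cjk(&mut run, &mut tokens);
            word.push(c);
        } else {
            // Punctuation or whitespace ends whichever buffer is open.
            flush_word(&mut word, &mut tokens);
            flush_cjk(&mut run, &mut tokens);
        }
    }
    flush_word(&mut word, &mut tokens);
    flush_cjk(&mut run, &mut tokens);
    tokens
}
```

One consequence of this scheme, shared with the diff, is that a two-character CJK run is emitted twice (once as a bigram and once as the full segment), which raises its term frequency.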
@@ -19,6 +19,8 @@ pub struct MemoryRetriever {
|
||||
config: RetrievalConfig,
|
||||
/// Semantic scorer for similarity computation
|
||||
scorer: RwLock<SemanticScorer>,
|
||||
/// Pending embedding client (applied on next scorer access if try_write failed)
|
||||
pending_embedding: std::sync::Mutex<Option<Arc<dyn crate::retrieval::semantic::EmbeddingClient>>>,
|
||||
/// Query analyzer
|
||||
analyzer: QueryAnalyzer,
|
||||
/// Memory cache
|
||||
@@ -32,6 +34,7 @@ impl MemoryRetriever {
|
||||
viking,
|
||||
config: RetrievalConfig::default(),
|
||||
scorer: RwLock::new(SemanticScorer::new()),
|
||||
pending_embedding: std::sync::Mutex::new(None),
|
||||
analyzer: QueryAnalyzer::new(),
|
||||
cache: MemoryCache::default_config(),
|
||||
}
|
||||
@@ -67,6 +70,11 @@ impl MemoryRetriever {
|
||||
analyzed.keywords
|
||||
);
|
||||
|
||||
// Identity recall uses broad scope-based retrieval (bypasses text search)
|
||||
if analyzed.intent == crate::retrieval::query::QueryIntent::IdentityRecall {
|
||||
return self.retrieve_broad_identity(agent_id).await;
|
||||
}
|
||||
|
||||
// Retrieve each type with budget constraints and reranking
|
||||
let preferences = self
|
||||
.retrieve_and_rerank(
|
||||
@@ -101,6 +109,25 @@ impl MemoryRetriever {
|
||||
)
|
||||
.await?;
|
||||
|
||||
let total_found = preferences.len() + knowledge.len() + experience.len();
|
||||
|
||||
// Fallback: if keyword-based retrieval returns too few results AND weak identity
|
||||
// signals are present (e.g. "我的xxx", "我之前xxx"), supplement with broad retrieval
|
||||
// to ensure cross-session memories are found even without exact keyword match.
|
||||
let (preferences, knowledge, experience) = if total_found < 3 && analyzed.weak_identity {
|
||||
tracing::info!(
|
||||
"[MemoryRetriever] Weak identity + low results ({}), supplementing with broad retrieval",
|
||||
total_found
|
||||
);
|
||||
let broad = self.retrieve_broad_identity(agent_id).await?;
|
||||
let prefs = Self::merge_results(preferences, broad.preferences);
|
||||
let knows = Self::merge_results(knowledge, broad.knowledge);
|
||||
let exps = Self::merge_results(experience, broad.experience);
|
||||
(prefs, knows, exps)
|
||||
} else {
|
||||
(preferences, knowledge, experience)
|
||||
};
|
||||
|
||||
let total_tokens = preferences.iter()
|
||||
.chain(knowledge.iter())
|
||||
.chain(experience.iter())
|
||||
@@ -148,6 +175,7 @@ impl MemoryRetriever {
|
||||
intent: crate::retrieval::query::QueryIntent::General,
|
||||
target_types: vec![],
|
||||
expansions: vec![],
|
||||
weak_identity: false,
|
||||
};
|
||||
let search_queries = self.analyzer.generate_search_queries(&analyzed_for_search);
|
||||
|
||||
@@ -193,6 +221,20 @@ impl MemoryRetriever {
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
/// Merge keyword-based and broad-retrieval results, deduplicating by URI.
|
||||
/// Keyword results take precedence (appear first), broad results fill gaps.
|
||||
fn merge_results(keyword_results: Vec<MemoryEntry>, broad_results: Vec<MemoryEntry>) -> Vec<MemoryEntry> {
|
||||
let mut seen = std::collections::HashSet::new();
|
||||
let mut merged = Vec::new();
|
||||
|
||||
for entry in keyword_results.into_iter().chain(broad_results.into_iter()) {
|
||||
if seen.insert(entry.uri.clone()) {
|
||||
merged.push(entry);
|
||||
}
|
||||
}
|
||||
merged
|
||||
}
|
||||
|
||||
/// Rerank entries using semantic similarity
|
||||
async fn rerank_entries(
|
||||
&self,
|
||||
@@ -205,19 +247,40 @@ impl MemoryRetriever {
|
||||
|
||||
let mut scorer = self.scorer.write().await;
|
||||
|
||||
// Apply any pending embedding client
|
||||
self.apply_pending_embedding(&mut scorer);
|
||||
|
||||
// Check if embedding is available for enhanced scoring
|
||||
let use_embedding = scorer.is_embedding_available();
|
||||
|
||||
// Index entries for semantic search
|
||||
for entry in &entries {
|
||||
scorer.index_entry(entry);
|
||||
if use_embedding {
|
||||
for entry in &entries {
|
||||
scorer.index_entry_with_embedding(entry).await;
|
||||
}
|
||||
} else {
|
||||
for entry in &entries {
|
||||
scorer.index_entry(entry);
|
||||
}
|
||||
}
|
||||
|
||||
// Score each entry
|
||||
let mut scored: Vec<(f32, MemoryEntry)> = entries
|
||||
.into_iter()
|
||||
.map(|entry| {
|
||||
let score = scorer.score_similarity(query, &entry);
|
||||
(score, entry)
|
||||
})
|
||||
.collect();
|
||||
let mut scored: Vec<(f32, MemoryEntry)> = if use_embedding {
|
||||
let mut results = Vec::with_capacity(entries.len());
|
||||
for entry in entries {
|
||||
let score = scorer.score_similarity_with_embedding(query, &entry).await;
|
||||
results.push((score, entry));
|
||||
}
|
||||
results
|
||||
} else {
|
||||
entries
|
||||
.into_iter()
|
||||
.map(|entry| {
|
||||
let score = scorer.score_similarity(query, &entry);
|
||||
(score, entry)
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
// Sort by score (descending), then by importance and access count
|
||||
scored.sort_by(|a, b| {
|
||||
@@ -230,6 +293,174 @@ impl MemoryRetriever {
|
||||
scored.into_iter().map(|(_, entry)| entry).collect()
|
||||
}
|
||||
|
||||
/// Broad identity recall: retrieves preference, knowledge, and a smaller set of experience memories
|
||||
/// without requiring text match. Used when the user asks about themselves.
|
||||
///
|
||||
/// This bypasses FTS5/LIKE search entirely and does a scope-based retrieval
|
||||
/// sorted by importance and access count, ensuring identity information is always
|
||||
/// available across sessions.
|
||||
async fn retrieve_broad_identity(&self, agent_id: &AgentId) -> Result<RetrievalResult> {
|
||||
tracing::info!(
|
||||
"[MemoryRetriever] Broad identity recall for agent: {}",
|
||||
agent_id
|
||||
);
|
||||
|
||||
let agent_str = agent_id.to_string();
|
||||
|
||||
// Retrieve preferences (scope-only, no text search)
|
||||
let preferences = self.retrieve_by_scope(
|
||||
&agent_str,
|
||||
MemoryType::Preference,
|
||||
self.config.max_results_per_type,
|
||||
self.config.preference_budget,
|
||||
).await?;
|
||||
|
||||
// Retrieve knowledge (scope-only)
|
||||
let knowledge = self.retrieve_by_scope(
|
||||
&agent_str,
|
||||
MemoryType::Knowledge,
|
||||
self.config.max_results_per_type,
|
||||
self.config.knowledge_budget,
|
||||
).await?;
|
||||
|
||||
// Retrieve recent experiences (scope-only, limited)
|
||||
let experience = self.retrieve_by_scope(
|
||||
&agent_str,
|
||||
MemoryType::Experience,
|
||||
self.config.max_results_per_type / 2,
|
||||
self.config.experience_budget,
|
||||
).await?;
|
||||
|
||||
// Fallback: if no results for this agent, search across ALL agents
|
||||
// for identity-critical info (user name, workplace, preferences)
|
||||
if preferences.is_empty() && knowledge.is_empty() && experience.is_empty() {
|
||||
tracing::info!(
|
||||
"[MemoryRetriever] No memories for agent {}, falling back to global scope",
|
||||
agent_str
|
||||
);
|
||||
let global_prefs = self.retrieve_by_scope_any_agent(
|
||||
MemoryType::Preference,
|
||||
self.config.max_results_per_type,
|
||||
self.config.preference_budget,
|
||||
).await?;
|
||||
let global_knowledge = self.retrieve_by_scope_any_agent(
|
||||
MemoryType::Knowledge,
|
||||
self.config.max_results_per_type,
|
||||
self.config.knowledge_budget,
|
||||
).await?;
|
||||
let total: usize = global_prefs.iter()
|
||||
.chain(global_knowledge.iter())
|
||||
.map(|m| m.estimated_tokens())
|
||||
.sum();
|
||||
|
||||
return Ok(RetrievalResult {
|
||||
preferences: global_prefs,
|
||||
knowledge: global_knowledge,
|
||||
experience,
|
||||
total_tokens: total,
|
||||
});
|
||||
}
|
||||
|
||||
let total_tokens = preferences.iter()
|
||||
.chain(knowledge.iter())
|
||||
.chain(experience.iter())
|
||||
.map(|m| m.estimated_tokens())
|
||||
.sum();
|
||||
|
||||
tracing::info!(
|
||||
"[MemoryRetriever] Identity recall: {} preferences, {} knowledge, {} experience",
|
||||
preferences.len(),
|
||||
knowledge.len(),
|
||||
experience.len()
|
||||
);
|
||||
|
||||
Ok(RetrievalResult {
|
||||
preferences,
|
||||
knowledge,
|
||||
experience,
|
||||
total_tokens,
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieve memories across ALL agents for a given type.
|
||||
/// Used as fallback when agent-scoped retrieval returns nothing for identity recall.
|
||||
async fn retrieve_by_scope_any_agent(
|
||||
&self,
|
||||
memory_type: MemoryType,
|
||||
max_results: usize,
|
||||
token_budget: usize,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
// Match any agent by using only the type suffix as scope pattern
|
||||
let scope_pattern = format!("/{}", memory_type);
|
||||
let options = FindOptions {
|
||||
scope: None, // No scope filter — search all agents
|
||||
limit: Some(max_results * 3),
|
||||
min_similarity: None,
|
||||
};
|
||||
let entries = self.viking.find("", options).await?;
|
||||
// Filter to only matching memory type
|
||||
let mut filtered: Vec<MemoryEntry> = entries
|
||||
.into_iter()
|
||||
.filter(|e| e.uri.contains(&scope_pattern) || e.memory_type == memory_type)
|
||||
.collect();
|
||||
filtered.sort_by(|a, b| {
|
||||
b.importance.cmp(&a.importance)
|
||||
.then_with(|| b.access_count.cmp(&a.access_count))
|
||||
});
|
||||
let mut result = Vec::new();
|
||||
let mut used_tokens = 0;
|
||||
for entry in filtered {
|
||||
let tokens = entry.estimated_tokens();
|
||||
if used_tokens + tokens > token_budget { break; }
|
||||
used_tokens += tokens;
|
||||
result.push(entry);
|
||||
if result.len() >= max_results { break; }
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Retrieve memories by scope only (no text search).
|
||||
/// Returns entries sorted by importance, then access count, limited by the token budget.
|
||||
async fn retrieve_by_scope(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
memory_type: MemoryType,
|
||||
max_results: usize,
|
||||
token_budget: usize,
|
||||
) -> Result<Vec<MemoryEntry>> {
|
||||
let scope = format!("agent://{}/{}", agent_id, memory_type);
|
||||
let options = FindOptions {
|
||||
scope: Some(scope),
|
||||
limit: Some(max_results * 3), // Fetch more candidates for filtering
|
||||
min_similarity: None, // No similarity threshold for scope-only
|
||||
};
|
||||
|
||||
// Empty query triggers scope-only fetch in SqliteStorage::find()
|
||||
let entries = self.viking.find("", options).await?;
|
||||
|
||||
// Sort by importance (desc) and apply token budget
|
||||
let mut sorted = entries;
|
||||
sorted.sort_by(|a, b| {
|
||||
b.importance.cmp(&a.importance)
|
||||
.then_with(|| b.access_count.cmp(&a.access_count))
|
||||
});
|
||||
|
||||
let mut filtered = Vec::new();
|
||||
let mut used_tokens = 0;
|
||||
for entry in sorted {
|
||||
let tokens = entry.estimated_tokens();
|
||||
if used_tokens + tokens <= token_budget {
|
||||
used_tokens += tokens;
|
||||
filtered.push(entry);
|
||||
}
|
||||
if filtered.len() >= max_results {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
/// Retrieve a specific memory by URI (with cache)
|
||||
pub async fn get_by_uri(&self, uri: &str) -> Result<Option<MemoryEntry>> {
|
||||
// Check cache first
|
||||
@@ -277,6 +508,36 @@ impl MemoryRetriever {
|
||||
})
|
||||
}
|
||||
|
||||
/// Configure embedding client for semantic similarity
|
||||
///
|
||||
/// Applies the client immediately when the scorer write lock is free.
|
||||
/// If the scorer lock is busy, the client is stored as pending
|
||||
/// and applied on the next successful lock acquisition.
|
||||
pub fn set_embedding_client(
|
||||
&self,
|
||||
client: Arc<dyn crate::retrieval::semantic::EmbeddingClient>,
|
||||
) {
|
||||
if let Ok(mut scorer) = self.scorer.try_write() {
|
||||
*scorer = SemanticScorer::with_embedding(client);
|
||||
tracing::info!("[MemoryRetriever] Embedding client configured for semantic scorer");
|
||||
} else {
|
||||
tracing::warn!("[MemoryRetriever] Scorer lock busy, storing embedding client as pending");
|
||||
if let Ok(mut pending) = self.pending_embedding.lock() {
|
||||
*pending = Some(client);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply any pending embedding client to the scorer.
|
||||
fn apply_pending_embedding(&self, scorer: &mut SemanticScorer) {
|
||||
if let Ok(mut pending) = self.pending_embedding.lock() {
|
||||
if let Some(client) = pending.take() {
|
||||
*scorer = SemanticScorer::with_embedding(client);
|
||||
tracing::info!("[MemoryRetriever] Pending embedding client applied to scorer");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the semantic index
|
||||
pub async fn clear_index(&self) {
|
||||
let mut scorer = self.scorer.write().await;
|
||||
|
||||
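The retriever changes combine two small techniques: URI-based deduplication when merging keyword and broad results, and token-budgeted selection. The sketch below illustrates both under stated assumptions: `Entry` is a hypothetical stand-in for `MemoryEntry` with only a `uri` and a precomputed token count, and the function names are illustrative. It also makes visible that the two scope helpers in the diff treat an over-budget entry differently: `retrieve_by_scope` skips it and keeps scanning, while `retrieve_by_scope_any_agent` stops at it.

```rust
use std::collections::HashSet;

/// Hypothetical stand-in for `MemoryEntry`; only the fields this sketch needs.
#[derive(Debug, Clone, PartialEq)]
struct Entry {
    uri: String,
    tokens: usize,
}

/// Keep the first occurrence of each URI; primary results come first.
fn merge_by_uri(primary: Vec<Entry>, fallback: Vec<Entry>) -> Vec<Entry> {
    let mut seen = HashSet::new();
    primary
        .into_iter()
        .chain(fallback)
        .filter(|e| seen.insert(e.uri.clone()))
        .collect()
}

/// Skip entries that do not fit the budget and keep scanning.
fn select_skipping(entries: &[Entry], budget: usize, max: usize) -> Vec<Entry> {
    let mut picked = Vec::new();
    let mut used = 0;
    for e in entries {
        if used + e.tokens <= budget {
            used += e.tokens;
            picked.push(e.clone());
        }
        if picked.len() >= max {
            break;
        }
    }
    picked
}

/// Stop at the first entry that does not fit the budget.
fn select_stopping(entries: &[Entry], budget: usize, max: usize) -> Vec<Entry> {
    let mut picked = Vec::new();
    let mut used = 0;
    for e in entries {
        if used + e.tokens > budget {
            break;
        }
        used += e.tokens;
        picked.push(e.clone());
        if picked.len() >= max {
            break;
        }
    }
    picked
}
```

With entries of 40, 80, and 30 tokens and a budget of 100, the skipping variant returns the first and third entries, while the stopping variant returns only the first. Whether that difference between the two helpers is intentional is not stated in the diff.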
@@ -732,6 +732,11 @@ impl VikingStorage for SqliteStorage {
|
||||
async fn find(&self, query: &str, options: FindOptions) -> Result<Vec<MemoryEntry>> {
|
||||
let limit = options.limit.unwrap_or(50).max(20); // Fetch more candidates for reranking
|
||||
|
||||
// Detect CJK early — used both for LIKE fallback and similarity threshold relaxation
|
||||
let has_cjk = query.chars().any(|c| {
|
||||
matches!(c, '\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}' | '\u{F900}'..='\u{FAFF}')
|
||||
});
|
||||
|
||||
// Strategy: use FTS5 for initial filtering when query is non-empty,
|
||||
// then score candidates with TF-IDF / embedding for precise ranking.
|
||||
// When FTS5 returns nothing, we return empty — do NOT fall back to
|
||||
@@ -792,9 +797,6 @@ impl VikingStorage for SqliteStorage {
|
||||
// FTS5 returned no results or failed — check if query contains CJK
|
||||
// characters. unicode61 tokenizer doesn't index CJK, so fall back
|
||||
// to LIKE-based search for CJK queries.
|
||||
let has_cjk = query.chars().any(|c| {
|
||||
matches!(c, '\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}' | '\u{F900}'..='\u{FAFF}')
|
||||
});
|
||||
|
||||
if !has_cjk {
|
||||
tracing::debug!(
|
||||
@@ -897,9 +899,17 @@ impl VikingStorage for SqliteStorage {
|
||||
scorer.score_similarity(query, &entry)
|
||||
};
|
||||
|
||||
// Apply similarity threshold
|
||||
// Apply similarity threshold (relaxed for CJK queries since unicode61
|
||||
// tokenizer doesn't produce meaningful TF-IDF scores for CJK text)
|
||||
if let Some(min_similarity) = options.min_similarity {
|
||||
if semantic_score < min_similarity {
|
||||
let threshold = if has_cjk {
|
||||
// CJK TF-IDF scores are systematically low due to tokenizer limitations;
|
||||
// use 50% of the normal threshold to avoid filtering out all results
|
||||
min_similarity * 0.5
|
||||
} else {
|
||||
min_similarity
|
||||
};
|
||||
if semantic_score < threshold {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
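The relaxed threshold for CJK queries reduces to a small pure function. The following sketch assumes hypothetical helper names (`has_cjk`, `passes_threshold`); the character ranges and the 0.5 factor are taken from the diff.

```rust
/// CJK detection with the same three ranges used in `SqliteStorage::find`.
fn has_cjk(query: &str) -> bool {
    query.chars().any(|c| {
        matches!(c, '\u{4E00}'..='\u{9FFF}' | '\u{3400}'..='\u{4DBF}' | '\u{F900}'..='\u{FAFF}')
    })
}

/// Decide whether a candidate survives the similarity filter.
/// With no threshold configured, everything passes; for CJK queries the
/// threshold is halved before comparison.
fn passes_threshold(score: f32, min_similarity: Option<f32>, cjk: bool) -> bool {
    match min_similarity {
        None => true,
        Some(min) => {
            let threshold = if cjk { min * 0.5 } else { min };
            score >= threshold
        }
    }
}
```

Note that this range list is narrower than `is_cjk_char` in the scorer: a query consisting only of supplementary-plane ideographs (U+20000 and above) is tokenized as CJK there but would not get the relaxed threshold here.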
@@ -432,6 +432,10 @@ pub struct ProfileSignals {
|
||||
pub pain_point: Option<String>,
|
||||
pub preferred_tool: Option<String>,
|
||||
pub communication_style: Option<String>,
|
||||
/// Name the user gave the assistant (e.g. "以后叫你小马", i.e. "from now on I'll call you 小马")
|
||||
pub agent_name: Option<String>,
|
||||
/// Name the user gave for themselves (e.g. "我叫张三", i.e. "my name is 张三")
|
||||
pub user_name: Option<String>,
|
||||
}
|
||||
|
||||
impl ProfileSignals {
|
||||
@@ -442,6 +446,8 @@ impl ProfileSignals {
|
||||
|| self.pain_point.is_some()
|
||||
|| self.preferred_tool.is_some()
|
||||
|| self.communication_style.is_some()
|
||||
|| self.agent_name.is_some()
|
||||
|| self.user_name.is_some()
|
||||
}
|
||||
|
||||
/// Number of signals that are set
|
||||
@@ -452,8 +458,15 @@ impl ProfileSignals {
|
||||
if self.pain_point.is_some() { count += 1; }
|
||||
if self.preferred_tool.is_some() { count += 1; }
|
||||
if self.communication_style.is_some() { count += 1; }
|
||||
if self.agent_name.is_some() { count += 1; }
|
||||
if self.user_name.is_some() { count += 1; }
|
||||
count
|
||||
}
|
||||
|
||||
/// Whether an identity signal is present (agent_name or user_name)
|
||||
pub fn has_identity_signal(&self) -> bool {
|
||||
self.agent_name.is_some() || self.user_name.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
/// Evolution event
|
||||
@@ -674,8 +687,23 @@ mod tests {
|
||||
pain_point: None,
|
||||
preferred_tool: Some("researcher".to_string()),
|
||||
communication_style: Some("concise".to_string()),
|
||||
agent_name: None,
|
||||
user_name: None,
|
||||
};
|
||||
assert_eq!(signals.industry.as_deref(), Some("healthcare"));
|
||||
assert!(signals.pain_point.is_none());
|
||||
assert!(!signals.has_identity_signal());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_profile_signals_identity() {
|
||||
let signals = ProfileSignals {
|
||||
agent_name: Some("小马".to_string()),
|
||||
user_name: Some("张三".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
assert!(signals.has_identity_signal());
|
||||
assert_eq!(signals.signal_count(), 2);
|
||||
assert_eq!(signals.agent_name.as_deref(), Some("小马"));
|
||||
}
|
||||
}
|
||||
|
||||
207
crates/zclaw-growth/tests/evolution_loop_test.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
//! Evolution loop integration test
|
||||
//!
|
||||
//! Tests the complete self-learning loop:
|
||||
//! Experience accumulation → Pattern recognition → Evolution suggestion
|
||||
|
||||
use std::sync::Arc;
|
||||
use zclaw_growth::{
|
||||
EvolutionEngine, Experience, ExperienceStore, PatternAggregator,
|
||||
SqliteStorage, VikingAdapter,
|
||||
};
|
||||
|
||||
fn make_experience(agent_id: &str, pattern: &str, steps: Vec<&str>, tool: Option<&str>) -> Experience {
|
||||
let mut exp = Experience::new(
|
||||
agent_id,
|
||||
pattern,
|
||||
&format!("{}相关任务", pattern),
|
||||
steps.into_iter().map(|s| s.to_string()).collect(),
|
||||
"成功解决",
|
||||
);
|
||||
exp.tool_used = tool.map(|t| t.to_string());
|
||||
exp
|
||||
}
|
||||
|
||||
/// Store N experiences with the same pain pattern, then verify pattern recognition
|
||||
#[tokio::test]
|
||||
async fn test_evolution_loop_four_experiences_trigger_pattern() {
|
||||
let storage = Arc::new(SqliteStorage::in_memory().await);
|
||||
let adapter = Arc::new(VikingAdapter::new(storage));
|
||||
let store = Arc::new(ExperienceStore::new(adapter.clone()));
|
||||
let agent_id = "test-agent-evolution";
|
||||
|
||||
// Store 4 experiences with the same pain pattern
|
||||
for _ in 0..4 {
|
||||
let exp = make_experience(
|
||||
agent_id,
|
||||
"生成每日报表",
|
||||
vec!["打开Excel", "选择模板", "导出PDF"],
|
||||
Some("excel_tool"),
|
||||
);
|
||||
store.store_experience(&exp).await.unwrap();
|
||||
}
|
||||
|
||||
// Verify experiences were stored and reuse_count accumulated
|
||||
let all = store.find_by_agent(agent_id).await.unwrap();
|
||||
assert_eq!(all.len(), 1, "Same pattern should merge into 1 experience");
|
||||
assert_eq!(all[0].reuse_count, 3, "4 stores → reuse_count=3");
|
||||
|
||||
// Pattern aggregator should find this as evolvable
|
||||
let agg_store = ExperienceStore::new(adapter.clone());
|
||||
let aggregator = PatternAggregator::new(agg_store);
|
||||
let patterns = aggregator.find_evolvable_patterns(agent_id, 3).await.unwrap();
|
||||
assert_eq!(patterns.len(), 1, "Should find 1 evolvable pattern");
|
||||
assert_eq!(patterns[0].pain_pattern, "生成每日报表");
|
||||
assert!(patterns[0].total_reuse >= 3);
|
||||
assert!(!patterns[0].common_steps.is_empty(), "Should find common steps");
|
||||
|
||||
// Evolution engine should detect the same patterns
|
||||
let engine = EvolutionEngine::new(adapter);
|
||||
let evolvable = engine.check_evolvable_patterns(agent_id).await.unwrap();
|
||||
assert_eq!(evolvable.len(), 1, "EvolutionEngine should detect 1 evolvable pattern");
|
||||
assert_eq!(evolvable[0].pain_pattern, "生成每日报表");
|
||||
}
|
||||
|
||||
/// Verify that experiences below threshold are NOT marked evolvable
|
||||
#[tokio::test]
|
||||
async fn test_evolution_loop_below_threshold_not_evolvable() {
|
||||
let storage = Arc::new(SqliteStorage::in_memory().await);
|
||||
let adapter = Arc::new(VikingAdapter::new(storage));
|
||||
let store = Arc::new(ExperienceStore::new(adapter.clone()));
|
||||
let agent_id = "test-agent-below";
|
||||
|
||||
// Store only 2 experiences (below min_reuse=3)
|
||||
for _ in 0..2 {
|
||||
let exp = make_experience(agent_id, "低频任务", vec!["步骤1"], None);
|
||||
store.store_experience(&exp).await.unwrap();
|
||||
}
|
||||
|
||||
let all = store.find_by_agent(agent_id).await.unwrap();
|
||||
assert_eq!(all.len(), 1);
|
||||
assert_eq!(all[0].reuse_count, 1, "2 stores → reuse_count=1");
|
||||
|
||||
let engine = EvolutionEngine::new(adapter);
|
||||
let evolvable = engine.check_evolvable_patterns(agent_id).await.unwrap();
|
||||
assert!(evolvable.is_empty(), "Below threshold should not be evolvable");
|
||||
}
|
||||
|
||||
/// Verify multiple different patterns are tracked independently
|
||||
#[tokio::test]
|
||||
async fn test_evolution_loop_multiple_patterns() {
|
||||
let storage = Arc::new(SqliteStorage::in_memory().await);
|
||||
let adapter = Arc::new(VikingAdapter::new(storage));
|
||||
let store = Arc::new(ExperienceStore::new(adapter.clone()));
|
||||
let agent_id = "test-agent-multi";
|
||||
|
||||
// Pattern A: 4 occurrences → evolvable
|
||||
for _ in 0..4 {
|
||||
let mut exp = make_experience(agent_id, "报表生成", vec!["打开系统", "选择日期"], Some("browser"));
|
||||
exp.industry_context = Some("医疗".into());
|
||||
store.store_experience(&exp).await.unwrap();
|
||||
}
|
||||
|
||||
// Pattern B: 2 occurrences → not evolvable
|
||||
for _ in 0..2 {
|
||||
let exp = make_experience(agent_id, "会议纪要", vec!["录音转文字"], None);
|
||||
store.store_experience(&exp).await.unwrap();
|
||||
}
|
||||
|
||||
let engine = EvolutionEngine::new(adapter);
|
||||
let evolvable = engine.check_evolvable_patterns(agent_id).await.unwrap();
|
||||
assert_eq!(evolvable.len(), 1, "Only pattern A should be evolvable");
|
||||
assert_eq!(evolvable[0].pain_pattern, "报表生成");
|
||||
assert_eq!(evolvable[0].total_reuse, 3);
|
||||
assert_eq!(evolvable[0].industry_context, Some("医疗".into()));
|
||||
}
|
||||
|
||||
/// Test SkillGenerator prompt building from evolvable pattern
#[tokio::test]
async fn test_skill_generator_from_evolvable_pattern() {
    use zclaw_growth::{AggregatedPattern, SkillGenerator};

    let pattern = AggregatedPattern {
        pain_pattern: "生成每日报表".to_string(),
        experiences: vec![],
        common_steps: vec!["打开Excel".into(), "选择模板".into(), "导出PDF".into()],
        total_reuse: 5,
        tools_used: vec!["excel_tool".into()],
        industry_context: Some("医疗".into()),
    };

    let prompt = SkillGenerator::build_prompt(&pattern);
    assert!(prompt.contains("生成每日报表"));
    assert!(prompt.contains("打开Excel"));
    assert!(prompt.contains("excel_tool"));
}

/// Test QualityGate validates skill candidates
#[tokio::test]
async fn test_quality_gate_validation() {
    use zclaw_growth::{QualityGate, SkillCandidate};

    let candidate = SkillCandidate {
        name: "每日报表生成".to_string(),
        description: "自动生成并导出每日报表".to_string(),
        triggers: vec!["生成报表".into(), "每日报表".into()],
        tools: vec!["excel_tool".into()],
        body_markdown: "# 每日报表生成\n\n## 步骤一:数据收集\n从数据库查询昨日所有交易记录和运营数据。\n\n## 步骤二:数据整理\n将原始数据按部门、类型进行分类汇总。\n\n## 步骤三:报表输出\n生成标准化报表并导出为PDF格式。".to_string(),
        source_pattern: "生成每日报表".to_string(),
        confidence: 0.85,
        version: 1,
    };

    let gate = QualityGate::new(0.7, vec![]);
    let report = gate.validate_skill(&candidate);
    assert!(report.passed, "Valid candidate should pass quality gate");
    assert!(report.issues.is_empty());

    // Test with conflicting trigger
    let gate_with_conflict = QualityGate::new(0.7, vec!["生成报表".into()]);
    let report = gate_with_conflict.validate_skill(&candidate);
    assert!(!report.passed, "Conflicting trigger should fail");
}

/// Test FeedbackCollector trust score updates
#[tokio::test]
async fn test_feedback_collector_trust_evolution() {
    use zclaw_growth::feedback_collector::{
        EvolutionArtifact, FeedbackCollector, FeedbackEntry, FeedbackSignal, Sentiment,
    };

    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let mut collector = FeedbackCollector::with_viking(adapter);

    // Submit 3 positive feedbacks across 2 skills
    for i in 0..3 {
        let entry = FeedbackEntry {
            artifact_id: format!("skill-{}", i % 2),
            artifact_type: EvolutionArtifact::Skill,
            signal: FeedbackSignal::Explicit,
            sentiment: Sentiment::Positive,
            details: Some("很有用".into()),
            timestamp: chrono::Utc::now(),
        };
        collector.submit_feedback(entry);
    }

    // Submit 1 negative feedback
    let negative = FeedbackEntry {
        artifact_id: "skill-0".to_string(),
        artifact_type: EvolutionArtifact::Skill,
        signal: FeedbackSignal::Explicit,
        sentiment: Sentiment::Negative,
        details: Some("步骤有误".into()),
        timestamp: chrono::Utc::now(),
    };
    collector.submit_feedback(negative);

    // skill-0: 2 positive + 1 negative
    let trust0 = collector.get_trust("skill-0").unwrap();
    assert_eq!(trust0.positive_count, 2);
    assert_eq!(trust0.negative_count, 1);

    // skill-1: 1 positive only
    let trust1 = collector.get_trust("skill-1").unwrap();
    assert_eq!(trust1.positive_count, 1);
    assert_eq!(trust1.negative_count, 0);
}
248
crates/zclaw-growth/tests/experience_chain_test.rs
Normal file
@@ -0,0 +1,248 @@
//! Experience chain tests (E-01 ~ E-06)
//!
//! Validates experience storage merging, overflow protection,
//! deserialization resilience, cross-industry merge behavior, concurrent safety,
//! and evolution threshold detection.

use std::sync::Arc;
use zclaw_growth::{
    Experience, ExperienceStore, PatternAggregator, SqliteStorage, VikingAdapter,
};

fn make_experience(agent_id: &str, pattern: &str, steps: Vec<&str>) -> Experience {
    let mut exp = Experience::new(
        agent_id,
        pattern,
        &format!("{}相关任务", pattern),
        steps.into_iter().map(String::from).collect(),
        "成功解决",
    );
    exp.industry_context = Some("healthcare".to_string());
    exp.source_trigger = Some("researcher".to_string());
    exp
}

fn make_experience_with_industry(
    agent_id: &str,
    pattern: &str,
    industry: &str,
) -> Experience {
    let mut exp = Experience::new(
        agent_id,
        pattern,
        &format!("{}相关任务", pattern),
        vec!["步骤一".to_string(), "步骤二".to_string()],
        "成功解决",
    );
    exp.industry_context = Some(industry.to_string());
    exp
}

/// E-01: reuse_count accumulates correctly across repeated stores.
#[tokio::test]
async fn e01_reuse_count_accumulates() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let store = ExperienceStore::new(adapter);

    let exp = make_experience("agent-1", "排班冲突", vec!["查询排班表", "调整排班"]);

    // Store 4 times: the first store leaves reuse_count=0, each merge adds 1
    for _ in 0..4 {
        store.store_experience(&exp).await.unwrap();
    }

    let results = store.find_by_agent("agent-1").await.unwrap();
    assert_eq!(results.len(), 1, "same pattern should merge into one entry");
    assert_eq!(
        results[0].reuse_count, 3,
        "4 stores => reuse_count = 3 (N-1)"
    );

    // industry_context should be preserved from first store
    assert_eq!(
        results[0].industry_context.as_deref(),
        Some("healthcare"),
        "industry_context preserved from first store"
    );
}

/// E-02: reuse_count overflow protection.
/// Currently uses plain `+`, which panics in debug builds (and wraps in release
/// builds) once the count would exceed u32::MAX.
/// This test documents the expected behavior: saturating add should be used.
#[tokio::test]
async fn e02_reuse_count_overflow_protection() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let store = ExperienceStore::new(adapter);

    let mut exp = make_experience("agent-1", "溢出测试", vec!["步骤"]);
    exp.reuse_count = u32::MAX - 1;

    // First store: no existing entry, stores as-is with reuse_count = u32::MAX - 1
    store.store_experience(&exp).await.unwrap();

    let results = store.find_by_agent("agent-1").await.unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(
        results[0].reuse_count,
        u32::MAX - 1,
        "first store keeps reuse_count as-is"
    );

    // Second store: triggers merge, reuse_count = (u32::MAX - 1) + 1 = u32::MAX
    store.store_experience(&exp).await.unwrap();
    let results = store.find_by_agent("agent-1").await.unwrap();
    assert_eq!(
        results[0].reuse_count, u32::MAX,
        "merge reaches MAX"
    );

    // Third store: should saturate at u32::MAX, not wrap to 0.
    // NOTE: Current implementation uses plain `+` which panics in debug.
    // After fix (saturating_add), this should pass without panic.
    // store.store_experience(&exp).await.unwrap();
    // let results = store.find_by_agent("agent-1").await.unwrap();
    // assert_eq!(results[0].reuse_count, u32::MAX, "should saturate at MAX");
}

/// E-03: Deserialization failure: old data should not be silently overwritten.
/// Current behavior: on corrupted JSON, the code OVERWRITES with new experience.
/// This test documents the issue (FRAGILE-3); for now it only asserts that the
/// store does not panic and leaves at most two entries.
#[tokio::test]
async fn e03_deserialization_failure_preserves_data() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));

    // Manually store a corrupted (non-JSON) entry at the URI of a valid experience
    let mut original = make_experience("agent-1", "数据报表", vec!["生成报表"]);
    original.reuse_count = 50;
    adapter
        .store(&zclaw_growth::MemoryEntry::new(
            "agent-1",
            zclaw_growth::MemoryType::Experience,
            &original.uri(),
            "this is not valid JSON - BROKEN DATA".to_string(),
        ))
        .await
        .unwrap();

    // Now try to store a new experience with the same pattern
    let store = ExperienceStore::new(adapter.clone());
    let new_exp = make_experience("agent-1", "数据报表", vec!["新步骤"]);

    // Current behavior: overwrites corrupted data (FRAGILE-3)
    // After fix, this should preserve reuse_count=50
    store.store_experience(&new_exp).await.unwrap();

    let results = store.find_by_agent("agent-1").await.unwrap();
    // The corrupted entry may be overwritten or stored as new
    // Key assertion: the system does not panic
    assert!(
        results.len() <= 2,
        "at most 2 entries (corrupted + new or merged)"
    );
}

/// E-04: Different industry, same pain pattern.
/// URI is based only on pain_pattern hash, so same pattern = same URI = merge.
/// This test documents the current merge behavior.
#[tokio::test]
async fn e04_different_industry_same_pattern() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let store = ExperienceStore::new(adapter);

    let exp_healthcare = make_experience_with_industry("agent-1", "数据报表", "healthcare");
    let exp_ecommerce = make_experience_with_industry("agent-1", "数据报表", "ecommerce");

    store.store_experience(&exp_healthcare).await.unwrap();
    store.store_experience(&exp_ecommerce).await.unwrap();

    let results = store.find_by_agent("agent-1").await.unwrap();
    // Same pattern = same URI = merged into 1 entry
    assert_eq!(results.len(), 1, "same pattern merges regardless of industry");
    assert_eq!(results[0].reuse_count, 1, "reuse_count incremented once");
    // industry_context: current code takes new value (ecommerce) since it's present
    assert_eq!(
        results[0].industry_context.as_deref(),
        Some("ecommerce"),
        "latest industry_context wins in merge"
    );
}

/// E-05: Concurrent merge: two tasks storing the same pattern simultaneously.
#[tokio::test]
async fn e05_concurrent_merge_safety() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let store = Arc::new(ExperienceStore::new(adapter));

    let exp1 = make_experience("agent-1", "并发测试", vec!["步骤A"]);
    let exp2 = make_experience("agent-1", "并发测试", vec!["步骤B"]);

    let store1 = store.clone();
    let store2 = store.clone();

    let handle1 = tokio::spawn(async move {
        store1.store_experience(&exp1).await.unwrap();
    });
    let handle2 = tokio::spawn(async move {
        store2.store_experience(&exp2).await.unwrap();
    });

    handle1.await.unwrap();
    handle2.await.unwrap();

    let results = store.find_by_agent("agent-1").await.unwrap();
    // At least 1 entry must survive; ideally reuse_count reflects both writes
    assert!(
        !results.is_empty(),
        "concurrent stores should not lose data"
    );
    // Because the two stores race, reuse_count may end up 0 or 1 depending on interleaving
    // The key assertion: no panic, no deadlock, and at least one entry survives
    let total_reuse: u32 = results.iter().map(|e| e.reuse_count).sum();
    assert!(
        total_reuse <= 2,
        "total reuse should be at most 2 from 2 concurrent stores"
    );
}

/// E-06: Evolution trigger threshold: PatternAggregator respects min_reuse.
#[tokio::test]
async fn e06_evolution_trigger_threshold() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));
    let store = Arc::new(ExperienceStore::new(adapter.clone()));
    let agg_store = ExperienceStore::new(adapter);
    let aggregator = PatternAggregator::new(agg_store);

    // Store same pattern 4 times => reuse_count = 3
    let exp = make_experience("agent-1", "月度报表", vec!["生成", "审核"]);
    for _ in 0..4 {
        store.store_experience(&exp).await.unwrap();
    }

    // Store a different pattern once => reuse_count = 0
    let exp2 = make_experience("agent-1", "会议纪要", vec!["记录"]);
    store.store_experience(&exp2).await.unwrap();

    let patterns = aggregator
        .find_evolvable_patterns("agent-1", 3)
        .await
        .unwrap();

    assert_eq!(patterns.len(), 1, "only the pattern with reuse_count >= 3");
    assert_eq!(patterns[0].pain_pattern, "月度报表");

    // Verify with higher threshold
    let patterns_strict = aggregator
        .find_evolvable_patterns("agent-1", 5)
        .await
        .unwrap();
    assert!(
        patterns_strict.is_empty(),
        "no pattern meets min_reuse=5"
    );
}
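The merge semantics that E-01, E-02 and E-04 exercise can be sketched in isolation. The snippet below is a toy model, not the real `ExperienceStore`: `Record`, `ToyStore`, and the `(agent_id, pain_pattern)` key are hypothetical stand-ins, and the use of `saturating_add` reflects the fix that E-02 proposes rather than the current implementation.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified stand-in for a stored experience.
#[derive(Debug, Clone, PartialEq)]
struct Record {
    reuse_count: u32,
    industry_context: Option<String>,
}

/// Toy in-memory store keyed by (agent_id, pain_pattern).
/// Assumption: this only mimics the behavior the tests describe.
#[derive(Default)]
struct ToyStore {
    entries: HashMap<(String, String), Record>,
}

impl ToyStore {
    /// The first store keeps the record as-is. Each later store of the same
    /// key adds one to reuse_count (saturating at u32::MAX) and lets a
    /// present industry_context replace the stored one.
    fn store(&mut self, agent_id: &str, pain_pattern: &str, incoming: &Record) {
        let key = (agent_id.to_string(), pain_pattern.to_string());
        self.entries
            .entry(key)
            .and_modify(|existing| {
                existing.reuse_count = existing.reuse_count.saturating_add(1);
                if incoming.industry_context.is_some() {
                    existing.industry_context = incoming.industry_context.clone();
                }
            })
            .or_insert_with(|| incoming.clone());
    }

    /// Look up the merged record for a key, if any.
    fn get(&self, agent_id: &str, pain_pattern: &str) -> Option<&Record> {
        self.entries
            .get(&(agent_id.to_string(), pain_pattern.to_string()))
    }
}
```

Under these assumptions, N stores of the same key leave `reuse_count` at N-1, and a record seeded at `u32::MAX - 1` stays pinned at `u32::MAX` after further stores instead of panicking or wrapping.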
108
crates/zclaw-growth/tests/memory_chain.rs
Normal file
@@ -0,0 +1,108 @@
//! Memory chain seam tests
//!
//! Verifies the integration seams in the memory pipeline:
//! 1. Extract & store: experience → FTS5 write
//! 2. Retrieve & inject: FTS5 search → memory found
//! 3. Dedup: same experience not duplicated (reuse_count incremented)

use std::sync::Arc;
use zclaw_growth::{
    ExperienceStore, Experience, VikingAdapter,
    storage::SqliteStorage,
};

async fn test_store() -> ExperienceStore {
    let sqlite = SqliteStorage::in_memory().await;
    let viking = Arc::new(VikingAdapter::new(Arc::new(sqlite)));
    ExperienceStore::new(viking)
}

// ---------------------------------------------------------------------------
// Seam 1: Extract & Store: experience written to FTS5
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_experience_store_and_retrieve() {
    let store = test_store().await;

    let exp = Experience::new(
        "agent-001",
        "高 CPU 使用率告警频繁",
        "生产环境 CPU 使用率告警",
        vec!["检查进程列表".to_string(), "重启服务".to_string()],
        "已解决",
    );

    store.store_experience(&exp).await.expect("store experience");

    // Retrieve by agent
    let found = store.find_by_agent("agent-001").await.expect("find");
    assert_eq!(found.len(), 1, "should find exactly one experience");
    assert_eq!(found[0].pain_pattern, "高 CPU 使用率告警频繁");
}

// ---------------------------------------------------------------------------
// Seam 2: Retrieve by pattern: FTS5 search finds relevant experiences
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_experience_pattern_search() {
    let store = test_store().await;

    // Store multiple experiences
    let exp1 = Experience::new(
        "agent-001",
        "数据库连接超时",
        "PostgreSQL 连接池耗尽",
        vec!["增加连接池大小".to_string()],
        "已解决",
    );
    let exp2 = Experience::new(
        "agent-001",
        "前端白屏问题",
        "React 渲染错误",
        vec!["检查错误边界".to_string()],
        "已修复",
    );

    store.store_experience(&exp1).await.expect("store exp1");
    store.store_experience(&exp2).await.expect("store exp2");

    // Search for database-related experience
    let results = store.find_by_pattern("agent-001", "数据库 连接").await.expect("search");
    assert!(!results.is_empty(), "FTS5 should find database experience");
    assert!(
        results.iter().any(|e| e.pain_pattern.contains("数据库")),
        "should match database experience, got: {:?}",
        results
    );
}

// ---------------------------------------------------------------------------
// Seam 3: Dedup: same pain_pattern increments reuse_count
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_experience_dedup() {
    let store = test_store().await;

    let exp = Experience::new(
        "agent-001",
        "内存泄漏检测",
        "服务运行一段时间后内存持续增长",
        vec!["分析 heap dump".to_string()],
        "已修复",
    );

    // Store twice with same agent_id and pain_pattern
    store.store_experience(&exp).await.expect("first store");
    store.store_experience(&exp).await.expect("second store (dedup)");

    let all = store.find_by_agent("agent-001").await.expect("find");
    assert_eq!(all.len(), 1, "dedup should keep only one experience");
    assert!(
        all[0].reuse_count >= 1,
        "reuse_count should be incremented, got: {}",
        all[0].reuse_count
    );
}
143
crates/zclaw-growth/tests/memory_embedding_test.rs
Normal file
@@ -0,0 +1,143 @@
//! Memory embedding tests (EM-07 ~ EM-08)
//!
//! Validates memory retrieval with embedding enhancement and configuration hot-update.

use std::sync::Arc;
use async_trait::async_trait;
use zclaw_growth::{
    EmbeddingClient, MemoryEntry, MemoryRetriever, MemoryType, SqliteStorage, VikingAdapter,
};
use zclaw_types::AgentId;

/// Mock embedding client that returns deterministic 128-dim vectors.
struct MockEmbeddingClient {
    dim: usize,
}

impl MockEmbeddingClient {
    fn new() -> Self {
        Self { dim: 128 }
    }
}

// Uses the imported `async_trait` attribute so the `use` above is not left unused.
#[async_trait]
impl EmbeddingClient for MockEmbeddingClient {
    async fn embed(&self, text: &str) -> Result<Vec<f32>, String> {
        // Fold the UTF-8 bytes into `dim` buckets, then L2-normalize.
        let mut vec = vec![0.0f32; self.dim];
        for (i, b) in text.as_bytes().iter().enumerate() {
            vec[i % self.dim] += (*b as f32) / 255.0;
        }
        let norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-8);
        for v in vec.iter_mut() {
            *v /= norm;
        }
        Ok(vec)
    }

    fn is_available(&self) -> bool {
        true
    }
}

/// EM-07: Memory retrieval with embedding enhancement.
#[tokio::test]
async fn em07_memory_retrieval_embedding_enhanced() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));

    let agent_id = AgentId::new();

    // Store 20 mixed Chinese/English memories
    let entries = vec![
        ("pref-theme", MemoryType::Preference, "用户偏好深色模式"),
        ("pref-language", MemoryType::Preference, "用户使用中文沟通"),
        ("know-rust", MemoryType::Knowledge, "Rust async programming with tokio"),
        ("know-python", MemoryType::Knowledge, "Python data science with pandas"),
        ("exp-report", MemoryType::Experience, "月度报表生成经验:使用Excel宏自动化"),
        ("know-react", MemoryType::Knowledge, "React hooks patterns"),
        ("pref-editor", MemoryType::Preference, "偏好 VS Code 编辑器"),
        ("exp-schedule", MemoryType::Experience, "排班冲突解决方案:协商调换"),
        ("know-sql", MemoryType::Knowledge, "SQL query optimization techniques"),
        ("exp-deploy", MemoryType::Experience, "部署失败经验:端口冲突检测"),
        ("know-docker", MemoryType::Knowledge, "Docker container networking"),
        ("pref-font", MemoryType::Preference, "字体大小偏好 14px"),
        ("know-tokio", MemoryType::Knowledge, "Tokio runtime configuration"),
        ("exp-review", MemoryType::Experience, "代码审查经验:关注错误处理"),
        ("know-git", MemoryType::Knowledge, "Git rebase vs merge strategies"),
        ("exp-perf", MemoryType::Experience, "性能优化经验:数据库索引"),
        ("pref-timezone", MemoryType::Preference, "时区 UTC+8"),
        ("know-linux", MemoryType::Knowledge, "Linux system administration basics"),
        ("exp-test", MemoryType::Experience, "测试经验:TDD方法论实践"),
        ("know-api", MemoryType::Knowledge, "RESTful API design principles"),
    ];

    for (key, mtype, content) in &entries {
        let entry = MemoryEntry::new(
            &agent_id.to_string(),
            *mtype,
            key,
            content.to_string(),
        );
        adapter.store(&entry).await.unwrap();
    }

    // Create retriever with embedding
    let retriever = MemoryRetriever::new(adapter);
    retriever.set_embedding_client(Arc::new(MockEmbeddingClient::new()));

    // Retrieve memories about user preferences
    let result = retriever
        .retrieve(&agent_id, "我之前说过什么偏好?")
        .await
        .unwrap();

    let total =
        result.knowledge.len() + result.preferences.len() + result.experience.len();
    assert!(
        total > 0,
        "embedding-enhanced retrieval should find memories"
    );

    assert!(
        !result.preferences.is_empty(),
        "should find preference memories"
    );
}

/// EM-08: Embedding configuration hot update: no panic, no disruption.
#[tokio::test]
async fn em08_embedding_hot_update() {
    let storage = Arc::new(SqliteStorage::in_memory().await);
    let adapter = Arc::new(VikingAdapter::new(storage));

    let agent_id = AgentId::new();

    // Store a memory
    let entry = MemoryEntry::new(
        &agent_id.to_string(),
        MemoryType::Knowledge,
        "rust-async",
        "Tokio runtime uses work-stealing scheduler".to_string(),
    );
    adapter.store(&entry).await.unwrap();

    // Start without embedding
    let retriever = MemoryRetriever::new(adapter);

    // Retrieve without embedding; should not panic
    let _result_before = retriever
        .retrieve(&agent_id, "async runtime")
        .await
        .unwrap();

    // Hot-update with embedding; should not disrupt ongoing operations
    retriever.set_embedding_client(Arc::new(MockEmbeddingClient::new()));

    // Retrieve with embedding; should not panic
    let _result_after = retriever
        .retrieve(&agent_id, "async runtime")
        .await
        .unwrap();

    // Key assertion: hot-update does not panic or disrupt
}
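The mock embedding used by EM-07 and EM-08 can be examined on its own, without the `EmbeddingClient` trait or an async runtime. The sketch below repeats the same byte-bucket folding and L2 normalization as free functions; `embed` and `cosine` are hypothetical helper names introduced here for illustration, and `dim` is assumed to be non-zero (a zero `dim` would panic on the modulo).

```rust
/// Deterministic byte-bucket embedding: fold the UTF-8 bytes of `text`
/// into `dim` buckets, then scale the vector to unit L2 norm.
/// Assumption: `dim > 0`.
fn embed(text: &str, dim: usize) -> Vec<f32> {
    let mut v = vec![0.0f32; dim];
    for (i, b) in text.as_bytes().iter().enumerate() {
        v[i % dim] += f32::from(*b) / 255.0;
    }
    // Clamp the norm so an empty input yields a zero vector, not NaN.
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-8);
    for x in v.iter_mut() {
        *x /= norm;
    }
    v
}

/// Dot product of two vectors; for unit-length inputs this equals
/// their cosine similarity.
fn cosine(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}
```

Because the output depends only on the input bytes, the same text always maps to the same vector, which keeps the retrieval tests reproducible. It carries no semantic meaning, so it is only suitable for exercising the plumbing.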
59
crates/zclaw-growth/tests/smoke_memory.rs
Normal file
@@ -0,0 +1,59 @@
//! Memory smoke test: full lifecycle (store → retrieve → dedup)
//!
//! Uses in-memory SqliteStorage with real FTS5.

use std::sync::Arc;
use zclaw_growth::{
    ExperienceStore, Experience, VikingAdapter,
    storage::SqliteStorage,
};

#[tokio::test]
async fn smoke_memory_full_lifecycle() {
    let sqlite = SqliteStorage::in_memory().await;
    let viking = Arc::new(VikingAdapter::new(Arc::new(sqlite)));
    let store = ExperienceStore::new(viking);

    // 1. Store first experience
    let exp1 = Experience::new(
        "agent-smoke",
        "用户反馈页面加载缓慢",
        "前端性能问题,首屏加载超 5 秒",
        vec![
            "分析 Network 瀑布图".to_string(),
            "启用代码分割".to_string(),
            "配置 CDN".to_string(),
        ],
        "首屏加载降至 1.2 秒",
    );
    store.store_experience(&exp1).await.expect("store exp1");

    // 2. Store second experience (different topic)
    let exp2 = Experience::new(
        "agent-smoke",
        "数据库查询缓慢",
        "订单列表查询超时",
        vec!["添加复合索引".to_string()],
        "查询时间从 3s 降至 50ms",
    );
    store.store_experience(&exp2).await.expect("store exp2");

    // 3. Retrieve by agent: should find both
    let all = store.find_by_agent("agent-smoke").await.expect("find by agent");
    assert_eq!(all.len(), 2, "should have 2 experiences");

    // 4. Search by pattern: should find relevant one
    let db_results = store.find_by_pattern("agent-smoke", "数据库 查询 缓慢").await.expect("search");
    assert!(!db_results.is_empty(), "FTS5 should find database experience");
    assert!(
        db_results.iter().any(|e| e.pain_pattern.contains("数据库")),
        "should match database experience"
    );

    // 5. Dedup: store same experience again
    store.store_experience(&exp1).await.expect("dedup store");
    let all_after_dedup = store.find_by_agent("agent-smoke").await.expect("find after dedup");
    assert_eq!(all_after_dedup.len(), 2, "should still have 2 after dedup");
    let deduped = all_after_dedup.iter().find(|e| e.pain_pattern.contains("页面加载")).unwrap();
    assert!(deduped.reuse_count >= 1, "reuse_count should be incremented");
}
@@ -20,4 +20,7 @@ thiserror = { workspace = true }
tracing = { workspace = true }
async-trait = { workspace = true }
reqwest = { workspace = true }
url = { workspace = true }
base64 = { workspace = true }
dirs = { workspace = true }
toml = { workspace = true }

@@ -1,7 +1,7 @@
//! Browser Hand - Web automation capabilities (TypeScript delegation)
//!
//! **Architecture note (M3-02):** This Rust Hand is a **schema validator and passthrough**.
//! Every action returns `{"status": "pending_execution"}`; no real browser work happens here.
//! Every action returns `{"status": "delegated_to_frontend"}`; no real browser work happens here.
//!
//! The actual execution path is:
//! 1. Frontend `HandsPanel.tsx` intercepts browser hands → routes to `BrowserHandCard`
@@ -117,6 +117,56 @@ pub enum BrowserAction {
    },
}

impl BrowserAction {
    pub fn action_name(&self) -> &'static str {
        match self {
            BrowserAction::Navigate { .. } => "navigate",
            BrowserAction::Click { .. } => "click",
            BrowserAction::Type { .. } => "type",
            BrowserAction::Select { .. } => "select",
            BrowserAction::Scrape { .. } => "scrape",
            BrowserAction::Screenshot { .. } => "screenshot",
            BrowserAction::FillForm { .. } => "fill_form",
            BrowserAction::Wait { .. } => "wait",
            BrowserAction::Execute { .. } => "execute",
            BrowserAction::GetSource => "get_source",
            BrowserAction::GetUrl => "get_url",
            BrowserAction::GetTitle => "get_title",
            BrowserAction::Scroll { .. } => "scroll",
            BrowserAction::Back => "back",
            BrowserAction::Forward => "forward",
            BrowserAction::Refresh => "refresh",
            BrowserAction::Hover { .. } => "hover",
            BrowserAction::PressKey { .. } => "press_key",
            BrowserAction::Upload { .. } => "upload",
        }
    }

    pub fn summary(&self) -> String {
        match self {
            BrowserAction::Navigate { url, .. } => format!("导航到 {}", url),
            BrowserAction::Click { selector, .. } => format!("点击 {}", selector),
            BrowserAction::Type { selector, text, .. } => format!("在 {} 输入 {}", selector, text),
            BrowserAction::Select { selector, value } => format!("在 {} 选择 {}", selector, value),
            BrowserAction::Scrape { selectors, .. } => format!("抓取 {} 个选择器", selectors.len()),
            BrowserAction::Screenshot { .. } => "截图".to_string(),
            BrowserAction::FillForm { fields, .. } => format!("填写 {} 个字段", fields.len()),
            BrowserAction::Wait { selector, .. } => format!("等待 {}", selector),
            BrowserAction::Execute { .. } => "执行脚本".to_string(),
            BrowserAction::GetSource => "获取页面源码".to_string(),
            BrowserAction::GetUrl => "获取当前URL".to_string(),
            BrowserAction::GetTitle => "获取页面标题".to_string(),
            BrowserAction::Scroll { x, y, .. } => format!("滚动到 ({},{})", x, y),
            BrowserAction::Back => "后退".to_string(),
            BrowserAction::Forward => "前进".to_string(),
            BrowserAction::Refresh => "刷新".to_string(),
            BrowserAction::Hover { selector } => format!("悬停 {}", selector),
            BrowserAction::PressKey { key } => format!("按键 {}", key),
            BrowserAction::Upload { selector, .. } => format!("上传文件到 {}", selector),
        }
    }
}

/// Form field definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FormField {
@@ -196,157 +246,30 @@ impl Hand for BrowserHand {
    }

    async fn execute(&self, _context: &HandContext, input: Value) -> Result<HandResult> {
        // Parse the action
        let action: BrowserAction = match serde_json::from_value(input) {
            Ok(a) => a,
            Err(e) => return Ok(HandResult::error(format!("Invalid action: {}", e))),
        };

        // Execute based on action type
        // Note: Actual browser operations are handled via Tauri commands
        // This Hand provides a structured interface for the runtime
        match action {
            BrowserAction::Navigate { url, wait_for } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "navigate",
                    "url": url,
                    "wait_for": wait_for,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Click { selector, wait_ms } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "click",
                    "selector": selector,
                    "wait_ms": wait_ms,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Type { selector, text, clear_first } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "type",
                    "selector": selector,
                    "text": text,
                    "clear_first": clear_first,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Scrape { selectors, wait_for } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "scrape",
                    "selectors": selectors,
                    "wait_for": wait_for,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Screenshot { selector, full_page } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "screenshot",
                    "selector": selector,
                    "full_page": full_page,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::FillForm { fields, submit_selector } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "fill_form",
                    "fields": fields,
                    "submit_selector": submit_selector,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Wait { selector, timeout_ms } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "wait",
                    "selector": selector,
                    "timeout_ms": timeout_ms,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Execute { script, args } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "execute",
                    "script": script,
                    "args": args,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::GetSource => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "get_source",
                    "status": "pending_execution"
                })))
            }
            BrowserAction::GetUrl => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "get_url",
                    "status": "pending_execution"
                })))
            }
            BrowserAction::GetTitle => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "get_title",
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Scroll { x, y, selector } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "scroll",
                    "x": x,
                    "y": y,
                    "selector": selector,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Back => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "back",
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Forward => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "forward",
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Refresh => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "refresh",
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Hover { selector } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "hover",
                    "selector": selector,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::PressKey { key } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "press_key",
                    "key": key,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Upload { selector, file_path } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "upload",
                    "selector": selector,
                    "file_path": file_path,
                    "status": "pending_execution"
                })))
            }
            BrowserAction::Select { selector, value } => {
                Ok(HandResult::success(serde_json::json!({
                    "action": "select",
                    "selector": selector,
                    "value": value,
                    "status": "pending_execution"
                })))
            }
let action_type = action.action_name();
|
||||
let summary = action.summary();
|
||||
|
||||
// Check if WebDriver is available
|
||||
if !self.check_webdriver() {
|
||||
return Ok(HandResult::error(format!(
|
||||
"浏览器操作「{}」无法执行:未检测到 WebDriver (ChromeDriver/GeckoDriver)。请先启动 WebDriver 服务。",
|
||||
summary
|
||||
)));
|
||||
}
|
||||
|
||||
// WebDriver is running — delegate to frontend BrowserHandCard.
|
||||
// The frontend manages the Fantoccini session lifecycle.
|
||||
Ok(HandResult::success(serde_json::json!({
|
||||
"action": action_type,
|
||||
"status": "delegated_to_frontend",
|
||||
"message": format!("浏览器操作「{}」已发送到前端执行。WebDriver 已就绪。", summary),
|
||||
"details": format!("{} — 由前端 BrowserHandCard 通过 Fantoccini 执行。", summary),
|
||||
})))
|
||||
}
|
||||
|
||||
fn is_dependency_available(&self, dep: &str) -> bool {
|
||||
@@ -595,12 +518,16 @@ mod tests {
|
||||
assert!(!sequence.stop_on_error);
|
||||
assert_eq!(sequence.steps.len(), 1);
|
||||
|
||||
// Execute the navigate step
|
||||
// Execute the navigate step — without WebDriver running, should report error
|
||||
let action_json = serde_json::to_value(&sequence.steps[0]).expect("serialize step");
|
||||
let result = hand.execute(&ctx, action_json).await.expect("execute");
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output["action"], "navigate");
|
||||
assert_eq!(result.output["url"], "https://example.com");
|
||||
// In test env no WebDriver is running, so we get an error about missing WebDriver
|
||||
if result.success {
|
||||
assert_eq!(result.output["action"], "navigate");
|
||||
assert_eq!(result.output["status"], "delegated_to_frontend");
|
||||
} else {
|
||||
assert!(result.error.as_deref().unwrap_or("").contains("WebDriver"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -616,11 +543,18 @@ mod tests {
|
||||
|
||||
assert_eq!(sequence.steps.len(), 4);
|
||||
|
||||
// Verify each step can execute
|
||||
// Verify each step can parse and execute (or report missing WebDriver)
|
||||
for (i, step) in sequence.steps.iter().enumerate() {
|
||||
let action_json = serde_json::to_value(step).expect("serialize step");
|
||||
let result = hand.execute(&ctx, action_json).await.expect("execute step");
|
||||
assert!(result.success, "Step {} failed: {:?}", i, result.error);
|
||||
// Without WebDriver, all steps should report the error cleanly
|
||||
if !result.success {
|
||||
assert!(
|
||||
result.error.as_deref().unwrap_or("").contains("WebDriver"),
|
||||
"Step {} unexpected error: {:?}",
|
||||
i, result.error
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
244
crates/zclaw-hands/src/hands/daily_report.rs
Normal file
@@ -0,0 +1,244 @@
|
||||
//! Daily Report Hand — generates a personalized daily briefing.
|
||||
//!
|
||||
//! System hand (`_daily_report`) triggered by SchedulerService at 09:00 cron.
|
||||
//! Produces a Markdown daily report containing:
|
||||
//! 1. Yesterday's conversation summary
|
||||
//! 2. Unresolved pain points follow-up
|
||||
//! 3. Recent experience highlights
|
||||
//! 4. Industry-specific daily reminder
|
||||
//!
|
||||
//! The caller (SchedulerService or Tauri command) is responsible for:
|
||||
//! - Assembling input data (trajectory summary, pain points, experiences)
|
||||
//! - Emitting `daily-report:ready` Tauri event after execution
|
||||
//! - Persisting the report to VikingStorage
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::{Hand, HandConfig, HandContext, HandResult, HandStatus};
|
||||
|
||||
/// Internal daily report hand.
|
||||
pub struct DailyReportHand {
|
||||
config: HandConfig,
|
||||
}
|
||||
|
||||
impl DailyReportHand {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "_daily_report".to_string(),
|
||||
name: "管家日报".to_string(),
|
||||
description: "Generates personalized daily briefing".to_string(),
|
||||
needs_approval: false,
|
||||
dependencies: vec![],
|
||||
input_schema: None,
|
||||
tags: vec!["system".to_string()],
|
||||
enabled: true,
|
||||
max_concurrent: 0,
|
||||
timeout_secs: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Hand for DailyReportHand {
|
||||
fn config(&self) -> &HandConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
async fn execute(&self, _context: &HandContext, input: Value) -> Result<HandResult> {
|
||||
let agent_id = input
|
||||
.get("agent_id")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default_user");
|
||||
|
||||
let industry = input
|
||||
.get("industry")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
let trajectory_summary = input
|
||||
.get("trajectory_summary")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("昨日无对话记录");
|
||||
|
||||
let pain_points = input
|
||||
.get("pain_points")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let recent_experiences = input
|
||||
.get("recent_experiences")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|v| v.as_str().map(String::from))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let report = self.build_report(industry, trajectory_summary, &pain_points, &recent_experiences);
|
||||
|
||||
tracing::info!(
|
||||
"[DailyReportHand] Generated report for agent {} ({} pains, {} experiences)",
|
||||
agent_id,
|
||||
pain_points.len(),
|
||||
recent_experiences.len(),
|
||||
);
|
||||
|
||||
Ok(HandResult::success(serde_json::json!({
|
||||
"agent_id": agent_id,
|
||||
"report": report,
|
||||
"pain_count": pain_points.len(),
|
||||
"experience_count": recent_experiences.len(),
|
||||
})))
|
||||
}
|
||||
|
||||
fn status(&self) -> HandStatus {
|
||||
HandStatus::Idle
|
||||
}
|
||||
}
|
||||
|
||||
impl DailyReportHand {
|
||||
fn build_report(
|
||||
&self,
|
||||
industry: &str,
|
||||
trajectory_summary: &str,
|
||||
pain_points: &[String],
|
||||
recent_experiences: &[String],
|
||||
) -> String {
|
||||
let industry_label = match industry {
|
||||
"healthcare" => "医疗行政",
|
||||
"education" => "教育培训",
|
||||
"garment" => "制衣制造",
|
||||
"ecommerce" => "电商零售",
|
||||
_ => "综合",
|
||||
};
|
||||
|
||||
let date = chrono::Utc::now().format("%Y年%m月%d日").to_string();
|
||||
|
||||
let mut sections = vec![
|
||||
format!("# {} 管家日报 — {}", industry_label, date),
|
||||
String::new(),
|
||||
"## 昨日对话摘要".to_string(),
|
||||
trajectory_summary.to_string(),
|
||||
String::new(),
|
||||
];
|
||||
|
||||
if !pain_points.is_empty() {
|
||||
sections.push("## 待解决问题".to_string());
|
||||
for (i, pain) in pain_points.iter().enumerate() {
|
||||
sections.push(format!("{}. {}", i + 1, pain));
|
||||
}
|
||||
sections.push(String::new());
|
||||
}
|
||||
|
||||
if !recent_experiences.is_empty() {
|
||||
sections.push("## 昨日收获".to_string());
|
||||
for exp in recent_experiences {
|
||||
sections.push(format!("- {}", exp));
|
||||
}
|
||||
sections.push(String::new());
|
||||
}
|
||||
|
||||
sections.push("## 今日提醒".to_string());
|
||||
sections.push(self.daily_reminder(industry));
|
||||
sections.push(String::new());
|
||||
sections.push("祝你今天工作顺利!".to_string());
|
||||
|
||||
sections.join("\n")
|
||||
}
|
||||
|
||||
fn daily_reminder(&self, industry: &str) -> String {
|
||||
match industry {
|
||||
"healthcare" => "记得检查今日科室排班,关注耗材库存预警。".to_string(),
|
||||
"education" => "今日有课程安排吗?提前准备教学材料。".to_string(),
|
||||
"garment" => "关注今日生产进度,及时跟进订单交期。".to_string(),
|
||||
"ecommerce" => "检查库存预警和待发货订单,把握促销节奏。".to_string(),
|
||||
_ => "新的一天,新的开始。有什么需要我帮忙的随时说。".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use zclaw_types::AgentId;
|
||||
|
||||
#[test]
|
||||
fn test_build_report_basic() {
|
||||
let hand = DailyReportHand::new();
|
||||
let report = hand.build_report(
|
||||
"healthcare",
|
||||
"讨论了科室排班问题",
|
||||
&["排班冲突".to_string()],
|
||||
&["学会了用数据报表工具".to_string()],
|
||||
);
|
||||
assert!(report.contains("医疗行政"));
|
||||
assert!(report.contains("排班冲突"));
|
||||
assert!(report.contains("学会了用数据报表工具"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_report_empty() {
|
||||
let hand = DailyReportHand::new();
|
||||
let report = hand.build_report("", "昨日无对话记录", &[], &[]);
|
||||
assert!(report.contains("管家日报"));
|
||||
assert!(report.contains("综合"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_report_all_industries() {
|
||||
let hand = DailyReportHand::new();
|
||||
for industry in &["healthcare", "education", "garment", "ecommerce", "unknown"] {
|
||||
let report = hand.build_report(industry, "test", &[], &[]);
|
||||
assert!(!report.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_with_data() {
|
||||
let hand = DailyReportHand::new();
|
||||
let ctx = HandContext {
|
||||
agent_id: AgentId::new(),
|
||||
working_dir: None,
|
||||
env: std::collections::HashMap::new(),
|
||||
timeout_secs: 30,
|
||||
callback_url: None,
|
||||
};
|
||||
let input = serde_json::json!({
|
||||
"agent_id": "test-agent",
|
||||
"industry": "education",
|
||||
"trajectory_summary": "讨论了课程安排",
|
||||
"pain_points": ["学生成绩下降"],
|
||||
"recent_experiences": ["掌握了成绩分析方法"],
|
||||
});
|
||||
|
||||
let result = hand.execute(&ctx, input).await.unwrap();
|
||||
assert!(result.success);
|
||||
let output = result.output;
|
||||
assert_eq!(output["agent_id"], "test-agent");
|
||||
assert!(output["report"].as_str().unwrap().contains("教育培训"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_execute_minimal() {
|
||||
let hand = DailyReportHand::new();
|
||||
let ctx = HandContext {
|
||||
agent_id: AgentId::new(),
|
||||
working_dir: None,
|
||||
env: std::collections::HashMap::new(),
|
||||
timeout_secs: 30,
|
||||
callback_url: None,
|
||||
};
|
||||
let result = hand.execute(&ctx, serde_json::json!({})).await.unwrap();
|
||||
assert!(result.success);
|
||||
}
|
||||
}
|
||||
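Reviewer sketch of the report assembly in `daily_report.rs`, using only the standard library. It is not the code from this commit: the English labels and headings stand in for the Chinese strings, the reminder section is left out, and the `date` parameter is an assumption added here so the output is deterministic (the diff calls `chrono::Utc::now()` inside `build_report`).

```rust
/// Map an industry key to a display label; unknown keys fall back to "General".
/// The English labels are illustrative stand-ins for the Chinese labels in the diff.
fn industry_label(industry: &str) -> &'static str {
    match industry {
        "healthcare" => "Healthcare Admin",
        "education" => "Education",
        "garment" => "Garment Manufacturing",
        "ecommerce" => "E-commerce",
        _ => "General",
    }
}

/// Build a Markdown report. `date` is passed in (an assumption of this sketch)
/// so the caller decides the time zone and tests can pin the value.
fn build_report(
    industry: &str,
    date: &str,
    summary: &str,
    pains: &[String],
    experiences: &[String],
) -> String {
    let mut sections = vec![
        format!("# {} Daily Report ({})", industry_label(industry), date),
        String::new(),
        "## Yesterday's conversations".to_string(),
        summary.to_string(),
        String::new(),
    ];

    // Optional sections are emitted only when they have entries.
    if !pains.is_empty() {
        sections.push("## Open issues".to_string());
        for (i, pain) in pains.iter().enumerate() {
            sections.push(format!("{}. {}", i + 1, pain));
        }
        sections.push(String::new());
    }
    if !experiences.is_empty() {
        sections.push("## Recent highlights".to_string());
        for exp in experiences {
            sections.push(format!("- {}", exp));
        }
        sections.push(String::new());
    }

    sections.join("\n")
}
```

Calling `build_report("education", "2024-01-02", "Discussed schedules", &[], &[])` yields only the title and summary sections, which mirrors how the diff skips empty pain-point and experience lists.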
@@ -15,6 +15,7 @@ mod collector;
|
||||
mod clip;
|
||||
mod twitter;
|
||||
pub mod reminder;
|
||||
pub mod daily_report;
|
||||
|
||||
pub use quiz::*;
|
||||
pub use browser::*;
|
||||
@@ -23,3 +24,4 @@ pub use collector::*;
|
||||
pub use clip::*;
|
||||
pub use twitter::*;
|
||||
pub use reminder::*;
|
||||
pub use daily_report::*;
|
||||
|
||||
File diff suppressed because it is too large
@@ -191,6 +191,8 @@ pub enum TwitterAction {
|
||||
Following { user_id: String, max_results: Option<u32> },
|
||||
#[serde(rename = "check_credentials")]
|
||||
CheckCredentials,
|
||||
#[serde(rename = "set_credentials")]
|
||||
SetCredentials { credentials: TwitterCredentials },
|
||||
}
|
||||
|
||||
/// Twitter Hand implementation
|
||||
@@ -200,14 +202,83 @@ pub struct TwitterHand {
|
||||
}
|
||||
|
||||
impl TwitterHand {
|
||||
/// Credential file path relative to app data dir
|
||||
const CREDS_FILE_NAME: &'static str = "twitter-credentials.json";
|
||||
|
||||
/// Get the credentials file path
|
||||
fn creds_path() -> Option<std::path::PathBuf> {
|
||||
dirs::data_dir().map(|d| d.join("zclaw").join("hands").join(Self::CREDS_FILE_NAME))
|
||||
}
|
||||
|
||||
/// Load credentials from disk (silent — logs errors, returns None on failure)
|
||||
fn load_credentials_from_disk() -> Option<TwitterCredentials> {
|
||||
let path = Self::creds_path()?;
|
||||
if !path.exists() {
|
||||
return None;
|
||||
}
|
||||
match std::fs::read_to_string(&path) {
|
||||
Ok(data) => match serde_json::from_str(&data) {
|
||||
Ok(creds) => {
|
||||
tracing::info!("[TwitterHand] Loaded persisted credentials from {:?}", path);
|
||||
Some(creds)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[TwitterHand] Failed to parse credentials file: {}", e);
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("[TwitterHand] Failed to read credentials file: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Save credentials to disk (best-effort, logs errors)
|
||||
fn save_credentials_to_disk(creds: &TwitterCredentials) {
|
||||
let path = match Self::creds_path() {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
tracing::warn!("[TwitterHand] Cannot determine credentials file path");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(parent) = path.parent() {
|
||||
if let Err(e) = std::fs::create_dir_all(parent) {
|
||||
tracing::warn!("[TwitterHand] Failed to create credentials dir: {}", e);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
match serde_json::to_string_pretty(creds) {
|
||||
Ok(data) => {
|
||||
if let Err(e) = std::fs::write(&path, data) {
|
||||
tracing::warn!("[TwitterHand] Failed to write credentials file: {}", e);
|
||||
} else {
|
||||
tracing::info!("[TwitterHand] Credentials persisted to {:?}", path);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[TwitterHand] Failed to serialize credentials: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Twitter hand
|
||||
pub fn new() -> Self {
|
||||
// Try to load persisted credentials
|
||||
let loaded = Self::load_credentials_from_disk();
|
||||
if loaded.is_some() {
|
||||
tracing::info!("[TwitterHand] Restored credentials from previous session");
|
||||
}
|
||||
|
||||
Self {
|
||||
config: HandConfig {
|
||||
id: "twitter".to_string(),
|
||||
name: "Twitter 自动化".to_string(),
|
||||
description: "Twitter/X 自动化能力,发布、搜索和管理内容".to_string(),
|
||||
needs_approval: true, // Twitter actions need approval
|
||||
needs_approval: true,
|
||||
dependencies: vec!["twitter_api_key".to_string()],
|
||||
input_schema: Some(serde_json::json!({
|
||||
"type": "object",
|
||||
@@ -275,12 +346,13 @@ impl TwitterHand {
|
||||
max_concurrent: 0,
|
||||
timeout_secs: 0,
|
||||
},
|
||||
credentials: Arc::new(RwLock::new(None)),
|
||||
credentials: Arc::new(RwLock::new(loaded)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set credentials
|
||||
/// Set credentials (also persists to disk)
|
||||
pub async fn set_credentials(&self, creds: TwitterCredentials) {
|
||||
Self::save_credentials_to_disk(&creds);
|
||||
let mut c = self.credentials.write().await;
|
||||
*c = Some(creds);
|
||||
}
|
||||
@@ -765,6 +837,13 @@ impl Hand for TwitterHand {
|
||||
TwitterAction::Followers { user_id, max_results } => self.execute_followers(&user_id, max_results).await?,
|
||||
TwitterAction::Following { user_id, max_results } => self.execute_following(&user_id, max_results).await?,
|
||||
TwitterAction::CheckCredentials => self.execute_check_credentials().await?,
|
||||
TwitterAction::SetCredentials { credentials } => {
|
||||
self.set_credentials(credentials).await;
|
||||
json!({
|
||||
"success": true,
|
||||
"message": "Twitter 凭据已设置并持久化。重启后自动恢复。"
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
@@ -785,9 +864,13 @@ impl Hand for TwitterHand {
|
||||
fn check_dependencies(&self) -> Result<Vec<String>> {
|
||||
let mut missing = Vec::new();
|
||||
|
||||
// Check if credentials are configured (synchronously)
|
||||
// This is a simplified check; actual async check would require runtime
|
||||
missing.push("Twitter API credentials required".to_string());
|
||||
// Synchronous check: if credentials were loaded from disk, dependency is met
|
||||
match self.credentials.try_read() {
|
||||
Ok(creds) if creds.is_some() => {},
|
||||
_ => {
|
||||
missing.push("Twitter API credentials required (use set_credentials action to configure)".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(missing)
|
||||
}
|
||||
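Reviewer sketch of the non-blocking dependency check above, assuming a `std::sync::RwLock` in place of the async lock the hand actually uses and a plain `String` in place of `TwitterCredentials`. It illustrates one consequence of the `try_read` pattern: if the lock cannot be acquired at that moment, the check reports the credentials as missing even when they are set.

```rust
use std::sync::RwLock;

/// Return the list of missing dependencies without blocking.
/// `String` stands in for the real credentials type (assumption of this sketch).
fn missing_deps(creds: &RwLock<Option<String>>) -> Vec<String> {
    let mut missing = Vec::new();
    match creds.try_read() {
        // Lock acquired and credentials present: nothing is missing.
        Ok(guard) if guard.is_some() => {}
        // Either no credentials are stored or the lock was not available.
        _ => missing.push("credentials required".to_string()),
    }
    missing
}
```

If a false "missing" report during a concurrent `set_credentials` call matters, the caller could retry the check; whether that is needed depends on how often `check_dependencies` runs, which the diff does not show.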
@@ -1058,6 +1141,62 @@ mod tests {
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_set_credentials_action_deserialize() {
|
||||
let json = json!({
|
||||
"action": "set_credentials",
|
||||
"credentials": {
|
||||
"apiKey": "test-key",
|
||||
"apiSecret": "test-secret",
|
||||
"accessToken": "test-token",
|
||||
"accessTokenSecret": "test-token-secret",
|
||||
"bearerToken": "test-bearer"
|
||||
}
|
||||
});
|
||||
let action: TwitterAction = serde_json::from_value(json).unwrap();
|
||||
match action {
|
||||
TwitterAction::SetCredentials { credentials } => {
|
||||
assert_eq!(credentials.api_key, "test-key");
|
||||
assert_eq!(credentials.api_secret, "test-secret");
|
||||
assert_eq!(credentials.bearer_token, Some("test-bearer".to_string()));
|
||||
}
|
||||
_ => panic!("Expected SetCredentials"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_set_credentials_persists_and_restores() {
|
||||
// NOTE: `temp_dir` is created but not used below; `set_credentials` writes to `TwitterHand::creds_path()` (the real data dir)
|
||||
let temp_dir = std::env::temp_dir().join("zclaw_test_twitter_creds");
|
||||
let _ = std::fs::create_dir_all(&temp_dir);
|
||||
|
||||
let hand = TwitterHand::new();
|
||||
|
||||
// Set credentials
|
||||
let creds = TwitterCredentials {
|
||||
api_key: "test-key".to_string(),
|
||||
api_secret: "test-secret".to_string(),
|
||||
access_token: "test-token".to_string(),
|
||||
access_token_secret: "test-secret".to_string(),
|
||||
bearer_token: Some("test-bearer".to_string()),
|
||||
};
|
||||
hand.set_credentials(creds.clone()).await;
|
||||
|
||||
// Verify in-memory
|
||||
let loaded = hand.get_credentials().await;
|
||||
assert!(loaded.is_some());
|
||||
assert_eq!(loaded.unwrap().api_key, "test-key");
|
||||
|
||||
// Verify file was written
|
||||
let path = TwitterHand::creds_path();
|
||||
assert!(path.is_some());
|
||||
let path = path.unwrap();
|
||||
assert!(path.exists(), "Credentials file should exist at {:?}", path);
|
||||
|
||||
// Clean up
|
||||
let _ = std::fs::remove_file(&path);
|
||||
}
|
||||
|
||||
// === Serialization Roundtrip ===
|
||||
|
||||
#[test]
|
||||
|
||||
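Reviewer sketch related to `test_set_credentials_persists_and_restores`: in the diff, `creds_path()` always resolves through `dirs::data_dir()`, so the test writes to, and then deletes, the file at the real credentials location. One possible way to isolate such a test is to make the path an argument. The helpers below (`save_to`, `load_from`, `scratch_path`) are hypothetical names, use only the standard library, and store a raw string rather than serialized `TwitterCredentials`.

```rust
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Write `data` to `path`, creating parent directories first.
fn save_to(path: &Path, data: &str) -> io::Result<()> {
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }
    fs::write(path, data)
}

/// Read the file back; a missing or unreadable file yields `None`.
fn load_from(path: &Path) -> Option<String> {
    fs::read_to_string(path).ok()
}

/// Hypothetical helper: a per-process scratch path under the OS temp dir,
/// so a test never touches the user's real data directory.
fn scratch_path(file_name: &str) -> PathBuf {
    std::env::temp_dir()
        .join(format!("zclaw_sketch_{}", std::process::id()))
        .join(file_name)
}
```

A test would then call `save_to(&scratch_path("twitter-credentials.json"), ...)`, read the value back with `load_from`, and remove the scratch directory afterwards.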
@@ -17,6 +17,7 @@ zclaw-runtime = { workspace = true }
|
||||
zclaw-protocols = { workspace = true }
|
||||
zclaw-hands = { workspace = true }
|
||||
zclaw-skills = { workspace = true }
|
||||
zclaw-growth = { workspace = true }
|
||||
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use zclaw_runtime::{LlmDriver, tool::SkillExecutor};
|
||||
use zclaw_skills::{SkillRegistry, LlmCompleter};
|
||||
use zclaw_types::Result;
|
||||
use zclaw_runtime::{LlmDriver, tool::{SkillExecutor, HandExecutor}};
|
||||
use zclaw_skills::{SkillRegistry, LlmCompleter, SkillCompletion, SkillToolCall};
|
||||
use zclaw_hands::HandRegistry;
|
||||
use zclaw_types::{AgentId, Result, ToolDefinition};
|
||||
|
||||
/// Adapter that bridges `zclaw_runtime::LlmDriver` -> `zclaw_skills::LlmCompleter`
|
||||
pub(crate) struct LlmDriverAdapter {
|
||||
@@ -43,18 +44,111 @@ impl LlmCompleter for LlmDriverAdapter {
|
||||
Ok(text)
|
||||
})
|
||||
}
|
||||
|
||||
fn complete_with_tools(
|
||||
&self,
|
||||
prompt: &str,
|
||||
system_prompt: Option<&str>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Pin<Box<dyn std::future::Future<Output = std::result::Result<SkillCompletion, String>> + Send + '_>> {
|
||||
let driver = self.driver.clone();
|
||||
let prompt = prompt.to_string();
|
||||
let system = system_prompt.map(|s| s.to_string());
|
||||
let max_tokens = self.max_tokens;
|
||||
let temperature = self.temperature;
|
||||
Box::pin(async move {
|
||||
let mut messages = Vec::new();
|
||||
messages.push(zclaw_types::Message::user(prompt));
|
||||
|
||||
let request = zclaw_runtime::CompletionRequest {
|
||||
model: String::new(),
|
||||
system,
|
||||
messages,
|
||||
tools,
|
||||
max_tokens: Some(max_tokens),
|
||||
temperature: Some(temperature),
|
||||
stop: Vec::new(),
|
||||
stream: false,
|
||||
thinking_enabled: false,
|
||||
reasoning_effort: None,
|
||||
plan_mode: false,
|
||||
};
|
||||
let response = driver.complete(request).await
|
||||
.map_err(|e| format!("LLM completion error: {}", e))?;
|
||||
|
||||
let mut text_parts = Vec::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
for block in &response.content {
|
||||
match block {
|
||||
zclaw_runtime::ContentBlock::Text { text } => {
|
||||
text_parts.push(text.clone());
|
||||
}
|
||||
zclaw_runtime::ContentBlock::ToolUse { id, name, input } => {
|
||||
tool_calls.push(SkillToolCall {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input: input.clone(),
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(SkillCompletion {
|
||||
text: text_parts.join(""),
|
||||
tool_calls,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Skill executor implementation for Kernel
|
||||
pub struct KernelSkillExecutor {
|
||||
pub(crate) skills: Arc<SkillRegistry>,
|
||||
pub(crate) llm: Arc<dyn LlmCompleter>,
|
||||
/// Shared tool registry, updated before each skill execution from the
|
||||
/// agent loop's freshly-built registry. Uses std::sync because reads
|
||||
/// happen from async code but writes are brief and infrequent.
|
||||
pub(crate) tool_registry: std::sync::RwLock<Option<zclaw_runtime::ToolRegistry>>,
|
||||
}
|
||||
|
||||
impl KernelSkillExecutor {
|
||||
pub fn new(skills: Arc<SkillRegistry>, driver: Arc<dyn LlmDriver>) -> Self {
|
||||
let llm: Arc<dyn zclaw_skills::LlmCompleter> = Arc::new(LlmDriverAdapter { driver, max_tokens: 4096, temperature: 0.7 });
|
||||
Self { skills, llm }
|
||||
let llm: Arc<dyn LlmCompleter> = Arc::new(LlmDriverAdapter { driver, max_tokens: 4096, temperature: 0.7 });
|
||||
Self { skills, llm, tool_registry: std::sync::RwLock::new(None) }
|
||||
}
|
||||
|
||||
/// Update the tool registry snapshot. Called by the kernel before each
|
||||
/// agent loop iteration so skill execution sees the latest tool set.
|
||||
pub fn set_tool_registry(&self, registry: zclaw_runtime::ToolRegistry) {
|
||||
if let Ok(mut guard) = self.tool_registry.write() {
|
||||
*guard = Some(registry);
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve the tool definitions declared by a skill manifest against
|
||||
/// the currently active tool registry.
|
||||
fn resolve_tool_definitions(&self, skill_id: &str) -> Vec<ToolDefinition> {
|
||||
let manifests = self.skills.manifests_snapshot();
|
||||
let manifest = match manifests.get(&zclaw_types::SkillId::new(skill_id)) {
|
||||
Some(m) => m,
|
||||
None => return vec![],
|
||||
};
|
||||
if manifest.tools.is_empty() {
|
||||
return vec![];
|
||||
}
|
||||
let guard = match self.tool_registry.read() {
|
||||
Ok(g) => g,
|
||||
Err(_) => return vec![],
|
||||
};
|
||||
let registry = match guard.as_ref() {
|
||||
Some(r) => r,
|
||||
None => return vec![],
|
||||
};
|
||||
// Only include definitions for tools declared in the skill manifest.
|
||||
registry.definitions().into_iter()
|
||||
.filter(|def| manifest.tools.iter().any(|t| t == &def.name))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,10 +161,12 @@ impl SkillExecutor for KernelSkillExecutor {
|
||||
session_id: &str,
|
||||
input: Value,
|
||||
) -> Result<Value> {
|
||||
let tool_definitions = self.resolve_tool_definitions(skill_id);
|
||||
let context = zclaw_skills::SkillContext {
|
||||
agent_id: agent_id.to_string(),
|
||||
session_id: session_id.to_string(),
|
||||
llm: Some(self.llm.clone()),
|
||||
tool_definitions,
|
||||
..Default::default()
|
||||
};
|
||||
let result = self.skills.execute(&zclaw_types::SkillId::new(skill_id), &context, input).await?;
|
||||
@@ -134,3 +230,47 @@ impl AgentInbox {
|
||||
self.pending.push_back(envelope);
|
||||
}
|
||||
}
|
||||
|
||||
/// Hand executor implementation for Kernel
|
||||
///
|
||||
/// Bridges `zclaw_runtime::tool::HandExecutor` → `zclaw_hands::HandRegistry`,
|
||||
/// allowing `HandTool::execute()` to dispatch to the real Hand implementations.
|
||||
pub struct KernelHandExecutor {
|
||||
hands: Arc<HandRegistry>,
|
||||
}
|
||||
|
||||
impl KernelHandExecutor {
|
||||
pub fn new(hands: Arc<HandRegistry>) -> Self {
|
||||
Self { hands }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl HandExecutor for KernelHandExecutor {
|
||||
async fn execute_hand(
|
||||
&self,
|
||||
hand_id: &str,
|
||||
agent_id: &AgentId,
|
||||
input: Value,
|
||||
) -> Result<Value> {
|
||||
let context = zclaw_hands::HandContext {
|
||||
agent_id: agent_id.clone(),
|
||||
working_dir: None,
|
||||
env: std::collections::HashMap::new(),
|
||||
timeout_secs: 300,
|
||||
callback_url: None,
|
||||
};
|
||||
let result = self.hands.execute(hand_id, &context, input).await?;
|
||||
if result.success {
|
||||
Ok(result.output)
|
||||
} else {
|
||||
Ok(json!({
|
||||
"hand_id": hand_id,
|
||||
"status": "failed",
|
||||
"error": result.error.unwrap_or_else(|| "Unknown hand execution error".to_string()),
|
||||
"output": result.output,
|
||||
"duration_ms": result.duration_ms,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
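Reviewer sketch of the filtering step in `resolve_tool_definitions`: only tools whose names are declared in the skill manifest are passed on. `ToolDef` is a hypothetical stand-in for `ToolDefinition`, reduced to the one field the filter reads.

```rust
/// Minimal stand-in for a tool definition (assumption of this sketch).
#[derive(Debug, Clone, PartialEq)]
struct ToolDef {
    name: String,
}

/// Keep only the definitions whose name appears in `declared`.
/// Registry order is preserved; declared names with no match are ignored.
fn filter_declared(all: Vec<ToolDef>, declared: &[String]) -> Vec<ToolDef> {
    all.into_iter()
        .filter(|def| declared.iter().any(|t| t == &def.name))
        .collect()
}
```

A manifest that declares a tool missing from the registry simply gets a shorter list, which matches the diff's behavior of returning an empty vector rather than an error when nothing resolves.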
120
crates/zclaw-kernel/src/kernel/evolution_bridge.rs
Normal file
@@ -0,0 +1,120 @@
|
||||
//! Evolution Bridge — connects growth crate's SkillCandidate to skills crate's SkillManifest
|
||||
//!
|
||||
//! The growth crate (zclaw-growth) generates SkillCandidate from conversation patterns.
|
||||
//! The skills crate (zclaw-skills) requires SkillManifest for disk persistence.
|
||||
//! This bridge lives in zclaw-kernel because it depends on both crates.
|
||||
|
||||
use zclaw_growth::skill_generator::SkillCandidate;
|
||||
use zclaw_skills::{SkillManifest, SkillMode};
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
/// Convert a validated SkillCandidate into a SkillManifest ready for registration.
|
||||
///
|
||||
/// Safety invariants:
|
||||
/// - `mode` is always `PromptOnly` (auto-generated skills cannot execute code)
|
||||
/// - `enabled` is `false` (requires one explicit positive feedback to activate)
|
||||
/// - `body_markdown` is stored in `manifest.body` and persisted by `serialize_skill_md`
|
||||
pub fn candidate_to_manifest(candidate: &SkillCandidate) -> SkillManifest {
|
||||
let slug = name_to_slug(&candidate.name);
|
||||
|
||||
SkillManifest {
|
||||
id: SkillId::new(format!("auto-{}", slug)),
|
||||
name: candidate.name.clone(),
|
||||
description: candidate.description.clone(),
|
||||
version: format!("{}", candidate.version),
|
||||
author: Some("zclaw-evolution".to_string()),
|
||||
mode: SkillMode::PromptOnly,
|
||||
capabilities: Vec::new(),
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags: vec!["auto-generated".to_string()],
|
||||
category: None,
|
||||
triggers: candidate.triggers.clone(),
|
||||
tools: candidate.tools.clone(),
|
||||
enabled: false,
|
||||
body: Some(candidate.body_markdown.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a human-readable name to a URL-safe slug.
|
||||
fn name_to_slug(name: &str) -> String {
|
||||
let mut result = String::new();
|
||||
for c in name.trim().chars() {
|
||||
if c.is_ascii_alphanumeric() {
|
||||
result.push(c.to_ascii_lowercase());
|
||||
} else if c == ' ' || c == '-' || c == '_' {
|
||||
result.push('-');
|
||||
} else {
|
||||
// Chinese/unicode characters: use hex representation
|
||||
result.push_str(&format!("{:x}", c as u32));
|
||||
}
|
||||
}
|
||||
let slug = result.trim_matches('-').to_string();
|
||||
if slug.is_empty() {
|
||||
// Fallback for empty or whitespace-only names
|
||||
format!("skill-{}", &uuid::Uuid::new_v4().to_string()[..8])
|
||||
} else {
|
||||
slug
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_candidate() -> SkillCandidate {
|
||||
SkillCandidate {
|
||||
name: "每日报表".to_string(),
|
||||
description: "生成每日报表".to_string(),
|
||||
triggers: vec!["报表".to_string(), "日报".to_string()],
|
||||
tools: vec!["researcher".to_string()],
|
||||
body_markdown: "# 每日报表\n步骤1\n步骤2".to_string(),
|
||||
source_pattern: "报表生成".to_string(),
|
||||
confidence: 0.85,
|
||||
version: 1,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_candidate_to_manifest() {
|
||||
let candidate = make_candidate();
|
||||
let manifest = candidate_to_manifest(&candidate);
|
||||
|
||||
assert!(manifest.id.as_str().starts_with("auto-"));
|
||||
assert_eq!(manifest.name, "每日报表");
|
||||
assert_eq!(manifest.description, "生成每日报表");
|
||||
assert_eq!(manifest.version, "1");
|
||||
assert_eq!(manifest.author.as_deref(), Some("zclaw-evolution"));
|
||||
assert_eq!(manifest.mode, SkillMode::PromptOnly);
|
||||
assert!(!manifest.enabled, "auto-generated skills must start disabled");
|
||||
assert_eq!(manifest.triggers, candidate.triggers);
|
||||
assert_eq!(manifest.tools, candidate.tools);
|
||||
assert!(manifest.tags.contains(&"auto-generated".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_name_to_slug_ascii() {
|
||||
assert_eq!(name_to_slug("Daily Report"), "daily-report");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_name_to_slug_chinese() {
|
||||
let slug = name_to_slug("每日报表");
|
||||
assert!(!slug.is_empty());
|
||||
assert!(!slug.contains(' '));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auto_generated_always_prompt_only() {
|
||||
let candidate = make_candidate();
|
||||
let manifest = candidate_to_manifest(&candidate);
|
||||
assert_eq!(manifest.mode, SkillMode::PromptOnly);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_auto_generated_starts_disabled() {
|
||||
let candidate = make_candidate();
|
||||
let manifest = candidate_to_manifest(&candidate);
|
||||
assert!(!manifest.enabled);
|
||||
}
|
||||
}
|
||||
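Reviewer sketch of the slug rules in `name_to_slug`, using only the standard library. The character handling follows the diff; the UUID fallback is replaced by a caller-supplied `fallback` string (a hypothetical parameter) so the function stays deterministic.

```rust
/// Same per-character rules as `name_to_slug` in the diff:
/// ASCII alphanumerics are lowercased, space/'-'/'_' become '-',
/// and every other character is written as the hex of its code point.
/// `fallback` replaces the diff's UUID-based fallback (assumption of this sketch).
fn slugify(name: &str, fallback: &str) -> String {
    let mut out = String::new();
    for c in name.trim().chars() {
        if c.is_ascii_alphanumeric() {
            out.push(c.to_ascii_lowercase());
        } else if c == ' ' || c == '-' || c == '_' {
            out.push('-');
        } else {
            out.push_str(&format!("{:x}", c as u32));
        }
    }
    let slug = out.trim_matches('-');
    if slug.is_empty() {
        fallback.to_string()
    } else {
        slug.to_string()
    }
}
```

Two properties follow from these rules and may be worth a look before the slug is used as a skill ID: repeated separators are not collapsed (`"a  b"` becomes `"a--b"`), and the hex encoding is not injective (`"\u{e9}"` and `"e9"` both map to `"e9"`), so two different names can produce the same `auto-` ID.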
@@ -13,7 +13,113 @@ pub struct ChatModeConfig {
    pub subagent_enabled: Option<bool>,
}

use zclaw_runtime::{AgentLoop, tool::builtin::PathValidator};
/// Result of a successful schedule intent interception.
pub struct ScheduleInterceptResult {
    /// Pre-built streaming receiver with confirmation message.
    pub rx: mpsc::Receiver<zclaw_runtime::LoopEvent>,
    /// Human-readable task description.
    pub task_description: String,
    /// Natural language description of the schedule.
    pub natural_description: String,
    /// Cron expression.
    pub cron_expression: String,
}

impl Kernel {
    /// Try to intercept a schedule intent from the user's message.
    ///
    /// If the message contains a clear schedule intent (e.g., "每天早上9点提醒我查房",
    /// i.e. "remind me to do ward rounds at 9 every morning"), parse it, create a
    /// trigger, and return a streaming receiver with the confirmation message.
    /// Returns `Ok(None)` if no interception occurred.
    pub async fn try_intercept_schedule(
        &self,
        message: &str,
        agent_id: &AgentId,
    ) -> Result<Option<ScheduleInterceptResult>> {
        if !zclaw_runtime::nl_schedule::has_schedule_intent(message) {
            return Ok(None);
        }

        let parse_result = zclaw_runtime::nl_schedule::parse_nl_schedule(message, agent_id);

        match parse_result {
            zclaw_runtime::nl_schedule::ScheduleParseResult::Exact(ref parsed)
                if parsed.confidence >= 0.8 =>
            {
                let trigger_id = format!(
                    "sched_{}_{}",
                    chrono::Utc::now().timestamp_millis(),
                    &uuid::Uuid::new_v4().to_string()[..8]
                );
                let trigger_config = zclaw_hands::TriggerConfig {
                    id: trigger_id.clone(),
                    name: parsed.task_description.clone(),
                    hand_id: "_reminder".to_string(),
                    trigger_type: zclaw_hands::TriggerType::Schedule {
                        cron: parsed.cron_expression.clone(),
                    },
                    enabled: true,
                    max_executions_per_hour: 60,
                };

                match self.create_trigger(trigger_config).await {
                    Ok(_entry) => {
                        tracing::info!(
                            "[Kernel] Schedule trigger created: {} (cron: {})",
                            trigger_id, parsed.cron_expression
                        );
                        // English: "A scheduled task has been set up for you: Task / Time / Cron.
                        // The task is active and will run automatically at the scheduled time."
                        let confirm_msg = format!(
                            "已为您设置定时任务:\n\n- **任务**:{}\n- **时间**:{}\n- **Cron**:`{}`\n\n任务已激活,将在设定时间自动执行。",
                            parsed.task_description,
                            parsed.natural_description,
                            parsed.cron_expression,
                        );

                        let (tx, rx) = mpsc::channel(32);
                        if tx.send(zclaw_runtime::LoopEvent::Delta(confirm_msg)).await.is_err() {
                            tracing::warn!("[Kernel] Failed to send confirm msg to channel — falling through to LLM");
                            return Ok(None);
                        }
                        if tx.send(zclaw_runtime::LoopEvent::Complete(
                            zclaw_runtime::AgentLoopResult {
                                response: String::new(),
                                input_tokens: 0,
                                output_tokens: 0,
                                iterations: 1,
                            }
                        )).await.is_err() {
                            tracing::warn!("[Kernel] Failed to send complete to channel");
                        }
                        drop(tx);

                        Ok(Some(ScheduleInterceptResult {
                            rx,
                            task_description: parsed.task_description.clone(),
                            natural_description: parsed.natural_description.clone(),
                            cron_expression: parsed.cron_expression.clone(),
                        }))
                    }
                    Err(e) => {
                        tracing::warn!(
                            "[Kernel] Failed to create schedule trigger, falling through to LLM: {}", e
                        );
                        Ok(None)
                    }
                }
            }
            _ => {
                tracing::debug!(
                    "[Kernel] Schedule intent detected but not confident enough, falling through to LLM"
                );
                Ok(None)
            }
        }
    }
}

use std::sync::Arc;
use zclaw_runtime::{AgentLoop, LlmDriver, tool::builtin::PathValidator};
use zclaw_runtime::driver::{RetryDriver, RetryConfig};

use super::Kernel;
use super::super::MessageResponse;
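The control flow of `try_intercept_schedule` in the hunk above can be reduced to two ideas: only an exact parse with confidence of at least 0.8 is intercepted, and the caller receives a receiver that already holds the confirmation and completion events. The sketch below illustrates that shape under stated assumptions: it uses the standard library's synchronous `std::sync::mpsc` rather than the async channel in the diff, and `ParseOutcome`, `Event`, `MIN_CONFIDENCE`, and `intercept` are hypothetical stand-ins, not types from `zclaw_runtime`.

```rust
use std::sync::mpsc;

/// Hypothetical stand-in for the parser result; only `Exact` appears in the diff.
#[derive(Debug, Clone, PartialEq)]
pub enum ParseOutcome {
    Exact { confidence: f32, cron: String },
    Other,
}

/// Hypothetical stand-in for the loop events sent to the UI.
#[derive(Debug, Clone, PartialEq)]
pub enum Event {
    Delta(String),
    Complete,
}

/// Same threshold as the match guard in the diff.
pub const MIN_CONFIDENCE: f32 = 0.8;

/// Returns a pre-filled receiver when the parse is confident enough,
/// or `None` so the caller can fall through to the normal LLM path.
pub fn intercept(outcome: &ParseOutcome) -> Option<mpsc::Receiver<Event>> {
    match outcome {
        ParseOutcome::Exact { confidence, cron } if *confidence >= MIN_CONFIDENCE => {
            let (tx, rx) = mpsc::channel();
            tx.send(Event::Delta(format!("scheduled with cron `{cron}`"))).ok()?;
            tx.send(Event::Complete).ok()?;
            // `tx` is dropped when this arm ends, so the receiver yields the
            // two queued events and then reports that the stream is closed.
            Some(rx)
        }
        _ => None,
    }
}
```

Dropping the sender before returning is what lets a consumer drain the stream with a plain loop and terminate, which mirrors the explicit `drop(tx)` in the diff.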
@@ -56,14 +162,19 @@ impl Kernel {
        // Create agent loop with model configuration
        let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
        let tools = self.create_tool_registry(subagent_enabled);
        self.skill_executor.set_tool_registry(tools.clone());
        let driver: Arc<dyn LlmDriver> = Arc::new(
            RetryDriver::new(self.driver.clone(), RetryConfig::default())
        );
        let mut loop_runner = AgentLoop::new(
            *agent_id,
            self.driver.clone(),
            driver,
            tools,
            self.memory.clone(),
        )
        .with_model(&model)
        .with_skill_executor(self.skill_executor.clone())
        .with_hand_executor(self.hand_executor.clone())
        .with_max_tokens(agent_config.max_tokens.unwrap_or_else(|| self.config.max_tokens()))
        .with_temperature(agent_config.temperature.unwrap_or_else(|| self.config.temperature()))
        .with_compaction_threshold(
@@ -168,14 +279,19 @@ impl Kernel {
        // Create agent loop with model configuration
        let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
        let tools = self.create_tool_registry(subagent_enabled);
        self.skill_executor.set_tool_registry(tools.clone());
        let driver: Arc<dyn LlmDriver> = Arc::new(
            RetryDriver::new(self.driver.clone(), RetryConfig::default())
        );
        let mut loop_runner = AgentLoop::new(
            *agent_id,
            self.driver.clone(),
            driver,
            tools,
            self.memory.clone(),
        )
        .with_model(&model)
        .with_skill_executor(self.skill_executor.clone())
        .with_hand_executor(self.hand_executor.clone())
        .with_max_tokens(agent_config.max_tokens.unwrap_or_else(|| self.config.max_tokens()))
        .with_temperature(agent_config.temperature.unwrap_or_else(|| self.config.temperature()))
        .with_compaction_threshold(
@@ -318,6 +434,7 @@ impl Kernel {
        prompt.push_str("- Provide clear options when possible\n");
        prompt.push_str("- Include brief context about why you're asking\n");
        prompt.push_str("- After receiving clarification, proceed immediately\n");
        // "以下信息" means "the following information".
        prompt.push_str("- CRITICAL: When calling ask_clarification, do NOT repeat the options in your text response. The options will be shown in a dedicated card above your reply. Simply greet the user and briefly explain why you need clarification — avoid phrases like \"以下信息\" or \"the following options\" that imply a list follows in your text\n");

        prompt
    }

@@ -9,6 +9,7 @@ mod triggers;
mod approvals;
mod orchestration;
mod a2a;
mod evolution_bridge;

use std::sync::Arc;
use tokio::sync::{broadcast, Mutex};
@@ -24,10 +25,12 @@ use crate::config::KernelConfig;
use zclaw_memory::MemoryStore;
use zclaw_runtime::{LlmDriver, ToolRegistry, tool::SkillExecutor};
use zclaw_skills::SkillRegistry;
use zclaw_hands::{HandRegistry, hands::{BrowserHand, QuizHand, ResearcherHand, CollectorHand, ClipHand, TwitterHand, ReminderHand, quiz::LlmQuizGenerator}};
use zclaw_hands::{HandRegistry, hands::{BrowserHand, QuizHand, ResearcherHand, CollectorHand, ClipHand, TwitterHand, ReminderHand, DailyReportHand, quiz::LlmQuizGenerator}};

pub use adapters::KernelSkillExecutor;
pub use adapters::KernelHandExecutor;
pub use messaging::ChatModeConfig;
pub use messaging::ScheduleInterceptResult;

/// The ZCLAW Kernel
pub struct Kernel {
@@ -40,15 +43,22 @@ pub struct Kernel {
    llm_completer: Arc<dyn zclaw_skills::LlmCompleter>,
    skills: Arc<SkillRegistry>,
    skill_executor: Arc<KernelSkillExecutor>,
    hand_executor: Arc<KernelHandExecutor>,
    hands: Arc<HandRegistry>,
    /// Cached hand configs (populated at boot, used for tool registry)
    hand_configs: Vec<zclaw_hands::HandConfig>,
    trigger_manager: crate::trigger_manager::TriggerManager,
    pending_approvals: Arc<Mutex<Vec<ApprovalEntry>>>,
    /// Running hand runs that can be cancelled (run_id -> cancelled flag)
    running_hand_runs: Arc<dashmap::DashMap<zclaw_types::HandRunId, Arc<std::sync::atomic::AtomicBool>>>,
    /// Shared memory storage backend for Growth system
    viking: Arc<zclaw_runtime::VikingAdapter>,
    /// Cached GrowthIntegration — avoids recreating empty scorer per request
    growth: std::sync::Mutex<Option<std::sync::Arc<zclaw_runtime::GrowthIntegration>>>,
    /// Optional LLM driver for memory extraction (set by Tauri desktop layer)
    extraction_driver: Option<Arc<dyn zclaw_runtime::LlmDriverForExtraction>>,
    /// Optional embedding client for semantic search (set by Tauri desktop layer)
    embedding_client: Option<Arc<dyn zclaw_runtime::EmbeddingClient>>,
    /// MCP tool adapters — shared with Tauri MCP manager, updated dynamically
    mcp_adapters: Arc<std::sync::RwLock<Vec<zclaw_protocols::McpToolAdapter>>>,
    /// Dynamic industry keyword configs — shared with Tauri frontend, loaded from SaaS
@@ -94,10 +104,17 @@ impl Kernel {
        hands.register(Arc::new(ClipHand::new())).await;
        hands.register(Arc::new(TwitterHand::new())).await;
        hands.register(Arc::new(ReminderHand::new())).await;
        hands.register(Arc::new(DailyReportHand::new())).await;

        // Cache hand configs for tool registry (sync access from create_tool_registry)
        let hand_configs = hands.list().await;

        // Create skill executor
        let skill_executor = Arc::new(KernelSkillExecutor::new(skills.clone(), driver.clone()));

        // Create hand executor — bridges HandTool calls to the HandRegistry
        let hand_executor = Arc::new(KernelHandExecutor::new(hands.clone()));

        // Create LLM completer for skill system (shared with skill_executor)
        let llm_completer: Arc<dyn zclaw_skills::LlmCompleter> =
            Arc::new(adapters::LlmDriverAdapter {
@@ -145,12 +162,16 @@ impl Kernel {
            llm_completer,
            skills,
            skill_executor,
            hand_executor,
            hands,
            hand_configs,
            trigger_manager,
            pending_approvals: Arc::new(Mutex::new(Vec::new())),
            running_hand_runs: Arc::new(dashmap::DashMap::new()),
            viking,
            growth: std::sync::Mutex::new(None),
            extraction_driver: None,
            embedding_client: None,
            mcp_adapters: Arc::new(std::sync::RwLock::new(Vec::new())),
            industry_keywords: Arc::new(tokio::sync::RwLock::new(Vec::new())),
            a2a_router,
@@ -158,7 +179,89 @@ impl Kernel {
        })
    }

    /// Create a tool registry with built-in tools + MCP tools.
    /// Boot the kernel with a pre-configured driver (for testing).
    ///
    /// **TEST ONLY.** Do not call from production code.
    ///
    /// Differences from `boot()`:
    /// - Uses the provided `driver` instead of `config.create_driver()`
    /// - Uses an in-memory SQLite database (no filesystem side effects)
    /// - Skips agent recovery from persistent storage (`memory.list_agents_with_runtime()`)
    pub async fn boot_with_driver(
        config: KernelConfig,
        driver: Arc<dyn LlmDriver>,
    ) -> Result<Self> {
        let memory = Arc::new(MemoryStore::new("sqlite::memory:").await?);

        let registry = AgentRegistry::new();
        let capabilities = CapabilityManager::new();
        let events = EventBus::new();
        let skills = Arc::new(SkillRegistry::new());

        if let Some(ref skills_dir) = config.skills_dir {
            if skills_dir.exists() {
                skills.add_skill_dir(skills_dir.clone()).await?;
            }
        }

        let hands = Arc::new(HandRegistry::new());
        let quiz_model = config.model().to_string();
        let quiz_generator = Arc::new(LlmQuizGenerator::new(driver.clone(), quiz_model));
        hands.register(Arc::new(BrowserHand::new())).await;
        hands.register(Arc::new(QuizHand::with_generator(quiz_generator))).await;
        hands.register(Arc::new(ResearcherHand::new())).await;
        hands.register(Arc::new(CollectorHand::new())).await;
        hands.register(Arc::new(ClipHand::new())).await;
        hands.register(Arc::new(TwitterHand::new())).await;
        hands.register(Arc::new(ReminderHand::new())).await;
        hands.register(Arc::new(DailyReportHand::new())).await;

        let hand_configs = hands.list().await;
        let skill_executor = Arc::new(KernelSkillExecutor::new(skills.clone(), driver.clone()));
        let hand_executor = Arc::new(KernelHandExecutor::new(hands.clone()));
        let llm_completer: Arc<dyn zclaw_skills::LlmCompleter> =
            Arc::new(adapters::LlmDriverAdapter {
                driver: driver.clone(),
                max_tokens: config.max_tokens(),
                temperature: config.temperature(),
            });

        let trigger_manager = crate::trigger_manager::TriggerManager::new(hands.clone());
        let viking = Arc::new(zclaw_runtime::VikingAdapter::in_memory());

        let a2a_router = {
            let kernel_agent_id = AgentId::new();
            Arc::new(A2aRouter::new(kernel_agent_id))
        };

        Ok(Self {
            config,
            registry,
            capabilities,
            events,
            memory,
            driver,
            llm_completer,
            skills,
            skill_executor,
            hand_executor,
            hands,
            hand_configs,
            trigger_manager,
            pending_approvals: Arc::new(Mutex::new(Vec::new())),
            running_hand_runs: Arc::new(dashmap::DashMap::new()),
            viking,
            growth: std::sync::Mutex::new(None),
            extraction_driver: None,
            embedding_client: None,
            mcp_adapters: Arc::new(std::sync::RwLock::new(Vec::new())),
            industry_keywords: Arc::new(tokio::sync::RwLock::new(Vec::new())),
            a2a_router,
            a2a_inboxes: Arc::new(dashmap::DashMap::new()),
        })
    }

    /// Create a tool registry with built-in tools + Hand tools + MCP tools.
    /// When `subagent_enabled` is false, TaskTool is excluded to prevent
    /// the LLM from attempting sub-agent delegation in non-Ultra modes.
    pub(crate) fn create_tool_registry(&self, subagent_enabled: bool) -> ToolRegistry {
@@ -175,6 +278,20 @@ impl Kernel {
            tools.register(Box::new(task_tool));
        }

        // Register Hand tools — expose registered Hands as LLM-callable tools
        // (e.g., hand_quiz, hand_researcher, hand_browser, etc.)
        for config in &self.hand_configs {
            if !config.enabled {
                continue;
            }
            let tool = zclaw_runtime::tool::hand_tool::HandTool::from_config(
                &config.id,
                &config.description,
                config.input_schema.clone(),
            );
            tools.register(Box::new(tool));
        }

        // Register MCP tools (dynamically updated by Tauri MCP manager)
        if let Ok(adapters) = self.mcp_adapters.read() {
            for adapter in adapters.iter() {
@@ -229,7 +346,17 @@ impl Kernel {
            }

            // Build semantic router from the skill registry (75 SKILL.md loaded at boot)
            let semantic_router = SemanticSkillRouter::new_tf_idf_only(self.skills.clone());
            let semantic_router = if let Some(ref embed_client) = self.embedding_client {
                let adapter = crate::skill_router::EmbeddingAdapter::new(embed_client.clone());
                let mut router = SemanticSkillRouter::new(self.skills.clone(), Arc::new(adapter));
                if let Some(llm_fallback) = self.make_llm_skill_fallback() {
                    router = router.with_llm_fallback(llm_fallback);
                }
                tracing::debug!("[Kernel] SemanticSkillRouter created with embedding support");
                router
            } else {
                SemanticSkillRouter::new_tf_idf_only(self.skills.clone())
            };
            let adapter = SemanticRouterAdapter::new(Arc::new(semantic_router));
            let mw = zclaw_runtime::middleware::butler_router::ButlerRouterMiddleware::with_router_and_shared_keywords(
                Box::new(adapter),
@@ -238,22 +365,28 @@ impl Kernel {
            chain.register(Arc::new(mw));
        }

        // Data masking middleware — mask sensitive entities before any other processing
        // NOTE: Registration order does NOT determine execution order.
        // The chain sorts by priority() ascending before execution.
        // Execution order: Evolution(78) → ButlerRouter(80) → DataMasking(90) → ...
        {
            use std::sync::Arc;
            let masker = Arc::new(zclaw_runtime::middleware::data_masking::DataMasker::new());
            let mw = zclaw_runtime::middleware::data_masking::DataMaskingMiddleware::new(masker);
            chain.register(Arc::new(mw));
        }

        // Growth integration — shared VikingAdapter for memory middleware & compaction
        let mut growth = zclaw_runtime::GrowthIntegration::new(self.viking.clone());
        if let Some(ref driver) = self.extraction_driver {
            growth = growth.with_llm_driver(driver.clone());
        }
        // Growth integration — cached to avoid recreating empty scorer per request
        let growth = {
            let mut cached = self.growth.lock().expect("growth lock");
            if cached.is_none() {
                let mut g = zclaw_runtime::GrowthIntegration::new(self.viking.clone());
                if let Some(ref driver) = self.extraction_driver {
                    g = g.with_llm_driver(driver.clone());
                }
                // Propagate embedding client to memory retriever if configured
                if let Some(ref embed_client) = self.embedding_client {
                    g.configure_embedding(embed_client.clone());
                }
                // Bridge UserProfileStore so extract_combined() can persist profile signals
                {
                    let profile_store = zclaw_memory::UserProfileStore::new(self.memory.pool());
                    g = g.with_profile_store(std::sync::Arc::new(profile_store));
                    tracing::info!("[Kernel] UserProfileStore bridged to GrowthIntegration");
                }
                *cached = Some(std::sync::Arc::new(g));
            }
            cached.as_ref().expect("growth present").clone()
        };

        // Evolution middleware — pushes evolution candidate skills into system prompt
        // priority=78, executed first by chain (before ButlerRouter@80)
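The NOTE in the hunk above states that registration order does not determine execution order, because the chain sorts by `priority()` ascending. A minimal sketch of that behaviour follows, assuming a stable sort on an integer priority. `Middleware`, `Named`, and `Chain` are hypothetical names for illustration and are not the `zclaw_runtime` middleware types; the priorities 78, 80, and 90 are the ones quoted in the comment.

```rust
/// Hypothetical minimal middleware interface.
pub trait Middleware {
    fn name(&self) -> &'static str;
    fn priority(&self) -> i32;
}

/// Hypothetical middleware that only carries a name and a priority.
pub struct Named(pub &'static str, pub i32);

impl Middleware for Named {
    fn name(&self) -> &'static str {
        self.0
    }
    fn priority(&self) -> i32 {
        self.1
    }
}

/// Hypothetical chain that remembers registration order.
#[derive(Default)]
pub struct Chain {
    items: Vec<Box<dyn Middleware>>,
}

impl Chain {
    pub fn register(&mut self, mw: Box<dyn Middleware>) {
        self.items.push(mw);
    }

    /// Names in execution order: ascending priority. `sort_by_key` is a
    /// stable sort, so equal priorities keep their registration order.
    pub fn execution_order(&self) -> Vec<&'static str> {
        let mut refs: Vec<&dyn Middleware> = self.items.iter().map(|b| &**b).collect();
        refs.sort_by_key(|mw| mw.priority());
        refs.into_iter().map(|mw| mw.name()).collect()
    }
}
```

Registering `DataMasking` (90) first and `Evolution` (78) last still yields `Evolution`, `ButlerRouter`, `DataMasking` from `execution_order` in this sketch.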
@@ -270,6 +403,9 @@ impl Kernel {
            if let Some(ref driver) = self.extraction_driver {
                growth_for_compaction = growth_for_compaction.with_llm_driver(driver.clone());
            }
            if let Some(ref embed_client) = self.embedding_client {
                growth_for_compaction.configure_embedding(embed_client.clone());
            }
            let mw = zclaw_runtime::middleware::compaction::CompactionMiddleware::new(
                threshold,
                zclaw_runtime::CompactionConfig::default(),
@@ -282,7 +418,7 @@ impl Kernel {
        // Memory middleware — auto-extract memories + check evolution after conversations
        {
            use std::sync::Arc;
            let mw = zclaw_runtime::middleware::memory::MemoryMiddleware::new(growth)
            let mw = zclaw_runtime::middleware::memory::MemoryMiddleware::new(growth.clone())
                .with_evolution(evolution_mw);
            chain.register(Arc::new(mw));
        }
@@ -415,6 +551,10 @@ impl Kernel {
    pub fn set_viking(&mut self, viking: Arc<zclaw_runtime::VikingAdapter>) {
        tracing::info!("[Kernel] Replacing in-memory VikingAdapter with persistent storage");
        self.viking = viking;
        // Invalidate cached GrowthIntegration so next request builds with new storage
        if let Ok(mut g) = self.growth.lock() {
            *g = None;
        }
    }

    /// Get a reference to the shared VikingAdapter
@@ -422,6 +562,11 @@ impl Kernel {
        self.viking.clone()
    }

    /// Get a reference to the shared MemoryStore
    pub fn memory(&self) -> Arc<MemoryStore> {
        self.memory.clone()
    }

    /// Set the LLM extraction driver for the Growth system.
    ///
    /// Required for `MemoryMiddleware` to extract memories from conversations
@@ -429,6 +574,29 @@ impl Kernel {
    pub fn set_extraction_driver(&mut self, driver: Arc<dyn zclaw_runtime::LlmDriverForExtraction>) {
        tracing::info!("[Kernel] Extraction driver configured for Growth system");
        self.extraction_driver = Some(driver);
        // Invalidate cached GrowthIntegration so next request uses new driver
        if let Ok(mut g) = self.growth.lock() {
            *g = None;
        }
    }

    /// Set the embedding client for semantic search.
    ///
    /// Propagates to both the skill router (ButlerRouter) and memory retrieval
    /// (GrowthIntegration). The next middleware chain creation will use the
    /// configured client for embedding-based similarity.
    pub fn set_embedding_client(&mut self, client: Arc<dyn zclaw_runtime::EmbeddingClient>) {
        tracing::info!("[Kernel] Embedding client configured for semantic search");
        self.embedding_client = Some(client);
        // Invalidate cached GrowthIntegration so next request builds with new embedding
        if let Ok(mut g) = self.growth.lock() {
            *g = None;
        }
    }

    /// Create an LLM skill fallback using the kernel's LLM driver.
    fn make_llm_skill_fallback(&self) -> Option<Arc<dyn zclaw_skills::semantic_router::RuntimeLlmIntent>> {
        Some(Arc::new(crate::skill_router::LlmSkillFallback::new(self.driver.clone())))
    }

    /// Get a reference to the shared MCP adapters list.

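The setters in the hunks above all follow one pattern: the `GrowthIntegration` is built lazily, cached behind a `Mutex<Option<Arc<_>>>`, and reset to `None` whenever one of its inputs changes. The sketch below shows that pattern in isolation. `Holder`, its `String` payload, and the build counter are hypothetical and exist only to make the behaviour observable; they are not part of the kernel.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

/// Hypothetical holder for a lazily built, shareable value.
pub struct Holder {
    storage: String,
    cached: Mutex<Option<Arc<String>>>,
    builds: AtomicUsize,
}

impl Holder {
    pub fn new(storage: &str) -> Self {
        Self {
            storage: storage.to_string(),
            cached: Mutex::new(None),
            builds: AtomicUsize::new(0),
        }
    }

    /// Build the value on first use, then hand out clones of the same `Arc`.
    pub fn get(&self) -> Arc<String> {
        let mut cached = self.cached.lock().expect("cache lock");
        if cached.is_none() {
            self.builds.fetch_add(1, Ordering::Relaxed);
            *cached = Some(Arc::new(format!("growth over {}", self.storage)));
        }
        cached.as_ref().expect("just populated").clone()
    }

    /// Change an input and invalidate the cache so the next `get` rebuilds.
    pub fn set_storage(&mut self, storage: &str) {
        self.storage = storage.to_string();
        if let Ok(mut cached) = self.cached.lock() {
            *cached = None;
        }
    }

    /// Number of times the cached value has been rebuilt.
    pub fn build_count(&self) -> usize {
        self.builds.load(Ordering::Relaxed)
    }
}
```

Callers that obtained an `Arc` before the invalidation keep using the old value; only later calls to `get` see the rebuilt one.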
@@ -76,4 +76,77 @@ impl Kernel {
        }
        self.skills.execute(&zclaw_types::SkillId::new(id), &ctx, input).await
    }

    /// Generate a skill from an aggregated pattern and register it.
    ///
    /// Full pipeline:
    /// 1. Build LLM prompt from pattern
    /// 2. Call LLM to get JSON response
    /// 3. Parse response into SkillCandidate
    /// 4. Validate through QualityGate (threshold 0.85 for auto-mode)
    /// 5. Convert to SkillManifest (PromptOnly, disabled by default)
    /// 6. Persist to disk via SkillRegistry
    pub async fn generate_and_register_skill(
        &self,
        pattern: &zclaw_growth::pattern_aggregator::AggregatedPattern,
    ) -> Result<String> {
        // 1. Build prompt
        let prompt = zclaw_growth::skill_generator::SkillGenerator::build_prompt(pattern);

        // 2. Call LLM
        let request = zclaw_runtime::driver::CompletionRequest {
            model: self.driver.provider().to_string(),
            // English: "You are a skill design expert; return only a skill definition in JSON format."
            system: Some("你是技能设计专家,只返回 JSON 格式的技能定义。".to_string()),
            messages: vec![zclaw_types::Message::user(prompt)],
            max_tokens: Some(1024),
            temperature: Some(0.3),
            stream: false,
            ..Default::default()
        };

        let response = self.driver.complete(request).await?;
        let text = response.content.iter()
            .filter_map(|block| match block {
                zclaw_runtime::driver::ContentBlock::Text { text } => Some(text.as_str()),
                _ => None,
            })
            .collect::<Vec<_>>()
            .join("");

        // 3. Parse into SkillCandidate
        let candidate = zclaw_growth::skill_generator::SkillGenerator::parse_response(
            &text, pattern,
        )?;

        // 4. Validate through QualityGate (higher threshold for auto-generation)
        let existing_triggers: Vec<String> = self.skills.list().await
            .into_iter()
            .flat_map(|m| m.triggers)
            .collect();
        let gate = zclaw_growth::quality_gate::QualityGate::new(0.85, existing_triggers);
        let report = gate.validate_skill(&candidate);
        if !report.passed {
            return Err(zclaw_types::ZclawError::ConfigError(format!(
                "QualityGate rejected: {}", report.issues.join("; ")
            )));
        }

        // 5. Convert to SkillManifest (PromptOnly, disabled)
        let manifest = super::evolution_bridge::candidate_to_manifest(&candidate);
        let skill_id = manifest.id.to_string();

        // 6. Persist to disk
        let skills_dir = self.config.skills_dir.as_ref()
            .ok_or_else(|| zclaw_types::ZclawError::InvalidInput(
                "Skills directory not configured".into()
            ))?;
        self.skills.create_skill(skills_dir, manifest).await?;

        tracing::info!(
            "[Kernel] Auto-generated skill '{}' (id={}) registered (disabled)",
            candidate.name, skill_id
        );

        Ok(skill_id)
    }
}

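Step 4 of the pipeline above rejects a candidate when the gate's report is not `passed` and joins `report.issues` into the error message. The diff does not show what `QualityGate` actually checks, so the sketch below is only an assumption about what a gate built from a threshold and a list of existing triggers might do: compare a confidence score against the threshold and flag trigger collisions. `Candidate`, `Report`, `Gate`, and the `confidence` field are hypothetical names, not the `zclaw_growth` API.

```rust
/// Hypothetical candidate with only the fields this sketch needs.
pub struct Candidate {
    pub name: String,
    pub triggers: Vec<String>,
    pub confidence: f32,
}

/// Hypothetical report shaped like the one used in the diff
/// (`passed` flag plus a list of human-readable issues).
pub struct Report {
    pub passed: bool,
    pub issues: Vec<String>,
}

/// Hypothetical gate: a minimum confidence and the triggers already in use.
pub struct Gate {
    min_confidence: f32,
    existing_triggers: Vec<String>,
}

impl Gate {
    pub fn new(min_confidence: f32, existing_triggers: Vec<String>) -> Self {
        Self { min_confidence, existing_triggers }
    }

    /// Collect every issue instead of stopping at the first one, so the
    /// caller can report all reasons for a rejection at once.
    pub fn validate(&self, candidate: &Candidate) -> Report {
        let mut issues = Vec::new();
        if candidate.confidence < self.min_confidence {
            issues.push(format!(
                "confidence {:.2} is below threshold {:.2}",
                candidate.confidence, self.min_confidence
            ));
        }
        for trigger in &candidate.triggers {
            if self.existing_triggers.iter().any(|t| t == trigger) {
                issues.push(format!("trigger '{trigger}' is already used by another skill"));
            }
        }
        Report { passed: issues.is_empty(), issues }
    }
}
```

Collecting all issues rather than returning early matches how the diff formats the rejection with `report.issues.join("; ")`.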
143
crates/zclaw-kernel/tests/chat_chain.rs
Normal file
@@ -0,0 +1,143 @@
//! Conversation chain seam tests
//!
//! Verifies the integration seams between layers in the chat pipeline:
//! 1. Tauri→Kernel: chat command correctly forwards to kernel
//! 2. Kernel→LLM: middleware-processed prompt reaches MockLlmDriver
//! 3. LLM→UI: event ordering is delta → delta → complete
//! 4. Streaming: full send→stream→complete lifecycle

use std::sync::Arc;
use zclaw_kernel::{Kernel, KernelConfig};
use zclaw_runtime::test_util::MockLlmDriver;
use zclaw_runtime::{LoopEvent, LlmDriver};
use zclaw_types::AgentConfig;

/// Create a test kernel with MockLlmDriver and a registered agent.
/// The mock is pre-configured with a default text response.
async fn test_kernel() -> (Kernel, zclaw_types::AgentId) {
    let mock = MockLlmDriver::new().with_text_response("Hello from mock!");
    let config = KernelConfig::default();
    let kernel = Kernel::boot_with_driver(config, Arc::new(mock) as Arc<dyn LlmDriver>)
        .await
        .expect("kernel boot");

    let agent_config = AgentConfig::new("test-agent")
        .with_system_prompt("You are a test assistant.");
    let id = agent_config.id;
    kernel.spawn_agent(agent_config).await.expect("spawn agent");

    (kernel, id)
}

// ---------------------------------------------------------------------------
// Seam 1: Tauri → Kernel (non-streaming)
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_tauri_to_kernel_non_streaming() {
    let (kernel, agent_id) = test_kernel().await;

    let result = kernel
        .send_message(&agent_id, "Hi".to_string())
        .await
        .expect("send_message");

    assert!(!result.content.is_empty(), "response content should not be empty");
}

// ---------------------------------------------------------------------------
// Seam 2: Kernel → LLM (middleware processes prompt before reaching driver)
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_kernel_to_llm_prompt_reaches_driver() {
    let (kernel, agent_id) = test_kernel().await;

    let _ = kernel
        .send_message(&agent_id, "What is 2+2?".to_string())
        .await;

    // Verify the kernel's driver was called by checking a second call succeeds
    let result2 = kernel
        .send_message(&agent_id, "And 3+3?".to_string())
        .await
        .expect("second send_message");

    assert!(!result2.content.is_empty(), "second response should not be empty");
}

// ---------------------------------------------------------------------------
// Seam 3: LLM → UI event ordering (delta → delta → complete)
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_llm_to_ui_event_ordering() {
    let (kernel, agent_id) = test_kernel().await;

    let mut rx = kernel
        .send_message_stream(&agent_id, "Hi".to_string())
        .await
        .expect("send_message_stream");

    let mut events = Vec::new();
    while let Some(event) = rx.recv().await {
        match &event {
            LoopEvent::Delta(_) => events.push("delta"),
            LoopEvent::ThinkingDelta(_) => events.push("thinking"),
            LoopEvent::Complete(_) => {
                events.push("complete");
                break;
            }
            LoopEvent::Error(msg) => {
                panic!("unexpected error: {}", msg);
            }
            LoopEvent::ToolStart { .. } => events.push("tool_start"),
            LoopEvent::ToolEnd { .. } => events.push("tool_end"),
            LoopEvent::SubtaskStatus { .. } => events.push("subtask"),
            LoopEvent::IterationStart { .. } => events.push("iteration"),
        }
    }

    assert!(!events.is_empty(), "should receive events");
    assert_eq!(events.last(), Some(&"complete"), "last event must be complete");
    assert!(
        events.iter().any(|e| *e == "delta"),
        "should have at least one delta event"
    );
}

// ---------------------------------------------------------------------------
// Seam 4: Full streaming lifecycle with consecutive messages
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_streaming_consecutive_messages() {
    let (kernel, agent_id) = test_kernel().await;

    // First message
    let mut rx1 = kernel
        .send_message_stream(&agent_id, "First message".to_string())
        .await
        .expect("first stream");

    while let Some(event) = rx1.recv().await {
        if let LoopEvent::Complete(result) = event {
            assert!(result.output_tokens > 0, "first response should have output tokens");
        }
    }

    // Second message (should use new session)
    let mut rx2 = kernel
        .send_message_stream(&agent_id, "Second message".to_string())
        .await
        .expect("second stream");

    let mut got_complete = false;
    while let Some(event) = rx2.recv().await {
        if let LoopEvent::Complete(result) = event {
            got_complete = true;
            assert!(result.output_tokens > 0, "second response should have output tokens");
        }
    }
    assert!(got_complete, "second stream should complete");
}
236
crates/zclaw-kernel/tests/hand_chain.rs
Normal file
@@ -0,0 +1,236 @@
//! Hands chain seam tests
//!
//! Verifies the integration seams in the Hand execution pipeline:
//! 1. Tool routing: LLM tool_call → HandRegistry correct dispatch
//! 2. Execution callback: Hand complete → LoopEvent emitted
//! 3. Non-hand tool routing

use std::sync::Arc;
use zclaw_kernel::{Kernel, KernelConfig};
use zclaw_runtime::test_util::MockLlmDriver;
use zclaw_runtime::stream::StreamChunk;
use zclaw_runtime::{LoopEvent, LlmDriver};
use zclaw_types::AgentConfig;

// ---------------------------------------------------------------------------
// Seam 1: Tool routing — LLM tool_call triggers HandTool dispatch
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_hand_tool_routing() {
    // First stream: tool_use for hand_quiz
    let mock = MockLlmDriver::new()
        .with_stream_chunks(vec![
            StreamChunk::TextDelta { delta: "Let me generate a quiz.".to_string() },
            StreamChunk::ToolUseStart { id: "call_quiz_1".to_string(), name: "hand_quiz".to_string() },
            StreamChunk::ToolUseEnd {
                id: "call_quiz_1".to_string(),
                input: serde_json::json!({ "topic": "math", "count": 3 }),
            },
            StreamChunk::Complete {
                input_tokens: 10,
                output_tokens: 20,
                stop_reason: "tool_use".to_string(),
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
            },
        ])
        // Second stream: final text after tool executes
        .with_stream_chunks(vec![
            StreamChunk::TextDelta { delta: "Here is your quiz!".to_string() },
            StreamChunk::Complete {
                input_tokens: 10,
                output_tokens: 5,
                stop_reason: "end_turn".to_string(),
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
            },
        ]);

    let config = KernelConfig::default();
    let kernel = Kernel::boot_with_driver(config, Arc::new(mock) as Arc<dyn LlmDriver>)
        .await
        .expect("kernel boot");

    let agent_config = AgentConfig::new("test-agent")
        .with_system_prompt("You are a test assistant.");
    let id = agent_config.id;
    kernel.spawn_agent(agent_config).await.expect("spawn agent");

    let mut rx = kernel
        .send_message_stream(&id, "Generate a math quiz".to_string())
        .await
        .expect("stream");

    let mut tool_starts = Vec::new();
    let mut tool_ends = Vec::new();
    let mut got_complete = false;
    while let Some(event) = rx.recv().await {
        match &event {
            LoopEvent::ToolStart { name, input } => {
                tool_starts.push((name.clone(), input.clone()));
            }
            LoopEvent::ToolEnd { name, output } => {
                tool_ends.push((name.clone(), output.clone()));
            }
            LoopEvent::Complete(_) => {
                got_complete = true;
                break;
            }
            LoopEvent::Error(msg) => {
                panic!("unexpected error: {}", msg);
            }
            _ => {}
        }
    }

    assert!(got_complete, "stream should complete");
    assert!(
        tool_starts.iter().any(|(n, _)| n == "hand_quiz"),
        "should see hand_quiz tool_start, got: {:?}",
        tool_starts
    );
}

// ---------------------------------------------------------------------------
// Seam 2: Execution callback — Hand completes and produces tool_end
// ---------------------------------------------------------------------------

#[tokio::test]
async fn seam_hand_execution_callback() {
    let mock = MockLlmDriver::new()
        .with_stream_chunks(vec![
            StreamChunk::ToolUseStart { id: "call_quiz_1".to_string(), name: "hand_quiz".to_string() },
            StreamChunk::ToolUseEnd {
                id: "call_quiz_1".to_string(),
                input: serde_json::json!({ "topic": "math" }),
            },
            StreamChunk::Complete {
                input_tokens: 10,
                output_tokens: 5,
                stop_reason: "tool_use".to_string(),
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
            },
        ])
        .with_stream_chunks(vec![
            StreamChunk::TextDelta { delta: "Done!".to_string() },
            StreamChunk::Complete {
                input_tokens: 5,
                output_tokens: 1,
                stop_reason: "end_turn".to_string(),
                cache_creation_input_tokens: None,
                cache_read_input_tokens: None,
            },
        ]);

    let config = KernelConfig::default();
    let kernel = Kernel::boot_with_driver(config, Arc::new(mock) as Arc<dyn LlmDriver>)
        .await
        .expect("kernel boot");

    let agent_config = AgentConfig::new("test-agent");
    let id = agent_config.id;
    kernel.spawn_agent(agent_config).await.expect("spawn agent");

    let mut rx = kernel
        .send_message_stream(&id, "Quiz me".to_string())
|
||||
.await
|
||||
.expect("stream");
|
||||
|
||||
let mut got_tool_end = false;
|
||||
let mut got_complete = false;
|
||||
while let Some(event) = rx.recv().await {
|
||||
match &event {
|
||||
LoopEvent::ToolEnd { name, output } => {
|
||||
got_tool_end = true;
|
||||
assert!(name.starts_with("hand_"), "tool_end should be hand tool, got: {}", name);
|
||||
// Quiz hand returns structured JSON output
|
||||
assert!(output.is_object() || output.is_string(), "output should be JSON, got: {}", output);
|
||||
}
|
||||
LoopEvent::Complete(_) => {
|
||||
got_complete = true;
|
||||
break;
|
||||
}
|
||||
LoopEvent::Error(msg) => {
|
||||
panic!("unexpected error: {}", msg);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_tool_end, "should receive tool_end after hand execution");
|
||||
assert!(got_complete, "should complete after tool_end");
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Seam 3: Non-hand tool call (generic tool) routes correctly
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn seam_generic_tool_routing() {
|
||||
// Mock with a generic tool call (web_search)
|
||||
let mock = MockLlmDriver::new()
|
||||
.with_stream_chunks(vec![
|
||||
StreamChunk::ToolUseStart { id: "call_ws_1".to_string(), name: "web_search".to_string() },
|
||||
StreamChunk::ToolUseEnd {
|
||||
id: "call_ws_1".to_string(),
|
||||
input: serde_json::json!({ "query": "test query" }),
|
||||
},
|
||||
StreamChunk::Complete {
|
||||
input_tokens: 10,
|
||||
output_tokens: 5,
|
||||
stop_reason: "tool_use".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
])
|
||||
.with_stream_chunks(vec![
|
||||
StreamChunk::TextDelta { delta: "Search results found.".to_string() },
|
||||
StreamChunk::Complete {
|
||||
input_tokens: 5,
|
||||
output_tokens: 3,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]);
|
||||
|
||||
let config = KernelConfig::default();
|
||||
let kernel = Kernel::boot_with_driver(config, Arc::new(mock) as Arc<dyn LlmDriver>)
|
||||
.await
|
||||
.expect("kernel boot");
|
||||
|
||||
let agent_config = AgentConfig::new("test-agent");
|
||||
let id = agent_config.id;
|
||||
kernel.spawn_agent(agent_config).await.expect("spawn agent");
|
||||
|
||||
let mut rx = kernel
|
||||
.send_message_stream(&id, "Search for test".to_string())
|
||||
.await
|
||||
.expect("stream");
|
||||
|
||||
let mut tool_names = Vec::new();
|
||||
let mut got_complete = false;
|
||||
while let Some(event) = rx.recv().await {
|
||||
match &event {
|
||||
LoopEvent::ToolStart { name, .. } => tool_names.push(name.clone()),
|
||||
LoopEvent::ToolEnd { name, .. } => tool_names.push(format!("end:{}", name)),
|
||||
LoopEvent::Complete(_) => {
|
||||
got_complete = true;
|
||||
break;
|
||||
}
|
||||
LoopEvent::Error(msg) => {
|
||||
panic!("unexpected error: {}", msg);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_complete, "stream should complete");
|
||||
assert!(
|
||||
tool_names.iter().any(|n| n.contains("web_search")),
|
||||
"should see web_search tool events, got: {:?}",
|
||||
tool_names
|
||||
);
|
||||
}
|
||||
59
crates/zclaw-kernel/tests/smoke_chat.rs
Normal file
@@ -0,0 +1,59 @@
|
||||
//! Chat smoke test — full lifecycle: send → stream → persist
|
||||
//!
|
||||
//! Uses MockLlmDriver to verify the complete chat pipeline without a real LLM.
|
||||
|
||||
use std::sync::Arc;
|
||||
use zclaw_kernel::{Kernel, KernelConfig};
|
||||
use zclaw_runtime::test_util::MockLlmDriver;
|
||||
use zclaw_runtime::{LoopEvent, LlmDriver};
|
||||
use zclaw_types::AgentConfig;
|
||||
|
||||
#[tokio::test]
|
||||
async fn smoke_chat_full_lifecycle() {
|
||||
let mock = MockLlmDriver::new().with_text_response("Hello! I am the mock assistant.");
|
||||
let config = KernelConfig::default();
|
||||
let kernel = Kernel::boot_with_driver(config, Arc::new(mock) as Arc<dyn LlmDriver>)
|
||||
.await
|
||||
.expect("kernel boot");
|
||||
|
||||
let agent = AgentConfig::new("smoke-agent")
|
||||
.with_system_prompt("You are a test assistant.");
|
||||
let id = agent.id;
|
||||
kernel.spawn_agent(agent).await.expect("spawn agent");
|
||||
|
||||
// 1. Non-streaming: send and get response
|
||||
let resp = kernel.send_message(&id, "Hello".to_string()).await.expect("send");
|
||||
assert!(!resp.content.is_empty());
|
||||
assert!(resp.output_tokens > 0);
|
||||
|
||||
// 2. Streaming: send and collect all events
|
||||
let mut rx = kernel
|
||||
.send_message_stream(&id, "Tell me more".to_string())
|
||||
.await
|
||||
.expect("stream");
|
||||
|
||||
let mut delta_count = 0;
|
||||
let mut complete_result = None;
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
LoopEvent::Delta(text) => {
|
||||
delta_count += 1;
|
||||
assert!(!text.is_empty(), "delta should have content");
|
||||
}
|
||||
LoopEvent::Complete(result) => {
|
||||
complete_result = Some(result);
|
||||
break;
|
||||
}
|
||||
LoopEvent::Error(msg) => panic!("unexpected error: {}", msg),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(delta_count > 0, "should receive at least one delta");
|
||||
let result = complete_result.expect("should receive complete");
|
||||
assert!(result.output_tokens > 0);
|
||||
|
||||
// 3. Verify session persistence — messages were saved
|
||||
let agent_info = kernel.get_agent(&id).expect("agent should exist");
|
||||
assert!(agent_info.message_count >= 2, "at least 2 messages should be tracked");
|
||||
}
|
||||
97
crates/zclaw-kernel/tests/smoke_hands.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
//! Hands smoke test — full lifecycle: trigger tool_call → hand execute → result
|
||||
//!
|
||||
//! Uses MockLlmDriver with stream chunks to simulate a real tool call flow.
|
||||
|
||||
use std::sync::Arc;
|
||||
use zclaw_kernel::{Kernel, KernelConfig};
|
||||
use zclaw_runtime::stream::StreamChunk;
|
||||
use zclaw_runtime::test_util::MockLlmDriver;
|
||||
use zclaw_runtime::{LoopEvent, LlmDriver};
|
||||
use zclaw_types::AgentConfig;
|
||||
|
||||
#[tokio::test]
|
||||
async fn smoke_hands_full_lifecycle() {
|
||||
// Simulate: LLM calls hand_quiz → quiz hand executes → LLM summarizes
|
||||
let mock = MockLlmDriver::new()
|
||||
.with_stream_chunks(vec![
|
||||
StreamChunk::TextDelta { delta: "Generating quiz...".to_string() },
|
||||
StreamChunk::ToolUseStart {
|
||||
id: "call_1".to_string(),
|
||||
name: "hand_quiz".to_string(),
|
||||
},
|
||||
StreamChunk::ToolUseEnd {
|
||||
id: "call_1".to_string(),
|
||||
input: serde_json::json!({ "topic": "history", "count": 2 }),
|
||||
},
|
||||
StreamChunk::Complete {
|
||||
input_tokens: 15,
|
||||
output_tokens: 10,
|
||||
stop_reason: "tool_use".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
])
|
||||
// After hand_quiz returns, LLM generates final response
|
||||
.with_stream_chunks(vec![
|
||||
StreamChunk::TextDelta { delta: "Quiz generated!".to_string() },
|
||||
StreamChunk::Complete {
|
||||
input_tokens: 20,
|
||||
output_tokens: 5,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]);
|
||||
|
||||
let config = KernelConfig::default();
|
||||
let kernel = Kernel::boot_with_driver(config, Arc::new(mock) as Arc<dyn LlmDriver>)
|
||||
.await
|
||||
.expect("kernel boot");
|
||||
|
||||
let agent = AgentConfig::new("smoke-agent");
|
||||
let id = agent.id;
|
||||
kernel.spawn_agent(agent).await.expect("spawn agent");
|
||||
|
||||
let mut rx = kernel
|
||||
.send_message_stream(&id, "Generate a history quiz".to_string())
|
||||
.await
|
||||
.expect("stream");
|
||||
|
||||
let mut saw_tool_start = false;
|
||||
let mut saw_tool_end = false;
|
||||
let mut saw_delta_before_tool = false;
|
||||
let mut saw_delta_after_tool = false;
|
||||
let mut phase = "before_tool";
|
||||
let mut got_complete = false;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
LoopEvent::Delta(_) if phase == "before_tool" => saw_delta_before_tool = true,
|
||||
LoopEvent::Delta(_) if phase == "after_tool" => saw_delta_after_tool = true,
|
||||
LoopEvent::ToolStart { name, .. } => {
|
||||
assert_eq!(name, "hand_quiz", "should be hand_quiz");
|
||||
saw_tool_start = true;
|
||||
}
|
||||
LoopEvent::ToolEnd { name, output } => {
|
||||
assert!(name.starts_with("hand_"), "should be hand tool");
|
||||
assert!(output.is_object() || output.is_string(), "hand should produce output");
|
||||
saw_tool_end = true;
|
||||
phase = "after_tool";
|
||||
}
|
||||
LoopEvent::Complete(result) => {
|
||||
assert!(result.output_tokens > 0, "should have output tokens");
|
||||
assert!(result.iterations >= 2, "should take at least 2 iterations");
|
||||
got_complete = true;
|
||||
break;
|
||||
}
|
||||
LoopEvent::Error(msg) => panic!("unexpected error: {}", msg),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(saw_delta_before_tool, "should see delta before tool execution");
|
||||
assert!(saw_tool_start, "should see hand_quiz ToolStart");
|
||||
assert!(saw_tool_end, "should see hand_quiz ToolEnd");
|
||||
assert!(saw_delta_after_tool, "should see delta after tool execution");
|
||||
assert!(got_complete, "should receive complete event");
|
||||
}
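The smoke test above tracks whether a delta arrived before or after the tool call with a string-valued `phase`. The same bookkeeping could be sketched with an enum, so that a misspelled phase name becomes a compile error. This is only an illustration: `Event`, `Phase`, `Seen`, `drain`, and `run_script` are hypothetical names, and a synchronous `std::sync::mpsc` channel stands in for the async receiver and `LoopEvent` type used by the real test.

```rust
use std::sync::mpsc;

/// Simplified stand-in for the runtime's event type (hypothetical).
#[derive(Debug)]
enum Event {
    Delta(String),
    ToolStart(String),
    ToolEnd(String),
    Complete,
}

/// Which side of the tool call the stream is currently on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
    BeforeTool,
    AfterTool,
}

/// What was observed while draining the stream.
#[derive(Debug, Default, PartialEq, Eq)]
struct Seen {
    delta_before_tool: bool,
    delta_after_tool: bool,
    tool_start: bool,
    tool_end: bool,
    complete: bool,
}

/// Drain events until `Complete` arrives or the sender is dropped.
fn drain(rx: mpsc::Receiver<Event>) -> Seen {
    let mut seen = Seen::default();
    let mut phase = Phase::BeforeTool;
    for event in rx {
        match event {
            Event::Delta(_) => match phase {
                Phase::BeforeTool => seen.delta_before_tool = true,
                Phase::AfterTool => seen.delta_after_tool = true,
            },
            Event::ToolStart(_) => seen.tool_start = true,
            Event::ToolEnd(_) => {
                seen.tool_end = true;
                phase = Phase::AfterTool;
            }
            Event::Complete => {
                seen.complete = true;
                break;
            }
        }
    }
    seen
}

/// Feed a scripted sequence of events through a channel and drain it.
fn run_script(events: Vec<Event>) -> Seen {
    let (tx, rx) = mpsc::channel();
    for event in events {
        tx.send(event).expect("receiver is alive");
    }
    drop(tx);
    drain(rx)
}

fn main() {
    let seen = run_script(vec![
        Event::Delta("before".to_string()),
        Event::ToolStart("hand_quiz".to_string()),
        Event::ToolEnd("hand_quiz".to_string()),
        Event::Delta("after".to_string()),
        Event::Complete,
    ]);
    assert!(seen.delta_before_tool && seen.delta_after_tool);
    assert!(seen.tool_start && seen.tool_end && seen.complete);
}
```

Matching on `Phase` inside the `Delta` arm also makes the handling exhaustive, whereas the string guards in the test silently fall through to `_ => {}` for any unexpected phase value.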
|
||||
@@ -398,6 +398,49 @@ impl TrajectoryStore {
|
||||
|
||||
Ok(result.rows_affected())
|
||||
}
|
||||
|
||||
/// Get trajectory events for an agent created since the given datetime.
|
||||
pub async fn get_events_since(
|
||||
&self,
|
||||
agent_id: &str,
|
||||
since: DateTime<Utc>,
|
||||
) -> Result<Vec<TrajectoryEvent>> {
|
||||
let rows = sqlx::query_as::<_, (String, String, String, i64, String, Option<String>, Option<String>, Option<i64>, String)>(
|
||||
r#"
|
||||
SELECT id, session_id, agent_id, step_index, step_type,
|
||||
input_summary, output_summary, duration_ms, timestamp
|
||||
FROM trajectory_events
|
||||
WHERE agent_id = ? AND timestamp >= ?
|
||||
ORDER BY timestamp ASC
|
||||
"#,
|
||||
)
|
||||
.bind(agent_id)
|
||||
.bind(since.to_rfc3339())
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| ZclawError::StorageError(e.to_string()))?;
|
||||
|
||||
let mut events = Vec::with_capacity(rows.len());
|
||||
for (id, sid, aid, step_idx, stype, input_s, output_s, dur_ms, ts) in rows {
|
||||
// Fall back to the current time if the stored value is not valid RFC 3339.
let timestamp = DateTime::parse_from_rfc3339(&ts)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or_else(|_| Utc::now());
|
||||
|
||||
events.push(TrajectoryEvent {
|
||||
id,
|
||||
session_id: sid,
|
||||
agent_id: aid,
|
||||
step_index: step_idx as usize,
|
||||
step_type: TrajectoryStepType::from_str_lossy(&stype),
|
||||
input_summary: input_s.unwrap_or_default(),
|
||||
output_summary: output_s.unwrap_or_default(),
|
||||
duration_ms: dur_ms.unwrap_or(0) as u64,
|
||||
timestamp,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -560,4 +603,27 @@ mod tests {
|
||||
assert_eq!(remaining.len(), 1);
|
||||
assert_eq!(remaining[0].id, "recent-evt");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_get_events_since() {
|
||||
let store = test_store().await;
|
||||
|
||||
// Insert event for agent-1
|
||||
let event = sample_event(0);
|
||||
store.insert_event(&event).await.unwrap();
|
||||
|
||||
// Query with since=far past → should find it
|
||||
let old_since = Utc::now() - chrono::Duration::days(365);
|
||||
let found = store.get_events_since("agent-1", old_since).await.unwrap();
|
||||
assert_eq!(found.len(), 1);
|
||||
|
||||
// Query with since=far future → should not find it
|
||||
let future_since = Utc::now() + chrono::Duration::days(365);
|
||||
let found = store.get_events_since("agent-1", future_since).await.unwrap();
|
||||
assert!(found.is_empty());
|
||||
|
||||
// Query for different agent → should not find it
|
||||
let found = store.get_events_since("other-agent", old_since).await.unwrap();
|
||||
assert!(found.is_empty());
|
||||
}
|
||||
}
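`get_events_since` binds `since.to_rfc3339()` and compares it with the stored `timestamp` column. If that column holds text (an assumption, since the schema is not part of this diff), `>=` is a plain string comparison, which agrees with chronological order only when every stored value uses the same UTC offset. A std-only sketch of that caveat follows; `text_at_or_after` and the literals are illustrative and not part of the crate.

```rust
/// Mirrors how a text `timestamp >= ?` predicate compares values:
/// byte by byte, with no knowledge of dates or UTC offsets.
fn text_at_or_after(stored: &str, since: &str) -> bool {
    stored >= since
}

fn main() {
    // Uniform "+00:00" offsets: text order matches time order.
    assert!(text_at_or_after(
        "2024-05-02T00:00:00+00:00",
        "2024-05-01T23:59:59+00:00"
    ));
    assert!(!text_at_or_after(
        "2024-05-01T23:59:59+00:00",
        "2024-05-02T00:00:00+00:00"
    ));

    // Mixed offsets: 01:00+02:00 is 23:00 UTC on the previous day,
    // which is earlier than 00:00 UTC, yet it sorts later as text.
    assert!(text_at_or_after(
        "2024-05-02T01:00:00+02:00",
        "2024-05-02T00:00:00+00:00"
    ));
}
```

As long as every writer stores UTC values produced the same way, the query should behave as the test expects.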
|
||||
|
||||
@@ -15,6 +15,56 @@ use zclaw_types::Result;
|
||||
// Data types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Pain point status for tracking resolution.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum PainStatus {
|
||||
Active,
|
||||
Resolved,
|
||||
Deferred,
|
||||
}
|
||||
|
||||
impl PainStatus {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
PainStatus::Active => "active",
|
||||
PainStatus::Resolved => "resolved",
|
||||
PainStatus::Deferred => "deferred",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str_lossy(s: &str) -> Self {
|
||||
match s {
|
||||
"resolved" => PainStatus::Resolved,
|
||||
"deferred" => PainStatus::Deferred,
|
||||
_ => PainStatus::Active,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Structured pain point with tracking metadata.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PainPoint {
|
||||
pub content: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub last_mentioned_at: DateTime<Utc>,
|
||||
pub status: PainStatus,
|
||||
pub occurrence_count: u32,
|
||||
}
|
||||
|
||||
impl PainPoint {
|
||||
pub fn new(content: &str) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
content: content.to_string(),
|
||||
created_at: now,
|
||||
last_mentioned_at: now,
|
||||
status: PainStatus::Active,
|
||||
occurrence_count: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Expertise level inferred from conversation patterns.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
@@ -366,6 +416,46 @@ impl UserProfileStore {
|
||||
|
||||
self.upsert(&profile).await
|
||||
}
|
||||
|
||||
/// Return all active pain points for a user as structured PainPoint objects.
|
||||
///
|
||||
/// Note: the existing schema stores pain points as flat strings without
/// timestamps. The returned `PainPoint.created_at` and `last_mentioned_at`
/// are both set to the profile's `updated_at` as the best available
/// approximation, and `occurrence_count` is always 1. This method takes no
/// `since` parameter because creation time cannot be filtered with the
/// current schema.
|
||||
pub async fn find_active_pains(
|
||||
&self,
|
||||
user_id: &str,
|
||||
) -> Result<Vec<PainPoint>> {
|
||||
let profile = self.get(user_id).await?;
|
||||
Ok(match profile {
|
||||
Some(p) => p
|
||||
.active_pain_points
|
||||
.into_iter()
|
||||
.map(|content| PainPoint {
|
||||
content,
|
||||
created_at: p.updated_at,
|
||||
last_mentioned_at: p.updated_at,
|
||||
status: PainStatus::Active,
|
||||
occurrence_count: 1,
|
||||
})
|
||||
.collect(),
|
||||
None => Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Mark a pain point as resolved by removing it from active_pain_points.
|
||||
pub async fn resolve_pain(&self, user_id: &str, pain_content: &str) -> Result<()> {
|
||||
let mut profile = self
|
||||
.get(user_id)
|
||||
.await?
|
||||
.unwrap_or_else(|| UserProfile::blank(user_id));
|
||||
|
||||
profile.active_pain_points.retain(|p| p != pain_content);
|
||||
profile.updated_at = Utc::now();
|
||||
self.upsert(&profile).await
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -589,4 +679,64 @@ mod tests {
|
||||
assert_eq!(decoded.communication_style, Some(CommStyle::Detailed));
|
||||
assert_eq!(decoded.recent_topics, vec!["exports", "customs"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pain_status_roundtrip() {
|
||||
assert_eq!(PainStatus::from_str_lossy(PainStatus::Active.as_str()), PainStatus::Active);
|
||||
assert_eq!(PainStatus::from_str_lossy(PainStatus::Resolved.as_str()), PainStatus::Resolved);
|
||||
assert_eq!(PainStatus::from_str_lossy(PainStatus::Deferred.as_str()), PainStatus::Deferred);
|
||||
assert_eq!(PainStatus::from_str_lossy("unknown"), PainStatus::Active);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pain_point_new() {
|
||||
let pp = PainPoint::new("scheduling conflict");
|
||||
assert_eq!(pp.content, "scheduling conflict");
|
||||
assert_eq!(pp.status, PainStatus::Active);
|
||||
assert_eq!(pp.occurrence_count, 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_active_pains() {
|
||||
let store = test_store().await;
|
||||
|
||||
store.add_pain_point("user", "pain_a", 5).await.unwrap();
|
||||
store.add_pain_point("user", "pain_b", 5).await.unwrap();
|
||||
|
||||
let pains = store.find_active_pains("user").await.unwrap();
|
||||
assert_eq!(pains.len(), 2);
|
||||
assert!(pains.iter().any(|p| p.content == "pain_a"));
|
||||
assert!(pains.iter().any(|p| p.content == "pain_b"));
|
||||
assert_eq!(pains[0].status, PainStatus::Active);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_find_active_pains_empty() {
|
||||
let store = test_store().await;
|
||||
let pains = store.find_active_pains("nonexistent").await.unwrap();
|
||||
assert!(pains.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_resolve_pain() {
|
||||
let store = test_store().await;
|
||||
|
||||
store.add_pain_point("user", "pain_a", 5).await.unwrap();
|
||||
store.add_pain_point("user", "pain_b", 5).await.unwrap();
|
||||
|
||||
store.resolve_pain("user", "pain_a").await.unwrap();
|
||||
|
||||
let loaded = store.get("user").await.unwrap().unwrap();
|
||||
assert_eq!(loaded.active_pain_points, vec!["pain_b"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_resolve_pain_nonexistent_is_noop() {
|
||||
let store = test_store().await;
|
||||
let profile = UserProfile::blank("user");
|
||||
store.upsert(&profile).await.unwrap();
|
||||
|
||||
// Should not error when pain doesn't exist
|
||||
store.resolve_pain("user", "nonexistent_pain").await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
55
crates/zclaw-protocols/tests/mcp_transport_tests.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
//! Tests for MCP Transport configuration (McpServerConfig)
|
||||
//!
|
||||
//! These tests cover McpServerConfig builder methods without spawning processes.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use zclaw_protocols::McpServerConfig;
|
||||
|
||||
#[test]
|
||||
fn npx_config_creates_correct_command() {
|
||||
let config = McpServerConfig::npx("@modelcontextprotocol/server-memory");
|
||||
assert_eq!(config.command, "npx");
|
||||
assert_eq!(config.args, vec!["-y", "@modelcontextprotocol/server-memory"]);
|
||||
assert!(config.env.is_empty());
|
||||
assert!(config.cwd.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn node_config_creates_correct_command() {
|
||||
let config = McpServerConfig::node("/path/to/server.js");
|
||||
assert_eq!(config.command, "node");
|
||||
assert_eq!(config.args, vec!["/path/to/server.js"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn python_config_creates_correct_command() {
|
||||
let config = McpServerConfig::python("mcp_server.py");
|
||||
assert_eq!(config.command, "python");
|
||||
assert_eq!(config.args, vec!["mcp_server.py"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_adds_variables() {
|
||||
let config = McpServerConfig::node("server.js")
|
||||
.env("API_KEY", "secret123")
|
||||
.env("DEBUG", "true");
|
||||
assert_eq!(config.env.get("API_KEY").unwrap(), "secret123");
|
||||
assert_eq!(config.env.get("DEBUG").unwrap(), "true");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cwd_sets_working_directory() {
|
||||
let config = McpServerConfig::node("server.js").cwd("/tmp/work");
|
||||
assert_eq!(config.cwd.unwrap(), "/tmp/work");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn combined_builder_pattern() {
|
||||
let config = McpServerConfig::npx("@scope/server")
|
||||
.env("PORT", "3000")
|
||||
.cwd("/app");
|
||||
assert_eq!(config.command, "npx");
|
||||
assert_eq!(config.args.len(), 2);
|
||||
assert_eq!(config.env.len(), 1);
|
||||
assert_eq!(config.cwd.unwrap(), "/app");
|
||||
}
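The assertions above pin down the observable shape of the builder: a constructor picks the command and its arguments, and `env` / `cwd` consume and return the value so calls can be chained. A minimal std-only sketch consistent with those assertions is shown here. `ServerConfig` is a hypothetical stand-in, and the real `McpServerConfig` in `zclaw-protocols` may be defined differently.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in shaped after the fields the tests read.
#[derive(Debug, Clone, Default)]
struct ServerConfig {
    command: String,
    args: Vec<String>,
    env: HashMap<String, String>,
    cwd: Option<String>,
}

impl ServerConfig {
    /// Run a package through `npx -y <package>`.
    fn npx(package: &str) -> Self {
        Self {
            command: "npx".to_string(),
            args: vec!["-y".to_string(), package.to_string()],
            ..Self::default()
        }
    }

    /// Run a script with `node <script>`.
    fn node(script: &str) -> Self {
        Self {
            command: "node".to_string(),
            args: vec![script.to_string()],
            ..Self::default()
        }
    }

    /// Add one environment variable; takes and returns `self` so calls chain.
    fn env(mut self, key: &str, value: &str) -> Self {
        self.env.insert(key.to_string(), value.to_string());
        self
    }

    /// Set the working directory for the spawned process.
    fn cwd(mut self, dir: &str) -> Self {
        self.cwd = Some(dir.to_string());
        self
    }
}

fn main() {
    let config = ServerConfig::npx("@scope/server")
        .env("PORT", "3000")
        .cwd("/app");
    assert_eq!(config.command, "npx");
    assert_eq!(config.args, vec!["-y", "@scope/server"]);
    assert_eq!(config.env.get("PORT").map(String::as_str), Some("3000"));
    assert_eq!(config.cwd.as_deref(), Some("/app"));

    let plain = ServerConfig::node("server.js");
    assert!(plain.env.is_empty());
    assert!(plain.cwd.is_none());
}
```

Taking `self` by value keeps each builder call a simple move, which is why the tests can read fields such as `config.cwd.unwrap()` directly at the end of the chain.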
|
||||
186
crates/zclaw-protocols/tests/mcp_types_domain_tests.rs
Normal file
@@ -0,0 +1,186 @@
|
||||
//! Tests for MCP domain types (mcp.rs) — McpTool, McpContent, McpResource, etc.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use zclaw_protocols::*;
|
||||
|
||||
// === McpTool ===
|
||||
|
||||
#[test]
|
||||
fn mcp_tool_roundtrip() {
|
||||
let tool = McpTool {
|
||||
name: "search".to_string(),
|
||||
description: "Search documents".to_string(),
|
||||
input_schema: serde_json::json!({"type": "object", "properties": {"query": {"type": "string"}}}),
|
||||
};
|
||||
let json = serde_json::to_string(&tool).unwrap();
|
||||
let parsed: McpTool = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.name, "search");
|
||||
assert_eq!(parsed.description, "Search documents");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_tool_empty_description() {
|
||||
let tool = McpTool {
|
||||
name: "ping".to_string(),
|
||||
description: String::new(),
|
||||
input_schema: serde_json::json!({}),
|
||||
};
|
||||
let parsed: McpTool = serde_json::from_str(&serde_json::to_string(&tool).unwrap()).unwrap();
|
||||
assert!(parsed.description.is_empty());
|
||||
}
|
||||
|
||||
// === McpContent ===
|
||||
|
||||
#[test]
|
||||
fn mcp_content_text_roundtrip() {
|
||||
let content = McpContent::Text { text: "hello".to_string() };
|
||||
let json = serde_json::to_string(&content).unwrap();
|
||||
let parsed: McpContent = serde_json::from_str(&json).unwrap();
|
||||
match parsed {
|
||||
McpContent::Text { text } => assert_eq!(text, "hello"),
|
||||
_ => panic!("Expected Text"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_content_image_roundtrip() {
|
||||
let content = McpContent::Image {
|
||||
data: "base64==".to_string(),
|
||||
mime_type: "image/png".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&content).unwrap();
|
||||
let parsed: McpContent = serde_json::from_str(&json).unwrap();
|
||||
match parsed {
|
||||
McpContent::Image { data, mime_type } => {
|
||||
assert_eq!(data, "base64==");
|
||||
assert_eq!(mime_type, "image/png");
|
||||
}
|
||||
_ => panic!("Expected Image"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_content_resource_roundtrip() {
|
||||
let content = McpContent::Resource {
|
||||
resource: McpResourceContent {
|
||||
uri: "file:///test.txt".to_string(),
|
||||
mime_type: Some("text/plain".to_string()),
|
||||
text: Some("content".to_string()),
|
||||
blob: None,
|
||||
},
|
||||
};
|
||||
let json = serde_json::to_string(&content).unwrap();
|
||||
let parsed: McpContent = serde_json::from_str(&json).unwrap();
|
||||
match parsed {
|
||||
McpContent::Resource { resource } => {
|
||||
assert_eq!(resource.uri, "file:///test.txt");
|
||||
assert_eq!(resource.text.unwrap(), "content");
|
||||
}
|
||||
_ => panic!("Expected Resource"),
|
||||
}
|
||||
}
|
||||
|
||||
// === McpToolCallRequest ===
|
||||
|
||||
#[test]
|
||||
fn mcp_tool_call_request_serialization() {
|
||||
let mut args = HashMap::new();
|
||||
args.insert("query".to_string(), serde_json::json!("test"));
|
||||
let req = McpToolCallRequest {
|
||||
name: "search".to_string(),
|
||||
arguments: args,
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"name\":\"search\""));
|
||||
assert!(json.contains("\"query\":\"test\""));
|
||||
}
|
||||
|
||||
// === McpToolCallResponse ===
|
||||
|
||||
#[test]
|
||||
fn mcp_tool_call_response_parse_success() {
|
||||
let json = r#"{"content":[{"type":"text","text":"found 3 results"}],"is_error":false}"#;
|
||||
let resp: McpToolCallResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(!resp.is_error);
|
||||
assert_eq!(resp.content.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_tool_call_response_parse_error() {
|
||||
let json = r#"{"content":[{"type":"text","text":"tool not found"}],"is_error":true}"#;
|
||||
let resp: McpToolCallResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.is_error);
|
||||
}
|
||||
|
||||
// === McpResource ===
|
||||
|
||||
#[test]
|
||||
fn mcp_resource_roundtrip() {
|
||||
let res = McpResource {
|
||||
uri: "file:///doc.md".to_string(),
|
||||
name: "Documentation".to_string(),
|
||||
description: Some("Project docs".to_string()),
|
||||
mime_type: Some("text/markdown".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&res).unwrap();
|
||||
let parsed: McpResource = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.uri, "file:///doc.md");
|
||||
assert_eq!(parsed.description.unwrap(), "Project docs");
|
||||
}
|
||||
|
||||
// === McpPrompt ===
|
||||
|
||||
#[test]
|
||||
fn mcp_prompt_roundtrip() {
|
||||
let prompt = McpPrompt {
|
||||
name: "summarize".to_string(),
|
||||
description: "Summarize text".to_string(),
|
||||
arguments: vec![
|
||||
McpPromptArgument {
|
||||
name: "length".to_string(),
|
||||
description: "Target length".to_string(),
|
||||
required: false,
|
||||
},
|
||||
],
|
||||
};
|
||||
let json = serde_json::to_string(&prompt).unwrap();
|
||||
let parsed: McpPrompt = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.arguments.len(), 1);
|
||||
assert!(!parsed.arguments[0].required);
|
||||
}
|
||||
|
||||
// === McpServerInfo ===
|
||||
|
||||
#[test]
|
||||
fn mcp_server_info_roundtrip() {
|
||||
let info = McpServerInfo {
|
||||
name: "test-mcp".to_string(),
|
||||
version: "2.0.0".to_string(),
|
||||
protocol_version: "2024-11-05".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&info).unwrap();
|
||||
let parsed: McpServerInfo = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.name, "test-mcp");
|
||||
assert_eq!(parsed.protocol_version, "2024-11-05");
|
||||
}
|
||||
|
||||
// === McpCapabilities ===
|
||||
|
||||
#[test]
|
||||
fn mcp_capabilities_default_empty() {
|
||||
let caps = McpCapabilities::default();
|
||||
assert!(caps.tools.is_none());
|
||||
assert!(caps.resources.is_none());
|
||||
assert!(caps.prompts.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mcp_capabilities_with_tools() {
|
||||
let caps = McpCapabilities {
|
||||
tools: Some(McpToolCapabilities { list_changed: true }),
|
||||
resources: None,
|
||||
prompts: None,
|
||||
};
|
||||
let json = serde_json::to_string(&caps).unwrap();
|
||||
assert!(json.contains("\"list_changed\":true"));
|
||||
}
|
||||
267
crates/zclaw-protocols/tests/mcp_types_tests.rs
Normal file
@@ -0,0 +1,267 @@
|
||||
//! Tests for MCP JSON-RPC types (mcp_types.rs)
|
||||
//!
|
||||
//! Covers: serialization, deserialization, builder patterns, edge cases.
|
||||
|
||||
use serde_json;
|
||||
use zclaw_protocols::*;
|
||||
|
||||
// === JsonRpcRequest ===
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_request_new_has_correct_defaults() {
|
||||
let req = JsonRpcRequest::new(42, "tools/list");
|
||||
assert_eq!(req.jsonrpc, "2.0");
|
||||
assert_eq!(req.id, 42);
|
||||
assert_eq!(req.method, "tools/list");
|
||||
assert!(req.params.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_request_with_params() {
|
||||
let req = JsonRpcRequest::new(1, "tools/call")
|
||||
.with_params(serde_json::json!({"name": "search"}));
|
||||
let serialized = serde_json::to_string(&req).unwrap();
|
||||
assert!(serialized.contains("\"params\""));
|
||||
assert!(serialized.contains("\"name\":\"search\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_request_skip_null_params() {
|
||||
let req = JsonRpcRequest::new(1, "ping");
|
||||
let serialized = serde_json::to_string(&req).unwrap();
|
||||
// params is None, should be skipped
|
||||
assert!(!serialized.contains("\"params\""));
|
||||
}
|
||||
|
||||
// === JsonRpcResponse ===
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_response_parse_success() {
|
||||
let json = r#"{"jsonrpc":"2.0","id":1,"result":{"tools":[]}}"#;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.id, 1);
|
||||
assert!(resp.result.is_some());
|
||||
assert!(resp.error.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_response_parse_error() {
|
||||
let json = r#"{"jsonrpc":"2.0","id":2,"error":{"code":-32600,"message":"Invalid Request"}}"#;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.id, 2);
|
||||
assert!(resp.result.is_none());
|
||||
let err = resp.error.unwrap();
|
||||
assert_eq!(err.code, -32600);
|
||||
assert_eq!(err.message, "Invalid Request");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_response_parse_error_with_data() {
|
||||
let json = r#"{"jsonrpc":"2.0","id":3,"error":{"code":-32602,"message":"Bad params","data":{"field":"uri"}}}"#;
|
||||
let resp: JsonRpcResponse = serde_json::from_str(json).unwrap();
|
||||
let err = resp.error.unwrap();
|
||||
assert!(err.data.is_some());
|
||||
assert_eq!(err.data.unwrap()["field"], "uri");
|
||||
}
|
||||
|
||||
// === InitializeRequest ===
|
||||
|
||||
#[test]
|
||||
fn initialize_request_default() {
|
||||
let req = InitializeRequest::default();
|
||||
assert_eq!(req.protocol_version, "2024-11-05");
|
||||
assert_eq!(req.client_info.name, "zclaw");
|
||||
assert!(!req.client_info.version.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn initialize_request_serializes() {
|
||||
let req = InitializeRequest::default();
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"protocol_version\":\"2024-11-05\""));
|
||||
assert!(json.contains("\"client_info\""));
|
||||
}
|
||||
|
||||
// === ServerCapabilities ===
|
||||
|
||||
#[test]
|
||||
fn server_capabilities_empty() {
|
||||
let json = r#"{"protocol_version":"2024-11-05","capabilities":{},"server_info":{"name":"test","version":"1.0"}}"#;
|
||||
let result: InitializeResult = serde_json::from_str(json).unwrap();
|
||||
assert!(result.capabilities.tools.is_none());
|
||||
assert!(result.capabilities.resources.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_capabilities_with_tools() {
|
||||
let json = r#"{"protocol_version":"2024-11-05","capabilities":{"tools":{"list_changed":true}},"server_info":{"name":"test","version":"1.0"}}"#;
|
||||
let result: InitializeResult = serde_json::from_str(json).unwrap();
|
||||
let tools = result.capabilities.tools.unwrap();
|
||||
assert!(tools.list_changed);
|
||||
}
|
||||
|
||||
// === ContentBlock ===
|
||||
|
||||
#[test]
|
||||
fn content_block_text() {
|
||||
let json = r#"{"type":"text","text":"hello world"}"#;
|
||||
let block: ContentBlock = serde_json::from_str(json).unwrap();
|
||||
match block {
|
||||
ContentBlock::Text { text } => assert_eq!(text, "hello world"),
|
||||
_ => panic!("Expected Text variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn content_block_image() {
|
||||
let json = r#"{"type":"image","data":"base64data","mime_type":"image/png"}"#;
|
||||
let block: ContentBlock = serde_json::from_str(json).unwrap();
|
||||
match block {
|
||||
ContentBlock::Image { data, mime_type } => {
|
||||
assert_eq!(data, "base64data");
|
||||
assert_eq!(mime_type, "image/png");
|
||||
}
|
||||
_ => panic!("Expected Image variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn content_block_resource() {
|
||||
let json = r#"{"type":"resource","resource":{"uri":"file:///test.txt","text":"content"}}"#;
|
||||
let block: ContentBlock = serde_json::from_str(json).unwrap();
|
||||
match block {
|
||||
ContentBlock::Resource { resource } => {
|
||||
assert_eq!(resource.uri, "file:///test.txt");
|
||||
assert_eq!(resource.text.unwrap(), "content");
|
||||
}
|
||||
_ => panic!("Expected Resource variant"),
|
||||
}
|
||||
}
|
||||
|
||||
// === CallToolResult ===
|
||||
|
||||
#[test]
|
||||
fn call_tool_result_parse() {
|
||||
let json = r#"{"content":[{"type":"text","text":"result"}],"is_error":false}"#;
|
||||
let result: CallToolResult = serde_json::from_str(json).unwrap();
|
||||
assert!(!result.is_error);
|
||||
assert_eq!(result.content.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn call_tool_result_error() {
|
||||
let json = r#"{"content":[{"type":"text","text":"something went wrong"}],"is_error":true}"#;
|
||||
let result: CallToolResult = serde_json::from_str(json).unwrap();
|
||||
assert!(result.is_error);
|
||||
}
|
||||
|
||||
// === ListToolsResult ===
|
||||
|
||||
#[test]
|
||||
fn list_tools_result_with_cursor() {
|
||||
let json = r#"{"tools":[{"name":"search","input_schema":{"type":"object"}}],"next_cursor":"abc123"}"#;
|
||||
let result: ListToolsResult = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(result.tools.len(), 1);
|
||||
assert_eq!(result.tools[0].name, "search");
|
||||
assert_eq!(result.next_cursor.unwrap(), "abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn list_tools_result_without_cursor() {
|
||||
let json = r#"{"tools":[]}"#;
|
||||
let result: ListToolsResult = serde_json::from_str(json).unwrap();
|
||||
assert!(result.tools.is_empty());
|
||||
assert!(result.next_cursor.is_none());
|
||||
}
|
||||
|
||||
// === Resource types ===
|
||||
|
||||
#[test]
|
||||
fn resource_parse_with_optional_fields() {
|
||||
let json = r#"{"uri":"file:///doc.txt","name":"doc","description":"A doc","mime_type":"text/plain"}"#;
|
||||
let res: Resource = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(res.uri, "file:///doc.txt");
|
||||
assert_eq!(res.name, "doc");
|
||||
assert_eq!(res.description.unwrap(), "A doc");
|
||||
assert_eq!(res.mime_type.unwrap(), "text/plain");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resource_parse_minimal() {
|
||||
let json = r#"{"uri":"file:///x","name":"x"}"#;
|
||||
let res: Resource = serde_json::from_str(json).unwrap();
|
||||
assert!(res.description.is_none());
|
||||
assert!(res.mime_type.is_none());
|
||||
}
|
||||
|
||||
// === LoggingLevel ===
|
||||
|
||||
#[test]
|
||||
fn logging_level_serialize_roundtrip() {
|
||||
let levels = vec![
|
||||
LoggingLevel::Debug,
|
||||
LoggingLevel::Info,
|
||||
LoggingLevel::Warning,
|
||||
LoggingLevel::Error,
|
||||
LoggingLevel::Critical,
|
||||
LoggingLevel::Emergency,
|
||||
];
|
||||
for level in levels {
|
||||
let json = serde_json::to_string(&level).unwrap();
|
||||
let parsed: LoggingLevel = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(std::mem::discriminant(&level), std::mem::discriminant(&parsed));
|
||||
}
|
||||
}
|
||||
|
||||
// === InitializedNotification ===
|
||||
|
||||
#[test]
|
||||
fn initialized_notification_fields() {
|
||||
let n = InitializedNotification::new();
|
||||
assert_eq!(n.jsonrpc, "2.0");
|
||||
assert_eq!(n.method, "notifications/initialized");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn initialized_notification_serializes() {
|
||||
let n = InitializedNotification::default();
|
||||
let json = serde_json::to_string(&n).unwrap();
|
||||
assert!(json.contains("\"notifications/initialized\""));
|
||||
}
|
||||
|
||||
// === Prompt types ===
|
||||
|
||||
#[test]
|
||||
fn prompt_parse_with_arguments() {
|
||||
let json = r#"{"name":"greet","description":"Greeting","arguments":[{"name":"lang","description":"Language","required":true}]}"#;
|
||||
let prompt: Prompt = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(prompt.name, "greet");
|
||||
assert_eq!(prompt.arguments.len(), 1);
|
||||
assert!(prompt.arguments[0].required);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_message_parse() {
|
||||
let json = r#"{"role":"user","content":{"type":"text","text":"hello"}}"#;
|
||||
let msg: PromptMessage = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(msg.role, "user");
|
||||
}
|
||||
|
||||
// === McpClientConfig ===
|
||||
|
||||
#[test]
|
||||
fn mcp_client_config_roundtrip() {
|
||||
let config = McpClientConfig {
|
||||
server_url: "http://localhost:3000".to_string(),
|
||||
server_info: McpServerInfo {
|
||||
name: "test-server".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
protocol_version: "2024-11-05".to_string(),
|
||||
},
|
||||
capabilities: McpCapabilities::default(),
|
||||
};
|
||||
let json = serde_json::to_string(&config).unwrap();
|
||||
let parsed: McpClientConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.server_url, config.server_url);
|
||||
assert_eq!(parsed.server_info.name, "test-server");
|
||||
}
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use serde_json::Value;
|
||||
use zclaw_types::{AgentId, Message, SessionId};
|
||||
|
||||
use crate::driver::{CompletionRequest, ContentBlock, LlmDriver};
|
||||
@@ -136,7 +137,7 @@ pub fn update_calibration(estimated: usize, actual: u32) {
|
||||
}
|
||||
|
||||
/// Estimate total tokens for messages with calibration applied.
|
||||
fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize {
|
||||
pub fn estimate_messages_tokens_calibrated(messages: &[Message]) -> usize {
|
||||
let raw = estimate_messages_tokens(messages);
|
||||
let factor = get_calibration_factor();
|
||||
if (factor - 1.0).abs() < f64::EPSILON {
|
||||
@@ -178,7 +179,7 @@ pub fn compact_messages(messages: Vec<Message>, keep_recent: usize) -> (Vec<Mess
|
||||
let old_messages = &messages[..split_index];
|
||||
let recent_messages = &messages[split_index..];
|
||||
|
||||
let summary = generate_summary(old_messages);
|
||||
let summary = generate_summary(old_messages, None);
|
||||
let removed_count = old_messages.len();
|
||||
|
||||
let mut compacted = Vec::with_capacity(1 + recent_messages.len());
|
||||
@@ -188,6 +189,38 @@ pub fn compact_messages(messages: Vec<Message>, keep_recent: usize) -> (Vec<Mess
|
||||
(compacted, removed_count)
|
||||
}
|
||||
|
||||
/// Tool results older than this many messages are eligible for pruning.
const PRUNE_AGE_THRESHOLD: usize = 8;
/// Outputs longer than this are pruned (measured in bytes via `str::len`).
const PRUNE_MAX_CHARS: usize = 2000;
/// Upper bound, in bytes, on the head that is kept from a pruned output.
const PRUNE_KEEP_HEAD_CHARS: usize = 500;

/// Prune old tool outputs to reduce token consumption. Runs before compaction.
/// Only prunes ToolResult messages older than PRUNE_AGE_THRESHOLD messages.
pub fn prune_tool_outputs(messages: &mut [Message]) -> usize {
|
||||
let total = messages.len();
|
||||
let mut pruned_count = 0;
|
||||
|
||||
for i in 0..total.saturating_sub(PRUNE_AGE_THRESHOLD) {
|
||||
if let Message::ToolResult { output, is_error, .. } = &mut messages[i] {
|
||||
if *is_error { continue; }
|
||||
|
||||
let text = match output {
|
||||
Value::String(ref s) => s.clone(),
|
||||
ref other => other.to_string(),
|
||||
};
|
||||
if text.len() <= PRUNE_MAX_CHARS { continue; }
|
||||
|
||||
let end = text.floor_char_boundary(PRUNE_KEEP_HEAD_CHARS.min(text.len()));
|
||||
*output = serde_json::json!({
|
||||
"_pruned": true,
|
||||
"_original_chars": text.len(),
|
||||
"head": &text[..end],
|
||||
});
|
||||
pruned_count += 1;
|
||||
}
|
||||
}
|
||||
pruned_count
|
||||
}
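The head truncation in `prune_tool_outputs` depends on cutting the text at a UTF-8 character boundary. Below is a minimal standalone sketch of that step. The helper names `floor_boundary` and `prune_head` are hypothetical and do not appear in the diff, and `str::is_char_boundary` is used here instead of `floor_char_boundary` so the sketch also builds on older toolchains.

```rust
/// Largest index <= `max` that falls on a UTF-8 character boundary of `s`.
/// (Hypothetical helper; stands in for `str::floor_char_boundary`.)
fn floor_boundary(s: &str, max: usize) -> usize {
    let mut end = max.min(s.len());
    // Index 0 is always a boundary, so this loop terminates.
    while !s.is_char_boundary(end) {
        end -= 1;
    }
    end
}

/// Returns the head to keep if `text` exceeds `max_len` bytes, otherwise `None`.
/// (Hypothetical helper mirroring the size check and slice in the diff.)
fn prune_head(text: &str, max_len: usize, keep: usize) -> Option<&str> {
    if text.len() <= max_len {
        return None;
    }
    Some(&text[..floor_boundary(text, keep)])
}
```

Slicing at an arbitrary byte index panics when the index lands inside a multi-byte character, which is why the boundary is rounded down before slicing.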
|
||||
|
||||
/// Check if compaction should be triggered and perform it if needed.
|
||||
///
|
||||
/// Returns the (possibly compacted) message list.
|
||||
@@ -315,6 +348,18 @@ pub async fn maybe_compact_with_config(
|
||||
.iter()
|
||||
.take_while(|m| matches!(m, Message::System { .. }))
|
||||
.count();
|
||||
|
||||
// Extract previous summary from leading system messages for iterative summarization
|
||||
let previous_summary = messages.iter()
|
||||
.take(leading_system_count)
|
||||
.filter_map(|m| match m {
|
||||
Message::System { content } if content.starts_with("[以下是之前对话的摘要]") => {
|
||||
Some(content.clone())
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.next();
|
||||
|
||||
let keep_from_end = DEFAULT_KEEP_RECENT
|
||||
.min(messages.len().saturating_sub(leading_system_count));
|
||||
let split_index = messages.len().saturating_sub(keep_from_end);
|
||||
@@ -333,14 +378,16 @@ pub async fn maybe_compact_with_config(
|
||||
let recent_messages = &messages[split_index..];
|
||||
let removed_count = old_messages.len();
|
||||
|
||||
// Step 3: Generate summary (LLM or rule-based)
|
||||
// Step 3: Generate summary (LLM or rule-based), with iterative context
|
||||
let prev_ref = previous_summary.as_deref();
|
||||
let summary = if config.use_llm {
|
||||
if let Some(driver) = driver {
|
||||
match generate_llm_summary(driver, old_messages, config.summary_max_tokens).await {
|
||||
match generate_llm_summary(driver, old_messages, prev_ref, config.summary_max_tokens).await {
|
||||
Ok(llm_summary) => {
|
||||
tracing::info!(
|
||||
"[Compaction] Generated LLM summary ({} chars)",
|
||||
llm_summary.len()
|
||||
"[Compaction] Generated LLM summary ({} chars, iterative={})",
|
||||
llm_summary.len(),
|
||||
previous_summary.is_some()
|
||||
);
|
||||
llm_summary
|
||||
}
|
||||
@@ -350,7 +397,7 @@ pub async fn maybe_compact_with_config(
|
||||
"[Compaction] LLM summary failed: {}, falling back to rules",
|
||||
e
|
||||
);
|
||||
generate_summary(old_messages)
|
||||
generate_summary(old_messages, prev_ref)
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"[Compaction] LLM summary failed: {}, returning original messages",
|
||||
@@ -369,10 +416,10 @@ pub async fn maybe_compact_with_config(
|
||||
tracing::warn!(
|
||||
"[Compaction] LLM compaction requested but no driver available, using rules"
|
||||
);
|
||||
generate_summary(old_messages)
|
||||
generate_summary(old_messages, prev_ref)
|
||||
}
|
||||
} else {
|
||||
generate_summary(old_messages)
|
||||
generate_summary(old_messages, prev_ref)
|
||||
};
|
||||
|
||||
let used_llm = config.use_llm && driver.is_some();
|
||||
@@ -398,9 +445,11 @@ pub async fn maybe_compact_with_config(
|
||||
}
|
||||
|
||||
/// Generate a summary using an LLM driver.
|
||||
/// If `previous_summary` is provided, builds on it iteratively.
|
||||
async fn generate_llm_summary(
|
||||
driver: &Arc<dyn LlmDriver>,
|
||||
messages: &[Message],
|
||||
previous_summary: Option<&str>,
|
||||
max_tokens: u32,
|
||||
) -> Result<String, String> {
|
||||
let mut conversation_text = String::new();
|
||||
@@ -437,11 +486,21 @@ async fn generate_llm_summary(
|
||||
conversation_text.push_str("\n...(对话已截断)");
|
||||
}
|
||||
|
||||
let prompt = format!(
|
||||
"请用简洁的中文总结以下对话的关键信息。保留重要的讨论主题、决策、结论和待办事项。\
|
||||
输出格式为段落式摘要,不超过200字。\n\n{}",
|
||||
conversation_text
|
||||
);
|
||||
let prompt = match previous_summary {
|
||||
Some(prev) => format!(
|
||||
"你是一个对话摘要助手。\n\n\
|
||||
## 上一轮摘要\n{}\n\n\
|
||||
## 新增对话内容\n{}\n\n\
|
||||
请在上一轮摘要的基础上更新,保留所有关键决策、用户偏好和文件操作。\
|
||||
输出200字以内的中文摘要。",
|
||||
prev, conversation_text
|
||||
),
|
||||
None => format!(
|
||||
"请用简洁的中文总结以下对话的关键信息。保留重要的讨论主题、决策、结论和待办事项。\
|
||||
输出格式为段落式摘要,不超过200字。\n\n{}",
|
||||
conversation_text
|
||||
),
|
||||
};
|
||||
|
||||
let request = CompletionRequest {
|
||||
model: String::new(),
|
||||
@@ -484,13 +543,22 @@ async fn generate_llm_summary(
|
||||
}
|
||||
|
||||
/// Generate a rule-based summary of old messages.
|
||||
fn generate_summary(messages: &[Message]) -> String {
|
||||
/// If `previous_summary` is provided, carries forward key info.
|
||||
fn generate_summary(messages: &[Message], previous_summary: Option<&str>) -> String {
|
||||
if messages.is_empty() {
|
||||
return "[对话开始]".to_string();
|
||||
}
|
||||
|
||||
let mut sections: Vec<String> = vec!["[以下是之前对话的摘要]".to_string()];
|
||||
|
||||
// Carry forward previous summary if available
|
||||
if let Some(prev) = previous_summary {
|
||||
// Strip the header line from previous summary for cleaner nesting
|
||||
let prev_body = prev.strip_prefix("[以下是之前对话的摘要]\n")
|
||||
.unwrap_or(prev);
|
||||
sections.push(format!("[上轮摘要保留]: {}", truncate(prev_body, 200)));
|
||||
}
|
||||
|
||||
let mut user_count = 0;
|
||||
let mut assistant_count = 0;
|
||||
let mut topics: Vec<String> = Vec::new();
|
||||
@@ -696,8 +764,21 @@ mod tests {
|
||||
Message::user("How does ownership work?"),
|
||||
Message::assistant("Ownership is Rust's memory management system"),
|
||||
];
|
||||
let summary = generate_summary(&messages);
|
||||
let summary = generate_summary(&messages, None);
|
||||
assert!(summary.contains("摘要"));
|
||||
assert!(summary.contains("2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_summary_iterative() {
|
||||
let messages = vec![
|
||||
Message::user("What is async/await?"),
|
||||
Message::assistant("Async/await is a concurrency model"),
|
||||
];
|
||||
let prev = "[以下是之前对话的摘要]\n讨论主题: Rust; 所有权\n(已压缩 4 条消息)";
|
||||
let summary = generate_summary(&messages, Some(prev));
|
||||
assert!(summary.contains("摘要"));
|
||||
assert!(summary.contains("上轮摘要保留"));
|
||||
assert!(summary.contains("所有权"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,12 @@ pub struct AnthropicDriver {
|
||||
impl AnthropicDriver {
|
||||
pub fn new(api_key: SecretString) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
client: Client::builder()
|
||||
.user_agent(crate::USER_AGENT)
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
api_key,
|
||||
base_url: "https://api.anthropic.com".to_string(),
|
||||
}
|
||||
@@ -30,7 +35,12 @@ impl AnthropicDriver {
|
||||
|
||||
pub fn with_base_url(api_key: SecretString, base_url: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
client: Client::builder()
|
||||
.user_agent(crate::USER_AGENT)
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
api_key,
|
||||
base_url,
|
||||
}
|
||||
@@ -111,6 +121,8 @@ impl LlmDriver for AnthropicDriver {
|
||||
let mut byte_stream = response.bytes_stream();
|
||||
let mut current_tool_id: Option<String> = None;
|
||||
let mut tool_input_buffer = String::new();
|
||||
let mut cache_creation_input_tokens: Option<u32> = None;
|
||||
let mut cache_read_input_tokens: Option<u32> = None;
|
||||
|
||||
while let Some(chunk_result) = byte_stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
@@ -131,6 +143,15 @@ impl LlmDriver for AnthropicDriver {
|
||||
match serde_json::from_str::<AnthropicStreamEvent>(data) {
|
||||
Ok(event) => {
|
||||
match event.event_type.as_str() {
|
||||
"message_start" => {
|
||||
// Capture cache token info from message_start event
|
||||
if let Some(msg) = event.message {
|
||||
if let Some(usage) = msg.usage {
|
||||
cache_creation_input_tokens = usage.cache_creation_input_tokens;
|
||||
cache_read_input_tokens = usage.cache_read_input_tokens;
|
||||
}
|
||||
}
|
||||
}
|
||||
"content_block_delta" => {
|
||||
if let Some(delta) = event.delta {
|
||||
if let Some(text) = delta.text {
|
||||
@@ -176,6 +197,8 @@ impl LlmDriver for AnthropicDriver {
|
||||
input_tokens: msg.usage.as_ref().map(|u| u.input_tokens).unwrap_or(0),
|
||||
output_tokens: msg.usage.as_ref().map(|u| u.output_tokens).unwrap_or(0),
|
||||
stop_reason: msg.stop_reason.unwrap_or_else(|| "end_turn".to_string()),
|
||||
cache_creation_input_tokens,
|
||||
cache_read_input_tokens,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -288,7 +311,15 @@ impl AnthropicDriver {
|
||||
AnthropicRequest {
|
||||
model: request.model.clone(),
|
||||
max_tokens: effective_max,
|
||||
system: request.system.clone(),
|
||||
system: request.system.as_ref().map(|s| {
|
||||
vec![SystemContentBlock {
|
||||
r#type: "text".to_string(),
|
||||
text: s.clone(),
|
||||
cache_control: Some(CacheControl {
|
||||
r#type: "ephemeral".to_string(),
|
||||
}),
|
||||
}]
|
||||
}),
|
||||
messages,
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
temperature: request.temperature,
|
||||
@@ -327,18 +358,35 @@ impl AnthropicDriver {
|
||||
input_tokens: api_response.usage.input_tokens,
|
||||
output_tokens: api_response.usage.output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: api_response.usage.cache_creation_input_tokens,
|
||||
cache_read_input_tokens: api_response.usage.cache_read_input_tokens,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Anthropic API types
|
||||
|
||||
/// Anthropic cache_control marker
|
||||
#[derive(Serialize, Clone)]
|
||||
struct CacheControl {
|
||||
r#type: String, // "ephemeral"
|
||||
}
|
||||
|
||||
/// Anthropic system prompt content block (supports cache_control)
|
||||
#[derive(Serialize, Clone)]
|
||||
struct SystemContentBlock {
|
||||
r#type: String, // "text"
|
||||
text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
cache_control: Option<CacheControl>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AnthropicRequest {
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
system: Option<Vec<SystemContentBlock>>,
|
||||
messages: Vec<AnthropicMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<AnthropicTool>>,
|
||||
@@ -394,6 +442,10 @@ struct AnthropicContentBlock {
|
||||
struct AnthropicUsage {
|
||||
input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
// Streaming types
|
||||
@@ -448,4 +500,8 @@ struct AnthropicStreamUsage {
|
||||
input_tokens: u32,
|
||||
#[serde(default)]
|
||||
output_tokens: u32,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
139
crates/zclaw-runtime/src/driver/error_classifier.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
//! LLM error classifier. Maps an HTTP status code plus error body to an LlmErrorKind.
|
||||
|
||||
use std::time::Duration;
|
||||
use zclaw_types::{LlmErrorKind, ClassifiedLlmError};
|
||||
|
||||
/// Classify an LLM error.
|
||||
pub fn classify_llm_error(
|
||||
provider: &str,
|
||||
status: u16,
|
||||
body: &str,
|
||||
is_timeout: bool,
|
||||
) -> ClassifiedLlmError {
|
||||
let _ = provider; // reserved for per-provider overrides
|
||||
|
||||
if is_timeout {
|
||||
return ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Timeout,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: "请求超时".to_string(),
|
||||
};
|
||||
}
|
||||
|
||||
match status {
|
||||
401 | 403 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Auth,
|
||||
retryable: false,
|
||||
should_compress: false,
|
||||
should_rotate_credential: true,
|
||||
retry_after: None,
|
||||
message: "认证失败,请检查 API Key".to_string(),
|
||||
},
|
||||
402 => {
|
||||
let is_quota_transient = body.contains("retry")
|
||||
|| body.contains("limit")
|
||||
|| body.contains("usage");
|
||||
ClassifiedLlmError {
|
||||
kind: if is_quota_transient { LlmErrorKind::RateLimited } else { LlmErrorKind::BillingExhausted },
|
||||
retryable: is_quota_transient,
|
||||
should_compress: false,
|
||||
should_rotate_credential: !is_quota_transient,
|
||||
retry_after: if is_quota_transient { Some(Duration::from_secs(30)) } else { None },
|
||||
message: if is_quota_transient { "使用限制,稍后重试".to_string() } else { "计费额度已耗尽".to_string() },
|
||||
}
|
||||
}
|
||||
429 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::RateLimited,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: true,
|
||||
retry_after: parse_retry_after(body),
|
||||
message: "速率限制".to_string(),
|
||||
},
|
||||
529 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Overloaded,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: Some(Duration::from_secs(5)),
|
||||
message: "提供商过载".to_string(),
|
||||
},
|
||||
500 | 502 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::ServerError,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: "服务端错误".to_string(),
|
||||
},
|
||||
503 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Overloaded,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: Some(Duration::from_secs(3)),
|
||||
message: "服务暂时不可用".to_string(),
|
||||
},
|
||||
400 => {
|
||||
let is_context_overflow = body.contains("context_length")
|
||||
|| body.contains("max_tokens")
|
||||
|| body.contains("too many tokens")
|
||||
|| body.contains("prompt is too long");
|
||||
ClassifiedLlmError {
|
||||
kind: if is_context_overflow { LlmErrorKind::ContextOverflow } else { LlmErrorKind::Unknown },
|
||||
retryable: false,
|
||||
should_compress: is_context_overflow,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: if is_context_overflow {
|
||||
"上下文过长,需要压缩".to_string()
|
||||
} else {
|
||||
format!("请求错误: {}", body.chars().take(200).collect::<String>()) // char-based cut: byte slicing can panic inside a multi-byte character
|
||||
},
|
||||
}
|
||||
}
|
||||
404 => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::ModelNotFound,
|
||||
retryable: false,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: "模型不存在".to_string(),
|
||||
},
|
||||
_ => ClassifiedLlmError {
|
||||
kind: LlmErrorKind::Unknown,
|
||||
retryable: true,
|
||||
should_compress: false,
|
||||
should_rotate_credential: false,
|
||||
retry_after: None,
|
||||
message: format!("未知错误 ({}) {}", status, body.chars().take(200).collect::<String>()), // char-based cut avoids a panic on multi-byte text
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_retry_after(body: &str) -> Option<Duration> {
|
||||
// Anthropic: "Please retry after X seconds"
|
||||
// OpenAI: "Please retry after Xms"
|
||||
if let Some(secs) = extract_retry_seconds(body) {
|
||||
return Some(Duration::from_secs(secs));
|
||||
}
|
||||
if let Some(ms) = extract_retry_millis(body) {
|
||||
return Some(Duration::from_millis(ms));
|
||||
}
|
||||
Some(Duration::from_secs(2))
|
||||
}
|
||||
|
||||
fn extract_retry_seconds(body: &str) -> Option<u64> {
|
||||
let re = regex::Regex::new(r"retry\s+(?:after\s+)?(\d+)\s*(?:s|sec|seconds?)").ok()?;
|
||||
let caps = re.captures(body)?;
|
||||
caps[1].parse().ok()
|
||||
}
|
||||
|
||||
fn extract_retry_millis(body: &str) -> Option<u64> {
|
||||
let re = regex::Regex::new(r"retry\s+(?:after\s+)?(\d+)\s*ms").ok()?;
|
||||
let caps = re.captures(body)?;
|
||||
caps[1].parse().ok()
|
||||
}
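The retry-hint parsing above can also be sketched with the standard library alone. This is a variant under stated assumptions, not the classifier's actual code. The function name `parse_retry_hint` is hypothetical, only the first occurrence of "retry" is inspected, matching is ASCII case-insensitive, and the sketch returns `None` instead of falling back to a 2-second default.

```rust
use std::time::Duration;

/// Hypothetical std-only parser for hints such as "retry after 20 seconds"
/// or "retry after 500ms". Returns `None` when no hint is found.
fn parse_retry_hint(body: &str) -> Option<Duration> {
    // ASCII lowercasing keeps byte offsets identical to the input.
    let lower = body.to_ascii_lowercase();
    let start = lower.find("retry")? + "retry".len();
    let rest = lower[start..].trim_start();
    // The word "after" is optional, as in the regex version.
    let rest = rest.strip_prefix("after").map(str::trim_start).unwrap_or(rest);
    let digits = rest.bytes().take_while(|b| b.is_ascii_digit()).count();
    if digits == 0 {
        return None;
    }
    let value: u64 = rest[..digits].parse().ok()?;
    let unit = rest[digits..].trim_start();
    // Check "ms" before "s" so milliseconds are not read as seconds.
    if unit.starts_with("ms") {
        Some(Duration::from_millis(value))
    } else if unit.starts_with('s') {
        Some(Duration::from_secs(value))
    } else {
        None
    }
}
```

A caller that wants the behaviour shown in the diff could apply its own default, for example `parse_retry_hint(body).unwrap_or(Duration::from_secs(2))`.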
|
||||
@@ -30,8 +30,7 @@ impl GeminiDriver {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent(crate::USER_AGENT)
|
||||
.http1_only()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
@@ -44,8 +43,7 @@ impl GeminiDriver {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent(crate::USER_AGENT)
|
||||
.http1_only()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
@@ -240,6 +238,8 @@ impl LlmDriver for GeminiDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason: stop_reason.to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -502,6 +502,8 @@ impl GeminiDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ impl LocalDriver {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent(crate::USER_AGENT)
|
||||
.http1_only()
|
||||
.timeout(std::time::Duration::from_secs(300)) // 5 min -- local inference can be slow
|
||||
.connect_timeout(std::time::Duration::from_secs(10)) // short connect timeout
|
||||
.build()
|
||||
@@ -239,6 +238,8 @@ impl LocalDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -397,6 +398,8 @@ impl LlmDriver for LocalDriver {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -15,11 +15,14 @@ mod anthropic;
|
||||
mod openai;
|
||||
mod gemini;
|
||||
mod local;
|
||||
mod error_classifier;
|
||||
mod retry_driver;
|
||||
|
||||
pub use anthropic::AnthropicDriver;
|
||||
pub use openai::OpenAiDriver;
|
||||
pub use gemini::GeminiDriver;
|
||||
pub use local::LocalDriver;
|
||||
pub use retry_driver::{RetryDriver, RetryConfig};
|
||||
|
||||
/// LLM Driver trait - unified interface for all providers
|
||||
#[async_trait]
|
||||
@@ -106,6 +109,12 @@ pub struct CompletionResponse {
|
||||
pub output_tokens: u32,
|
||||
/// Stop reason
|
||||
pub stop_reason: StopReason,
|
||||
/// Cache creation input tokens (Anthropic prompt caching)
|
||||
#[serde(default)]
|
||||
pub cache_creation_input_tokens: Option<u32>,
|
||||
/// Cache read input tokens (Anthropic prompt caching)
|
||||
#[serde(default)]
|
||||
pub cache_read_input_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
/// LLM driver response content block (subset of canonical zclaw_types::ContentBlock).
|
||||
|
||||
@@ -24,9 +24,8 @@ impl OpenAiDriver {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent(crate::USER_AGENT)
|
||||
.http1_only()
|
||||
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
||||
.connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
api_key,
|
||||
@@ -38,9 +37,8 @@ impl OpenAiDriver {
|
||||
Self {
|
||||
client: Client::builder()
|
||||
.user_agent(crate::USER_AGENT)
|
||||
.http1_only()
|
||||
.timeout(std::time::Duration::from_secs(120)) // 2 minute timeout
|
||||
.connect_timeout(std::time::Duration::from_secs(30)) // 30 second connect timeout
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
api_key,
|
||||
@@ -165,6 +163,7 @@ impl LlmDriver for OpenAiDriver {
|
||||
let mut current_tool_id: Option<String> = None;
|
||||
let mut sse_event_count: usize = 0;
|
||||
let mut raw_bytes_total: usize = 0;
|
||||
let mut pending_line = String::new(); // Buffer for incomplete SSE lines
|
||||
|
||||
while let Some(chunk_result) = byte_stream.next().await {
|
||||
let chunk = match chunk_result {
|
||||
@@ -182,13 +181,21 @@ impl LlmDriver for OpenAiDriver {
|
||||
if raw_bytes_total <= 600 {
|
||||
tracing::debug!("[OpenAI:stream] RAW chunk ({} bytes): {:?}", text.len(), text.chars().take(500).collect::<String>()); // char-based cut avoids a panic on multi-byte text
|
||||
}
|
||||
for line in text.lines() {
|
||||
// Accumulate text and split by lines, handling incomplete last line
|
||||
pending_line.push_str(&text);
|
||||
// Extract complete lines (ending with \n), keep the rest pending
|
||||
let mut complete_lines: Vec<String> = Vec::new();
|
||||
while let Some(pos) = pending_line.find('\n') {
|
||||
complete_lines.push(pending_line[..pos].to_string());
|
||||
pending_line = pending_line[pos + 1..].to_string();
|
||||
}
|
||||
for line in complete_lines {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() || trimmed.starts_with(':') {
|
||||
continue; // Skip empty lines and SSE comments
|
||||
}
|
||||
// Handle both "data: " (standard) and "data:" (no space)
|
||||
let data = if let Some(d) = trimmed.strip_prefix("data: ") {
|
||||
let data: Option<&str> = if let Some(d) = trimmed.strip_prefix("data: ") {
|
||||
Some(d)
|
||||
} else if let Some(d) = trimmed.strip_prefix("data:") {
|
||||
Some(d.trim_start())
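The line buffering introduced in this hunk exists because a network chunk can end in the middle of an SSE line. Below is a minimal sketch of the same idea in isolation. The function name `drain_complete_lines` is hypothetical, and unlike the diff it strips a trailing `\r` as well as the `\n`.

```rust
/// Appends `chunk` to `pending` and returns every complete line
/// (terminated by '\n'). Any unterminated tail stays in `pending`
/// until a later chunk completes it. Hypothetical helper.
fn drain_complete_lines(pending: &mut String, chunk: &str) -> Vec<String> {
    pending.push_str(chunk);
    let mut lines = Vec::new();
    while let Some(pos) = pending.find('\n') {
        // Remove the line including its '\n' from the buffer.
        let line: String = pending.drain(..=pos).collect();
        lines.push(line.trim_end_matches(&['\r', '\n'][..]).to_string());
    }
    lines
}
```

Calling `text.lines()` on each chunk separately, as the replaced code did, treats a partial trailing line as complete, so a `data:` payload split across two chunks would fail to parse as JSON.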
|
||||
@@ -201,7 +208,7 @@ impl LlmDriver for OpenAiDriver {
|
||||
tracing::debug!("[OpenAI:stream] SSE #{}: {}", sse_event_count, data.chars().take(300).collect::<String>()); // char-based cut avoids a panic on multi-byte text
|
||||
}
|
||||
if data == "[DONE]" {
|
||||
tracing::debug!("[OpenAI:stream] Received [DONE], total SSE events: {}, raw bytes: {}", sse_event_count, raw_bytes_total);
|
||||
tracing::debug!("[OpenAI:stream] Received [DONE], total SSE events: {}, raw bytes: {}, tool_calls: {:?}", sse_event_count, raw_bytes_total, accumulated_tool_calls);
|
||||
|
||||
// Emit ToolUseEnd for all accumulated tool calls (skip invalid ones with empty name)
|
||||
for (id, (name, args)) in &accumulated_tool_calls {
|
||||
@@ -230,6 +237,8 @@ impl LlmDriver for OpenAiDriver {
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
@@ -257,7 +266,7 @@ impl LlmDriver for OpenAiDriver {
|
||||
|
||||
// Handle tool calls
|
||||
if let Some(tool_calls) = &delta.tool_calls {
|
||||
tracing::trace!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
||||
tracing::debug!("[OpenAI] Received tool_calls delta: {:?}", tool_calls);
|
||||
for tc in tool_calls {
|
||||
// Tool call start - has id and name
|
||||
if let Some(id) = &tc.id {
|
||||
@@ -631,6 +640,8 @@ impl OpenAiDriver {
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
stop_reason,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -754,6 +765,8 @@ impl OpenAiDriver {
|
||||
StopReason::StopSequence => "stop",
|
||||
StopReason::Error => "error",
|
||||
}.to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
123
crates/zclaw-runtime/src/driver/retry_driver.rs
Normal file
@@ -0,0 +1,123 @@
//! RetryDriver: a retry decorator for LlmDriver.
//! Used only on the local Kernel path; the SaaS Relay already has its own retry logic.

use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use futures::Stream;
use rand::Rng;
use zclaw_types::{Result, ZclawError};

use super::{LlmDriver, CompletionRequest, CompletionResponse, StreamChunk};
use super::error_classifier::classify_llm_error;

/// Retry configuration
#[derive(Debug, Clone)]
pub struct RetryConfig {
    pub max_attempts: u32,
    pub base_delay_secs: f64,
    pub max_delay_secs: f64,
    pub jitter_ratio: f64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay_secs: 1.0,
            max_delay_secs: 8.0,
            jitter_ratio: 0.5,
        }
    }
}

/// Retry decorator
pub struct RetryDriver {
    inner: Arc<dyn LlmDriver>,
    config: RetryConfig,
}

impl RetryDriver {
    pub fn new(inner: Arc<dyn LlmDriver>, config: RetryConfig) -> Self {
        Self { inner, config }
    }

    fn jittered_backoff(&self, attempt: u32) -> Duration {
        let base = self.config.base_delay_secs * 2_f64.powi(attempt as i32);
        let capped = base.min(self.config.max_delay_secs);
        let mut rng = rand::thread_rng();
        let jitter = capped * self.config.jitter_ratio * rng.gen::<f64>();
        Duration::from_secs_f64(capped + jitter)
    }
}

#[async_trait]
impl LlmDriver for RetryDriver {
    fn provider(&self) -> &str {
        self.inner.provider()
    }

    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
        let mut last_error: Option<ZclawError> = None;

        for attempt in 0..self.config.max_attempts {
            match self.inner.complete(request.clone()).await {
                Ok(response) => return Ok(response),
                Err(e) => {
                    let message = e.to_string();
                    let status = extract_status_from_error(&message);
                    let classified = classify_llm_error(
                        self.inner.provider(),
                        status,
                        &message,
                        message.contains("timeout") || message.contains("Timeout"),
                    );

                    if !classified.retryable {
                        return Err(e);
                    }

                    if classified.should_compress {
                        return Err(ZclawError::LlmError(
                            format!("[CONTEXT_OVERFLOW] {}", message)
                        ));
                    }

                    last_error = Some(e);

                    if attempt + 1 < self.config.max_attempts {
                        let delay = classified.retry_after
                            .unwrap_or_else(|| self.jittered_backoff(attempt));
                        tracing::warn!(
                            "[RetryDriver] Attempt {}/{} failed ({}), retrying in {:.1}s",
                            attempt + 1, self.config.max_attempts, classified.message,
                            delay.as_secs_f64()
                        );
                        tokio::time::sleep(delay).await;
                    }
                }
            }
        }

        // "重试耗尽" means "retries exhausted".
        Err(last_error.unwrap_or_else(|| ZclawError::LlmError("重试耗尽".to_string())))
    }

    fn stream(
        &self,
        request: CompletionRequest,
    ) -> std::pin::Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
        // The streaming path is not retried: some deltas have already been sent,
        // so a retry would duplicate output in the UI.
        self.inner.stream(request)
    }

    fn is_configured(&self) -> bool {
        self.inner.is_configured()
    }
}

fn extract_status_from_error(message: &str) -> u16 {
    let re = regex::Regex::new(r"(?:error|status)[:\s]+(\d{3})").ok();
    re.and_then(|re| re.captures(message))
        .and_then(|caps| caps[1].parse().ok())
        .unwrap_or(0)
}
|
||||
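Review note: the delay schedule produced by `jittered_backoff` is easier to check when the random draw is passed in. The sketch below reproduces the same arithmetic with only the standard library; `BackoffConfig` and the `unit_random` parameter are hypothetical stand-ins for `RetryConfig` and `rng.gen::<f64>()`.

```rust
use std::time::Duration;

/// Hypothetical subset of `RetryConfig` holding only the delay fields.
struct BackoffConfig {
    base_delay_secs: f64,
    max_delay_secs: f64,
    jitter_ratio: f64,
}

/// Same formula as in the diff: exponential growth, cap, then additive jitter.
/// `unit_random` replaces the RNG draw and is expected to lie in [0, 1).
fn jittered_backoff(cfg: &BackoffConfig, attempt: u32, unit_random: f64) -> Duration {
    let base = cfg.base_delay_secs * 2_f64.powi(attempt as i32);
    let capped = base.min(cfg.max_delay_secs);
    // Jitter is added on top of the capped value, so the result can exceed
    // `max_delay_secs` by up to `jitter_ratio * max_delay_secs`.
    let jitter = capped * cfg.jitter_ratio * unit_random;
    Duration::from_secs_f64(capped + jitter)
}
```

With the defaults in the diff (`max_attempts: 3`, base 1.0 s, cap 8.0 s, ratio 0.5), only attempts 0 and 1 are followed by a sleep, giving delays in [1.0, 1.5) s and [2.0, 3.0) s unless the classifier supplies `retry_after`.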
@@ -148,6 +148,18 @@ impl GrowthIntegration {
|
||||
self.config.auto_extract = auto_extract;
|
||||
}
|
||||
|
||||
/// Configure embedding client for memory retrieval.
|
||||
///
|
||||
/// Propagates the embedding client to the MemoryRetriever's SemanticScorer,
|
||||
/// enabling embedding-based similarity in addition to TF-IDF.
|
||||
/// Safe to call from non-async contexts.
|
||||
pub fn configure_embedding(
|
||||
&self,
|
||||
client: Arc<dyn zclaw_growth::retrieval::semantic::EmbeddingClient>,
|
||||
) {
|
||||
self.retriever.set_embedding_client(client);
|
||||
}
|
||||
|
||||
/// Set the user profile store for incremental profile updates
|
||||
pub fn with_profile_store(mut self, store: Arc<UserProfileStore>) -> Self {
|
||||
self.profile_store = Some(store);
|
||||
@@ -318,15 +330,43 @@ impl GrowthIntegration {
|
||||
&& combined.experiences.is_empty()
|
||||
&& !combined.profile_signals.has_any_signal()
|
||||
{
|
||||
tracing::debug!(
|
||||
"[GrowthIntegration] Combined extraction produced nothing for agent {}",
|
||||
agent_id
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mem_count = combined.memories.len();
|
||||
tracing::info!(
|
||||
"[GrowthIntegration] Combined extraction for agent {}: {} memories, {} experiences, {} profile signals",
|
||||
agent_id,
|
||||
mem_count,
|
||||
combined.experiences.len(),
|
||||
combined.profile_signals.signal_count()
|
||||
);
|
||||
|
||||
// Store raw memories
|
||||
self.extractor
|
||||
match self.extractor
|
||||
.store_memories(&agent_id.to_string(), &combined.memories)
|
||||
.await?;
|
||||
.await
|
||||
{
|
||||
Ok(stored) => {
|
||||
tracing::info!(
|
||||
"[GrowthIntegration] Stored {} memories for agent {}",
|
||||
stored,
|
||||
agent_id
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
"[GrowthIntegration] Failed to store memories for agent {}: {}",
|
||||
agent_id,
|
||||
e
|
||||
);
|
||||
return Err(e);
|
||||
}
|
||||
}
|
||||
|
||||
// Track learning event
|
||||
self.tracker
|
||||
@@ -350,6 +390,11 @@ impl GrowthIntegration {
|
||||
// Update user profile from extraction signals (L1 enhancement)
|
||||
if let Some(profile_store) = &self.profile_store {
|
||||
let updates = self.profile_updater.collect_updates(&combined);
|
||||
tracing::info!(
|
||||
"[GrowthIntegration] Applying {} profile updates for agent {}",
|
||||
updates.len(),
|
||||
agent_id
|
||||
);
|
||||
let user_id = agent_id.to_string();
|
||||
for update in updates {
|
||||
let result = match update.kind {
|
||||
@@ -395,6 +440,39 @@ impl GrowthIntegration {
|
||||
}
|
||||
}
|
||||
|
||||
// Store identity signals as special memories for cross-session persistence
|
||||
if combined.profile_signals.has_identity_signal() {
|
||||
let agent_id_str = agent_id.to_string();
|
||||
if let Some(ref agent_name) = combined.profile_signals.agent_name {
|
||||
let entry = zclaw_growth::types::MemoryEntry::new(
|
||||
&agent_id_str,
|
||||
zclaw_growth::types::MemoryType::Preference,
|
||||
"identity",
|
||||
format!("助手的名字是{}", agent_name),
|
||||
).with_importance(8)
|
||||
.with_keywords(vec!["名字".to_string(), "称呼".to_string(), "identity".to_string(), agent_name.clone()]);
|
||||
if let Err(e) = self.extractor.store_memory_entry(&entry).await {
|
||||
tracing::warn!("[GrowthIntegration] Failed to store agent_name signal: {}", e);
|
||||
} else {
|
||||
tracing::info!("[GrowthIntegration] Stored agent_name '{}' for {}", agent_name, agent_id_str);
|
||||
}
|
||||
}
|
||||
if let Some(ref user_name) = combined.profile_signals.user_name {
|
||||
let entry = zclaw_growth::types::MemoryEntry::new(
|
||||
&agent_id_str,
|
||||
zclaw_growth::types::MemoryType::Preference,
|
||||
"identity",
|
||||
format!("用户的名字是{}", user_name),
|
||||
).with_importance(8)
|
||||
.with_keywords(vec!["名字".to_string(), "用户名".to_string(), "identity".to_string(), user_name.clone()]);
|
||||
if let Err(e) = self.extractor.store_memory_entry(&entry).await {
|
||||
tracing::warn!("[GrowthIntegration] Failed to store user_name signal: {}", e);
|
||||
} else {
|
||||
tracing::info!("[GrowthIntegration] Stored user_name '{}' for {}", user_name, agent_id_str);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convert extracted memories to structured facts
|
||||
let facts: Vec<Fact> = combined
|
||||
.memories
|
||||
|
||||
@@ -19,6 +19,8 @@ pub mod middleware;
|
||||
pub mod prompt;
|
||||
pub mod nl_schedule;
|
||||
|
||||
pub mod test_util;
|
||||
|
||||
// Re-export main types
|
||||
pub use driver::{
|
||||
LlmDriver, CompletionRequest, CompletionResponse, ContentBlock, StopReason,
|
||||
|
||||
@@ -4,10 +4,11 @@ use std::sync::Arc;
|
||||
use futures::StreamExt;
|
||||
use tokio::sync::mpsc;
|
||||
use zclaw_types::{AgentId, SessionId, Message, Result};
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::driver::{LlmDriver, CompletionRequest, ContentBlock};
|
||||
use crate::stream::StreamChunk;
|
||||
use crate::tool::{ToolRegistry, ToolContext, SkillExecutor};
|
||||
use crate::tool::{ToolRegistry, ToolContext, SkillExecutor, HandExecutor, ToolConcurrency};
|
||||
use crate::tool::builtin::PathValidator;
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
@@ -28,6 +29,7 @@ pub struct AgentLoop {
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
skill_executor: Option<Arc<dyn SkillExecutor>>,
|
||||
hand_executor: Option<Arc<dyn HandExecutor>>,
|
||||
path_validator: Option<PathValidator>,
|
||||
/// Growth system integration (optional)
|
||||
growth: Option<GrowthIntegration>,
|
||||
@@ -64,6 +66,7 @@ impl AgentLoop {
|
||||
max_tokens: 16384,
|
||||
temperature: 0.7,
|
||||
skill_executor: None,
|
||||
hand_executor: None,
|
||||
path_validator: None,
|
||||
growth: None,
|
||||
compaction_threshold: 0,
|
||||
@@ -81,6 +84,12 @@ impl AgentLoop {
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the hand executor for dispatching Hand tool calls to HandRegistry
|
||||
pub fn with_hand_executor(mut self, executor: Arc<dyn HandExecutor>) -> Self {
|
||||
self.hand_executor = Some(executor);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set the path validator for file system operations
|
||||
pub fn with_path_validator(mut self, validator: PathValidator) -> Self {
|
||||
self.path_validator = Some(validator);
|
||||
@@ -199,6 +208,7 @@ impl AgentLoop {
|
||||
working_directory: working_dir,
|
||||
session_id: Some(session_id.to_string()),
|
||||
skill_executor: self.skill_executor.clone(),
|
||||
hand_executor: self.hand_executor.clone(),
|
||||
path_validator: Some(path_validator),
|
||||
event_sender: None,
|
||||
}
|
||||
@@ -294,8 +304,28 @@ impl AgentLoop {
|
||||
plan_mode: self.plan_mode,
|
||||
};
|
||||
|
||||
// Call LLM
|
||||
let response = self.driver.complete(request).await?;
|
||||
// Call LLM with context-overflow recovery
|
||||
let response = match self.driver.complete(request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("[CONTEXT_OVERFLOW]") && self.compaction_threshold > 0 {
|
||||
tracing::warn!("[AgentLoop] Context overflow detected, triggering emergency compaction");
|
||||
let pruned = compaction::prune_tool_outputs(&mut messages);
|
||||
if pruned > 0 {
|
||||
tracing::info!("[AgentLoop] Emergency pruning removed {} tool outputs", pruned);
|
||||
}
|
||||
let keep_recent = messages.len().saturating_sub(messages.len() / 3);
|
||||
let (compacted, removed) = compaction::compact_messages(messages, keep_recent.max(4));
|
||||
if removed > 0 {
|
||||
tracing::info!("[AgentLoop] Emergency compaction removed {} messages", removed);
|
||||
messages = compacted;
|
||||
continue; // retry the iteration with compacted messages
|
||||
}
|
||||
}
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
total_input_tokens += response.input_tokens;
|
||||
total_output_tokens += response.output_tokens;
|
||||
|
||||
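Review note: the `keep_recent` arithmetic in the emergency path keeps roughly two thirds of the history, with a floor of four. A small sketch of just that calculation follows, assuming `compact_messages` treats its second argument as the number of most recent messages to keep (its definition is not part of this hunk); `emergency_keep_recent` is a hypothetical name.

```rust
/// Number of recent messages retained by the emergency compaction path.
/// Mirrors `len.saturating_sub(len / 3)` followed by `.max(4)` from the hunk.
fn emergency_keep_recent(message_count: usize) -> usize {
    let keep = message_count.saturating_sub(message_count / 3);
    keep.max(4)
}
```

For short histories the floor is at least as large as the history itself (for example 3 messages give a value of 4), so under the stated assumption nothing would be removed and the original error would be returned.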
@@ -366,106 +396,164 @@ impl AgentLoop {
|
||||
let tool_context = self.create_tool_context(session_id.clone());
|
||||
let mut abort_result: Option<AgentLoopResult> = None;
|
||||
let mut clarification_result: Option<AgentLoopResult> = None;
|
||||
for (id, name, input) in tool_calls {
|
||||
// Check if loop was already aborted
|
||||
if abort_result.is_some() {
|
||||
break;
|
||||
}
|
||||
// Check tool call safety — via middleware chain
|
||||
{
|
||||
let mw_ctx_ref = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.to_string(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
match self.middleware_chain.run_before_tool_call(&mw_ctx_ref, &name, &input).await? {
|
||||
middleware::ToolCallDecision::Allow => {}
|
||||
middleware::ToolCallDecision::Block(msg) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||
let error_output = serde_json::json!({ "error": msg });
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), error_output, true));
|
||||
continue;
|
||||
}
|
||||
middleware::ToolCallDecision::ReplaceInput(new_input) => {
|
||||
// Execute with replaced input (with timeout)
|
||||
let tool_result = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
self.execute_tool(&name, new_input, &tool_context),
|
||||
).await {
|
||||
Ok(Ok(result)) => result,
|
||||
Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }),
|
||||
Err(_) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' (replaced input) timed out after 30s", name);
|
||||
serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", name) })
|
||||
}
|
||||
};
|
||||
messages.push(Message::tool_result(id, zclaw_types::ToolId::new(&name), tool_result, false));
|
||||
continue;
|
||||
}
|
||||
middleware::ToolCallDecision::AbortLoop(reason) => {
|
||||
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
|
||||
let msg = format!("{}\n已自动终止", reason);
|
||||
self.memory.append_message(&session_id, &Message::assistant(&msg)).await?;
|
||||
abort_result = Some(AgentLoopResult {
|
||||
response: msg,
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
iterations,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let tool_result = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
self.execute_tool(&name, input, &tool_context),
|
||||
).await {
|
||||
Ok(Ok(result)) => result,
|
||||
Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }),
|
||||
Err(_) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s", name);
|
||||
serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", name) })
|
||||
// Phase 1: Pre-process inputs + middleware checks (serial)
|
||||
struct ToolPlan {
|
||||
idx: usize,
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
}
|
||||
let mut plans: Vec<ToolPlan> = Vec::new();
|
||||
for (idx, (id, name, input)) in tool_calls.into_iter().enumerate() {
|
||||
if abort_result.is_some() { break; }
|
||||
|
||||
// GLM and other models sometimes send tool calls with empty arguments `{}`
|
||||
let input = if input.as_object().map_or(false, |obj| obj.is_empty()) {
|
||||
if let Some(last_user_msg) = messages.iter().rev().find_map(|m| {
|
||||
if let Message::User { content } = m { Some(content.clone()) } else { None }
|
||||
}) {
|
||||
tracing::info!("[AgentLoop] Tool '{}' received empty input, injecting user message as fallback query", name);
|
||||
serde_json::json!({ "_fallback_query": last_user_msg })
|
||||
} else {
|
||||
input
|
||||
}
|
||||
} else {
|
||||
input
|
||||
};
|
||||
|
||||
// Check if this is a clarification response — terminate loop immediately
|
||||
// so the LLM waits for user input instead of continuing to generate.
|
||||
if name == "ask_clarification"
|
||||
&& tool_result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed")
|
||||
{
|
||||
tracing::info!("[AgentLoop] Clarification requested, terminating loop");
|
||||
let question = tool_result.get("question")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("需要更多信息")
|
||||
.to_string();
|
||||
messages.push(Message::tool_result(
|
||||
id,
|
||||
zclaw_types::ToolId::new(&name),
|
||||
tool_result,
|
||||
false,
|
||||
));
|
||||
self.memory.append_message(&session_id, &Message::assistant(&question)).await?;
|
||||
clarification_result = Some(AgentLoopResult {
|
||||
response: question,
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
iterations,
|
||||
let mw_ctx = middleware::MiddlewareContext {
|
||||
agent_id: self.agent_id.clone(),
|
||||
session_id: session_id.clone(),
|
||||
user_input: input.to_string(),
|
||||
system_prompt: enhanced_prompt.clone(),
|
||||
messages: messages.clone(),
|
||||
response_content: Vec::new(),
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
};
|
||||
match self.middleware_chain.run_before_tool_call(&mw_ctx, &name, &input).await? {
|
||||
middleware::ToolCallDecision::Allow => {
|
||||
plans.push(ToolPlan { idx, id, name, input });
|
||||
}
|
||||
middleware::ToolCallDecision::Block(msg) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' blocked by middleware: {}", name, msg);
|
||||
messages.push(Message::tool_result(&id, zclaw_types::ToolId::new(&name), serde_json::json!({ "error": msg }), true));
|
||||
}
|
||||
middleware::ToolCallDecision::ReplaceInput(new_input) => {
|
||||
plans.push(ToolPlan { idx, id, name, input: new_input });
|
||||
}
|
||||
middleware::ToolCallDecision::AbortLoop(reason) => {
|
||||
tracing::warn!("[AgentLoop] Loop aborted by middleware: {}", reason);
|
||||
let msg = format!("{}\n已自动终止", reason);
|
||||
self.memory.append_message(&session_id, &Message::assistant(&msg)).await?;
|
||||
abort_result = Some(AgentLoopResult {
|
||||
response: msg,
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
iterations,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase 2: Execute tools (parallel for ReadOnly, serial for others)
|
||||
if abort_result.is_none() && !plans.is_empty() {
|
||||
let (parallel_plans, sequential_plans): (Vec<_>, Vec<_>) = plans.iter()
|
||||
.partition(|p| {
|
||||
self.tools.get(&p.name)
|
||||
.map(|t| t.concurrency())
|
||||
.unwrap_or(ToolConcurrency::Exclusive) == ToolConcurrency::ReadOnly
|
||||
});
|
||||
break;
|
||||
|
||||
let mut results: std::collections::HashMap<usize, (String, String, serde_json::Value)> = std::collections::HashMap::new();
|
||||
|
||||
// Execute parallel (ReadOnly) tools with JoinSet (max 3 concurrent)
|
||||
if !parallel_plans.is_empty() {
|
||||
let semaphore = Arc::new(tokio::sync::Semaphore::new(3));
|
||||
let mut join_set = tokio::task::JoinSet::new();
|
||||
|
||||
for plan in &parallel_plans {
|
||||
let tool = self.tools.get(&plan.name).unwrap();
|
||||
let ctx = tool_context.clone();
|
||||
let input = plan.input.clone();
|
||||
let idx = plan.idx;
|
||||
let id = plan.id.clone();
|
||||
let name = plan.name.clone();
|
||||
let permit = semaphore.clone().acquire_owned().await.unwrap();
|
||||
|
||||
join_set.spawn(async move {
|
||||
let result = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
tool.execute(input, &ctx)
|
||||
).await;
|
||||
drop(permit);
|
||||
(idx, id, name, result)
|
||||
});
|
||||
}
|
||||
|
||||
while let Some(res) = join_set.join_next().await {
|
||||
match res {
|
||||
Ok((idx, id, name, Ok(Ok(value)))) => {
|
||||
results.insert(idx, (id, name, value));
|
||||
}
|
||||
Ok((idx, id, name, Ok(Err(e)))) => {
|
||||
results.insert(idx, (id, name, serde_json::json!({ "error": e.to_string() })));
|
||||
}
|
||||
Ok((idx, id, name, Err(_))) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s (parallel)", name);
|
||||
results.insert(idx, (id, name.clone(), serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", name) })));
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[AgentLoop] JoinError in parallel tool execution: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool result to messages
|
||||
messages.push(Message::tool_result(
|
||||
id,
|
||||
zclaw_types::ToolId::new(&name),
|
||||
tool_result,
|
||||
false, // is_error - we include errors in the result itself
|
||||
));
|
||||
// Execute sequential (Exclusive/Interactive) tools
|
||||
for plan in &sequential_plans {
|
||||
let tool_result = match tokio::time::timeout(
|
||||
std::time::Duration::from_secs(30),
|
||||
self.execute_tool(&plan.name, plan.input.clone(), &tool_context),
|
||||
).await {
|
||||
Ok(Ok(result)) => result,
|
||||
Ok(Err(e)) => serde_json::json!({ "error": e.to_string() }),
|
||||
Err(_) => {
|
||||
tracing::warn!("[AgentLoop] Tool '{}' timed out after 30s", plan.name);
|
||||
serde_json::json!({ "error": format!("工具 '{}' 执行超时(30秒),请重试", plan.name) })
|
||||
}
|
||||
};
|
||||
|
||||
// Check if this is a clarification response
|
||||
if plan.name == "ask_clarification"
|
||||
&& tool_result.get("status").and_then(|v| v.as_str()) == Some("clarification_needed")
|
||||
{
|
||||
tracing::info!("[AgentLoop] Clarification requested, terminating loop");
|
||||
let question = tool_result.get("question")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("需要更多信息")
|
||||
.to_string();
|
||||
results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), tool_result));
|
||||
self.memory.append_message(&session_id, &Message::assistant(&question)).await?;
|
||||
clarification_result = Some(AgentLoopResult {
|
||||
response: question,
|
||||
input_tokens: total_input_tokens,
|
||||
output_tokens: total_output_tokens,
|
||||
iterations,
|
||||
});
|
||||
break;
|
||||
}
|
||||
results.insert(plan.idx, (plan.id.clone(), plan.name.clone(), tool_result));
|
||||
}
|
||||
|
||||
// Push results in original tool_call order
|
||||
let mut sorted_indices: Vec<usize> = results.keys().copied().collect();
|
||||
sorted_indices.sort();
|
||||
for idx in sorted_indices {
|
||||
let (id, name, result) = results.remove(&idx).unwrap();
|
||||
messages.push(Message::tool_result(&id, zclaw_types::ToolId::new(&name), result, false));
|
||||
}
|
||||
}
|
||||
|
||||
// Continue the loop - LLM will process tool results and generate final response
|
||||
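Review note: the two-phase execution above partitions plans by concurrency class, runs the read-only ones concurrently, and then emits results in the original `tool_calls` order. The sketch below shows that ordering idea with the standard library only. It swaps in `std::thread::scope` for the tokio `JoinSet`, a `BTreeMap` for the `HashMap` plus sorted keys, and does not model the semaphore limit of 3 or the 30 s timeout; `Plan`, `Concurrency`, and `run` are hypothetical.

```rust
use std::collections::BTreeMap;
use std::thread;

#[derive(Clone, Copy, PartialEq)]
enum Concurrency {
    ReadOnly,
    Exclusive,
}

struct Plan {
    idx: usize,
    name: &'static str,
    concurrency: Concurrency,
}

/// Stand-in for the real tool execution.
fn run(plan: &Plan) -> String {
    format!("{}-result", plan.name)
}

fn execute_in_order(plans: &[Plan]) -> Vec<String> {
    let (parallel, sequential): (Vec<&Plan>, Vec<&Plan>) = plans
        .iter()
        .partition(|p| p.concurrency == Concurrency::ReadOnly);

    // Keyed by original index so completion order does not matter.
    let mut results: BTreeMap<usize, String> = BTreeMap::new();

    // Read-only plans run concurrently on scoped threads.
    thread::scope(|scope| {
        let handles: Vec<_> = parallel
            .iter()
            .map(|&p| scope.spawn(move || (p.idx, run(p))))
            .collect();
        for handle in handles {
            // A panicked task yields no entry, like the JoinError arm in the diff.
            if let Ok((idx, out)) = handle.join() {
                results.insert(idx, out);
            }
        }
    });

    // Exclusive plans run one at a time.
    for p in &sequential {
        results.insert(p.idx, run(p));
    }

    // BTreeMap iterates in key order, which restores the original call order.
    results.into_values().collect()
}
```

One consequence visible in both the sketch and the diff: a task that fails to join leaves no result for its index, so the corresponding tool call would get no `tool_result` message.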
@@ -567,6 +655,7 @@ impl AgentLoop {
|
||||
let tools = self.tools.clone();
|
||||
let middleware_chain = self.middleware_chain.clone();
|
||||
let skill_executor = self.skill_executor.clone();
|
||||
let hand_executor = self.hand_executor.clone();
|
||||
let path_validator = self.path_validator.clone();
|
||||
let agent_id = self.agent_id.clone();
|
||||
let model = self.model.clone();
|
||||
@@ -849,6 +938,7 @@ impl AgentLoop {
|
||||
working_directory: working_dir,
|
||||
session_id: Some(session_id_clone.to_string()),
|
||||
skill_executor: skill_executor.clone(),
|
||||
hand_executor: hand_executor.clone(),
|
||||
path_validator: Some(pv),
|
||||
event_sender: Some(tx.clone()),
|
||||
};
|
||||
@@ -903,6 +993,7 @@ impl AgentLoop {
|
||||
working_directory: working_dir,
|
||||
session_id: Some(session_id_clone.to_string()),
|
||||
skill_executor: skill_executor.clone(),
|
||||
hand_executor: hand_executor.clone(),
|
||||
path_validator: Some(pv),
|
||||
event_sender: Some(tx.clone()),
|
||||
};
|
||||
|
||||
@@ -12,6 +12,13 @@
|
||||
//! | 200-399 | Capability | SkillIndex, Guardrail |
|
||||
//! | 400-599 | Safety | LoopGuard, Guardrail |
|
||||
//! | 600-799 | Telemetry | TokenCalibration, Tracking |
|
||||
//!
|
||||
//! # Wave parallelization
|
||||
//!
|
||||
//! `before_completion` middlewares that only modify `system_prompt` (not `messages`)
|
||||
//! can declare `parallel_safe() == true`. The chain runs consecutive parallel-safe
|
||||
//! middlewares concurrently, merging their prompt contributions. This reduces
|
||||
//! sequential latency for the context-injection phase.
|
||||
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
@@ -50,6 +57,7 @@ pub enum ToolCallDecision {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Carries the mutable state that middleware may inspect or modify.
|
||||
#[derive(Clone)]
|
||||
pub struct MiddlewareContext {
|
||||
/// The agent that owns this loop.
|
||||
pub agent_id: AgentId,
|
||||
@@ -101,6 +109,15 @@ pub trait AgentMiddleware: Send + Sync {
|
||||
500
|
||||
}
|
||||
|
||||
/// Whether `before_completion` is safe to run concurrently with other
|
||||
/// parallel-safe middlewares. Only return `true` if the middleware:
|
||||
/// - Only modifies `ctx.system_prompt` (never `ctx.messages`)
|
||||
/// - Does not depend on prompt modifications from other middlewares
|
||||
/// - Does not return `MiddlewareDecision::Stop`
|
||||
fn parallel_safe(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Hook executed **before** the LLM completion request is sent.
|
||||
///
|
||||
/// Use this to inject context (memory, skill index, etc.) or to
|
||||
@@ -163,15 +180,74 @@ impl MiddlewareChain {
|
||||
self.middlewares.insert(pos, mw);
|
||||
}
|
||||
|
||||
/// Run all `before_completion` hooks in order.
|
||||
/// Run all `before_completion` hooks with wave-based parallelization.
|
||||
///
|
||||
/// Consecutive `parallel_safe` middlewares run concurrently — each gets
|
||||
/// its own cloned context and appends to `system_prompt` independently.
|
||||
/// Their contributions are merged after all complete. Non-parallel-safe
|
||||
/// middlewares (and non-consecutive ones) run sequentially as before.
|
||||
pub async fn run_before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
for mw in &self.middlewares {
|
||||
match mw.before_completion(ctx).await? {
|
||||
MiddlewareDecision::Continue => {}
|
||||
MiddlewareDecision::Stop(reason) => {
|
||||
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
|
||||
return Ok(MiddlewareDecision::Stop(reason));
|
||||
let mut idx = 0;
|
||||
while idx < self.middlewares.len() {
|
||||
// Find the extent of consecutive parallel-safe middlewares
|
||||
let wave_start = idx;
|
||||
let mut wave_end = idx;
|
||||
while wave_end < self.middlewares.len()
|
||||
&& self.middlewares[wave_end].parallel_safe()
|
||||
{
|
||||
wave_end += 1;
|
||||
}
|
||||
|
||||
if wave_end - wave_start >= 2 {
|
||||
// Run parallel wave (2+ consecutive parallel-safe middlewares)
|
||||
let base_prompt_len = ctx.system_prompt.len();
|
||||
let wave = &self.middlewares[wave_start..wave_end];
|
||||
|
||||
// Spawn concurrent tasks — each owns its cloned context + Arc ref to middleware
|
||||
let mut join_handles = Vec::with_capacity(wave.len());
|
||||
for mw in wave.iter() {
|
||||
let mut ctx_clone = ctx.clone();
|
||||
let mw_arc = Arc::clone(mw);
|
||||
join_handles.push(tokio::spawn(async move {
|
||||
let result = mw_arc.before_completion(&mut ctx_clone).await;
|
||||
(result, ctx_clone.system_prompt)
|
||||
}));
|
||||
}
|
||||
|
||||
// Await all and merge prompt contributions
|
||||
for (i, handle) in join_handles.into_iter().enumerate() {
|
||||
let (result, modified_prompt): (Result<MiddlewareDecision>, String) = handle.await
|
||||
.map_err(|e| zclaw_types::ZclawError::Internal(format!("Parallel middleware panicked: {}", e)))?;
|
||||
match result? {
|
||||
MiddlewareDecision::Continue => {}
|
||||
MiddlewareDecision::Stop(reason) => {
|
||||
tracing::info!(
|
||||
"[MiddlewareChain] '{}' requested stop: {}",
|
||||
self.middlewares[wave_start + i].name(),
|
||||
reason
|
||||
);
|
||||
return Ok(MiddlewareDecision::Stop(reason));
|
||||
}
|
||||
}
|
||||
// Merge system_prompt contribution from this clone
|
||||
if modified_prompt.len() > base_prompt_len {
|
||||
let contribution = &modified_prompt[base_prompt_len..];
|
||||
ctx.system_prompt.push_str(contribution);
|
||||
}
|
||||
}
|
||||
|
||||
idx = wave_end;
|
||||
} else {
|
||||
// Run single middleware sequentially
|
||||
let mw = &self.middlewares[idx];
|
||||
match mw.before_completion(ctx).await? {
|
||||
MiddlewareDecision::Continue => {}
|
||||
MiddlewareDecision::Stop(reason) => {
|
||||
tracing::info!("[MiddlewareChain] '{}' requested stop: {}", mw.name(), reason);
|
||||
return Ok(MiddlewareDecision::Stop(reason));
|
||||
}
|
||||
}
|
||||
idx += 1;
|
||||
}
|
||||
}
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
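Review note: the wave logic in `run_before_completion` has two separable parts, grouping consecutive parallel-safe middlewares and merging the prompt suffix each clone appended. The sketch below isolates both with plain data; `Step`, `plan_waves`, and `merge_contributions` are hypothetical names. The merge uses `strip_prefix` rather than the byte-length slice in the hunk, which is a deliberate substitution that ignores clones whose prompt no longer starts with the base.

```rust
use std::ops::Range;

#[derive(Debug, PartialEq)]
enum Step {
    /// Two or more consecutive parallel-safe middlewares, run concurrently.
    Wave(Range<usize>),
    /// A single middleware, run sequentially.
    Single(usize),
}

/// Groups middleware indices the same way the loop in the hunk does.
fn plan_waves(parallel_safe: &[bool]) -> Vec<Step> {
    let mut steps = Vec::new();
    let mut idx = 0;
    while idx < parallel_safe.len() {
        let mut wave_end = idx;
        while wave_end < parallel_safe.len() && parallel_safe[wave_end] {
            wave_end += 1;
        }
        if wave_end - idx >= 2 {
            steps.push(Step::Wave(idx..wave_end));
            idx = wave_end;
        } else {
            // A lone parallel-safe middleware is treated like any other.
            steps.push(Step::Single(idx));
            idx += 1;
        }
    }
    steps
}

/// Appends each clone's addition to the base prompt, in middleware order.
fn merge_contributions(base: &str, modified: &[String]) -> String {
    let mut merged = base.to_string();
    for prompt in modified {
        if let Some(extra) = prompt.strip_prefix(base) {
            merged.push_str(extra);
        }
    }
    merged
}
```

The merge only works for middlewares that append to `system_prompt`; one that rewrites or prepends would lose its change, which matches the `parallel_safe` contract documented in the trait.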
@@ -268,7 +344,6 @@ impl Default for MiddlewareChain {
|
||||
pub mod butler_router;
|
||||
pub mod compaction;
|
||||
pub mod dangling_tool;
|
||||
pub mod data_masking;
|
||||
pub mod guardrail;
|
||||
pub mod loop_guard;
|
||||
pub mod memory;
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
//! Intercepts user messages before LLM processing, uses SemanticSkillRouter
|
||||
//! to classify intent, and injects routing context into the system prompt.
|
||||
//!
|
||||
//! Priority: 80 (runs before data_masking at 90, so it sees raw user input).
|
||||
//! Priority: 80 (runs before compaction and other post-routing middleware).
|
||||
//!
|
||||
//! Supports two modes:
|
||||
//! 1. **Static mode** (default): Uses built-in `KeywordClassifier` with 4 healthcare domains.
|
||||
@@ -290,6 +290,8 @@ impl AgentMiddleware for ButlerRouterMiddleware {
|
||||
80
|
||||
}
|
||||
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
// Only route on the first user message in a turn (not tool results)
|
||||
let user_input = &ctx.user_input;
|
||||
|
||||
@@ -1,21 +1,49 @@
|
||||
//! Compaction middleware — wraps the existing compaction module.
|
||||
//!
|
||||
//! Supports debounce (cooldown + min-round checks), async LLM compression
|
||||
//! with cached fallback, and iterative summaries that carry forward key info.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use zclaw_types::Result;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::driver::LlmDriver;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use zclaw_types::{Message, Result};
|
||||
use crate::compaction::{self, CompactionConfig};
|
||||
use crate::driver::LlmDriver;
|
||||
use crate::growth::GrowthIntegration;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
|
||||
/// Minimum seconds between consecutive compactions.
|
||||
const COMPACTION_COOLDOWN_SECS: u64 = 30;
|
||||
/// Minimum message pairs (user+assistant) since last compaction before triggering again.
|
||||
const COMPACTION_MIN_ROUNDS: u64 = 3;
|
||||
|
||||
fn now_millis() -> u64 {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64
|
||||
}
|
||||
|
||||
/// Shared compaction debounce state (lock-free).
|
||||
struct CompactionState {
|
||||
last_compaction_ms: AtomicU64,
|
||||
last_compaction_msg_count: AtomicU64,
|
||||
}
|
||||
|
||||
/// Cached result from a previous async LLM compaction.
|
||||
struct AsyncCompactionCache {
|
||||
last_result: RwLock<Option<Vec<Message>>>,
|
||||
}
|
||||
|
||||
/// Middleware that compresses conversation history when it exceeds a token threshold.
|
||||
pub struct CompactionMiddleware {
|
||||
threshold: usize,
|
||||
config: CompactionConfig,
|
||||
/// Optional LLM driver for async compaction (LLM summarisation, memory flush).
|
||||
driver: Option<Arc<dyn LlmDriver>>,
|
||||
/// Optional growth integration for memory flushing during compaction.
|
||||
growth: Option<GrowthIntegration>,
|
||||
state: Arc<CompactionState>,
|
||||
cache: Arc<AsyncCompactionCache>,
|
||||
}
|
||||
|
||||
impl CompactionMiddleware {
|
||||
@@ -25,7 +53,39 @@ impl CompactionMiddleware {
|
||||
driver: Option<Arc<dyn LlmDriver>>,
|
||||
growth: Option<GrowthIntegration>,
|
||||
) -> Self {
|
||||
Self { threshold, config, driver, growth }
|
||||
Self {
|
||||
threshold,
|
||||
config,
|
||||
driver,
|
||||
growth,
|
||||
state: Arc::new(CompactionState {
|
||||
last_compaction_ms: AtomicU64::new(0),
|
||||
last_compaction_msg_count: AtomicU64::new(0),
|
||||
}),
|
||||
cache: Arc::new(AsyncCompactionCache {
|
||||
last_result: RwLock::new(None),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn should_compact(&self, msg_count: u64) -> bool {
|
||||
let last_ms = self.state.last_compaction_ms.load(Ordering::Relaxed);
|
||||
let last_count = self.state.last_compaction_msg_count.load(Ordering::Relaxed);
|
||||
|
||||
if now_millis().saturating_sub(last_ms) < COMPACTION_COOLDOWN_SECS * 1000 {
|
||||
return false;
|
||||
}
|
||||
|
||||
if msg_count.saturating_sub(last_count) < COMPACTION_MIN_ROUNDS * 2 {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn record_compaction(&self, msg_count: u64) {
|
||||
self.state.last_compaction_ms.store(now_millis(), Ordering::Relaxed);
|
||||
self.state.last_compaction_msg_count.store(msg_count, Ordering::Relaxed);
|
||||
}
|
||||
}
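The debounce added in this hunk combines a time cooldown with a minimum number of new messages. A minimal sketch of that check follows, assuming a hypothetical `Debounce` type that takes the current time as a parameter so it can be exercised without a real clock; the names and constructor are illustrative and not part of the diff. As in the diff, the two atomics are updated independently with `Ordering::Relaxed`, so concurrent callers may briefly observe one value updated and the other not.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Hypothetical stand-in for the debounce state shown in the diff.
pub struct Debounce {
    cooldown_ms: u64,
    min_new_messages: u64,
    last_ms: AtomicU64,
    last_count: AtomicU64,
}

impl Debounce {
    pub fn new(cooldown_ms: u64, min_new_messages: u64) -> Self {
        Self {
            cooldown_ms,
            min_new_messages,
            last_ms: AtomicU64::new(0),
            last_count: AtomicU64::new(0),
        }
    }

    /// Returns true only when both the cooldown has elapsed and enough
    /// messages have arrived since the last recorded run.
    /// `now_ms` is injected by the caller (an assumption made for testability).
    pub fn should_run(&self, now_ms: u64, msg_count: u64) -> bool {
        let last_ms = self.last_ms.load(Ordering::Relaxed);
        let last_count = self.last_count.load(Ordering::Relaxed);
        now_ms.saturating_sub(last_ms) >= self.cooldown_ms
            && msg_count.saturating_sub(last_count) >= self.min_new_messages
    }

    /// Remember when the last run happened and how many messages existed then.
    pub fn record(&self, now_ms: u64, msg_count: u64) {
        self.last_ms.store(now_ms, Ordering::Relaxed);
        self.last_count.store(msg_count, Ordering::Relaxed);
    }
}
```

With the constants from the diff this would be constructed as `Debounce::new(30 * 1000, 3 * 2)`. `saturating_sub` keeps the comparison safe when the recorded message count is larger than the current one.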
|
||||
|
||||
@@ -39,6 +99,29 @@ impl AgentMiddleware for CompactionMiddleware {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
// Step 1: Prune old tool outputs (cheap, no LLM needed)
|
||||
let pruned = compaction::prune_tool_outputs(&mut ctx.messages);
|
||||
if pruned > 0 {
|
||||
tracing::info!("[CompactionMiddleware] Pruned {} old tool outputs", pruned);
|
||||
}
|
||||
|
||||
// Step 2: Re-estimate tokens after pruning
|
||||
let tokens = compaction::estimate_messages_tokens_calibrated(&ctx.messages);
|
||||
if tokens < self.threshold {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
// Step 3: Debounce check
|
||||
if !self.should_compact(ctx.messages.len() as u64) {
|
||||
// Still over threshold but within cooldown — use cached result if available
|
||||
if let Some(cached) = self.cache.last_result.read().await.clone() {
|
||||
tracing::debug!("[CompactionMiddleware] Cooldown active, using cached compaction result");
|
||||
ctx.messages = cached;
|
||||
}
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
// Step 4: Execute compaction
|
||||
let needs_async = self.config.use_llm || self.config.memory_flush_enabled;
|
||||
if needs_async {
|
||||
let outcome = compaction::maybe_compact_with_config(
|
||||
@@ -56,6 +139,14 @@ impl AgentMiddleware for CompactionMiddleware {
|
||||
ctx.messages = compaction::maybe_compact(ctx.messages.clone(), self.threshold);
|
||||
}
|
||||
|
||||
self.record_compaction(ctx.messages.len() as u64);
|
||||
|
||||
// Cache result for cooldown fallback
|
||||
{
|
||||
let mut cache = self.cache.last_result.write().await;
|
||||
*cache = Some(ctx.messages.clone());
|
||||
}
|
||||
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,323 +0,0 @@
|
||||
//! Data Masking Middleware — protect sensitive business data from leaving the user's machine.
|
||||
//!
|
||||
//! Before LLM calls, replaces detected entities (company names, amounts, phone numbers)
|
||||
//! with deterministic tokens. After responses, the caller can restore the original entities.
|
||||
//!
|
||||
//! Priority: 90 (runs before Compaction@100 and Memory@150)
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::{Arc, LazyLock, RwLock};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use zclaw_types::{Message, Result};
|
||||
|
||||
use super::{AgentMiddleware, MiddlewareContext, MiddlewareDecision};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Pre-compiled regex patterns (compiled once, reused across all calls)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
static RE_COMPANY: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"[^\s]{1,20}(?:公司|厂|集团|工作室|商行|有限|股份)").expect("static regex is valid")
|
||||
});
|
||||
static RE_MONEY: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"[¥¥$]\s*[\d,.]+[万亿]?元?|[\d,.]+[万亿]元").expect("static regex is valid")
|
||||
});
|
||||
static RE_PHONE: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"1[3-9]\d-?\d{4}-?\d{4}").expect("static regex is valid")
|
||||
});
|
||||
static RE_EMAIL: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}").expect("static regex is valid")
|
||||
});
|
||||
static RE_ID_CARD: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"\b\d{17}[\dXx]\b").expect("static regex is valid")
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DataMasker — entity detection and token mapping
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Global, monotonically increasing counter used to number entity tokens (shared across all entity types).
|
||||
static ENTITY_COUNTER: AtomicU64 = AtomicU64::new(1);
|
||||
|
||||
/// Detects and replaces sensitive entities with deterministic tokens.
|
||||
pub struct DataMasker {
|
||||
/// entity text → token mapping (persistent across conversations).
|
||||
forward: Arc<RwLock<HashMap<String, String>>>,
|
||||
/// token → entity text reverse mapping (in-memory only).
|
||||
reverse: Arc<RwLock<HashMap<String, String>>>,
|
||||
}
|
||||
|
||||
impl DataMasker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
forward: Arc::new(RwLock::new(HashMap::new())),
|
||||
reverse: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Mask all detected entities in `text`, replacing them with tokens.
|
||||
pub fn mask(&self, text: &str) -> Result<String> {
|
||||
let entities = self.detect_entities(text);
|
||||
if entities.is_empty() {
|
||||
return Ok(text.to_string());
|
||||
}
|
||||
|
||||
let mut result = text.to_string();
|
||||
for entity in entities {
|
||||
let token = self.get_or_create_token(&entity);
|
||||
// Replace all occurrences (longest entities first to avoid partial matches)
|
||||
result = result.replace(&entity, &token);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Restore all tokens in `text` back to their original entities.
|
||||
pub fn unmask(&self, text: &str) -> Result<String> {
|
||||
let reverse = self.reverse.read().map_err(|e| zclaw_types::ZclawError::IoError(std::io::Error::other(e.to_string())))?;
|
||||
if reverse.is_empty() {
|
||||
return Ok(text.to_string());
|
||||
}
|
||||
|
||||
let mut result = text.to_string();
|
||||
for (token, entity) in reverse.iter() {
|
||||
result = result.replace(token, entity);
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Detect sensitive entities in text using regex patterns.
|
||||
fn detect_entities(&self, text: &str) -> Vec<String> {
|
||||
let mut entities = Vec::new();
|
||||
|
||||
// Company names: X公司、XX集团、XX工作室 (1-20 char prefix + suffix)
|
||||
for cap in RE_COMPANY.find_iter(text) {
|
||||
entities.push(cap.as_str().to_string());
|
||||
}
|
||||
|
||||
// Money amounts: ¥50万、¥100元、$200、50万元
|
||||
for cap in RE_MONEY.find_iter(text) {
|
||||
entities.push(cap.as_str().to_string());
|
||||
}
|
||||
|
||||
// Phone numbers: 1XX-XXXX-XXXX or 1XXXXXXXXXX
|
||||
for cap in RE_PHONE.find_iter(text) {
|
||||
entities.push(cap.as_str().to_string());
|
||||
}
|
||||
|
||||
// Email addresses
|
||||
for cap in RE_EMAIL.find_iter(text) {
|
||||
entities.push(cap.as_str().to_string());
|
||||
}
|
||||
|
||||
// ID card numbers (simplified): 18 digits
|
||||
for cap in RE_ID_CARD.find_iter(text) {
|
||||
entities.push(cap.as_str().to_string());
|
||||
}
|
||||
|
||||
// Sort by length descending to replace longest entities first
|
||||
entities.sort_by(|a, b| b.len().cmp(&a.len()));
|
||||
entities.dedup();
|
||||
entities
|
||||
}
|
||||
|
||||
/// Get existing token for entity or create a new one.
|
||||
fn get_or_create_token(&self, entity: &str) -> String {
|
||||
/// Recover from a poisoned RwLock by taking the inner value and re-wrapping.
|
||||
/// A poisoned lock only means a panic occurred while holding it — the data is still valid.
|
||||
fn recover_read<T>(lock: &RwLock<T>) -> std::sync::LockResult<std::sync::RwLockReadGuard<'_, T>> {
|
||||
match lock.read() {
|
||||
Ok(guard) => Ok(guard),
|
||||
Err(_e) => {
|
||||
tracing::warn!("[DataMasker] RwLock poisoned during read, recovering");
|
||||
// Poison error still gives us access to the inner guard
|
||||
lock.read()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn recover_write<T>(lock: &RwLock<T>) -> std::sync::LockResult<std::sync::RwLockWriteGuard<'_, T>> {
|
||||
match lock.write() {
|
||||
Ok(guard) => Ok(guard),
|
||||
Err(_e) => {
|
||||
tracing::warn!("[DataMasker] RwLock poisoned during write, recovering");
|
||||
lock.write()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if already mapped
|
||||
{
|
||||
if let Ok(forward) = recover_read(&self.forward) {
|
||||
if let Some(token) = forward.get(entity) {
|
||||
return token.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create new token
|
||||
let counter = ENTITY_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
let token = format!("__ENTITY_{}__", counter);
|
||||
|
||||
// Store in both mappings
|
||||
if let Ok(mut forward) = recover_write(&self.forward) {
|
||||
forward.insert(entity.to_string(), token.clone());
|
||||
}
|
||||
if let Ok(mut reverse) = recover_write(&self.reverse) {
|
||||
reverse.insert(token.clone(), entity.to_string());
|
||||
}
|
||||
|
||||
token
|
||||
}
|
||||
}
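The `recover_read` and `recover_write` helpers above call `lock.read()` / `lock.write()` a second time after a poison error, which returns the same error again. The standard library instead exposes the guard through `PoisonError::into_inner`. A minimal sketch of that approach, using hypothetical helper names `read_recovering` and `write_recovering` (not from the diff):

```rust
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};

/// Acquire a read guard, taking the guard out of the error if the lock is poisoned.
/// Hypothetical helper; the name is an assumption for illustration.
pub fn read_recovering<T>(lock: &RwLock<T>) -> RwLockReadGuard<'_, T> {
    lock.read().unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Acquire a write guard, taking the guard out of the error if the lock is poisoned.
pub fn write_recovering<T>(lock: &RwLock<T>) -> RwLockWriteGuard<'_, T> {
    lock.write().unwrap_or_else(|poisoned| poisoned.into_inner())
}
```

Whether ignoring the poison flag is acceptable depends on whether a panic could have left the map half-updated; for two independent `insert` calls that risk is likely small, but that is a judgment about this code rather than a general rule.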
|
||||
|
||||
impl Default for DataMasker {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// DataMaskingMiddleware — masks user messages before LLM completion
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub struct DataMaskingMiddleware {
|
||||
masker: Arc<DataMasker>,
|
||||
}
|
||||
|
||||
impl DataMaskingMiddleware {
|
||||
pub fn new(masker: Arc<DataMasker>) -> Self {
|
||||
Self { masker }
|
||||
}
|
||||
|
||||
/// Get a reference to the masker for unmasking responses externally.
|
||||
pub fn masker(&self) -> &Arc<DataMasker> {
|
||||
&self.masker
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AgentMiddleware for DataMaskingMiddleware {
|
||||
fn name(&self) -> &str { "data_masking" }
|
||||
fn priority(&self) -> i32 { 90 }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
// Mask user messages — replace sensitive entities with tokens
|
||||
for msg in &mut ctx.messages {
|
||||
if let Message::User { ref mut content } = msg {
|
||||
let masked = self.masker.mask(content)?;
|
||||
*content = masked;
|
||||
}
|
||||
}
|
||||
|
||||
// Also mask user_input field
|
||||
if !ctx.user_input.is_empty() {
|
||||
ctx.user_input = self.masker.mask(&ctx.user_input)?;
|
||||
}
|
||||
|
||||
Ok(MiddlewareDecision::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mask_company_name() {
|
||||
let masker = DataMasker::new();
|
||||
let input = "A公司的订单被退了";
|
||||
let masked = masker.mask(input).unwrap();
|
||||
assert!(!masked.contains("A公司"), "Company name should be masked: {}", masked);
|
||||
assert!(masked.contains("__ENTITY_"), "Should contain token: {}", masked);
|
||||
|
||||
let unmasked = masker.unmask(&masked).unwrap();
|
||||
assert_eq!(unmasked, input, "Unmask should restore original");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_consistency() {
|
||||
let masker = DataMasker::new();
|
||||
let masked1 = masker.mask("A公司").unwrap();
|
||||
let masked2 = masker.mask("A公司").unwrap();
|
||||
assert_eq!(masked1, masked2, "Same entity should always get same token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_money() {
|
||||
let masker = DataMasker::new();
|
||||
let input = "成本是¥50万";
|
||||
let masked = masker.mask(input).unwrap();
|
||||
assert!(!masked.contains("¥50万"), "Money should be masked: {}", masked);
|
||||
|
||||
let unmasked = masker.unmask(&masked).unwrap();
|
||||
assert_eq!(unmasked, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_phone() {
|
||||
let masker = DataMasker::new();
|
||||
let input = "联系13812345678";
|
||||
let masked = masker.mask(input).unwrap();
|
||||
assert!(!masked.contains("13812345678"), "Phone should be masked: {}", masked);
|
||||
|
||||
let unmasked = masker.unmask(&masked).unwrap();
|
||||
assert_eq!(unmasked, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_email() {
|
||||
let masker = DataMasker::new();
|
||||
let input = "发到 test@example.com 吧";
|
||||
let masked = masker.mask(input).unwrap();
|
||||
assert!(!masked.contains("test@example.com"), "Email should be masked: {}", masked);
|
||||
|
||||
let unmasked = masker.unmask(&masked).unwrap();
|
||||
assert_eq!(unmasked, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_no_entities() {
|
||||
let masker = DataMasker::new();
|
||||
let input = "今天天气不错";
|
||||
let masked = masker.mask(input).unwrap();
|
||||
assert_eq!(masked, input, "Text without entities should pass through unchanged");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_multiple_entities() {
|
||||
let masker = DataMasker::new();
|
||||
let input = "A公司的订单花了¥50万,联系13812345678";
|
||||
let masked = masker.mask(input).unwrap();
|
||||
assert!(!masked.contains("A公司"));
|
||||
assert!(!masked.contains("¥50万"));
|
||||
assert!(!masked.contains("13812345678"));
|
||||
|
||||
let unmasked = masker.unmask(&masked).unwrap();
|
||||
assert_eq!(unmasked, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unmask_empty() {
|
||||
let masker = DataMasker::new();
|
||||
let result = masker.unmask("hello world").unwrap();
|
||||
assert_eq!(result, "hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mask_id_card() {
|
||||
let masker = DataMasker::new();
|
||||
let input = "身份证号 110101199001011234";
|
||||
let masked = masker.mask(input).unwrap();
|
||||
assert!(!masked.contains("110101199001011234"), "ID card should be masked: {}", masked);
|
||||
|
||||
let unmasked = masker.unmask(&masked).unwrap();
|
||||
assert_eq!(unmasked, input);
|
||||
}
|
||||
}
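`detect_entities` sorts candidates by length and then calls `dedup()`. Since `Vec::dedup` only removes consecutive duplicates, a length-only sort can leave repeats in place (for example `["ab", "xy", "ab"]` stays unchanged). In the removed file this looks harmless, because a second `replace` finds nothing left to replace, but a tie-break in the sort makes the deduplication exact. A sketch under that assumption, with hypothetical function names and a simplified index-based token scheme:

```rust
/// Order candidates longest-first (by byte length) and drop exact duplicates.
/// The lexicographic tie-break places equal strings next to each other,
/// which is what `dedup` needs.
pub fn longest_first_unique(mut entities: Vec<String>) -> Vec<String> {
    entities.sort_by(|a, b| b.len().cmp(&a.len()).then_with(|| a.cmp(b)));
    entities.dedup();
    entities
}

/// Replace each entity with a numbered placeholder, longest entities first,
/// so that a short entity never clobbers part of a longer one.
/// The token numbering here is illustrative, not the scheme from the diff.
pub fn mask_with(text: &str, entities: Vec<String>) -> String {
    let mut out = text.to_string();
    for (i, entity) in longest_first_unique(entities).iter().enumerate() {
        out = out.replace(entity.as_str(), &format!("__ENTITY_{}__", i + 1));
    }
    out
}
```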
|
||||
@@ -19,21 +19,45 @@ pub struct PendingEvolution {
|
||||
}
|
||||
|
||||
/// Evolution engine middleware
|
||||
/// Checks for pending evolution events awaiting confirmation and injects a confirmation prompt into the system prompt
|
||||
/// Checks for pending evolution events awaiting confirmation and, depending on the mode:
|
||||
/// - suggest mode (default): injects a confirmation prompt into the system prompt
|
||||
/// - auto mode: injects nothing; events only stay queued for the kernel to process automatically
|
||||
pub struct EvolutionMiddleware {
|
||||
pending: Arc<RwLock<Vec<PendingEvolution>>>,
|
||||
auto_mode: bool,
|
||||
}
|
||||
|
||||
impl EvolutionMiddleware {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
pending: Arc::new(RwLock::new(Vec::new())),
|
||||
auto_mode: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with auto mode enabled
|
||||
pub fn new_auto() -> Self {
|
||||
Self {
|
||||
pending: Arc::new(RwLock::new(Vec::new())),
|
||||
auto_mode: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if auto mode is enabled
|
||||
pub fn is_auto_mode(&self) -> bool {
|
||||
self.auto_mode
|
||||
}
|
||||
|
||||
/// Add a pending evolution event awaiting confirmation
|
||||
pub async fn add_pending(&self, evolution: PendingEvolution) {
|
||||
self.pending.write().await.push(evolution);
|
||||
let mut pending = self.pending.write().await;
|
||||
if pending.len() >= 100 {
|
||||
tracing::warn!(
|
||||
"[EvolutionMiddleware] Pending queue full (100), dropping oldest event"
|
||||
);
|
||||
pending.remove(0);
|
||||
}
|
||||
pending.push(evolution);
|
||||
}
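The new `add_pending` caps the queue at 100 entries and drops the oldest one with `Vec::remove(0)`, which shifts every remaining element. A `VecDeque` gives the same drop-oldest behavior with constant-time removal at the front. A minimal sketch, assuming a hypothetical `BoundedQueue` type that is not part of the diff:

```rust
use std::collections::VecDeque;

/// Hypothetical FIFO queue that evicts its oldest item when full.
pub struct BoundedQueue<T> {
    items: VecDeque<T>,
    capacity: usize,
}

impl<T> BoundedQueue<T> {
    pub fn new(capacity: usize) -> Self {
        assert!(capacity > 0, "capacity must be at least 1");
        Self { items: VecDeque::with_capacity(capacity), capacity }
    }

    /// Push an item; if the queue was full, the oldest item is evicted and returned
    /// so the caller can log or otherwise account for the drop.
    pub fn push(&mut self, item: T) -> Option<T> {
        let evicted = if self.items.len() >= self.capacity {
            self.items.pop_front()
        } else {
            None
        };
        self.items.push_back(item);
        evicted
    }

    /// Remove and return the oldest item, if any.
    pub fn pop_oldest(&mut self) -> Option<T> {
        self.items.pop_front()
    }

    pub fn len(&self) -> usize {
        self.items.len()
    }
}
```

At a cap of 100 the cost difference is negligible; the sketch mainly shows how returning the evicted item keeps the warning log at the call site.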
|
||||
|
||||
/// Take and clear all pending events
|
||||
@@ -64,6 +88,8 @@ impl AgentMiddleware for EvolutionMiddleware {
|
||||
78 // before ButlerRouter (80)
|
||||
}
|
||||
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(
|
||||
&self,
|
||||
ctx: &mut MiddlewareContext,
|
||||
@@ -73,7 +99,12 @@ impl AgentMiddleware for EvolutionMiddleware {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
// Remove only the first event; keep the rest for the next injection
|
||||
// Auto mode: don't inject into prompt, leave for kernel to process
|
||||
if self.auto_mode {
|
||||
return Ok(MiddlewareDecision::Continue);
|
||||
}
|
||||
|
||||
// Suggest mode: remove only the first event; keep the rest for the next injection
|
||||
let to_inject = {
|
||||
let mut pending = self.pending.write().await;
|
||||
if pending.is_empty() {
|
||||
|
||||
@@ -19,7 +19,7 @@ use crate::middleware::evolution::EvolutionMiddleware;
|
||||
/// - `before_completion` → `enhance_prompt()` for memory injection
|
||||
/// - `after_completion` → `extract_combined()` for memory extraction + evolution check
|
||||
pub struct MemoryMiddleware {
|
||||
growth: GrowthIntegration,
|
||||
growth: std::sync::Arc<GrowthIntegration>,
|
||||
/// Shared EvolutionMiddleware for pushing evolution suggestions
|
||||
evolution_mw: Option<std::sync::Arc<EvolutionMiddleware>>,
|
||||
/// Minimum seconds between extractions for the same agent (debounce).
|
||||
@@ -29,7 +29,7 @@ pub struct MemoryMiddleware {
|
||||
}
|
||||
|
||||
impl MemoryMiddleware {
|
||||
pub fn new(growth: GrowthIntegration) -> Self {
|
||||
pub fn new(growth: std::sync::Arc<GrowthIntegration>) -> Self {
|
||||
Self {
|
||||
growth,
|
||||
evolution_mw: None,
|
||||
@@ -111,6 +111,7 @@ impl MemoryMiddleware {
|
||||
impl AgentMiddleware for MemoryMiddleware {
|
||||
fn name(&self) -> &str { "memory" }
|
||||
fn priority(&self) -> i32 { 150 }
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
tracing::debug!(
|
||||
|
||||
@@ -40,6 +40,7 @@ impl SkillIndexMiddleware {
|
||||
impl AgentMiddleware for SkillIndexMiddleware {
|
||||
fn name(&self) -> &str { "skill_index" }
|
||||
fn priority(&self) -> i32 { 200 }
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
async fn before_completion(&self, ctx: &mut MiddlewareContext) -> Result<MiddlewareDecision> {
|
||||
if self.entries.is_empty() {
|
||||
|
||||
@@ -41,6 +41,7 @@ impl Default for TitleMiddleware {
|
||||
impl AgentMiddleware for TitleMiddleware {
|
||||
fn name(&self) -> &str { "title" }
|
||||
fn priority(&self) -> i32 { 180 }
|
||||
fn parallel_safe(&self) -> bool { true }
|
||||
|
||||
// All hooks default to Continue — placeholder until LLM driver is wired in.
|
||||
async fn before_completion(&self, _ctx: &mut crate::middleware::MiddlewareContext) -> zclaw_types::Result<MiddlewareDecision> {
|
||||
|
||||
@@ -4,12 +4,16 @@
|
||||
//! Inspired by DeerFlow's ToolErrorMiddleware: instead of propagating raw errors
|
||||
//! that crash the agent loop, this middleware wraps tool errors into a structured
|
||||
//! format that the LLM can use to self-correct.
|
||||
//!
|
||||
//! Also tracks consecutive tool failures across different tools — if N consecutive
|
||||
//! tool calls all fail, the loop is aborted to prevent infinite retry cycles.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::Value;
|
||||
use zclaw_types::Result;
|
||||
use crate::driver::ContentBlock;
|
||||
use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
|
||||
/// Middleware that intercepts tool call errors and formats recovery messages.
|
||||
///
|
||||
@@ -17,12 +21,18 @@ use crate::middleware::{AgentMiddleware, MiddlewareContext, ToolCallDecision};
|
||||
pub struct ToolErrorMiddleware {
|
||||
/// Maximum error message length before truncation.
|
||||
max_error_length: usize,
|
||||
/// Maximum consecutive failures before aborting the loop.
|
||||
max_consecutive_failures: u32,
|
||||
/// Tracks consecutive tool failures.
|
||||
consecutive_failures: AtomicU32,
|
||||
}
|
||||
|
||||
impl ToolErrorMiddleware {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_error_length: 500,
|
||||
max_consecutive_failures: 3,
|
||||
consecutive_failures: AtomicU32::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +71,6 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
tool_input: &Value,
|
||||
) -> Result<ToolCallDecision> {
|
||||
// Pre-validate tool input structure for common issues.
|
||||
// This catches malformed JSON inputs before they reach the tool executor.
|
||||
if tool_input.is_null() {
|
||||
tracing::warn!(
|
||||
"[ToolErrorMiddleware] Tool '{}' received null input — replacing with empty object",
|
||||
@@ -69,6 +78,19 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
);
|
||||
return Ok(ToolCallDecision::ReplaceInput(serde_json::json!({})));
|
||||
}
|
||||
|
||||
// Check consecutive failure count — abort if too many failures
|
||||
let failures = self.consecutive_failures.load(Ordering::SeqCst);
|
||||
if failures >= self.max_consecutive_failures {
|
||||
tracing::warn!(
|
||||
"[ToolErrorMiddleware] Aborting loop: {} consecutive tool failures",
|
||||
failures
|
||||
);
|
||||
return Ok(ToolCallDecision::AbortLoop(
|
||||
format!("连续 {} 次工具调用失败,已自动终止以避免无限重试", failures)
|
||||
));
|
||||
}
|
||||
|
||||
Ok(ToolCallDecision::Allow)
|
||||
}
|
||||
|
||||
@@ -80,12 +102,12 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
) -> Result<()> {
|
||||
// Check if the tool result indicates an error.
|
||||
if let Some(error) = result.get("error") {
|
||||
let failures = self.consecutive_failures.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
let error_msg = match error {
|
||||
Value::String(s) => s.clone(),
|
||||
other => other.to_string(),
|
||||
};
|
||||
let truncated = if error_msg.len() > self.max_error_length {
|
||||
// Use char-boundary-safe truncation to avoid panic on UTF-8 strings (e.g. Chinese)
|
||||
let end = error_msg.floor_char_boundary(self.max_error_length);
|
||||
format!("{}...(truncated)", &error_msg[..end])
|
||||
} else {
|
||||
@@ -93,19 +115,19 @@ impl AgentMiddleware for ToolErrorMiddleware {
|
||||
};
|
||||
|
||||
tracing::warn!(
|
||||
"[ToolErrorMiddleware] Tool '{}' failed: {}",
|
||||
tool_name, truncated
|
||||
"[ToolErrorMiddleware] Tool '{}' failed ({}/{} consecutive): {}",
|
||||
tool_name, failures, self.max_consecutive_failures, truncated
|
||||
);
|
||||
|
||||
// Build a guided recovery message so the LLM can self-correct.
|
||||
let guided_message = self.format_tool_error(tool_name, &truncated);
|
||||
|
||||
// Inject into response_content so the agent loop feeds this back
|
||||
// to the LLM alongside the raw tool result.
|
||||
ctx.response_content.push(ContentBlock::Text {
|
||||
text: guided_message,
|
||||
});
|
||||
} else {
|
||||
// Success — reset consecutive failure counter
|
||||
self.consecutive_failures.store(0, Ordering::SeqCst);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
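The consecutive-failure tracking in this middleware reduces to three operations: increment on error, reset on success, and compare against a limit before the next tool call. A minimal sketch of that counter in isolation, assuming a hypothetical `FailureTracker` type (the name and method shapes are illustrative, not taken from the diff):

```rust
use std::sync::atomic::{AtomicU32, Ordering};

/// Hypothetical counter for consecutive tool failures.
pub struct FailureTracker {
    max_consecutive: u32,
    consecutive: AtomicU32,
}

impl FailureTracker {
    pub fn new(max_consecutive: u32) -> Self {
        Self { max_consecutive, consecutive: AtomicU32::new(0) }
    }

    /// Record the outcome of one tool call and return the current streak length.
    /// A success resets the streak to zero.
    pub fn record(&self, success: bool) -> u32 {
        if success {
            self.consecutive.store(0, Ordering::SeqCst);
            0
        } else {
            // fetch_add returns the previous value, so add 1 for the new streak.
            self.consecutive.fetch_add(1, Ordering::SeqCst) + 1
        }
    }

    /// True once the streak has reached the configured limit.
    pub fn should_abort(&self) -> bool {
        self.consecutive.load(Ordering::SeqCst) >= self.max_consecutive
    }
}
```

The diff does not show whether the counter is reset after an abort or between agent loops; in this sketch it stays at the limit until a success is recorded, so a caller reusing the tracker would need to decide that explicitly.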
|
||||
|
||||
@@ -68,14 +68,14 @@ const PERIOD: &str = "(凌晨|早上|早晨|上午|中午|下午|午后|傍晚|
|
||||
// extract_task_description
|
||||
static RE_TIME_STRIP: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(
|
||||
r"^(?:凌晨|早上|早晨|上午|中午|下午|午后|傍晚|黄昏|晚上|晚间|夜里|夜晚|半夜|午夜)?\d{1,2}[点时::]\d{0,2}分?"
|
||||
r"^(?:凌晨|早上|早晨|上午|中午|下午|午后|傍晚|黄昏|晚上|晚间|夜里|夜晚|半夜|午夜)?\d{1,2}[点时::](?:\d{1,2}分?|半)?"
|
||||
).expect("static regex pattern is valid")
|
||||
});
|
||||
|
||||
// try_every_day
|
||||
static RE_EVERY_DAY_EXACT: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(&format!(
|
||||
r"(?:每天|每日)(?:的)?{}(\d{{1,2}})[点时::](\d{{1,2}})?",
|
||||
r"(?:每天|每日)(?:的)?{}(\d{{1,2}})[点时::](?:(\d{{1,2}})|(半))?",
|
||||
PERIOD
|
||||
)).expect("static regex pattern is valid")
|
||||
});
|
||||
@@ -89,15 +89,15 @@ static RE_EVERY_DAY_PERIOD: LazyLock<Regex> = LazyLock::new(|| {
|
||||
// try_every_week
|
||||
static RE_EVERY_WEEK: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(&format!(
|
||||
r"(?:每周|每个?星期|每个?礼拜)(一|二|三|四|五|六|日|天|周一|周二|周三|周四|周五|周六|周日|周天|星期一|星期二|星期三|星期四|星期五|星期六|星期日|星期天|礼拜一|礼拜二|礼拜三|礼拜四|礼拜五|礼拜六|礼拜日|礼拜天)(?:的)?{}(\d{{1,2}})[点时::](\d{{1,2}})?",
|
||||
r"(?:每周|每个?星期|每个?礼拜)(一|二|三|四|五|六|日|天|周一|周二|周三|周四|周五|周六|周日|周天|星期一|星期二|星期三|星期四|星期五|星期六|星期日|星期天|礼拜一|礼拜二|礼拜三|礼拜四|礼拜五|礼拜六|礼拜日|礼拜天)(?:的)?{}(\d{{1,2}})[点时::](?:(\d{{1,2}})|(半))?",
|
||||
PERIOD
|
||||
)).expect("static regex pattern is valid")
|
||||
});
|
||||
|
||||
// try_workday
|
||||
// try_workday — also matches "工作日每天..." and "工作日每日..."
|
||||
static RE_WORKDAY_EXACT: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(&format!(
|
||||
r"(?:工作日|每个?工作日|工作日(?:的)?){}(\d{{1,2}})[点时::](\d{{1,2}})?",
|
||||
r"(?:工作日|每个?工作日)(?:每天|每日)?(?:的)?{}(\d{{1,2}})[点时::](?:(\d{{1,2}})|(半))?",
|
||||
PERIOD
|
||||
)).expect("static regex pattern is valid")
|
||||
});
|
||||
@@ -113,10 +113,15 @@ static RE_INTERVAL: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"每(\d{1,2})(小时|分钟|分|钟|个小时)").expect("static regex pattern is valid")
|
||||
});
|
||||
|
||||
// try_relative_delay — "X秒后", "X分钟后", "X小时后"
|
||||
static RE_RELATIVE_DELAY: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r"(\d{1,3})\s*(秒|秒钟|分钟|分|小时|个?小时)后").expect("static regex pattern is valid")
|
||||
});
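`RE_RELATIVE_DELAY` captures an amount and a unit. The body of `try_relative_delay` is not visible in this diff, so the following is only a sketch of how the two captures could be turned into a delay in seconds; the function name `delay_seconds` and its signature are assumptions.

```rust
/// Convert an amount and a captured unit string into seconds.
/// Returns None for an unknown unit or on arithmetic overflow.
/// Hypothetical helper; the unit strings mirror the regex alternatives above.
pub fn delay_seconds(amount: u64, unit: &str) -> Option<u64> {
    let seconds_per_unit: u64 = match unit {
        "秒" | "秒钟" => 1,
        "分" | "分钟" => 60,
        "小时" | "个小时" => 3600,
        _ => return None,
    };
    amount.checked_mul(seconds_per_unit)
}
```

For example, `delay_seconds(5, "分钟")` yields `Some(300)`, which could then be added to the current time to build a one-shot trigger.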
|
||||
|
||||
// try_monthly
|
||||
static RE_MONTHLY: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(&format!(
|
||||
r"(?:每月|每个月)(?:的)?(\d{{1,2}})[号日](?:的)?{}(\d{{1,2}})?[点时::]?(\d{{1,2}})?",
|
||||
r"(?:每月|每个月)(?:的)?(\d{{1,2}})[号日](?:的)?{}(\d{{1,2}})?[点时::]?(?:(\d{{1,2}})|(半))?",
|
||||
PERIOD
|
||||
)).expect("static regex pattern is valid")
|
||||
});
|
||||
@@ -124,7 +129,16 @@ static RE_MONTHLY: LazyLock<Regex> = LazyLock::new(|| {
|
||||
// try_one_shot
|
||||
static RE_ONE_SHOT: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(&format!(
|
||||
r"(明天|后天|大后天)(?:的)?{}(\d{{1,2}})[点时::](\d{{1,2}})?",
|
||||
r"(明天|后天|大后天)(?:的)?{}(\d{{1,2}})[点时::](?:(\d{{1,2}})|(半))?",
|
||||
PERIOD
|
||||
)).expect("static regex pattern is valid")
|
||||
});
|
||||
|
||||
/// Matches same-day one-shot triggers: "下午3点半提醒我..." or "上午10点提醒我..."
|
||||
/// Pattern: period + time + "提醒我" (no date prefix — implied today)
|
||||
static RE_ONE_SHOT_TODAY: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(&format!(
|
||||
r"^{}(\d{{1,2}})[点时::](?:(\d{{1,2}})|(半))?.*提醒我",
|
||||
PERIOD
|
||||
)).expect("static regex pattern is valid")
|
||||
});
|
||||
@@ -194,15 +208,16 @@ pub fn parse_nl_schedule(input: &str, default_agent_id: &AgentId) -> SchedulePar
|
||||
|
||||
let task_description = extract_task_description(input);
|
||||
|
||||
// Try workday BEFORE every_day, so "工作日每天..." matches workday first
|
||||
if let Some(result) = try_workday(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
if let Some(result) = try_every_day(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
if let Some(result) = try_every_week(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
if let Some(result) = try_workday(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
if let Some(result) = try_interval(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
@@ -212,6 +227,9 @@ pub fn parse_nl_schedule(input: &str, default_agent_id: &AgentId) -> SchedulePar
|
||||
if let Some(result) = try_one_shot(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
if let Some(result) = try_relative_delay(input, &task_description, default_agent_id) {
|
||||
return result;
|
||||
}
|
||||
|
||||
ScheduleParseResult::Unclear
|
||||
}
|
||||
@@ -248,11 +266,21 @@ fn extract_task_description(input: &str) -> String {
|
||||
|
||||
// -- Pattern matchers (all use pre-compiled statics) --
|
||||
|
||||
/// Extract minute value from a regex capture group that may be a digit string or "半".
|
||||
/// `digit_group` is the index of the digit capture and `han_group` the index of the "半" capture; a "半" match yields 30, and a missing or unparsable digit capture yields 0.
|
||||
fn extract_minute(caps: &regex::Captures, digit_group: usize, han_group: usize) -> u32 {
|
||||
// Check if the "半" (half) group matched
|
||||
if caps.get(han_group).is_some() {
|
||||
return 30;
|
||||
}
|
||||
caps.get(digit_group).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0)
|
||||
}
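The minute logic in `extract_minute` can be checked without the `regex` crate by modelling the two optional capture groups as `Option<&str>`. A minimal sketch, assuming a hypothetical `minute_from_groups` function that mirrors the same precedence ("半" wins, then digits, then 0):

```rust
/// Resolve the minute from two mutually exclusive optional captures.
/// `digits` stands for the `(\d{1,2})` group and `half` for the `(半)` group.
/// Hypothetical helper for illustration only.
pub fn minute_from_groups(digits: Option<&str>, half: Option<&str>) -> u32 {
    // "半" (half past) always means 30 minutes.
    if half.is_some() {
        return 30;
    }
    // Missing or unparsable digits fall back to 0, i.e. on the hour.
    digits.and_then(|d| d.parse().ok()).unwrap_or(0)
}
```

Because the patterns use `(?:(\d{1,2})|(半))?`, at most one of the two groups can match, so the order of the checks only matters for readability.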
|
||||
|
||||
fn try_every_day(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
|
||||
if let Some(caps) = RE_EVERY_DAY_EXACT.captures(input) {
|
||||
let period = caps.get(1).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(2)?.as_str().parse().ok()?;
|
||||
let minute: u32 = caps.get(3).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
|
||||
let minute: u32 = extract_minute(&caps, 3, 4);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if hour > 23 || minute > 59 {
|
||||
return None;
|
||||
@@ -288,7 +316,7 @@ fn try_every_week(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<Sc
|
||||
let dow = weekday_to_cron(day_str)?;
|
||||
let period = caps.get(2).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(3)?.as_str().parse().ok()?;
|
||||
let minute: u32 = caps.get(4).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
|
||||
let minute: u32 = extract_minute(&caps, 4, 5);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if hour > 23 || minute > 59 {
|
||||
return None;
|
||||
@@ -307,7 +335,7 @@ fn try_workday(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<Sched
|
||||
if let Some(caps) = RE_WORKDAY_EXACT.captures(input) {
|
||||
let period = caps.get(1).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(2)?.as_str().parse().ok()?;
|
||||
let minute: u32 = caps.get(3).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
|
||||
let minute: u32 = extract_minute(&caps, 3, 4);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if hour > 23 || minute > 59 {
|
||||
return None;
|
||||
@@ -366,7 +394,7 @@ fn try_monthly(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<Sched
|
||||
let day: u32 = caps.get(1)?.as_str().parse().ok()?;
|
||||
let period = caps.get(2).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(3).map(|m| m.as_str().parse().unwrap_or(9)).unwrap_or(9);
|
||||
let minute: u32 = caps.get(4).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
|
||||
let minute: u32 = extract_minute(&caps, 4, 5);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if day > 31 || hour > 23 || minute > 59 {
|
||||
return None;
|
||||
@@ -384,35 +412,95 @@ fn try_monthly(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<Sched
|
||||
}
|
||||
|
||||
fn try_one_shot(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
|
||||
let caps = RE_ONE_SHOT.captures(input)?;
|
||||
let day_offset = match caps.get(1)?.as_str() {
|
||||
"明天" => 1,
|
||||
"后天" => 2,
|
||||
"大后天" => 3,
|
||||
_ => return None,
|
||||
};
|
||||
let period = caps.get(2).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(3)?.as_str().parse().ok()?;
|
||||
let minute: u32 = caps.get(4).map(|m| m.as_str().parse().unwrap_or(0)).unwrap_or(0);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if hour > 23 || minute > 59 {
|
||||
// First try explicit date prefix: 明天/后天/大后天 + time
|
||||
if let Some(caps) = RE_ONE_SHOT.captures(input) {
|
||||
let day_offset = match caps.get(1)?.as_str() {
|
||||
"明天" => 1,
|
||||
"后天" => 2,
|
||||
"大后天" => 3,
|
||||
_ => return None,
|
||||
};
|
||||
let period = caps.get(2).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(3)?.as_str().parse().ok()?;
|
||||
let minute: u32 = extract_minute(&caps, 4, 5);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if hour > 23 || minute > 59 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let target = chrono::Utc::now()
|
||||
.checked_add_signed(chrono::Duration::days(day_offset))
|
||||
.unwrap_or_else(chrono::Utc::now)
|
||||
.with_hour(hour)
|
||||
.unwrap_or_else(|| chrono::Utc::now())
|
||||
.with_minute(minute)
|
||||
.unwrap_or_else(|| chrono::Utc::now())
|
||||
.with_second(0)
|
||||
.unwrap_or_else(|| chrono::Utc::now());
|
||||
|
||||
return Some(ScheduleParseResult::Exact(ParsedSchedule {
|
||||
cron_expression: target.to_rfc3339(),
|
||||
natural_description: format!("{} {:02}:{:02}", caps.get(1)?.as_str(), hour, minute),
|
||||
confidence: 0.88,
|
||||
task_description: task_desc.to_string(),
|
||||
task_target: TaskTarget::Agent(agent_id.to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
// Then try same-day implicit: "下午3点半提醒我..." (no date prefix)
|
||||
if let Some(caps) = RE_ONE_SHOT_TODAY.captures(input) {
|
||||
let period = caps.get(1).map(|m| m.as_str());
|
||||
let raw_hour: u32 = caps.get(2)?.as_str().parse().ok()?;
|
||||
let minute: u32 = extract_minute(&caps, 3, 4);
|
||||
let hour = adjust_hour_for_period(raw_hour, period);
|
||||
if hour > 23 || minute > 59 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let target = chrono::Utc::now()
|
||||
.with_hour(hour)
|
||||
.unwrap_or_else(|| chrono::Utc::now())
|
||||
.with_minute(minute)
|
||||
.unwrap_or_else(|| chrono::Utc::now())
|
||||
.with_second(0)
|
||||
.unwrap_or_else(|| chrono::Utc::now());
|
||||
|
||||
let period_desc = period.unwrap_or("");
|
||||
return Some(ScheduleParseResult::Exact(ParsedSchedule {
|
||||
cron_expression: target.to_rfc3339(),
|
||||
natural_description: format!("今天{} {:02}:{:02}", period_desc, hour, minute),
|
||||
confidence: 0.82,
|
||||
task_description: task_desc.to_string(),
|
||||
task_target: TaskTarget::Agent(agent_id.to_string()),
|
||||
}));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Parse relative delay expressions like "10秒后", "5分钟后", "2小时后".
|
||||
/// Converts the delay to an RFC 3339 (ISO-8601 compatible) UTC timestamp measured from now.
|
||||
fn try_relative_delay(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<ScheduleParseResult> {
|
||||
let caps = RE_RELATIVE_DELAY.captures(input)?;
|
||||
let amount: i64 = caps.get(1)?.as_str().parse().ok()?;
|
||||
if amount <= 0 || amount > 999 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let target = chrono::Utc::now()
|
||||
.checked_add_signed(chrono::Duration::days(day_offset))
|
||||
.unwrap_or_else(chrono::Utc::now)
|
||||
.with_hour(hour)
|
||||
.unwrap_or_else(|| chrono::Utc::now())
|
||||
.with_minute(minute)
|
||||
.unwrap_or_else(|| chrono::Utc::now())
|
||||
.with_second(0)
|
||||
.unwrap_or_else(|| chrono::Utc::now());
|
||||
let unit = caps.get(2)?.as_str();
|
||||
let (seconds, desc_unit) = match unit {
|
||||
"秒" | "秒钟" => (amount, "秒"),
|
||||
"分钟" | "分" => (amount * 60, "分钟"),
|
||||
"小时" | "个小时" => (amount * 3600, "小时"),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let target = chrono::Utc::now() + chrono::Duration::seconds(seconds);
|
||||
|
||||
Some(ScheduleParseResult::Exact(ParsedSchedule {
|
||||
cron_expression: target.to_rfc3339(),
|
||||
natural_description: format!("{} {:02}:{:02}", caps.get(1)?.as_str(), hour, minute),
|
||||
confidence: 0.88,
|
||||
natural_description: format!("{}{}后", amount, desc_unit),
|
||||
confidence: 0.92,
|
||||
task_description: task_desc.to_string(),
|
||||
task_target: TaskTarget::Agent(agent_id.to_string()),
|
||||
}))
|
||||
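The unit handling in `try_relative_delay` can be pulled out into a pure function, which makes it testable without reading the clock. The following is a minimal sketch, not code from this diff: `delay_seconds` is a hypothetical helper name, and it only mirrors the bounds check and match arms shown above.

```rust
/// Hypothetical helper (not part of the diff): converts an amount plus a
/// Chinese time unit into seconds, mirroring the arms of `try_relative_delay`.
fn delay_seconds(amount: i64, unit: &str) -> Option<i64> {
    // Same bounds as the diff: reject non-positive and very large amounts.
    if amount <= 0 || amount > 999 {
        return None;
    }
    match unit {
        "秒" | "秒钟" => Some(amount),
        "分钟" | "分" => Some(amount * 60),
        "小时" | "个小时" => Some(amount * 3600),
        // Unknown units are rejected rather than guessed.
        _ => None,
    }
}
```

With a helper of this shape, the caller would only need to add the returned seconds to the current time, so the tests can assert on plain integers instead of timestamps.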
@@ -426,7 +514,7 @@ fn try_one_shot(input: &str, task_desc: &str, agent_id: &AgentId) -> Option<Sche
|
||||
const SCHEDULE_INTENT_KEYWORDS: &[&str] = &[
|
||||
"提醒我", "提醒", "定时", "每天", "每日", "每周", "每月",
|
||||
"工作日", "每隔", "每", "定期", "到时候", "准时",
|
||||
"闹钟", "闹铃", "日程", "日历",
|
||||
"闹钟", "闹铃", "日程", "日历", "秒后", "分钟后", "小时后",
|
||||
];
|
||||
|
||||
/// Check if user input contains schedule intent.
|
||||
@@ -604,4 +692,115 @@ mod tests {
|
||||
fn test_task_description_extraction() {
|
||||
assert_eq!(extract_task_description("每天早上9点提醒我查房"), "查房");
|
||||
}
|
||||
|
||||
// --- New tests for BUG-3 (半) and BUG-4 (工作日每天) ---
|
||||
|
||||
#[test]
|
||||
fn test_every_day_half_hour() {
|
||||
// "8点半" should parse as 08:30
|
||||
let result = parse_nl_schedule("每天早上8点半提醒我打卡", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert_eq!(s.cron_expression, "30 8 * * *");
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_every_day_afternoon_half() {
|
||||
// "下午3点半" should parse as 15:30
|
||||
let result = parse_nl_schedule("每天下午3点半提醒我", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert_eq!(s.cron_expression, "30 15 * * *");
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workday_with_every_day_prefix() {
|
||||
// "工作日每天早上8点半" should parse as weekday 08:30 with 1-5
|
||||
let result = parse_nl_schedule("工作日每天早上8点半提醒我打卡", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert_eq!(s.cron_expression, "30 8 * * 1-5");
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_workday_half_hour() {
|
||||
// "工作日下午5点半" should parse as weekday 17:30
|
||||
let result = parse_nl_schedule("工作日下午5点半提醒我写周报", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert_eq!(s.cron_expression, "30 17 * * 1-5");
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_every_week_half_hour() {
|
||||
// "每周一下午3点半" should parse as 15:30 on Monday
|
||||
let result = parse_nl_schedule("每周一下午3点半提醒我开会", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert_eq!(s.cron_expression, "30 15 * * 1");
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_one_shot_half_hour() {
|
||||
// "明天早上9点半" should parse as tomorrow 09:30
|
||||
let result = parse_nl_schedule("明天早上9点半提醒我开会", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
// Should contain the time in ISO format
|
||||
assert!(s.cron_expression.contains("T09:30:"));
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_relative_delay_seconds() {
|
||||
let result = parse_nl_schedule("30秒后提醒我", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert!(s.natural_description.contains("30秒"));
|
||||
assert!(s.confidence >= 0.9);
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_relative_delay_minutes() {
|
||||
let result = parse_nl_schedule("5分钟后提醒我喝水", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert!(s.natural_description.contains("5分钟"));
|
||||
// task_description preserves the original text minus schedule keywords
|
||||
assert!(s.task_description.contains("喝水"));
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_relative_delay_hours() {
|
||||
let result = parse_nl_schedule("2小时后提醒我开会", &default_agent());
|
||||
match result {
|
||||
ScheduleParseResult::Exact(s) => {
|
||||
assert!(s.natural_description.contains("2小时"));
|
||||
}
|
||||
_ => panic!("Expected Exact, got {:?}", result),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
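The tests above expect "8点半" to yield minute 30, but the body of `extract_minute` is not shown in this part of the diff. The sketch below is an assumption about how such a helper could behave: `minute_from_groups` and `cron_daily` are hypothetical names, and the two `Option<&str>` arguments stand in for the two regex capture groups (numeric minutes and the "半" marker).

```rust
/// Hypothetical stand-in for `extract_minute` (its real body is not shown here).
/// `digits` models the numeric-minute capture group, `half` the "半" group.
fn minute_from_groups(digits: Option<&str>, half: Option<&str>) -> u32 {
    // "半" means half past the hour.
    if half == Some("半") {
        return 30;
    }
    // Fall back to 0 when the group is absent or not a number.
    digits.and_then(|m| m.parse::<u32>().ok()).unwrap_or(0)
}

/// Hypothetical helper: builds a daily cron expression, applying the same
/// range check as the diff (`hour > 23 || minute > 59` is rejected).
fn cron_daily(hour: u32, minute: u32) -> Option<String> {
    if hour > 23 || minute > 59 {
        return None;
    }
    Some(format!("{} {} * * *", minute, hour))
}
```

Keeping the "半" check ahead of the numeric parse means an input that somehow matched both groups would resolve to 30; whether the real helper makes the same choice is not visible here.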
@@ -24,6 +24,10 @@ pub enum StreamChunk {
|
||||
input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
stop_reason: String,
|
||||
#[serde(default)]
|
||||
cache_creation_input_tokens: Option<u32>,
|
||||
#[serde(default)]
|
||||
cache_read_input_tokens: Option<u32>,
|
||||
},
|
||||
/// Error occurred
|
||||
Error { message: String },
|
||||
|
||||
216
crates/zclaw-runtime/src/test_util.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
//! Shared test utilities for zclaw-runtime and dependent crates.
|
||||
//!
|
||||
//! Provides `MockLlmDriver` — a controllable LLM driver for offline testing.
|
||||
|
||||
use crate::driver::{
|
||||
CompletionRequest, CompletionResponse, ContentBlock, LlmDriver, StopReason,
|
||||
};
|
||||
use crate::stream::StreamChunk;
|
||||
use async_trait::async_trait;
|
||||
use futures::{Stream, StreamExt};
|
||||
use serde_json::Value;
|
||||
use std::collections::VecDeque;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use zclaw_types::Result;
|
||||
use zclaw_types::ZclawError;
|
||||
|
||||
/// Thread-safe mock LLM driver for testing.
|
||||
///
|
||||
/// # Usage
|
||||
/// ```ignore
|
||||
/// let mock = MockLlmDriver::new()
|
||||
/// .with_text_response("Hello!")
|
||||
/// .with_text_response("How can I help?");
|
||||
///
|
||||
/// let resp = mock.complete(request).await?;
|
||||
/// assert_eq!(resp.content_text(), "Hello!");
|
||||
/// ```
|
||||
pub struct MockLlmDriver {
|
||||
responses: Arc<Mutex<VecDeque<CompletionResponse>>>,
|
||||
stream_chunks: Arc<Mutex<VecDeque<Vec<StreamChunk>>>>,
|
||||
call_count: AtomicUsize,
|
||||
last_request: Arc<Mutex<Option<CompletionRequest>>>,
|
||||
/// If true, `complete()` returns an error instead of a response.
|
||||
fail_mode: Arc<Mutex<bool>>,
|
||||
}
|
||||
|
||||
impl MockLlmDriver {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
responses: Arc::new(Mutex::new(VecDeque::new())),
|
||||
stream_chunks: Arc::new(Mutex::new(VecDeque::new())),
|
||||
call_count: AtomicUsize::new(0),
|
||||
last_request: Arc::new(Mutex::new(None)),
|
||||
fail_mode: Arc::new(Mutex::new(false)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Queue a text response.
|
||||
pub fn with_text_response(mut self, text: &str) -> Self {
|
||||
self.push_response(CompletionResponse {
|
||||
content: vec![ContentBlock::Text { text: text.to_string() }],
|
||||
model: "mock-model".to_string(),
|
||||
input_tokens: 10,
|
||||
output_tokens: text.len() as u32 / 4,
|
||||
stop_reason: StopReason::EndTurn,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Queue a response with tool calls.
|
||||
pub fn with_tool_call(mut self, tool_name: &str, args: Value) -> Self {
|
||||
self.push_response(CompletionResponse {
|
||||
content: vec![
|
||||
ContentBlock::Text { text: format!("Calling {}", tool_name) },
|
||||
ContentBlock::ToolUse {
|
||||
id: format!("call_{}", self.call_count()),
|
||||
name: tool_name.to_string(),
|
||||
input: args,
|
||||
},
|
||||
],
|
||||
model: "mock-model".to_string(),
|
||||
input_tokens: 10,
|
||||
output_tokens: 20,
|
||||
stop_reason: StopReason::ToolUse,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Queue a response with `StopReason::Error` and empty content (the `_error` message itself is currently ignored).
|
||||
pub fn with_error(mut self, _error: &str) -> Self {
|
||||
self.push_response(CompletionResponse {
|
||||
content: vec![],
|
||||
model: "mock-model".to_string(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: StopReason::Error,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Queue a raw response.
|
||||
pub fn with_response(mut self, response: CompletionResponse) -> Self {
|
||||
self.push_response(response);
|
||||
self
|
||||
}
|
||||
|
||||
/// Queue stream chunks for a streaming call.
|
||||
pub fn with_stream_chunks(self, chunks: Vec<StreamChunk>) -> Self {
|
||||
self.stream_chunks
|
||||
.lock()
|
||||
.expect("stream_chunks lock")
|
||||
.push_back(chunks);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set fail mode: while it is `true`, every `complete()` call returns an error.
|
||||
pub fn set_fail_mode(&self, fail: bool) {
|
||||
*self.fail_mode.lock().expect("fail_mode lock") = fail;
|
||||
}
|
||||
|
||||
/// Number of times `complete()` or `stream()` was called.
|
||||
pub fn call_count(&self) -> usize {
|
||||
self.call_count.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/// Inspect the last request sent to the driver.
|
||||
pub fn last_request(&self) -> Option<CompletionRequest> {
|
||||
self.last_request
|
||||
.lock()
|
||||
.expect("last_request lock")
|
||||
.clone()
|
||||
}
|
||||
|
||||
fn push_response(&mut self, resp: CompletionResponse) {
|
||||
self.responses
|
||||
.lock()
|
||||
.expect("responses lock")
|
||||
.push_back(resp);
|
||||
}
|
||||
|
||||
fn next_response(&self) -> CompletionResponse {
|
||||
let mut queue = self.responses.lock().expect("responses lock");
|
||||
queue
|
||||
.pop_front()
|
||||
.unwrap_or_else(|| CompletionResponse {
|
||||
content: vec![ContentBlock::Text {
|
||||
text: "mock default response".to_string(),
|
||||
}],
|
||||
model: "mock-model".to_string(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
stop_reason: StopReason::EndTurn,
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MockLlmDriver {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmDriver for MockLlmDriver {
|
||||
fn provider(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
*self.last_request.lock().expect("last_request lock") = Some(request);
|
||||
|
||||
if *self.fail_mode.lock().expect("fail_mode lock") {
|
||||
return Err(ZclawError::LlmError("mock driver fail mode".to_string()));
|
||||
}
|
||||
|
||||
Ok(self.next_response())
|
||||
}
|
||||
|
||||
fn stream(
|
||||
&self,
|
||||
request: CompletionRequest,
|
||||
) -> Pin<Box<dyn Stream<Item = Result<StreamChunk>> + Send + '_>> {
|
||||
self.call_count.fetch_add(1, Ordering::SeqCst);
|
||||
*self.last_request.lock().expect("last_request lock") = Some(request);
|
||||
|
||||
let chunks: Vec<Result<StreamChunk>> = self
|
||||
.stream_chunks
|
||||
.lock()
|
||||
.expect("stream_chunks lock")
|
||||
.pop_front()
|
||||
.unwrap_or_else(|| {
|
||||
vec![
|
||||
StreamChunk::TextDelta {
|
||||
delta: "mock stream".to_string(),
|
||||
},
|
||||
StreamChunk::Complete {
|
||||
input_tokens: 10,
|
||||
output_tokens: 2,
|
||||
stop_reason: "end_turn".to_string(),
|
||||
cache_creation_input_tokens: None,
|
||||
cache_read_input_tokens: None,
|
||||
},
|
||||
]
|
||||
})
|
||||
.into_iter()
|
||||
.map(Ok)
|
||||
.collect();
|
||||
|
||||
futures::stream::iter(chunks).boxed()
|
||||
}
|
||||
|
||||
fn is_configured(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
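The core of `MockLlmDriver` is a FIFO queue of canned responses behind a `Mutex`, plus an atomic call counter, so the mock can be driven through `&self`. The sketch below reduces that pattern to the standard library only; `MockQueue` is a hypothetical simplification that returns `String` instead of `CompletionResponse`, not the type added in this diff.

```rust
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Mutex;

/// Hypothetical, std-only reduction of the mock driver's queue pattern.
struct MockQueue {
    responses: Mutex<VecDeque<String>>,
    call_count: AtomicUsize,
}

impl MockQueue {
    fn new() -> Self {
        Self {
            responses: Mutex::new(VecDeque::new()),
            call_count: AtomicUsize::new(0),
        }
    }

    /// Builder-style: queue one canned response and hand the mock back.
    fn with_text_response(self, text: &str) -> Self {
        self.responses
            .lock()
            .expect("responses lock")
            .push_back(text.to_string());
        self
    }

    /// Pop the next queued response, or fall back to a default when empty.
    fn complete(&self) -> String {
        self.call_count.fetch_add(1, Ordering::SeqCst);
        let mut queue = self.responses.lock().expect("responses lock");
        queue
            .pop_front()
            .unwrap_or_else(|| "mock default response".to_string())
    }

    fn call_count(&self) -> usize {
        self.call_count.load(Ordering::SeqCst)
    }
}
```

Responses come back in the order they were queued, and an exhausted queue yields the default instead of panicking, which keeps tests that make one call too many from failing for an unrelated reason.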
@@ -11,6 +11,17 @@ use crate::driver::ToolDefinition;
|
||||
use crate::loop_runner::LoopEvent;
|
||||
use crate::tool::builtin::PathValidator;
|
||||
|
||||
/// Tool concurrency safety level
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ToolConcurrency {
|
||||
/// Read-only operations, always safe to parallelize (file_read, web_fetch, etc.)
|
||||
ReadOnly,
|
||||
/// Exclusive operations, must be serial (file_write, shell_exec, etc.)
|
||||
Exclusive,
|
||||
/// Interactive operations, never parallelize (ask_clarification, etc.)
|
||||
Interactive,
|
||||
}
|
||||
|
||||
/// Tool trait for implementing agent tools
|
||||
#[async_trait]
|
||||
pub trait Tool: Send + Sync {
|
||||
@@ -25,6 +36,11 @@ pub trait Tool: Send + Sync {
|
||||
|
||||
/// Execute the tool
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value>;
|
||||
|
||||
/// Tool concurrency safety level. Default: ReadOnly.
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::ReadOnly
|
||||
}
|
||||
}
|
||||
|
||||
/// Skill executor trait for runtime skill execution
|
||||
@@ -74,12 +90,27 @@ pub struct SkillDetail {
|
||||
pub capabilities: Vec<String>,
|
||||
}
|
||||
|
||||
/// Hand executor trait for runtime hand execution
|
||||
/// This allows tools (HandTool) to execute hands without direct dependency on zclaw-hands
|
||||
#[async_trait]
|
||||
pub trait HandExecutor: Send + Sync {
|
||||
/// Execute a hand by ID, returning the output as JSON
|
||||
async fn execute_hand(
|
||||
&self,
|
||||
hand_id: &str,
|
||||
agent_id: &AgentId,
|
||||
input: Value,
|
||||
) -> Result<Value>;
|
||||
}
|
||||
|
||||
/// Context provided to tool execution
|
||||
pub struct ToolContext {
|
||||
pub agent_id: AgentId,
|
||||
pub working_directory: Option<String>,
|
||||
pub session_id: Option<String>,
|
||||
pub skill_executor: Option<Arc<dyn SkillExecutor>>,
|
||||
/// Hand executor for dispatching Hand tool calls to the HandRegistry
|
||||
pub hand_executor: Option<Arc<dyn HandExecutor>>,
|
||||
/// Path validator for file system operations
|
||||
pub path_validator: Option<PathValidator>,
|
||||
/// Optional event sender for streaming tool progress to the frontend.
|
||||
@@ -94,6 +125,7 @@ impl std::fmt::Debug for ToolContext {
|
||||
.field("working_directory", &self.working_directory)
|
||||
.field("session_id", &self.session_id)
|
||||
.field("skill_executor", &self.skill_executor.as_ref().map(|_| "SkillExecutor"))
|
||||
.field("hand_executor", &self.hand_executor.as_ref().map(|_| "HandExecutor"))
|
||||
.field("path_validator", &self.path_validator.as_ref().map(|_| "PathValidator"))
|
||||
.field("event_sender", &self.event_sender.as_ref().map(|_| "Sender<LoopEvent>"))
|
||||
.finish()
|
||||
@@ -107,6 +139,7 @@ impl Clone for ToolContext {
|
||||
working_directory: self.working_directory.clone(),
|
||||
session_id: self.session_id.clone(),
|
||||
skill_executor: self.skill_executor.clone(),
|
||||
hand_executor: self.hand_executor.clone(),
|
||||
path_validator: self.path_validator.clone(),
|
||||
event_sender: self.event_sender.clone(),
|
||||
}
|
||||
@@ -191,3 +224,4 @@ impl Default for ToolRegistry {
|
||||
|
||||
// Built-in tools module
|
||||
pub mod builtin;
|
||||
pub mod hand_tool;
|
||||
|
||||
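The diff introduces `ToolConcurrency` but the scheduler that consumes it is not shown here. One possible way to use the level is sketched below, purely as an assumption: consecutive `ReadOnly` calls are grouped into a batch that could run in parallel, while `Exclusive` and `Interactive` calls each get a batch of their own. `plan_batches` is a hypothetical function, and the enum is redeclared locally so the example is self-contained.

```rust
/// Local copy of the enum from the diff, so the sketch compiles on its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ToolConcurrency {
    ReadOnly,
    Exclusive,
    Interactive,
}

/// Hypothetical planner (not in the diff): preserves call order, merging
/// only runs of adjacent read-only calls into one parallel batch.
fn plan_batches(calls: &[(&str, ToolConcurrency)]) -> Vec<Vec<String>> {
    let mut batches: Vec<Vec<String>> = Vec::new();
    let mut last_was_read_only = false;
    for (name, level) in calls {
        let read_only = *level == ToolConcurrency::ReadOnly;
        if read_only && last_was_read_only {
            // Extend the current read-only batch.
            if let Some(batch) = batches.last_mut() {
                batch.push(name.to_string());
            }
        } else {
            // Exclusive/Interactive calls, or the first read-only call
            // after one of them, start a new batch.
            batches.push(vec![name.to_string()]);
        }
        last_was_read_only = read_only;
    }
    batches
}
```

Because the trait default is `ReadOnly`, a tool that forgets to override `concurrency()` would be treated as parallel-safe under a planner like this, so side-effecting tools need the explicit override the diff adds.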
@@ -9,7 +9,7 @@ use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Clarification type — categorizes the reason for asking.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
@@ -96,6 +96,10 @@ impl Tool for AskClarificationTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Interactive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
let question = input["question"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'question' parameter".into()))?;
|
||||
|
||||
@@ -4,7 +4,7 @@ use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
pub struct ExecuteSkillTool;
|
||||
|
||||
@@ -42,6 +42,10 @@ impl Tool for ExecuteSkillTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let skill_id = input["skill_id"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'skill_id' parameter".into()))?;
|
||||
|
||||
@@ -139,6 +139,7 @@ mod tests {
|
||||
working_directory: None,
|
||||
session_id: None,
|
||||
skill_executor: None,
|
||||
hand_executor: None,
|
||||
path_validator,
|
||||
event_sender: None,
|
||||
};
|
||||
|
||||
@@ -6,7 +6,7 @@ use zclaw_types::{Result, ZclawError};
|
||||
use std::fs;
|
||||
use std::io::Write;
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
use super::path_validator::PathValidator;
|
||||
|
||||
pub struct FileWriteTool;
|
||||
@@ -55,6 +55,10 @@ impl Tool for FileWriteTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let path = input["path"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'path' parameter".into()))?;
|
||||
@@ -162,6 +166,7 @@ mod tests {
|
||||
working_directory: None,
|
||||
session_id: None,
|
||||
skill_executor: None,
|
||||
hand_executor: None,
|
||||
path_validator,
|
||||
event_sender: None,
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use serde_json::Value;
|
||||
use std::sync::Arc;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Wraps an MCP tool adapter into the `Tool` trait.
|
||||
///
|
||||
@@ -42,6 +42,10 @@ impl Tool for McpToolWrapper {
|
||||
self.adapter.input_schema().clone()
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
self.adapter.execute(input).await
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use std::process::{Command, Stdio};
|
||||
use std::time::{Duration, Instant};
|
||||
use zclaw_types::{Result, ZclawError};
|
||||
|
||||
use crate::tool::{Tool, ToolContext};
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Parse a command string into program and arguments using proper shell quoting
|
||||
fn parse_command(command: &str) -> Result<(String, Vec<String>)> {
|
||||
@@ -175,6 +175,10 @@ impl Tool for ShellExecTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, _context: &ToolContext) -> Result<Value> {
|
||||
let command = input["command"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'command' parameter".into()))?;
|
||||
|
||||
@@ -11,7 +11,7 @@ use zclaw_memory::MemoryStore;
|
||||
|
||||
use crate::driver::LlmDriver;
|
||||
use crate::loop_runner::{AgentLoop, LoopEvent};
|
||||
use crate::tool::{Tool, ToolContext, ToolRegistry};
|
||||
use crate::tool::{Tool, ToolContext, ToolRegistry, ToolConcurrency};
|
||||
use crate::tool::builtin::register_builtin_tools;
|
||||
use std::sync::Arc;
|
||||
|
||||
@@ -91,6 +91,10 @@ impl Tool for TaskTool {
|
||||
})
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
let description = input["description"].as_str()
|
||||
.ok_or_else(|| ZclawError::InvalidInput("Missing 'description' parameter".into()))?;
|
||||
|
||||
159
crates/zclaw-runtime/src/tool/hand_tool.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
//! Hand Tool Wrapper
|
||||
//!
|
||||
//! Bridges the Hand trait (zclaw-hands) to the Tool trait (zclaw-runtime),
|
||||
//! allowing Hands to be registered in the ToolRegistry and callable by the LLM.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_types::Result;
|
||||
|
||||
use crate::tool::{Tool, ToolContext, ToolConcurrency};
|
||||
|
||||
/// Wrapper that exposes a Hand as a Tool in the agent's tool registry.
|
||||
///
|
||||
/// When the LLM calls `hand_quiz`, `hand_researcher`, etc., the call is
|
||||
/// routed through this wrapper to the actual Hand implementation.
|
||||
pub struct HandTool {
|
||||
/// Hand identifier (e.g., "hand_quiz", "hand_researcher")
|
||||
name: String,
|
||||
/// Human-readable description
|
||||
description: String,
|
||||
/// Input JSON schema
|
||||
input_schema: Value,
|
||||
/// Hand ID for registry lookup
|
||||
hand_id: String,
|
||||
}
|
||||
|
||||
impl HandTool {
|
||||
/// Create a new HandTool wrapper from hand metadata.
|
||||
pub fn new(
|
||||
tool_name: &str,
|
||||
description: &str,
|
||||
input_schema: Value,
|
||||
hand_id: &str,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: tool_name.to_string(),
|
||||
description: description.to_string(),
|
||||
input_schema,
|
||||
hand_id: hand_id.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a HandTool from HandConfig fields.
|
||||
pub fn from_config(hand_id: &str, description: &str, input_schema: Option<Value>) -> Self {
|
||||
let tool_name = format!("hand_{}", hand_id);
|
||||
let schema = input_schema.unwrap_or_else(|| {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input": {
|
||||
"type": "string",
|
||||
"description": format!("Input for the {} hand", hand_id)
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
})
|
||||
});
|
||||
Self::new(&tool_name, description, schema, hand_id)
|
||||
}
|
||||
|
||||
/// Get the hand ID for registry lookup
|
||||
pub fn hand_id(&self) -> &str {
|
||||
&self.hand_id
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for HandTool {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
&self.description
|
||||
}
|
||||
|
||||
fn input_schema(&self) -> Value {
|
||||
self.input_schema.clone()
|
||||
}
|
||||
|
||||
fn concurrency(&self) -> ToolConcurrency {
|
||||
ToolConcurrency::Exclusive
|
||||
}
|
||||
|
||||
async fn execute(&self, input: Value, context: &ToolContext) -> Result<Value> {
|
||||
// Delegate to the HandExecutor (bridged from HandRegistry via kernel).
|
||||
// If no hand_executor is available (e.g., standalone runtime without kernel),
|
||||
// return a descriptive error so the LLM knows the hand is unavailable.
|
||||
match &context.hand_executor {
|
||||
Some(executor) => {
|
||||
executor.execute_hand(&self.hand_id, &context.agent_id, input).await
|
||||
}
|
||||
None => {
|
||||
Ok(json!({
|
||||
"hand_id": self.hand_id,
|
||||
"status": "unavailable",
|
||||
"error": format!(
|
||||
"Hand '{}' cannot execute: no hand executor configured. \
|
||||
This usually means the kernel is not running or hands are not registered.",
|
||||
self.hand_id
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_hand_tool_creation() {
|
||||
let tool = HandTool::from_config(
|
||||
"quiz",
|
||||
"Generate quizzes on various topics",
|
||||
None,
|
||||
);
|
||||
assert_eq!(tool.name(), "hand_quiz");
|
||||
assert_eq!(tool.hand_id(), "quiz");
|
||||
assert!(tool.description().contains("quiz"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hand_tool_custom_schema() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"topic": { "type": "string" },
|
||||
"difficulty": { "type": "string" }
|
||||
}
|
||||
});
|
||||
let tool = HandTool::from_config(
|
||||
"quiz",
|
||||
"Generate quizzes",
|
||||
Some(schema.clone()),
|
||||
);
|
||||
assert_eq!(tool.input_schema(), schema);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_hand_tool_execute_no_executor() {
|
||||
let tool = HandTool::from_config("quiz", "Generate quizzes", None);
|
||||
let ctx = ToolContext {
|
||||
agent_id: zclaw_types::AgentId::new(),
|
||||
working_directory: None,
|
||||
session_id: None,
|
||||
skill_executor: None,
|
||||
hand_executor: None,
|
||||
path_validator: None,
|
||||
event_sender: None,
|
||||
};
|
||||
let result = tool.execute(json!({"topic": "Python"}), &ctx).await;
|
||||
assert!(result.is_ok());
|
||||
let val = result.unwrap();
|
||||
assert_eq!(val["hand_id"], "quiz");
|
||||
assert_eq!(val["status"], "unavailable");
|
||||
}
|
||||
}
|
||||
@@ -186,5 +186,8 @@ pub async fn create_agent_from_template(
|
||||
Path(id): Path<String>,
|
||||
) -> SaasResult<Json<AgentConfigFromTemplate>> {
|
||||
check_permission(&ctx, "model:read")?;
|
||||
Ok(Json(service::create_agent_from_template(&state.db, &id).await?))
|
||||
tracing::info!("[AgentTemplate] create_agent_from_template: id={}, account={}", id, ctx.account_id);
|
||||
let result = service::create_agent_from_template(&state.db, &id).await?;
|
||||
tracing::info!("[AgentTemplate] create_agent_from_template OK: name={}", result.name);
|
||||
Ok(Json(result))
|
||||
}
|
||||
|
||||
@@ -299,3 +299,68 @@ pub async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> {
|
||||
tracing::info!("Seeded {} builtin industries", builtin_industries().len());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Auto-optimize industry config based on actual usage data.
|
||||
///
|
||||
/// Analyzes experience data for all agents under an account and updates
|
||||
/// `skill_priorities` on each linked active industry to reflect actual usage
|
||||
/// patterns rather than static configuration.
|
||||
pub async fn auto_optimize_config(
|
||||
pool: &sqlx::PgPool,
|
||||
account_id: i64,
|
||||
usage_signals: &std::collections::HashMap<String, u32>,
|
||||
) -> crate::Result<()> {
|
||||
// Find active industries for this account
|
||||
let industries: Vec<(String, serde_json::Value)> = sqlx::query_as(
|
||||
"SELECT i.id, i.skill_priorities FROM industries i
|
||||
JOIN account_industries ai ON ai.industry_id = i.id
|
||||
WHERE ai.account_id = $1 AND i.status = 'active'",
|
||||
)
|
||||
.bind(account_id)
|
||||
.fetch_all(pool)
|
||||
.await
|
||||
.map_err(crate::SaasError::from)?;
|
||||
|
||||
if industries.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Build updated skill_priorities based on actual usage
|
||||
let mut new_priorities: Vec<(String, i32)> = Vec::new();
|
||||
for (skill, count) in usage_signals {
|
||||
let priority = (*count as i32).min(10);
|
||||
if priority > 0 {
|
||||
new_priorities.push((skill.clone(), priority));
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by priority descending
|
||||
new_priorities.sort_by(|a, b| b.1.cmp(&a.1));
|
||||
|
||||
if new_priorities.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Update each linked industry's skill_priorities
|
||||
let priorities_json = serde_json::to_string(&new_priorities)
|
||||
.unwrap_or_else(|_| "[]".to_string());
|
||||
|
||||
for (industry_id, _old_priorities) in &industries {
|
||||
sqlx::query(
|
||||
"UPDATE industries SET skill_priorities = $1, updated_at = NOW() WHERE id = $2",
|
||||
)
|
||||
.bind(&priorities_json)
|
||||
.bind(industry_id)
|
||||
.execute(pool)
|
||||
.await
|
||||
.map_err(crate::SaasError::from)?;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
"[auto_optimize] Updated skill_priorities for {} industries under account {}",
|
||||
industries.len(),
|
||||
account_id,
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
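The ranking step inside `auto_optimize_config` is independent of the database and can be sketched as a pure function. This is an illustrative variant rather than the code in the diff: `rank_skills` is a hypothetical name, it clamps the count before casting (so very large `u32` values cannot wrap negative), and it adds a tie-break on the skill name so the output does not depend on `HashMap` iteration order.

```rust
use std::collections::HashMap;

/// Hypothetical pure version of the priority ranking (not in the diff).
/// Usage counts are capped at 10, zero-usage skills are dropped, and the
/// result is sorted by priority descending, then by name ascending.
fn rank_skills(usage: &HashMap<String, u32>) -> Vec<(String, i32)> {
    let mut ranked: Vec<(String, i32)> = usage
        .iter()
        // Clamp first, then cast: the value is at most 10, so it fits in i32.
        .map(|(skill, count)| (skill.clone(), (*count).min(10) as i32))
        .filter(|(_, priority)| *priority > 0)
        .collect();
    // Name tie-break is an addition of this sketch, for deterministic output.
    ranked.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
    ranked
}
```

A function of this shape can be unit-tested without a `PgPool`, leaving only the `UPDATE` loop to integration tests.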
@@ -28,3 +28,5 @@ pub mod telemetry;
|
||||
pub mod billing;
|
||||
pub mod industry;
|
||||
pub mod knowledge;
|
||||
|
||||
pub use error::{SaasError, SaasError as Error, SaasResult as Result};
|
||||
|
||||
@@ -142,13 +142,13 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
return Ok(selection);
|
||||
}
|
||||
|
||||
// All keys are over their limits, or there are no keys: first check whether any active key exists
|
||||
let has_any_key: Option<(bool,)> = sqlx::query_as(
|
||||
// All active keys are over their limits: first check whether any active key exists
|
||||
let has_any_active: Option<(bool,)> = sqlx::query_as(
|
||||
"SELECT COUNT(*) > 0 FROM provider_keys WHERE provider_id = $1 AND is_active = TRUE"
|
||||
).bind(provider_id).fetch_optional(db).await?;
|
||||
|
||||
if has_any_key.is_some_and(|(b,)| b) {
|
||||
// Keys exist but all are in cooldown or over their limits: find the earliest recovery time
|
||||
if has_any_active.is_some_and(|(b,)| b) {
|
||||
// Active keys exist but all are in cooldown or over their limits: find the earliest recovery time
|
||||
let cooldown_row: Option<(String,)> = sqlx::query_as(
|
||||
"SELECT cooldown_until::TEXT FROM provider_keys
|
||||
WHERE provider_id = $1 AND is_active = TRUE AND cooldown_until IS NOT NULL AND cooldown_until::timestamptz > $2
|
||||
@@ -169,7 +169,79 @@ pub async fn select_best_key(db: &PgPool, provider_id: &str, enc_key: &[u8; 32])
|
||||
));
|
||||
}
|
||||
|
||||
Err(SaasError::NotFound(format!("Provider {} 没有可用的 API Key", provider_id)))
|
||||
// No active key: automatically reactivate keys whose cooldown has expired but is_active=false
|
||||
let reactivated: Option<(i64,)> = sqlx::query_as(
|
||||
"WITH reactivated AS (
UPDATE provider_keys SET is_active = TRUE, cooldown_until = NULL, updated_at = NOW()
WHERE provider_id = $1 AND is_active = FALSE
AND (cooldown_until IS NOT NULL AND cooldown_until::timestamptz <= $2)
RETURNING id
)
SELECT COUNT(*) FROM reactivated"
|
||||
).bind(provider_id).bind(&now).fetch_optional(db).await?;
|
||||
|
||||
if let Some((active_count,)) = &reactivated {
|
||||
if *active_count > 0 {
|
||||
tracing::info!(
|
||||
"Provider {} 自动恢复了 {} 个 cooldown 过期的 Key,重试选择",
|
||||
provider_id, active_count
|
||||
);
|
||||
invalidate_cache(provider_id);
|
||||
// Retry the query (no recursion; run the selection query once more)
|
||||
let retry_rows: Vec<(String, String, i32, Option<i64>, Option<i64>, Option<i64>, Option<i64>)> =
|
||||
sqlx::query_as(
|
||||
"SELECT pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm,
|
||||
COALESCE(SUM(uw.request_count), 0)::bigint,
|
||||
COALESCE(SUM(uw.token_count), 0)::bigint
|
||||
FROM provider_keys pk
|
||||
LEFT JOIN key_usage_window uw ON pk.id = uw.key_id
|
||||
AND uw.window_minute >= to_char(NOW() - INTERVAL '1 minute', 'YYYY-MM-DDTHH24:MI')
|
||||
WHERE pk.provider_id = $1 AND pk.is_active = TRUE
|
||||
AND (pk.cooldown_until IS NULL OR pk.cooldown_until::timestamptz <= $2)
|
||||
GROUP BY pk.id, pk.key_value, pk.priority, pk.max_rpm, pk.max_tpm
|
||||
ORDER BY pk.priority ASC, pk.last_used_at ASC NULLS FIRST"
|
||||
).bind(provider_id).bind(&now).fetch_all(db).await?;
|
||||
|
||||
for (id, key_value, _priority, max_rpm, max_tpm, req_count, token_count) in &retry_rows {
|
||||
if let Some(rpm_limit) = max_rpm {
|
||||
if *rpm_limit > 0 && req_count.unwrap_or(0) >= *rpm_limit {
|
||||
tracing::debug!("[retry] Reactivated key {} hit RPM limit ({}/{})", id, req_count.unwrap_or(0), rpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if let Some(tpm_limit) = max_tpm {
|
||||
if *tpm_limit > 0 && token_count.unwrap_or(0) >= *tpm_limit {
|
||||
tracing::debug!("[retry] Reactivated key {} hit TPM limit ({}/{})", id, token_count.unwrap_or(0), tpm_limit);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
let decrypted_kv = match decrypt_key_value(key_value, enc_key) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
tracing::warn!("[retry] Reactivated key {} decryption failed: {}", id, e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let selection = KeySelection {
|
||||
key: PoolKey { id: id.clone(), key_value: decrypted_kv, priority: *_priority, max_rpm: *max_rpm, max_tpm: *max_tpm },
|
||||
key_id: id.clone(),
|
||||
};
|
||||
get_cache().insert(provider_id.to_string(), CachedSelection {
|
||||
selection: selection.clone(),
|
||||
cached_at: Instant::now(),
|
||||
});
|
||||
return Ok(selection);
|
||||
}
|
||||
|
||||
// All reactivated keys are still RPM/TPM-limited or failed to decrypt
|
||||
tracing::warn!("Provider {} 恢复的 Key 全部不可用(RPM/TPM 超限或解密失败)", provider_id);
|
||||
return Err(SaasError::RateLimited(
|
||||
format!("Provider {} 恢复的 Key 仍在限流中,请稍后重试", provider_id)
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
Err(SaasError::NotFound(format!(
|
||||
"Provider {} 没有可用的 API Key(所有 Key 已停用,请在管理后台激活)",
|
||||
provider_id
|
||||
)))
|
||||
}
|
||||
|
||||
/// Record key usage (sliding window)
|
||||
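The per-key RPM/TPM check in the retry loop above can be isolated as a small predicate. The sketch below assumes a simplified row type (`KeyUsage` is hypothetical, and decryption is left out) and treats a missing or non-positive limit as "unlimited", which is how the loop's `*rpm_limit > 0` guard reads.

```rust
/// Hypothetical, simplified view of one row from the retry query.
struct KeyUsage {
    id: String,
    max_rpm: Option<i64>,
    max_tpm: Option<i64>,
    req_count: i64,
    token_count: i64,
}

/// A limit of None or a non-positive value means "unlimited".
fn within_limit(limit: Option<i64>, used: i64) -> bool {
    match limit {
        Some(max) if max > 0 => used < max,
        _ => true,
    }
}

/// Returns the id of the first key under both limits.
/// Rows are assumed to be pre-sorted by priority, as the SQL ORDER BY does.
fn first_usable_key(rows: &[KeyUsage]) -> Option<&str> {
    rows.iter()
        .find(|k| within_limit(k.max_rpm, k.req_count) && within_limit(k.max_tpm, k.token_count))
        .map(|k| k.id.as_str())
}
```

Keeping the predicate separate from the database call makes the "skip and continue" behaviour easy to unit-test without a Postgres instance.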
@@ -229,14 +301,14 @@ pub async fn mark_key_429(
|
||||
let now = chrono::Utc::now();
|
||||
|
||||
sqlx::query(
|
||||
"UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, updated_at = $3
|
||||
"UPDATE provider_keys SET last_429_at = $1, cooldown_until = $2, is_active = FALSE, updated_at = $3
|
||||
WHERE id = $4"
|
||||
)
|
||||
.bind(&now).bind(&cooldown).bind(&now).bind(key_id)
|
||||
.execute(db).await?;
|
||||
|
||||
tracing::warn!(
|
||||
"Key {} 收到 429,冷却至 {}",
|
||||
"Key {} 收到 429,标记 is_active=FALSE,冷却至 {}",
|
||||
key_id,
|
||||
cooldown
|
||||
);
|
||||
@@ -315,9 +387,16 @@ pub async fn toggle_key_active(
|
||||
active: bool,
|
||||
) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now();
|
||||
sqlx::query(
|
||||
"UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
|
||||
).bind(active).bind(&now).bind(key_id).execute(db).await?;
|
||||
// When activating, clear cooldown so the key is immediately selectable
|
||||
if active {
|
||||
sqlx::query(
|
||||
"UPDATE provider_keys SET is_active = $1, cooldown_until = NULL, updated_at = $2 WHERE id = $3"
|
||||
).bind(active).bind(&now).bind(key_id).execute(db).await?;
|
||||
} else {
|
||||
sqlx::query(
|
||||
"UPDATE provider_keys SET is_active = $1, updated_at = $2 WHERE id = $3"
|
||||
).bind(active).bind(&now).bind(key_id).execute(db).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
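Taken together, `mark_key_429`, the auto-recovery UPDATE in `select_best_key`, and `toggle_key_active` form a small state machine over `is_active` and `cooldown_until`. The following in-memory model is only a sketch of that reading: `KeyState` and its methods are hypothetical, and timestamps are plain seconds rather than `chrono` values.

```rust
/// Hypothetical in-memory model of a provider key; timestamps are plain seconds.
#[derive(Debug, Clone, PartialEq)]
struct KeyState {
    is_active: bool,
    cooldown_until: Option<u64>,
}

impl KeyState {
    /// Mirrors `mark_key_429`: deactivate the key and start a cooldown.
    fn mark_429(&mut self, now: u64, cooldown_secs: u64) {
        self.is_active = false;
        self.cooldown_until = Some(now + cooldown_secs);
    }

    /// Mirrors the auto-recovery UPDATE: only inactive keys with an expired
    /// cooldown come back. Returns true if the key was reactivated.
    fn try_reactivate(&mut self, now: u64) -> bool {
        match (self.is_active, self.cooldown_until) {
            (false, Some(until)) if until <= now => {
                self.is_active = true;
                self.cooldown_until = None;
                true
            }
            _ => false,
        }
    }

    /// Mirrors `toggle_key_active`: activating clears the cooldown,
    /// deactivating leaves it untouched.
    fn toggle(&mut self, active: bool) {
        self.is_active = active;
        if active {
            self.cooldown_until = None;
        }
    }
}
```

In this model a key that was deactivated by hand and has no recorded cooldown is never picked up by `try_reactivate`, while one that still carries an expired cooldown is; whether that matches the intended behaviour is for the authors to confirm.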
@@ -191,6 +191,7 @@ pub fn parse_skill_md(content: &str) -> Result<SkillManifest> {
|
||||
triggers,
|
||||
tools,
|
||||
enabled: true,
|
||||
body: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -292,6 +293,7 @@ pub fn parse_skill_toml(content: &str) -> Result<SkillManifest> {
|
||||
triggers,
|
||||
tools,
|
||||
enabled: true,
|
||||
body: None,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -241,6 +241,7 @@ impl SkillRegistry {
|
||||
// P2-19: Preserve tools field during update (was silently dropped)
|
||||
tools: if updates.tools.is_empty() { existing.tools } else { updates.tools },
|
||||
enabled: updates.enabled,
|
||||
body: existing.body,
|
||||
};
|
||||
|
||||
// Rewrite SKILL.md
|
||||
@@ -318,10 +319,14 @@ fn serialize_skill_md(manifest: &SkillManifest) -> String {
|
||||
parts.push("---".to_string());
|
||||
parts.push(String::new());
|
||||
|
||||
// Body: use description as the skill content
|
||||
parts.push(format!("# {}", manifest.name));
|
||||
parts.push(String::new());
|
||||
parts.push(manifest.description.clone());
|
||||
// Body: use custom body if provided, otherwise default to "# {name}\n\n{description}"
|
||||
if let Some(ref body) = manifest.body {
|
||||
parts.push(body.clone());
|
||||
} else {
|
||||
parts.push(format!("# {}", manifest.name));
|
||||
parts.push(String::new());
|
||||
parts.push(manifest.description.clone());
|
||||
}
|
||||
|
||||
parts.join("\n")
|
||||
}
|
||||
|
||||
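The body selection in `serialize_skill_md` reduces to "custom body if present, otherwise a heading plus the description". A minimal sketch, assuming a trimmed-down manifest type (`Manifest` and `render_body` are hypothetical names):

```rust
/// Hypothetical, trimmed-down manifest with only the fields used here.
struct Manifest {
    name: String,
    description: String,
    body: Option<String>,
}

/// Renders the part of SKILL.md that follows the frontmatter.
fn render_body(manifest: &Manifest) -> String {
    match &manifest.body {
        // A custom body is emitted verbatim.
        Some(body) => body.clone(),
        // Default layout: "# {name}", a blank line, then the description.
        None => format!("# {}\n\n{}", manifest.name, manifest.description),
    }
}
```

Because both parsers in the diff set `body: None`, the default layout is what a freshly parsed skill would round-trip to under this reading.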
@@ -7,7 +7,7 @@ use std::time::Instant;
|
||||
use tracing::warn;
|
||||
use zclaw_types::Result;
|
||||
|
||||
use super::{Skill, SkillContext, SkillManifest, SkillResult};
|
||||
use super::{Skill, SkillCompletion, SkillContext, SkillManifest, SkillResult};
|
||||
|
||||
/// Returns the platform-appropriate Python binary name.
|
||||
/// On Windows, the standard installer provides `python.exe`, not `python3.exe`.
|
||||
@@ -39,6 +39,17 @@ impl PromptOnlySkill {
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
fn completion_to_result(&self, completion: SkillCompletion) -> SkillResult {
|
||||
if completion.tool_calls.is_empty() {
|
||||
return SkillResult::success(Value::String(completion.text));
|
||||
}
|
||||
// Include both text and tool calls so the caller can relay them.
|
||||
SkillResult::success(serde_json::json!({
|
||||
"text": completion.text,
|
||||
"tool_calls": completion.tool_calls,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -50,13 +61,25 @@ impl Skill for PromptOnlySkill {
|
||||
async fn execute(&self, context: &SkillContext, input: Value) -> Result<SkillResult> {
|
||||
let prompt = self.format_prompt(&input);
|
||||
|
||||
// If an LLM completer is available, generate an AI response
|
||||
if let Some(completer) = &context.llm {
|
||||
// If tool definitions are available and the manifest declares tools,
|
||||
// use tool-augmented completion so the LLM can invoke tools.
|
||||
if !context.tool_definitions.is_empty() && !self.manifest.tools.is_empty() {
|
||||
match completer.complete_with_tools(&prompt, None, context.tool_definitions.clone()).await {
|
||||
Ok(completion) => {
|
||||
return Ok(self.completion_to_result(completion));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("[PromptOnlySkill] Tool completion failed: {}, falling back", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Plain completion (no tools or fallback)
|
||||
match completer.complete(&prompt).await {
|
||||
Ok(response) => return Ok(SkillResult::success(Value::String(response))),
|
||||
Err(e) => {
|
||||
warn!("[PromptOnlySkill] LLM completion failed: {}, falling back to raw prompt", e);
|
||||
// Fall through to return raw prompt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
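The control flow of `execute` above is a three-step fallback: tool-augmented completion, then plain completion, then the raw prompt. The sketch below shows that chain in synchronous form; `run_with_fallback` is a hypothetical name, and the two closures stand in for `complete_with_tools` and `complete`, which are async in the real trait.

```rust
/// Hypothetical stand-in for the completer calls; each closure returns
/// Ok(text) on success or Err(reason) on failure.
fn run_with_fallback(
    prompt: &str,
    use_tools: bool,
    with_tools: impl Fn(&str) -> Result<String, String>,
    plain: impl Fn(&str) -> Result<String, String>,
) -> String {
    if use_tools {
        if let Ok(text) = with_tools(prompt) {
            return text;
        }
        // Tool completion failed: fall through to plain completion.
    }
    match plain(prompt) {
        Ok(text) => text,
        // Last resort: hand back the formatted prompt itself.
        Err(_) => prompt.to_string(),
    }
}
```

Here `use_tools` plays the role of the `!context.tool_definitions.is_empty() && !self.manifest.tools.is_empty()` condition in the diff.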
@@ -93,6 +93,8 @@ pub struct SemanticSkillRouter {
|
||||
confidence_threshold: f32,
|
||||
/// LLM fallback for ambiguous queries (confidence below threshold)
|
||||
llm_fallback: Option<Arc<dyn RuntimeLlmIntent>>,
|
||||
/// Experience-based boost factors: tool_name → boost weight (0.0 - 0.15)
|
||||
experience_boosts: HashMap<String, f32>,
|
||||
}
|
||||
|
||||
impl SemanticSkillRouter {
|
||||
@@ -104,6 +106,7 @@ impl SemanticSkillRouter {
|
||||
tfidf_index: SkillTfidfIndex::new(),
|
||||
skill_embeddings: HashMap::new(),
|
||||
confidence_threshold: 0.85,
|
||||
experience_boosts: HashMap::new(),
|
||||
llm_fallback: None,
|
||||
};
|
||||
router.rebuild_index_sync();
|
||||
@@ -194,7 +197,7 @@ impl SemanticSkillRouter {
|
||||
for (skill_id, manifest) in &manifests {
|
||||
let tfidf_score = self.tfidf_index.score(query, &skill_id.to_string());
|
||||
|
||||
let final_score = if let Some(ref q_emb) = query_embedding {
|
||||
let base_score = if let Some(ref q_emb) = query_embedding {
|
||||
// Hybrid: embedding (70%) + TF-IDF (30%)
|
||||
if let Some(s_emb) = self.skill_embeddings.get(&skill_id.to_string()) {
|
||||
let emb_sim = cosine_similarity(q_emb, s_emb);
|
||||
@@ -206,6 +209,10 @@ impl SemanticSkillRouter {
|
||||
tfidf_score
|
||||
};
|
||||
|
||||
// Apply experience-based boost for frequently used tools
|
||||
let boost = self.experience_boosts.get(&skill_id.to_string()).copied().unwrap_or(0.0);
|
||||
let final_score = base_score + boost;
|
||||
|
||||
scored.push(ScoredCandidate {
|
||||
manifest: manifest.clone(),
|
||||
score: final_score,
|
||||
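The scoring change above adds the experience boost on top of the existing base score. As a rough sketch, assuming the hybrid base is the plain weighted sum suggested by the "embedding (70%) + TF-IDF (30%)" comment (the exact expression sits outside the shown hunk, and `final_score` as a free function is hypothetical):

```rust
/// Combines the scores as described above.
/// `emb_sim` is None when no embedding is available for the query or the skill.
fn final_score(emb_sim: Option<f32>, tfidf: f32, boost: f32) -> f32 {
    let base = match emb_sim {
        // Assumed hybrid weighting: 70% embedding similarity, 30% TF-IDF.
        Some(sim) => 0.7 * sim + 0.3 * tfidf,
        // Without an embedding the TF-IDF score is used on its own.
        None => tfidf,
    };
    base + boost
}
```

Nothing in this sketch clamps the result, so a boosted score can exceed 1.0; callers comparing against a confidence threshold may want to keep that in mind.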
@@ -281,6 +288,22 @@ impl SemanticSkillRouter {
|
||||
confidence_threshold: self.confidence_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update experience-based boost factors.
|
||||
///
|
||||
/// `experiences` maps tool/skill names to reuse counts.
|
||||
/// Higher reuse count → higher boost (capped at 0.15).
|
||||
/// This lets the router prefer skills the user frequently uses.
|
||||
pub fn update_experience_boosts(&mut self, experiences: &HashMap<String, u32>) {
|
||||
self.experience_boosts.clear();
|
||||
for (tool, count) in experiences {
|
||||
// Boost = min(0.05 * ln(count + 1), 0.15) — logarithmic scaling
|
||||
let boost = (0.05 * (*count as f32 + 1.0).ln()).min(0.15);
|
||||
if boost > 0.01 {
|
||||
self.experience_boosts.insert(tool.clone(), boost);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Router statistics
|
||||
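The boost formula in `update_experience_boosts`, `min(0.05 * ln(count + 1), 0.15)` with entries at or below 0.01 dropped, can be checked in isolation. The helper names below are hypothetical; the arithmetic follows the diff.

```rust
use std::collections::HashMap;

/// Logarithmic boost for a reuse count, capped at 0.15.
fn experience_boost(count: u32) -> f32 {
    (0.05 * (count as f32 + 1.0).ln()).min(0.15)
}

/// Builds the boost table, keeping only boosts above the 0.01 floor.
fn build_boosts(experiences: &HashMap<String, u32>) -> HashMap<String, f32> {
    experiences
        .iter()
        .map(|(tool, count)| (tool.clone(), experience_boost(*count)))
        .filter(|(_, boost)| *boost > 0.01)
        .collect()
}
```

Working the numbers: a count of 1 gives roughly 0.035, a count of 10 roughly 0.12, and the cap is reached at a count of 20, since 0.05 * ln(21) is about 0.152. Only a count of 0 falls under the 0.01 floor.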
@@ -534,6 +557,7 @@ mod tests {
|
||||
triggers: triggers.into_iter().map(|s| s.to_string()).collect(),
|
||||
tools: vec![],
|
||||
enabled: true,
|
||||
body: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -719,4 +743,40 @@ mod tests {
|
||||
// Should still return best TF-IDF match even below threshold
|
||||
assert_eq!(result.unwrap().skill_id, "skill-x");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_experience_boost_applied() {
|
||||
let registry = Arc::new(SkillRegistry::new());
|
||||
let embedder = Arc::new(NoOpEmbedder);
|
||||
let mut router = SemanticSkillRouter::new(registry.clone(), embedder);
|
||||
|
||||
let skill_a = make_manifest("researcher", "研究员", "深度研究分析报告", vec!["研究", "分析"]);
|
||||
let skill_b = make_manifest("collector", "收集器", "数据采集整理汇总", vec!["收集", "采集"]);
|
||||
registry.register(
|
||||
Arc::new(crate::runner::PromptOnlySkill::new(skill_a.clone(), String::new())),
|
||||
skill_a,
|
||||
).await;
|
||||
registry.register(
|
||||
Arc::new(crate::runner::PromptOnlySkill::new(skill_b.clone(), String::new())),
|
||||
skill_b,
|
||||
).await;
|
||||
|
||||
router.rebuild_index().await;
|
||||
|
||||
let mut exp = HashMap::new();
|
||||
exp.insert("researcher".to_string(), 10);
|
||||
router.update_experience_boosts(&exp);
|
||||
|
||||
let candidates = router.retrieve_candidates("帮我研究一下", 5).await;
|
||||
assert!(!candidates.is_empty());
|
||||
|
||||
let rid = SkillId::new("researcher");
|
||||
let cid = SkillId::new("collector");
|
||||
let researcher_score = candidates.iter().find(|c| c.manifest.id == rid).map(|c| c.score);
|
||||
let collector_score = candidates.iter().find(|c| c.manifest.id == cid).map(|c| c.score);
|
||||
|
||||
if let (Some(r), Some(c)) = (researcher_score, collector_score) {
|
||||
assert!(r >= c, "Experience-boosted researcher should score >= collector");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use zclaw_types::{SkillId, Result};
|
||||
use zclaw_types::{SkillId, ToolDefinition, Result};
|
||||
|
||||
/// Type-erased LLM completion interface.
|
||||
///
|
||||
@@ -15,6 +15,43 @@ pub trait LlmCompleter: Send + Sync {
|
||||
&self,
|
||||
prompt: &str,
|
||||
) -> Pin<Box<dyn std::future::Future<Output = std::result::Result<String, String>> + Send + '_>>;
|
||||
|
||||
/// Complete a prompt with tool definitions available to the LLM.
|
||||
///
|
||||
/// The LLM may return text, tool calls, or both. Tool calls are returned
|
||||
/// in the `tool_calls` field for the caller to execute or relay.
|
||||
/// Default implementation falls back to plain `complete()`.
|
||||
fn complete_with_tools(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_system_prompt: Option<&str>,
|
||||
_tools: Vec<ToolDefinition>,
|
||||
) -> Pin<Box<dyn std::future::Future<Output = std::result::Result<SkillCompletion, String>> + Send + '_>> {
|
||||
let prompt = prompt.to_string();
|
||||
Box::pin(async move {
|
||||
self.complete(&prompt).await.map(|text| SkillCompletion { text, tool_calls: vec![] })
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of an LLM completion that may include tool calls.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillCompletion {
|
||||
/// The text portion of the LLM response.
|
||||
pub text: String,
|
||||
/// Tool calls the LLM requested, if any.
|
||||
pub tool_calls: Vec<SkillToolCall>,
|
||||
}
|
||||
|
||||
/// A single tool call returned by the LLM during skill execution.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SkillToolCall {
|
||||
/// Unique call ID.
|
||||
pub id: String,
|
||||
/// Name of the tool to invoke.
|
||||
pub name: String,
|
||||
/// Input arguments for the tool.
|
||||
pub input: Value,
|
||||
}
|
||||
|
||||
/// Skill manifest definition
|
||||
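The default `complete_with_tools` above lets existing `LlmCompleter` implementations keep compiling: they inherit a version that ignores the tools and wraps the plain completion. The sketch below shows the same pattern in synchronous form; `Completer`, `Completion`, and `Echo` are hypothetical simplifications, and the real trait returns a boxed future instead of a `Result` directly.

```rust
/// Simplified stand-in for `SkillCompletion`.
#[derive(Debug, PartialEq)]
struct Completion {
    text: String,
    tool_calls: Vec<String>,
}

trait Completer {
    fn complete(&self, prompt: &str) -> Result<String, String>;

    /// Default: ignore the tools and wrap the plain completion,
    /// reporting no tool calls.
    fn complete_with_tools(&self, prompt: &str, _tools: &[String]) -> Result<Completion, String> {
        self.complete(prompt)
            .map(|text| Completion { text, tool_calls: Vec::new() })
    }
}

/// An implementer that only knows plain completion.
struct Echo;

impl Completer for Echo {
    fn complete(&self, prompt: &str) -> Result<String, String> {
        Ok(format!("echo: {}", prompt))
    }
}
```

An implementation that does support tools would override `complete_with_tools` and fill `tool_calls`; callers can treat an empty list as "text only", which is what `completion_to_result` does.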
@@ -58,6 +95,9 @@ pub struct SkillManifest {
|
||||
/// Whether the skill is enabled
|
||||
#[serde(default = "default_enabled")]
|
||||
pub enabled: bool,
|
||||
/// Custom body content for SKILL.md (overrides default "# {name}\n\n{description}")
|
||||
#[serde(default, skip)]
|
||||
pub body: Option<String>,
|
||||
}
|
||||
|
||||
fn default_enabled() -> bool { true }
|
||||
@@ -97,6 +137,9 @@ pub struct SkillContext {
|
||||
pub file_access_allowed: bool,
|
||||
/// Optional LLM completer for skills that need AI generation (e.g. PromptOnly)
|
||||
pub llm: Option<std::sync::Arc<dyn LlmCompleter>>,
|
||||
/// Tool definitions resolved from the skill manifest's `tools` field.
|
||||
/// Populated by the kernel when creating the context.
|
||||
pub tool_definitions: Vec<ToolDefinition>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for SkillContext {
|
||||
@@ -109,6 +152,7 @@ impl std::fmt::Debug for SkillContext {
|
||||
.field("network_allowed", &self.network_allowed)
|
||||
.field("file_access_allowed", &self.file_access_allowed)
|
||||
.field("llm", &self.llm.as_ref().map(|_| "Arc<dyn LlmCompleter>"))
|
||||
.field("tool_definitions", &self.tool_definitions.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -124,6 +168,7 @@ impl Default for SkillContext {
|
||||
network_allowed: false,
|
||||
file_access_allowed: false,
|
||||
llm: None,
|
||||
tool_definitions: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -468,6 +468,7 @@ mod tests {
|
||||
triggers: vec![],
|
||||
tools: vec![],
|
||||
enabled: true,
|
||||
body: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
271
crates/zclaw-skills/tests/embedding_router_test.rs
Normal file
@@ -0,0 +1,271 @@
|
||||
//! Embedding router tests (EM-01 ~ EM-06)
|
||||
//!
|
||||
//! Validates SemanticSkillRouter with embedding, TF-IDF fallback,
|
||||
//! dimension mismatch handling, empty queries, CJK queries, and LLM fallback.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use async_trait::async_trait;
|
||||
use zclaw_skills::semantic_router::{
|
||||
Embedder, NoOpEmbedder, SemanticSkillRouter, RuntimeLlmIntent,
|
||||
RoutingResult, ScoredCandidate, cosine_similarity,
|
||||
};
|
||||
use zclaw_skills::{SkillRegistry, PromptOnlySkill, SkillManifest, SkillMode};
|
||||
use zclaw_types::id::SkillId;
|
||||
|
||||
fn make_manifest(id: &str, name: &str, triggers: Vec<&str>) -> SkillManifest {
|
||||
SkillManifest {
|
||||
id: SkillId::new(id),
|
||||
name: name.to_string(),
|
||||
description: format!("{} description", name),
|
||||
version: "1.0.0".to_string(),
|
||||
mode: SkillMode::PromptOnly,
|
||||
triggers: triggers.into_iter().map(String::from).collect(),
|
||||
enabled: true,
|
||||
author: None,
|
||||
capabilities: Vec::new(),
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags: Vec::new(),
|
||||
category: None,
|
||||
tools: Vec::new(),
|
||||
body: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock embedder that returns deterministic, L2-normalized `dim`-dimensional vectors derived from the text's bytes.
|
||||
struct MockEmbedder {
|
||||
dim: usize,
|
||||
should_fail: bool,
|
||||
}
|
||||
|
||||
impl MockEmbedder {
|
||||
fn new(dim: usize) -> Self {
|
||||
Self { dim, should_fail: false }
|
||||
}
|
||||
fn failing() -> Self {
|
||||
Self { dim: 768, should_fail: true }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Embedder for MockEmbedder {
|
||||
async fn embed(&self, text: &str) -> Option<Vec<f32>> {
|
||||
if self.should_fail {
|
||||
return None;
|
||||
}
|
||||
// Deterministic vector based on text content
|
||||
let mut vec = vec![0.0f32; self.dim];
|
||||
for (i, b) in text.as_bytes().iter().enumerate() {
|
||||
vec[i % self.dim] += (*b as f32) / 255.0;
|
||||
}
|
||||
// Normalize
|
||||
let norm: f32 = vec.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-8);
|
||||
for v in vec.iter_mut() {
|
||||
*v /= norm;
|
||||
}
|
||||
Some(vec)
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper: register skills and build router with embedding.
|
||||
async fn build_router_with_skills(
|
||||
embedder: Arc<dyn Embedder>,
|
||||
skills: Vec<(&str, &str, Vec<&str>)>,
|
||||
) -> SemanticSkillRouter {
|
||||
let registry = Arc::new(SkillRegistry::new());
|
||||
for (id, name, triggers) in skills {
|
||||
let manifest = make_manifest(id, name, triggers);
|
||||
registry
|
||||
.register(
|
||||
Arc::new(zclaw_skills::PromptOnlySkill::new(
|
||||
manifest.clone(),
|
||||
format!("Execute {}", name),
|
||||
)),
|
||||
manifest,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
let mut router = SemanticSkillRouter::new(registry, embedder);
|
||||
router.rebuild_index().await;
|
||||
router
|
||||
}
|
||||
|
||||
/// EM-01: Embedding API normal routing with 70/30 hybrid scoring.
|
||||
#[tokio::test]
|
||||
async fn em01_embedding_normal_routing() {
|
||||
let router = build_router_with_skills(
|
||||
Arc::new(MockEmbedder::new(768)),
|
||||
vec![
|
||||
("finance", "财务追踪", vec!["财务", "花销", "支出", "账单"]),
|
||||
("scheduling", "排班管理", vec!["排班", "班表", "值班"]),
|
||||
("news", "新闻搜索", vec!["新闻", "资讯", "头条"]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = router.route("帮我查一下上个月的花销").await;
|
||||
assert!(result.is_some(), "should match a skill");
|
||||
let r = result.unwrap();
|
||||
assert_eq!(r.skill_id, "finance", "should match finance skill");
|
||||
assert!(
|
||||
r.confidence > 0.1,
|
||||
"confidence should exceed 0.1: {}",
|
||||
r.confidence
|
||||
);
|
||||
}
|
||||
|
||||
/// EM-02: Embedding API failure degrades to TF-IDF.
|
||||
#[tokio::test]
|
||||
async fn em02_embedding_failure_fallback_to_tfidf() {
|
||||
let router = build_router_with_skills(
|
||||
Arc::new(MockEmbedder::failing()),
|
||||
vec![
|
||||
("finance", "财务追踪", vec!["财务", "花销"]),
|
||||
("scheduling", "排班管理", vec!["排班", "班表"]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
// Should still return results via TF-IDF fallback
|
||||
let result = router.route("帮我查花销").await;
|
||||
assert!(
|
||||
result.is_some(),
|
||||
"TF-IDF fallback should still produce results"
|
||||
);
|
||||
}
|
||||
|
||||
/// EM-03: Embedding dimension mismatch — no panic.
|
||||
#[tokio::test]
|
||||
async fn em03_embedding_dimension_mismatch() {
|
||||
// Use a mismatched embedder that returns different dimensions
|
||||
struct MismatchedEmbedder;
|
||||
#[async_trait]
|
||||
impl Embedder for MismatchedEmbedder {
|
||||
async fn embed(&self, _text: &str) -> Option<Vec<f32>> {
|
||||
// Return a small vector — won't match index embeddings
|
||||
Some(vec![0.5; 64])
|
||||
}
|
||||
}
|
||||
|
||||
let router = build_router_with_skills(
|
||||
Arc::new(MismatchedEmbedder),
|
||||
vec![("finance", "财务追踪", vec!["财务", "花销"])],
|
||||
)
|
||||
.await;
|
||||
|
||||
// Should not panic
|
||||
let result = router.route("查花销").await;
|
||||
// May return None or a result via TF-IDF — key assertion: no panic
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
/// EM-04: Empty query handling.
|
||||
#[tokio::test]
|
||||
async fn em04_empty_query_handling() {
|
||||
let router = build_router_with_skills(
|
||||
Arc::new(MockEmbedder::new(768)),
|
||||
vec![("finance", "财务追踪", vec!["财务"])],
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = router.route("").await;
|
||||
// Empty query may return None or a low-confidence result
|
||||
// Key: no panic
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
/// EM-05: Pure Chinese CJK query with bigram matching.
|
||||
#[tokio::test]
|
||||
async fn em05_cjk_query_matching() {
|
||||
let router = build_router_with_skills(
|
||||
Arc::new(NoOpEmbedder), // TF-IDF only
|
||||
vec![
|
||||
("scheduling", "排班管理", vec!["排班", "班表", "值班"]),
|
||||
("news", "新闻搜索", vec!["新闻"]),
|
||||
],
|
||||
)
|
||||
.await;
|
||||
|
||||
let result = router.route("我这个月值班表怎么排").await;
|
||||
assert!(result.is_some(), "CJK query should match");
|
||||
assert_eq!(
|
||||
result.unwrap().skill_id,
|
||||
"scheduling",
|
||||
"should match scheduling skill"
|
||||
);
|
||||
}
|
||||
|
||||
/// EM-06: LLM fallback triggered for ambiguous queries.
|
||||
#[tokio::test]
|
||||
async fn em06_llm_fallback_triggered() {
|
||||
struct MockLlmFallback {
|
||||
target: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RuntimeLlmIntent for MockLlmFallback {
|
||||
async fn resolve_skill(
|
||||
&self,
|
||||
_query: &str,
|
||||
candidates: &[ScoredCandidate],
|
||||
) -> Option<RoutingResult> {
|
||||
let c = candidates
|
||||
.iter()
|
||||
.find(|c| c.manifest.id.as_str() == self.target)?;
|
||||
Some(RoutingResult {
|
||||
skill_id: c.manifest.id.to_string(),
|
||||
confidence: 0.75,
|
||||
parameters: serde_json::json!({}),
|
||||
reasoning: "LLM selected".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
let registry = Arc::new(SkillRegistry::new());
|
||||
let manifest = make_manifest("helper", "通用助手", vec!["帮助", "处理"]);
|
||||
registry
|
||||
.register(
|
||||
Arc::new(zclaw_skills::PromptOnlySkill::new(
|
||||
manifest.clone(),
|
||||
"Help".to_string(),
|
||||
)),
|
||||
manifest,
|
||||
)
|
||||
.await;
|
||||
|
||||
let mut router = SemanticSkillRouter::new_tf_idf_only(registry)
|
||||
.with_confidence_threshold(100.0) // Force all to be below threshold
|
||||
.with_llm_fallback(Arc::new(MockLlmFallback {
|
||||
target: "helper".to_string(),
|
||||
}));
|
||||
router.rebuild_index().await;
|
||||
|
||||
let result = router.route("帮我处理一下那个东西").await;
|
||||
assert!(result.is_some(), "LLM fallback should resolve");
|
||||
assert_eq!(result.unwrap().skill_id, "helper");
|
||||
}
|
||||
|
||||
/// Bonus: cosine_similarity utility correctness.
|
||||
#[test]
|
||||
fn cosine_similarity_identical_vectors() {
|
||||
let v = vec![1.0, 0.0, 1.0, 0.0];
|
||||
let sim = cosine_similarity(&v, &v);
|
||||
assert!((sim - 1.0).abs() < 1e-6, "identical vectors => cosine=1.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_similarity_orthogonal_vectors() {
|
||||
let a = vec![1.0, 0.0];
|
||||
let b = vec![0.0, 1.0];
|
||||
let sim = cosine_similarity(&a, &b);
|
||||
assert!(sim.abs() < 1e-6, "orthogonal => cosine≈0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cosine_similarity_mismatched_dimensions() {
|
||||
let a = vec![1.0, 0.0, 1.0];
|
||||
let b = vec![1.0, 0.0];
|
||||
let sim = cosine_similarity(&a, &b);
|
||||
assert_eq!(sim, 0.0, "mismatched dimensions => 0.0");
|
||||
}
|
||||
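The three `cosine_similarity` tests above pin down its contract: 1.0 for identical vectors, about 0 for orthogonal ones, and exactly 0.0 when the dimensions differ. One implementation that satisfies that contract is sketched below; it is not the crate's actual function (hence the `_sketch` suffix), and the zero-norm guard is an added assumption.

```rust
/// A cosine similarity consistent with the tests above; not the crate's own code.
fn cosine_similarity_sketch(a: &[f32], b: &[f32]) -> f32 {
    // Mismatched or empty inputs are treated as "no similarity" instead of panicking.
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    // Assumption: a zero vector has no direction, so report 0.0 rather than NaN.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
```

Returning 0.0 on a length mismatch is what allows EM-03 to degrade quietly when an embedder changes its output dimension.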
247
crates/zclaw-skills/tests/loader_tests.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
//! Tests for skill loader — SKILL.md and TOML parsing
|
||||
|
||||
use zclaw_skills::*;
|
||||
|
||||
// === parse_skill_md ===
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_basic_frontmatter() {
|
||||
let content = r#"---
|
||||
name: "Code Reviewer"
|
||||
description: "Reviews code"
|
||||
version: "1.0.0"
|
||||
mode: prompt-only
|
||||
tags: coding, review
|
||||
---
|
||||
# Code Reviewer
|
||||
Reviews code for quality.
|
||||
"#;
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.name, "Code Reviewer");
|
||||
assert_eq!(manifest.description, "Reviews code");
|
||||
assert_eq!(manifest.version, "1.0.0");
|
||||
assert_eq!(manifest.mode, zclaw_skills::SkillMode::PromptOnly);
|
||||
assert_eq!(manifest.tags, vec!["coding", "review"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_with_triggers_list() {
|
||||
let content = r#"---
|
||||
name: "Translator"
|
||||
description: "Translates text"
|
||||
version: "1.0.0"
|
||||
mode: prompt-only
|
||||
triggers:
|
||||
- "翻译"
|
||||
- "translate"
|
||||
- "中译英"
|
||||
---
|
||||
# Translator
|
||||
"#;
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.triggers, vec!["翻译", "translate", "中译英"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_with_tools_list() {
|
||||
let content = r#"---
|
||||
name: "Builder"
|
||||
description: "Builds projects"
|
||||
version: "1.0.0"
|
||||
mode: shell
|
||||
tools:
|
||||
- "bash"
|
||||
- "cargo"
|
||||
---
|
||||
# Builder
|
||||
"#;
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.tools, vec!["bash", "cargo"]);
|
||||
assert_eq!(manifest.mode, zclaw_skills::SkillMode::Shell);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_with_category() {
|
||||
let content = r#"---
|
||||
name: "Math Solver"
|
||||
description: "Solves math problems"
|
||||
version: "1.0.0"
|
||||
mode: prompt-only
|
||||
category: "math"
|
||||
---
|
||||
# Math Solver
|
||||
"#;
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.category.unwrap(), "math");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_auto_classify_coding() {
|
||||
let content = r#"---
|
||||
name: "Code Helper"
|
||||
description: "Helps with programming and debugging"
|
||||
version: "1.0.0"
|
||||
mode: prompt-only
|
||||
---
|
||||
# Code Helper
|
||||
"#;
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
// Should auto-classify as "coding" based on description
|
||||
assert_eq!(manifest.category.unwrap(), "coding");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_auto_classify_translation() {
|
||||
let content = r#"---
|
||||
name: "Translator"
|
||||
description: "Helps with translation between languages"
|
||||
version: "1.0.0"
|
||||
mode: prompt-only
|
||||
---
|
||||
# Translator
|
||||
"#;
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
// Should auto-classify based on "translat" keyword
|
||||
assert!(manifest.category.is_some(), "Should auto-classify translation skill");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_no_frontmatter_extracts_name() {
|
||||
let content = "# My Skill\n\nThis is a cool skill.";
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.name, "My Skill");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_fallback_name() {
|
||||
let content = "Just some text without structure.";
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.name, "unnamed-skill");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_id_generation() {
|
||||
let content = "---\nname: \"Hello World\"\n---\n";
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.id.as_str(), "hello-world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_all_modes() {
|
||||
for (mode_str, expected) in &[
|
||||
("prompt-only", zclaw_skills::SkillMode::PromptOnly),
|
||||
("python", zclaw_skills::SkillMode::Python),
|
||||
("shell", zclaw_skills::SkillMode::Shell),
|
||||
("wasm", zclaw_skills::SkillMode::Wasm),
|
||||
("native", zclaw_skills::SkillMode::Native),
|
||||
] {
|
||||
let content = format!("---\nname: \"Test\"\nmode: {}\n---\n", mode_str);
|
||||
let manifest = parse_skill_md(&content).unwrap();
|
||||
assert_eq!(&manifest.mode, expected, "Failed for mode: {}", mode_str);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_md_capabilities_csv() {
|
||||
let content = "---\nname: \"Multi\"\ncapabilities: llm, web, file\n---\n";
|
||||
let manifest = parse_skill_md(content).unwrap();
|
||||
assert_eq!(manifest.capabilities, vec!["llm", "web", "file"]);
|
||||
}
|
||||
|
||||
// === parse_skill_toml ===
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_basic() {
|
||||
let content = r#"
|
||||
name = "Calculator"
|
||||
description = "Performs calculations"
|
||||
version = "2.0.0"
|
||||
mode = "prompt_only"
|
||||
"#;
|
||||
let manifest = parse_skill_toml(content).unwrap();
|
||||
assert_eq!(manifest.name, "Calculator");
|
||||
assert_eq!(manifest.description, "Performs calculations");
|
||||
assert_eq!(manifest.version, "2.0.0");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_with_id() {
|
||||
let content = r#"
|
||||
id = "my-calc"
|
||||
name = "Calculator"
|
||||
description = "Calc"
|
||||
"#;
|
||||
let manifest = parse_skill_toml(content).unwrap();
|
||||
assert_eq!(manifest.id.as_str(), "my-calc");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_generates_id_from_name() {
|
||||
let content = "name = \"Hello World\"\ndescription = \"x\"";
|
||||
let manifest = parse_skill_toml(content).unwrap();
|
||||
assert_eq!(manifest.id.as_str(), "hello-world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_requires_name() {
|
||||
let content = r#"description = "no name""#;
|
||||
let result = parse_skill_toml(content);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_arrays() {
|
||||
let content = r#"
|
||||
name = "X"
|
||||
description = "x"
|
||||
tags = ["a", "b", "c"]
|
||||
capabilities = ["llm"]
|
||||
triggers = ["go", "run"]
|
||||
"#;
|
||||
let manifest = parse_skill_toml(content).unwrap();
|
||||
assert_eq!(manifest.tags, vec!["a", "b", "c"]);
|
||||
assert_eq!(manifest.capabilities, vec!["llm"]);
|
||||
assert_eq!(manifest.triggers, vec!["go", "run"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_category() {
|
||||
let content = r#"
|
||||
name = "X"
|
||||
description = "x"
|
||||
category = "data"
|
||||
"#;
|
||||
let manifest = parse_skill_toml(content).unwrap();
|
||||
assert_eq!(manifest.category.unwrap(), "data");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_tools() {
|
||||
let content = r#"
|
||||
name = "X"
|
||||
description = "x"
|
||||
tools = ["bash", "cargo"]
|
||||
"#;
|
||||
let manifest = parse_skill_toml(content).unwrap();
|
||||
assert_eq!(manifest.tools, vec!["bash", "cargo"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_skill_toml_ignores_comments_and_sections() {
|
||||
let content = r#"
|
||||
# This is a comment
|
||||
[section]
|
||||
name = "X"
|
||||
description = "x"
|
||||
"#;
|
||||
let manifest = parse_skill_toml(content).unwrap();
|
||||
assert_eq!(manifest.name, "X");
|
||||
}
|
||||
|
||||
// === discover_skills ===
|
||||
|
||||
#[test]
|
||||
fn discover_skills_nonexistent_dir() {
|
||||
let result = discover_skills(std::path::Path::new("/nonexistent/path")).unwrap();
|
||||
assert!(result.is_empty());
|
||||
}
|
||||
79
crates/zclaw-skills/tests/runner_tests.rs
Normal file
@@ -0,0 +1,79 @@
|
||||
//! Tests for PromptOnlySkill runner
|
||||
|
||||
use zclaw_skills::*;
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
/// Helper to create a minimal manifest
|
||||
fn test_manifest(mode: SkillMode) -> SkillManifest {
|
||||
SkillManifest {
|
||||
id: SkillId::new("test-prompt-skill"),
|
||||
name: "Test Prompt Skill".to_string(),
|
||||
description: "A test prompt skill".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
author: None,
|
||||
mode,
|
||||
capabilities: vec![],
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags: vec![],
|
||||
category: None,
|
||||
triggers: vec![],
|
||||
tools: vec![],
|
||||
enabled: true,
|
||||
body: None,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_only_skill_returns_formatted_prompt() {
|
||||
let manifest = test_manifest(SkillMode::PromptOnly);
|
||||
let template = "Hello {{input}}, welcome!".to_string();
|
||||
let skill = PromptOnlySkill::new(manifest, template);
|
||||
|
||||
let ctx = SkillContext::default();
|
||||
let skill_ref: &dyn Skill = &skill;
|
||||
let result = skill_ref.execute(&ctx, serde_json::json!("World")).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
let output = result.output.as_str().unwrap();
|
||||
assert_eq!(output, "Hello World, welcome!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_only_skill_json_input() {
|
||||
let manifest = test_manifest(SkillMode::PromptOnly);
|
||||
let template = "Input: {{input}}".to_string();
|
||||
let skill = PromptOnlySkill::new(manifest, template);
|
||||
|
||||
let ctx = SkillContext::default();
|
||||
let input = serde_json::json!({"key": "value"});
|
||||
let skill_ref: &dyn Skill = &skill;
|
||||
let result = skill_ref.execute(&ctx, input).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
let output = result.output.as_str().unwrap();
|
||||
assert!(output.contains("key"));
|
||||
assert!(output.contains("value"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_only_skill_no_placeholder() {
|
||||
let manifest = test_manifest(SkillMode::PromptOnly);
|
||||
let template = "Static prompt content".to_string();
|
||||
let skill = PromptOnlySkill::new(manifest, template);
|
||||
|
||||
let ctx = SkillContext::default();
|
||||
let skill_ref: &dyn Skill = &skill;
|
||||
let result = skill_ref.execute(&ctx, serde_json::json!("ignored")).await.unwrap();
|
||||
|
||||
assert!(result.success);
|
||||
assert_eq!(result.output.as_str().unwrap(), "Static prompt content");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prompt_only_skill_manifest() {
|
||||
let manifest = test_manifest(SkillMode::PromptOnly);
|
||||
let skill = PromptOnlySkill::new(manifest.clone(), "prompt".to_string());
|
||||
assert_eq!(skill.manifest().id.as_str(), "test-prompt-skill");
|
||||
assert_eq!(skill.manifest().name, "Test Prompt Skill");
|
||||
}
|
||||
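The runner tests above pin down the observable behavior of the `{{input}}` placeholder: a string input is substituted verbatim, and a template without the placeholder comes back unchanged. The implementation of `PromptOnlySkill` is not part of this diff, so the following is only a minimal sketch of that substitution rule. `render_prompt` is a hypothetical helper name, and it assumes the input has already been converted to a string.

```rust
/// Hypothetical helper (not from the diff): replace every `{{input}}`
/// placeholder in `template` with `input`.
/// A template that contains no placeholder is returned unchanged.
fn render_prompt(template: &str, input: &str) -> String {
    // `str::replace` substitutes all occurrences and allocates a new String.
    template.replace("{{input}}", input)
}
```

For a JSON object input, the caller would first serialize the value to text; the `prompt_only_skill_json_input` test only requires that the key and value appear somewhere in the output.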
150
crates/zclaw-skills/tests/skill_types_tests.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
//! Tests for zclaw-skills types: SkillManifest, SkillMode, SkillResult, SkillContext
|
||||
|
||||
use serde_json;
|
||||
use zclaw_skills::*;
|
||||
use zclaw_types::SkillId;
|
||||
|
||||
// === SkillMode ===
|
||||
|
||||
#[test]
|
||||
fn skill_mode_serialization_roundtrip() {
|
||||
let modes = vec![
|
||||
SkillMode::PromptOnly,
|
||||
SkillMode::Python,
|
||||
SkillMode::Shell,
|
||||
SkillMode::Wasm,
|
||||
SkillMode::Native,
|
||||
];
|
||||
for mode in modes {
|
||||
let json = serde_json::to_string(&mode).unwrap();
|
||||
let parsed: SkillMode = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(mode, parsed);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_mode_snake_case_serialization() {
|
||||
let json = serde_json::to_string(&SkillMode::PromptOnly).unwrap();
|
||||
assert!(json.contains("prompt_only"));
|
||||
}
|
||||
|
||||
// === SkillResult ===
|
||||
|
||||
#[test]
|
||||
fn skill_result_success() {
|
||||
let result = SkillResult::success(serde_json::json!({"answer": 42}));
|
||||
assert!(result.success);
|
||||
assert!(result.error.is_none());
|
||||
assert_eq!(result.output["answer"], 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_result_error() {
|
||||
let result = SkillResult::error("execution failed");
|
||||
assert!(!result.success);
|
||||
assert_eq!(result.error.unwrap(), "execution failed");
|
||||
assert!(result.output.is_null());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_result_roundtrip() {
|
||||
let result = SkillResult {
|
||||
success: true,
|
||||
output: serde_json::json!("hello"),
|
||||
error: None,
|
||||
duration_ms: Some(150),
|
||||
tokens_used: Some(42),
|
||||
};
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
let parsed: SkillResult = serde_json::from_str(&json).unwrap();
|
||||
assert!(parsed.success);
|
||||
assert_eq!(parsed.duration_ms.unwrap(), 150);
|
||||
assert_eq!(parsed.tokens_used.unwrap(), 42);
|
||||
}
|
||||
|
||||
// === SkillManifest ===
|
||||
|
||||
#[test]
|
||||
fn skill_manifest_full_roundtrip() {
|
||||
let manifest = SkillManifest {
|
||||
id: SkillId::new("test-skill"),
|
||||
name: "Test Skill".to_string(),
|
||||
description: "A test skill".to_string(),
|
||||
version: "2.0.0".to_string(),
|
||||
author: Some("tester".to_string()),
|
||||
mode: SkillMode::PromptOnly,
|
||||
capabilities: vec!["llm".to_string()],
|
||||
input_schema: Some(serde_json::json!({"type": "object"})),
|
||||
output_schema: None,
|
||||
tags: vec!["test".to_string()],
|
||||
category: Some("coding".to_string()),
|
||||
triggers: vec!["test trigger".to_string()],
|
||||
tools: vec!["bash".to_string()],
|
||||
enabled: true,
|
||||
body: None,
|
||||
};
|
||||
let json = serde_json::to_string(&manifest).unwrap();
|
||||
let parsed: SkillManifest = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.id.as_str(), "test-skill");
|
||||
assert_eq!(parsed.name, "Test Skill");
|
||||
assert_eq!(parsed.mode, SkillMode::PromptOnly);
|
||||
assert_eq!(parsed.capabilities.len(), 1);
|
||||
assert_eq!(parsed.triggers.len(), 1);
|
||||
assert_eq!(parsed.tools.len(), 1);
|
||||
assert_eq!(parsed.category.unwrap(), "coding");
|
||||
assert!(parsed.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_manifest_default_enabled() {
|
||||
let json = r#"{"id":"x","name":"X","description":"x","version":"1.0","mode":"prompt_only"}"#;
|
||||
let manifest: SkillManifest = serde_json::from_str(json).unwrap();
|
||||
assert!(manifest.enabled, "enabled should default to true");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_manifest_disabled() {
|
||||
let json = r#"{"id":"x","name":"X","description":"x","version":"1.0","mode":"prompt_only","enabled":false}"#;
|
||||
let manifest: SkillManifest = serde_json::from_str(json).unwrap();
|
||||
assert!(!manifest.enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skill_manifest_all_modes_roundtrip() {
|
||||
for mode in &[SkillMode::PromptOnly, SkillMode::Python, SkillMode::Shell, SkillMode::Wasm, SkillMode::Native] {
|
||||
let manifest = SkillManifest {
|
||||
id: SkillId::new("m"),
|
||||
name: "M".into(),
|
||||
description: "d".into(),
|
||||
version: "1.0".into(),
|
||||
author: None,
|
||||
mode: mode.clone(),
|
||||
capabilities: vec![],
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags: vec![],
|
||||
category: None,
|
||||
triggers: vec![],
|
||||
tools: vec![],
|
||||
enabled: true,
|
||||
body: None,
|
||||
};
|
||||
let json = serde_json::to_string(&manifest).unwrap();
|
||||
let parsed: SkillManifest = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(*mode, parsed.mode);
|
||||
}
|
||||
}
|
||||
|
||||
// === SkillContext ===
|
||||
|
||||
#[test]
|
||||
fn skill_context_default() {
|
||||
let ctx = SkillContext::default();
|
||||
assert!(ctx.agent_id.is_empty());
|
||||
assert!(ctx.session_id.is_empty());
|
||||
assert!(ctx.working_dir.is_none());
|
||||
assert_eq!(ctx.timeout_secs, 60);
|
||||
assert!(!ctx.network_allowed);
|
||||
assert!(!ctx.file_access_allowed);
|
||||
assert!(ctx.llm.is_none());
|
||||
}
|
||||
222
crates/zclaw-skills/tests/tool_enabled_skill_test.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
//! Tool-enabled skill execution tests (SK-01 ~ SK-03)
|
||||
//!
|
||||
//! Validates that skills with tool declarations actually pass tools to the LLM,
|
||||
//! skills without tools use pure prompt mode, and lock poisoning is handled gracefully.
|
||||
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use serde_json::{json, Value};
|
||||
use zclaw_skills::{
|
||||
PromptOnlySkill, LlmCompleter, Skill, SkillCompletion, SkillContext,
|
||||
SkillManifest, SkillMode, SkillToolCall, SkillRegistry,
|
||||
};
|
||||
use zclaw_types::id::SkillId;
|
||||
use zclaw_types::tool::ToolDefinition;
|
||||
|
||||
fn make_tool_manifest(id: &str, tools: Vec<&str>) -> SkillManifest {
|
||||
SkillManifest {
|
||||
id: SkillId::new(id),
|
||||
name: id.to_string(),
|
||||
description: format!("{} test skill", id),
|
||||
version: "1.0.0".to_string(),
|
||||
mode: SkillMode::PromptOnly,
|
||||
tools: tools.into_iter().map(String::from).collect(),
|
||||
enabled: true,
|
||||
author: None,
|
||||
capabilities: Vec::new(),
|
||||
input_schema: None,
|
||||
output_schema: None,
|
||||
tags: Vec::new(),
|
||||
category: None,
|
||||
triggers: Vec::new(),
|
||||
body: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock LLM completer that records calls and returns preset responses.
|
||||
struct MockCompleter {
|
||||
response_text: String,
|
||||
tool_calls: Vec<SkillToolCall>,
|
||||
calls: std::sync::Mutex<Vec<String>>,
|
||||
tools_received: std::sync::Mutex<Vec<Vec<ToolDefinition>>>,
|
||||
}
|
||||
|
||||
impl MockCompleter {
|
||||
fn new(text: &str) -> Self {
|
||||
Self {
|
||||
response_text: text.to_string(),
|
||||
tool_calls: Vec::new(),
|
||||
calls: std::sync::Mutex::new(Vec::new()),
|
||||
tools_received: std::sync::Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn with_tool_call(mut self, name: &str, input: Value) -> Self {
|
||||
self.tool_calls.push(SkillToolCall {
|
||||
id: format!("call_{}", name),
|
||||
name: name.to_string(),
|
||||
input,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
fn call_count(&self) -> usize {
|
||||
self.calls.lock().unwrap().len()
|
||||
}
|
||||
|
||||
fn last_tools(&self) -> Vec<ToolDefinition> {
|
||||
self.tools_received
|
||||
.lock()
|
||||
.unwrap()
|
||||
.last()
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
impl LlmCompleter for MockCompleter {
|
||||
fn complete(
|
||||
&self,
|
||||
prompt: &str,
|
||||
) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send + '_>> {
|
||||
self.calls.lock().unwrap().push(prompt.to_string());
|
||||
let text = self.response_text.clone();
|
||||
Box::pin(async move { Ok(text) })
|
||||
}
|
||||
|
||||
fn complete_with_tools(
|
||||
&self,
|
||||
prompt: &str,
|
||||
_system_prompt: Option<&str>,
|
||||
tools: Vec<ToolDefinition>,
|
||||
) -> Pin<Box<dyn Future<Output = Result<SkillCompletion, String>> + Send + '_>> {
|
||||
self.calls.lock().unwrap().push(prompt.to_string());
|
||||
self.tools_received.lock().unwrap().push(tools);
|
||||
let text = self.response_text.clone();
|
||||
let tool_calls = self.tool_calls.clone();
|
||||
Box::pin(async move {
|
||||
Ok(SkillCompletion { text, tool_calls })
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// SK-01: Skill with tool declarations passes tools to LLM via complete_with_tools.
|
||||
#[tokio::test]
|
||||
async fn sk01_skill_with_tools_calls_complete_with_tools() {
|
||||
let completer = Arc::new(MockCompleter::new("Research completed").with_tool_call(
|
||||
"web_fetch",
|
||||
json!({"url": "https://example.com"}),
|
||||
));
|
||||
|
||||
let manifest = make_tool_manifest("web-researcher", vec!["web_fetch", "execute_skill"]);
|
||||
|
||||
let tool_defs = vec![
|
||||
ToolDefinition::new("web_fetch", "Fetch a URL", json!({"type": "object"})),
|
||||
ToolDefinition::new("execute_skill", "Execute another skill", json!({"type": "object"})),
|
||||
];
|
||||
|
||||
let ctx = SkillContext {
|
||||
agent_id: "agent-1".into(),
|
||||
session_id: "sess-1".into(),
|
||||
llm: Some(completer.clone()),
|
||||
tool_definitions: tool_defs.clone(),
|
||||
..SkillContext::default()
|
||||
};
|
||||
|
||||
let skill = PromptOnlySkill::new(
|
||||
manifest.clone(),
|
||||
"Research: {{input}}".to_string(),
|
||||
);
|
||||
let result = skill.execute(&ctx, json!("rust programming")).await;
|
||||
|
||||
assert!(result.is_ok(), "skill execution should succeed");
|
||||
let skill_result = result.unwrap();
|
||||
assert!(skill_result.success, "skill result should be successful");
|
||||
|
||||
// Verify LLM was called
|
||||
assert_eq!(completer.call_count(), 1, "LLM should be called once");
|
||||
|
||||
// Verify tools were passed
|
||||
let tools = completer.last_tools();
|
||||
assert_eq!(tools.len(), 2, "both tools should be passed to LLM");
|
||||
assert_eq!(tools[0].name, "web_fetch");
|
||||
assert_eq!(tools[1].name, "execute_skill");
|
||||
}
|
||||
|
||||
/// SK-02: Skill without tool declarations uses pure complete() call.
|
||||
#[tokio::test]
|
||||
async fn sk02_skill_without_tools_uses_pure_prompt() {
|
||||
let completer = Arc::new(MockCompleter::new("Writing helper response"));
|
||||
|
||||
let manifest = make_tool_manifest("writing-helper", vec![]);
|
||||
|
||||
let ctx = SkillContext {
|
||||
agent_id: "agent-1".into(),
|
||||
session_id: "sess-1".into(),
|
||||
llm: Some(completer.clone()),
|
||||
tool_definitions: vec![],
|
||||
..SkillContext::default()
|
||||
};
|
||||
|
||||
let skill = PromptOnlySkill::new(
|
||||
manifest,
|
||||
"Help with: {{input}}".to_string(),
|
||||
);
|
||||
let result = skill.execute(&ctx, json!("write a summary")).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let skill_result = result.unwrap();
|
||||
assert!(skill_result.success);
|
||||
|
||||
// Verify LLM was called (via complete(), not complete_with_tools)
|
||||
assert_eq!(completer.call_count(), 1);
|
||||
// No tools should have been received (complete path, not complete_with_tools)
|
||||
assert!(
|
||||
completer.last_tools().is_empty(),
|
||||
"pure prompt should not pass tools"
|
||||
);
|
||||
}
|
||||
|
||||
/// SK-03: Skill execution degrades gracefully on lock poisoning.
|
||||
/// Note: SkillRegistry uses std::sync::RwLock, which can be poisoned.
/// This test does not poison the lock; it only checks that concurrent reads
/// complete without panicking and that the skill stays registered.
|
||||
#[tokio::test]
|
||||
async fn sk03_registry_handles_lock_contention() {
|
||||
let registry = Arc::new(SkillRegistry::new());
|
||||
|
||||
let manifest = make_tool_manifest("test-skill", vec![]);
|
||||
|
||||
// Register skill
|
||||
registry
|
||||
.register(
|
||||
Arc::new(PromptOnlySkill::new(
|
||||
manifest.clone(),
|
||||
"Test: {{input}}".to_string(),
|
||||
)),
|
||||
manifest,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Concurrent reads should not panic (both tasks only call list())
|
||||
let r1 = registry.clone();
|
||||
let r2 = registry.clone();
|
||||
|
||||
let h1 = tokio::spawn(async move {
|
||||
for _ in 0..10 {
|
||||
let _ = r1.list().await;
|
||||
}
|
||||
});
|
||||
let h2 = tokio::spawn(async move {
|
||||
for _ in 0..10 {
|
||||
let _ = r2.list().await;
|
||||
}
|
||||
});
|
||||
|
||||
h1.await.unwrap();
|
||||
h2.await.unwrap();
|
||||
|
||||
// Verify skill is still accessible
|
||||
let skill = registry.get(&SkillId::new("test-skill")).await;
|
||||
assert!(skill.is_some(), "skill should still be registered");
|
||||
}
|
||||
@@ -223,6 +223,33 @@ impl Serialize for ZclawError {
|
||||
/// Result type alias for ZCLAW operations
|
||||
pub type Result<T> = std::result::Result<T, ZclawError>;
|
||||
|
||||
/// Fine-grained classification of LLM call errors, used to guide retry and recovery strategy
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum LlmErrorKind {
|
||||
Auth,
|
||||
AuthPermanent,
|
||||
BillingExhausted,
|
||||
RateLimited,
|
||||
Overloaded,
|
||||
ServerError,
|
||||
Timeout,
|
||||
ContextOverflow,
|
||||
ModelNotFound,
|
||||
Unknown,
|
||||
}
|
||||
|
||||
/// A classified LLM error, with recovery hints attached
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClassifiedLlmError {
|
||||
pub kind: LlmErrorKind,
|
||||
pub retryable: bool,
|
||||
pub should_compress: bool,
|
||||
pub should_rotate_credential: bool,
|
||||
pub retry_after: Option<std::time::Duration>,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
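The hunk above adds `LlmErrorKind` and `ClassifiedLlmError` but does not show how a kind is turned into recovery hints. The sketch below illustrates one way that mapping could look. The `classify` function, the choice of which kinds are retryable, and the backoff durations are all assumptions made for illustration; only the type and field names come from the diff, and the serde derives are omitted to keep the example free of external crates.

```rust
use std::time::Duration;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LlmErrorKind {
    Auth,
    AuthPermanent,
    BillingExhausted,
    RateLimited,
    Overloaded,
    ServerError,
    Timeout,
    ContextOverflow,
    ModelNotFound,
    Unknown,
}

#[derive(Debug, Clone)]
struct ClassifiedLlmError {
    kind: LlmErrorKind,
    retryable: bool,
    should_compress: bool,
    should_rotate_credential: bool,
    retry_after: Option<Duration>,
    message: String,
}

/// Hypothetical mapping from an error kind to recovery hints.
/// The policy here is an assumption, not the project's actual rules.
fn classify(kind: LlmErrorKind, message: &str) -> ClassifiedLlmError {
    use LlmErrorKind::*;
    ClassifiedLlmError {
        kind,
        // Transient failures are worth retrying as-is.
        retryable: matches!(kind, RateLimited | Overloaded | ServerError | Timeout),
        // An oversized context needs compression before another attempt.
        should_compress: kind == ContextOverflow,
        // A different credential may help with auth or billing failures.
        should_rotate_credential: matches!(kind, Auth | BillingExhausted),
        // Illustrative backoff values only.
        retry_after: match kind {
            RateLimited => Some(Duration::from_secs(30)),
            Overloaded => Some(Duration::from_secs(5)),
            _ => None,
        },
        message: message.to_string(),
    }
}
```

Keeping the hints as plain booleans lets a retry loop branch on them without matching on every variant again.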
210
desktop/src-tauri/src/intelligence/cold_start_prompt.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
//! Cold start prompt generation for conversation-driven onboarding.
|
||||
//!
|
||||
//! Generates stage-specific system prompts that guide the agent through
|
||||
//! the 6-phase cold start flow without requiring form-filling.
|
||||
|
||||
/// Cold start phases matching the frontend state machine.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ColdStartPhase {
|
||||
Idle,
|
||||
AgentGreeting,
|
||||
IndustryDiscovery,
|
||||
IdentitySetup,
|
||||
FirstTask,
|
||||
Completed,
|
||||
}
|
||||
|
||||
impl ColdStartPhase {
|
||||
pub fn from_str(s: &str) -> Self {
|
||||
match s {
|
||||
"idle" => Self::Idle,
|
||||
"agent_greeting" => Self::AgentGreeting,
|
||||
"industry_discovery" => Self::IndustryDiscovery,
|
||||
"identity_setup" => Self::IdentitySetup,
|
||||
"first_task" => Self::FirstTask,
|
||||
"completed" => Self::Completed,
|
||||
_ => Self::Idle,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Industry-specific task suggestions for first_task phase.
|
||||
struct IndustryTasks {
|
||||
tasks: &'static [(&'static str, &'static str)],
|
||||
}
|
||||
|
||||
const HEALTHCARE_TASKS: IndustryTasks = IndustryTasks {
|
||||
tasks: &[
|
||||
("排班查询", "今天有需要处理的排班问题吗?"),
|
||||
("数据报表", "需要我帮你整理上周的数据报表吗?"),
|
||||
("政策查询", "最近有医保政策变化需要了解吗?"),
|
||||
],
|
||||
};
|
||||
|
||||
const EDUCATION_TASKS: IndustryTasks = IndustryTasks {
|
||||
tasks: &[
|
||||
("课程安排", "需要帮你安排下周的课程吗?"),
|
||||
("成绩分析", "有学生成绩需要分析吗?"),
|
||||
("测验生成", "需要帮学生出一份测验吗?告诉我科目和年级就行。"),
|
||||
],
|
||||
};
|
||||
|
||||
const GARMENT_TASKS: IndustryTasks = IndustryTasks {
|
||||
tasks: &[
|
||||
("订单跟踪", "有需要跟踪的订单吗?"),
|
||||
("生产排期", "需要安排生产计划吗?"),
|
||||
("成本核算", "有需要核算的成本数据吗?"),
|
||||
],
|
||||
};
|
||||
|
||||
const ECOMMERCE_TASKS: IndustryTasks = IndustryTasks {
|
||||
tasks: &[
|
||||
("库存检查", "需要检查库存情况吗?"),
|
||||
("销售分析", "想看看最近的销售数据吗?"),
|
||||
("商品文案", "有新商品需要写详情页吗?"),
|
||||
],
|
||||
};
|
||||
|
||||
/// Generate the cold start system prompt for a given phase and optional industry.
|
||||
pub fn generate_cold_start_prompt(phase: ColdStartPhase, industry: Option<&str>) -> String {
|
||||
match phase {
|
||||
ColdStartPhase::Idle | ColdStartPhase::AgentGreeting => format!(
|
||||
"你是一个正在认识新用户的 AI 管家。\n\n\
|
||||
## 当前任务\n\
|
||||
向用户打招呼并了解他们的工作。用简短自然的方式询问。\n\n\
|
||||
## 规则\n\
|
||||
- 每条消息不超过 3 句话\n\
|
||||
- 不要问\"你的行业是什么\",而是问\"你每天最常处理什么事?\"\n\
|
||||
- 保持热情友好,像一个新同事在打招呼\n\
|
||||
- 用中文交流"
|
||||
),
|
||||
|
||||
ColdStartPhase::IndustryDiscovery => {
|
||||
let industry_hint = match industry {
|
||||
Some("healthcare") => "用户可能从事医疗行政工作。",
|
||||
Some("education") => "用户可能从事教育培训工作。",
|
||||
Some("garment") => "用户可能从事制衣制造工作。",
|
||||
Some("ecommerce") => "用户可能从事电商零售工作。",
|
||||
_ => "继续了解用户的工作场景。",
|
||||
};
|
||||
format!(
|
||||
"你是一个正在了解用户工作场景的 AI 管家。\n\n\
|
||||
## 当前阶段:行业发现\n\
|
||||
{industry_hint}\n\n\
|
||||
## 规则\n\
|
||||
- 根据用户的回答确认行业\n\
|
||||
- 如果检测到行业,主动说出你的理解,让用户确认\n\
|
||||
- 每条消息不超过 3 句话\n\
|
||||
- 用中文交流"
|
||||
)
|
||||
}
|
||||
|
||||
ColdStartPhase::IdentitySetup => {
|
||||
let name_suggestion = match industry {
|
||||
Some("healthcare") => "小医",
|
||||
Some("education") => "小教",
|
||||
Some("garment") => "小织",
|
||||
Some("ecommerce") => "小商",
|
||||
_ => "小助手",
|
||||
};
|
||||
format!(
|
||||
"你是一个正在为自己起名字的 AI 管家。\n\n\
|
||||
## 当前阶段:身份设定\n\
|
||||
根据你了解的行业信息,向用户提议一个合适的名字和沟通风格。\n\n\
|
||||
## 建议\n\
|
||||
- 可以提议叫\"{name_suggestion}\"或其他合适的名字\n\
|
||||
- 说明你选择的沟通风格(专业/亲切/简洁)\n\
|
||||
- 让用户确认或提出自己的想法\n\
|
||||
- 每条消息不超过 3 句话\n\
|
||||
- 用中文交流"
|
||||
)
|
||||
}
|
||||
|
||||
ColdStartPhase::FirstTask => {
|
||||
let task_prompt = match industry {
|
||||
Some("healthcare") => HEALTHCARE_TASKS.tasks[2].1,
|
||||
Some("education") => EDUCATION_TASKS.tasks[2].1,
|
||||
Some("garment") => GARMENT_TASKS.tasks[2].1,
|
||||
Some("ecommerce") => ECOMMERCE_TASKS.tasks[2].1,
|
||||
_ => "有什么我可以帮你的吗?",
|
||||
};
|
||||
format!(
|
||||
"你是一个 AI 管家,用户已经完成了初始设置。\n\n\
|
||||
## 当前阶段:首次任务引导\n\
|
||||
引导用户完成第一个实际任务,让他们体验你的能力。\n\n\
|
||||
## 建议\n\
|
||||
- {task_prompt}\n\
|
||||
- 根据用户需求灵活调整\n\
|
||||
- 保持简短,1-2 句话\n\
|
||||
- 用中文交流"
|
||||
)
|
||||
}
|
||||
|
||||
ColdStartPhase::Completed => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a cold start prompt should be injected for the given phase.
|
||||
pub fn should_inject_prompt(phase: ColdStartPhase) -> bool {
|
||||
!matches!(phase, ColdStartPhase::Idle | ColdStartPhase::Completed)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_phase_from_str() {
|
||||
assert_eq!(ColdStartPhase::from_str("idle"), ColdStartPhase::Idle);
|
||||
assert_eq!(ColdStartPhase::from_str("agent_greeting"), ColdStartPhase::AgentGreeting);
|
||||
assert_eq!(ColdStartPhase::from_str("industry_discovery"), ColdStartPhase::IndustryDiscovery);
|
||||
assert_eq!(ColdStartPhase::from_str("identity_setup"), ColdStartPhase::IdentitySetup);
|
||||
assert_eq!(ColdStartPhase::from_str("first_task"), ColdStartPhase::FirstTask);
|
||||
assert_eq!(ColdStartPhase::from_str("completed"), ColdStartPhase::Completed);
|
||||
assert_eq!(ColdStartPhase::from_str("unknown"), ColdStartPhase::Idle);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_greeting_prompt_not_empty() {
|
||||
let prompt = generate_cold_start_prompt(ColdStartPhase::AgentGreeting, None);
|
||||
assert!(!prompt.is_empty());
|
||||
assert!(prompt.contains("AI 管家"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_industry_discovery_with_industry() {
|
||||
let prompt = generate_cold_start_prompt(ColdStartPhase::IndustryDiscovery, Some("healthcare"));
|
||||
assert!(prompt.contains("医疗行政"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identity_setup_suggests_name() {
|
||||
let prompt = generate_cold_start_prompt(ColdStartPhase::IdentitySetup, Some("education"));
|
||||
assert!(prompt.contains("小教"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_first_task_has_suggestion() {
|
||||
let prompt = generate_cold_start_prompt(ColdStartPhase::FirstTask, Some("ecommerce"));
|
||||
assert!(!prompt.is_empty());
|
||||
assert!(prompt.contains("库存") || prompt.contains("销售") || prompt.contains("商品"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_completed_returns_empty() {
|
||||
let prompt = generate_cold_start_prompt(ColdStartPhase::Completed, None);
|
||||
assert!(prompt.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_inject() {
|
||||
assert!(!should_inject_prompt(ColdStartPhase::Idle));
|
||||
assert!(should_inject_prompt(ColdStartPhase::AgentGreeting));
|
||||
assert!(should_inject_prompt(ColdStartPhase::IndustryDiscovery));
|
||||
assert!(should_inject_prompt(ColdStartPhase::IdentitySetup));
|
||||
assert!(should_inject_prompt(ColdStartPhase::FirstTask));
|
||||
assert!(!should_inject_prompt(ColdStartPhase::Completed));
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,21 @@ use zclaw_types::Result;
|
||||
use super::pain_aggregator::PainPoint;
|
||||
use super::solution_generator::Proposal;
|
||||
|
||||
/// Brief summary of a stored experience, for suggestion context enrichment.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExperienceBrief {
|
||||
pub pain_pattern: String,
|
||||
pub solution_summary: String,
|
||||
pub reuse_count: u32,
|
||||
}
|
||||
|
||||
static EXPERIENCE_EXTRACTOR: std::sync::OnceLock<std::sync::Arc<ExperienceExtractor>> = std::sync::OnceLock::new();
|
||||
|
||||
/// Get the global ExperienceExtractor singleton (if initialized).
|
||||
pub(crate) fn get_experience_extractor() -> Option<std::sync::Arc<ExperienceExtractor>> {
|
||||
EXPERIENCE_EXTRACTOR.get().cloned()
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared completion status
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -263,6 +278,36 @@ fn xml_escape(s: &str) -> String {
|
||||
.replace('>', "&gt;")
|
||||
}
|
||||
|
||||
/// Initialize the global ExperienceExtractor singleton.
|
||||
/// Called once during app startup, after viking storage is ready.
|
||||
pub async fn init_experience_extractor() -> Result<()> {
|
||||
let sqlite_storage = crate::viking_commands::get_storage().await
|
||||
.map_err(|e| zclaw_types::ZclawError::StorageError(e))?;
|
||||
let viking = std::sync::Arc::new(zclaw_growth::VikingAdapter::new(sqlite_storage));
|
||||
let store = std::sync::Arc::new(ExperienceStore::new(viking));
|
||||
let extractor = std::sync::Arc::new(ExperienceExtractor::new(store));
|
||||
EXPERIENCE_EXTRACTOR.set(extractor)
|
||||
.map_err(|_| zclaw_types::ZclawError::StorageError("ExperienceExtractor already initialized".into()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find experiences relevant to the current conversation for suggestion enrichment.
|
||||
#[tauri::command]
|
||||
pub async fn experience_find_relevant(
|
||||
agent_id: String,
|
||||
query: String,
|
||||
) -> std::result::Result<Vec<ExperienceBrief>, String> {
|
||||
let extractor = get_experience_extractor()
|
||||
.ok_or("ExperienceExtractor not initialized".to_string())?;
|
||||
let experiences = extractor.find_relevant_experiences(&agent_id, &query).await;
|
||||
Ok(experiences.into_iter().take(3).map(|e| ExperienceBrief {
|
||||
pain_pattern: e.pain_pattern,
|
||||
solution_summary: e.solution_steps.join(";")
|
||||
.chars().take(100).collect(),
|
||||
reuse_count: e.reuse_count,
|
||||
}).collect())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
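`experience_find_relevant` builds `solution_summary` by joining the steps and keeping the first 100 characters with `.chars().take(100)`. Counting `char`s rather than bytes matters for the Chinese text used throughout this codebase, because slicing a `String` by byte index can land inside a multi-byte UTF-8 sequence and panic. A minimal sketch of the same idea, with `summarize_steps` as a hypothetical helper name:

```rust
/// Hypothetical helper (not from the diff): join steps with ";" and keep
/// at most `max_chars` Unicode scalar values.
/// Iterating over `chars()` never splits a multi-byte UTF-8 sequence.
fn summarize_steps(steps: &[String], max_chars: usize) -> String {
    steps.join(";").chars().take(max_chars).collect()
}
```

Note that this truncates silently; unlike the `truncate` helper tested in the same file, it does not append an ellipsis.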
@@ -407,4 +452,17 @@ mod tests {
|
||||
assert_eq!(truncate("hello", 10), "hello");
|
||||
assert_eq!(truncate("这是一个很长的字符串用于测试截断", 10).chars().count(), 11); // 10 + …
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_experience_brief_serialization() {
|
||||
let brief = super::ExperienceBrief {
|
||||
pain_pattern: "报表生成慢".to_string(),
|
||||
solution_summary: "使用 researcher 技能自动收集".to_string(),
|
||||
reuse_count: 3,
|
||||
};
|
||||
let json = serde_json::to_string(&brief).unwrap();
|
||||
let parsed: super::ExperienceBrief = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.pain_pattern, "报表生成慢");
|
||||
assert_eq!(parsed.reuse_count, 3);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,9 +47,30 @@ pub async fn health_snapshot(
|
||||
) -> Result<HealthSnapshot, String> {
|
||||
let engines = heartbeat_state.lock().await;
|
||||
|
||||
let engine = engines
|
||||
.get(&agent_id)
|
||||
.ok_or_else(|| format!("Heartbeat engine not initialized for agent: {}", agent_id))?;
|
||||
// If heartbeat engine not yet initialized, return a graceful "pending" snapshot
|
||||
// instead of erroring — this avoids race conditions when HealthPanel mounts
|
||||
// before the heartbeat bootstrap sequence completes.
|
||||
let engine = match engines.get(&agent_id) {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
tracing::debug!("[health_snapshot] Engine not initialized for {}, returning pending snapshot", agent_id);
|
||||
return Ok(HealthSnapshot {
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
intelligence: IntelligenceHealth {
|
||||
engine_running: false,
|
||||
config: HeartbeatConfig::default(),
|
||||
last_tick: None,
|
||||
alert_count_24h: 0,
|
||||
total_checks: 5,
|
||||
},
|
||||
memory: MemoryHealth {
|
||||
total_entries: 0,
|
||||
storage_size_bytes: 0,
|
||||
last_extraction: None,
|
||||
},
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let engine_running = engine.is_running().await;
|
||||
let config = engine.get_config().await;
|
||||
|
||||
@@ -357,6 +357,7 @@ async fn execute_tick(
|
||||
let checks: Vec<(&str, fn(&str) -> Option<HeartbeatAlert>)> = vec![
|
||||
("pending-tasks", check_pending_tasks),
|
||||
("memory-health", check_memory_health),
|
||||
("unresolved-pains", check_unresolved_pains),
|
||||
("idle-greeting", check_idle_greeting),
|
||||
("personality-improvement", check_personality_improvement),
|
||||
("learning-opportunities", check_learning_opportunities),
|
||||
@@ -447,7 +448,48 @@ static MEMORY_STATS_CACHE: OnceLock<RwLock<StdHashMap<String, MemoryStatsCache>>
|
||||
/// Key: agent_id, Value: last interaction timestamp (RFC3339)
|
||||
static LAST_INTERACTION: OnceLock<RwLock<StdHashMap<String, String>>> = OnceLock::new();
|
||||
|
||||
/// Cached memory stats for an agent
|
||||
/// Global pain points cache (updated by frontend via Tauri command)
|
||||
/// Key: agent_id, Value: list of unresolved pain point descriptions
|
||||
static PAIN_POINTS_CACHE: OnceLock<RwLock<StdHashMap<String, Vec<String>>>> = OnceLock::new();
|
||||
|
||||
fn get_pain_points_cache() -> &'static RwLock<StdHashMap<String, Vec<String>>> {
|
||||
PAIN_POINTS_CACHE.get_or_init(|| RwLock::new(StdHashMap::new()))
|
||||
}
|
||||
|
||||
/// Update pain points cache (called from frontend or growth middleware)
|
||||
pub fn update_pain_points_cache(agent_id: &str, pain_points: Vec<String>) {
|
||||
let cache = get_pain_points_cache();
|
||||
if let Ok(mut cache) = cache.write() {
|
||||
cache.insert(agent_id.to_string(), pain_points);
|
||||
}
|
||||
}
|
||||
|
||||
/// Global experience cache: high-reuse experiences per agent.
|
||||
/// Key: agent_id, Value: list of (tool_used, reuse_count) tuples.
|
||||
static EXPERIENCE_CACHE: OnceLock<RwLock<StdHashMap<String, Vec<(String, u32)>>>> = OnceLock::new();
|
||||
|
||||
fn get_experience_cache() -> &'static RwLock<StdHashMap<String, Vec<(String, u32)>>> {
|
||||
EXPERIENCE_CACHE.get_or_init(|| RwLock::new(StdHashMap::new()))
|
||||
}
|
||||
|
||||
/// Update experience cache (called from frontend or growth middleware)
|
||||
pub fn update_experience_cache(agent_id: &str, experiences: Vec<(String, u32)>) {
|
||||
let cache = get_experience_cache();
|
||||
if let Ok(mut cache) = cache.write() {
|
||||
cache.insert(agent_id.to_string(), experiences);
|
||||
}
|
||||
}
|
||||
|
||||
fn get_cached_experiences(agent_id: &str) -> Option<Vec<(String, u32)>> {
|
||||
let cache = get_experience_cache();
|
||||
cache.read().ok()?.get(agent_id).cloned()
|
||||
}
|
||||
|
||||
/// Get cached pain points for an agent
|
||||
fn get_cached_pain_points(agent_id: &str) -> Option<Vec<String>> {
|
||||
let cache = get_pain_points_cache();
|
||||
cache.read().ok().and_then(|c| c.get(agent_id).cloned())
|
||||
}
|
||||
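The two caches added in this hunk share one pattern: a lazily initialised global map behind a read/write lock, with a write helper and a cloning read helper. A minimal std-only sketch of that pattern follows; `DEMO_CACHE`, `demo_cache`, `demo_update`, and `demo_get` are hypothetical names used only for illustration and are not part of the patch.

```rust
use std::collections::HashMap;
use std::sync::{OnceLock, RwLock};

// Hypothetical stand-in for PAIN_POINTS_CACHE: agent_id -> list of descriptions.
static DEMO_CACHE: OnceLock<RwLock<HashMap<String, Vec<String>>>> = OnceLock::new();

// Initialise the map on first use and hand out a 'static reference afterwards.
fn demo_cache() -> &'static RwLock<HashMap<String, Vec<String>>> {
    DEMO_CACHE.get_or_init(|| RwLock::new(HashMap::new()))
}

// Replace the whole entry for one agent; a poisoned lock is silently ignored,
// matching the `if let Ok(..)` style used in the diff.
fn demo_update(agent_id: &str, items: Vec<String>) {
    if let Ok(mut guard) = demo_cache().write() {
        guard.insert(agent_id.to_string(), items);
    }
}

// Return a clone so the read lock is released before the caller uses the data.
fn demo_get(agent_id: &str) -> Option<Vec<String>> {
    demo_cache().read().ok()?.get(agent_id).cloned()
}
```

Returning a clone keeps lock hold times short at the cost of a copy per read, which seems a reasonable trade-off for small per-agent lists.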
#[derive(Clone, Debug, Default)]
|
||||
pub struct MemoryStatsCache {
|
||||
pub task_count: usize,
|
||||
@@ -755,6 +797,47 @@ fn check_learning_opportunities(agent_id: &str) -> Option<HeartbeatAlert> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Check for unresolved user pain points accumulated by the butler system.
|
||||
/// When pain points persist across multiple conversations, surface them as
|
||||
/// proactive suggestions. Also considers high-reuse experiences to generate
|
||||
/// contextual skill suggestions.
|
||||
fn check_unresolved_pains(agent_id: &str) -> Option<HeartbeatAlert> {
|
||||
let pains = get_cached_pain_points(agent_id)?;
|
||||
if pains.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let count = pains.len();
|
||||
let summary = if count <= 3 {
|
||||
pains.join("、")
|
||||
} else {
|
||||
format!("{}等 {} 项", pains[..3].join("、"), count)
|
||||
};
|
||||
|
||||
// Enhance with experience-based suggestions
|
||||
let experience_hint = if let Some(experiences) = get_cached_experiences(agent_id) {
|
||||
let high_use: Vec<&(String, u32)> = experiences.iter().filter(|(_, c)| *c >= 3).collect();
|
||||
if !high_use.is_empty() {
|
||||
let tools: Vec<&str> = high_use.iter().map(|(t, _)| t.as_str()).collect();
|
||||
format!(" 用户频繁使用{},可主动提供相关技能建议。", tools.join("、"))
|
||||
} else {
|
||||
String::new()
|
||||
}
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
Some(HeartbeatAlert {
|
||||
title: "未解决的用户痛点".to_string(),
|
||||
content: format!(
|
||||
"检测到 {} 个持续痛点:{}。建议主动提供解决方案或相关建议。{}",
|
||||
count, summary, experience_hint
|
||||
),
|
||||
urgency: if count >= 3 { Urgency::High } else { Urgency::Medium },
|
||||
source: "unresolved-pains".to_string(),
|
||||
timestamp: chrono::Utc::now().to_rfc3339(),
|
||||
})
|
||||
}
|
||||
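The summary and urgency rules in `check_unresolved_pains` can be sketched in isolation as below. This is an illustrative approximation, not the patch's code: `summarize`, `urgency_for`, and `DemoUrgency` are hypothetical names, and English separators replace the localized strings used in the diff.

```rust
/// Hypothetical helper: list every item when there are at most `max_listed`,
/// otherwise list the first `max_listed` and append the total count.
fn summarize(items: &[String], max_listed: usize) -> String {
    if items.len() <= max_listed {
        items.join(", ")
    } else {
        format!(
            "{} and others ({} total)",
            items[..max_listed].join(", "),
            items.len()
        )
    }
}

/// Hypothetical stand-in for the `Urgency` enum referenced in the diff.
#[derive(Debug, PartialEq)]
enum DemoUrgency {
    Medium,
    High,
}

/// Three or more unresolved items escalate to high urgency, as in the diff.
fn urgency_for(count: usize) -> DemoUrgency {
    if count >= 3 {
        DemoUrgency::High
    } else {
        DemoUrgency::Medium
    }
}
```

Note that with these thresholds exactly three items are listed in full and also rated high urgency.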
|
||||
// === Tauri Commands ===
|
||||
|
||||
/// Heartbeat engine state for Tauri
|
||||
@@ -800,6 +883,9 @@ pub async fn heartbeat_init(
|
||||
// Restore heartbeat history from VikingStorage metadata
|
||||
engine.restore_history().await;
|
||||
|
||||
// Restore pain points cache from VikingStorage metadata
|
||||
restore_pain_points(&agent_id).await;
|
||||
|
||||
let mut engines = state.lock().await;
|
||||
engines.insert(agent_id, engine);
|
||||
Ok(())
|
||||
@@ -865,6 +951,33 @@ pub async fn restore_last_interaction(agent_id: &str) {
|
||||
}
|
||||
}
|
||||
|
||||
/// Restore pain points cache from VikingStorage metadata.
|
||||
async fn restore_pain_points(agent_id: &str) {
|
||||
let key = format!("heartbeat:pain_points:{}", agent_id);
|
||||
match crate::viking_commands::get_storage().await {
|
||||
Ok(storage) => {
|
||||
match zclaw_growth::VikingStorage::get_metadata_json(&*storage, &key).await {
|
||||
Ok(Some(json)) => {
|
||||
if let Ok(pains) = serde_json::from_str::<Vec<String>>(&json) {
|
||||
let count = pains.len();
|
||||
update_pain_points_cache(agent_id, pains);
|
||||
tracing::info!("[heartbeat] Restored {} pain points for {}", count, agent_id);
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::debug!("[heartbeat] No persisted pain points for {}", agent_id);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[heartbeat] Failed to restore pain points: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[heartbeat] Storage unavailable for pain points restore: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Start heartbeat engine for an agent
|
||||
// @connected
|
||||
#[tauri::command]
|
||||
@@ -998,6 +1111,51 @@ pub async fn heartbeat_record_interaction(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update pain points cache for heartbeat pain-awareness checks.
|
||||
/// Called by frontend when pain points are extracted from conversations.
|
||||
// @connected
|
||||
#[tauri::command]
|
||||
pub async fn heartbeat_update_pain_points(
|
||||
agent_id: String,
|
||||
pain_points: Vec<String>,
|
||||
) -> Result<(), String> {
|
||||
update_pain_points_cache(&agent_id, pain_points.clone());
|
||||
// Persist to VikingStorage for survival across restarts
|
||||
let key = format!("heartbeat:pain_points:{}", agent_id);
|
||||
tokio::spawn(async move {
|
||||
if let Ok(storage) = crate::viking_commands::get_storage().await {
|
||||
if let Ok(json) = serde_json::to_string(&pain_points) {
|
||||
if let Err(e) = zclaw_growth::VikingStorage::store_metadata_json(&*storage, &key, &json).await {
|
||||
tracing::warn!("[heartbeat] Failed to persist pain points: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update experience cache for heartbeat proactive suggestions.
|
||||
/// Called by frontend when high-reuse experiences are detected.
|
||||
// @reserved
|
||||
#[tauri::command]
|
||||
pub async fn heartbeat_update_experiences(
|
||||
agent_id: String,
|
||||
experiences: Vec<(String, u32)>,
|
||||
) -> Result<(), String> {
|
||||
update_experience_cache(&agent_id, experiences.clone());
|
||||
let key = format!("heartbeat:experiences:{}", agent_id);
|
||||
tokio::spawn(async move {
|
||||
if let Ok(storage) = crate::viking_commands::get_storage().await {
|
||||
if let Ok(json) = serde_json::to_string(&experiences) {
|
||||
if let Err(e) = zclaw_growth::VikingStorage::store_metadata_json(&*storage, &key, &json).await {
|
||||
tracing::warn!("[heartbeat] Failed to persist experiences: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
@@ -45,6 +45,7 @@ pub mod triggers;
|
||||
pub mod user_profiler;
|
||||
pub mod trajectory_compressor;
|
||||
pub mod health_snapshot;
|
||||
pub mod cold_start_prompt;
|
||||
|
||||
// Re-export main types for convenience
|
||||
pub use heartbeat::HeartbeatEngineState;
|
||||
|
||||
@@ -7,13 +7,47 @@
|
||||
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tauri::Emitter;
|
||||
use tokio::sync::RwLock;
|
||||
use zclaw_growth::VikingStorage;
|
||||
|
||||
use crate::intelligence::identity::IdentityManagerState;
|
||||
use crate::intelligence::heartbeat::HeartbeatEngineState;
|
||||
use crate::intelligence::reflection::{MemoryEntryForAnalysis, ReflectionEngineState};
|
||||
use zclaw_runtime::driver::LlmDriver;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Identity prompt cache — avoids mutex + disk I/O on every request
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct CachedIdentity {
|
||||
prompt: String,
|
||||
#[allow(dead_code)] // Reserved for future hash-based cache validation
|
||||
soul_hash: u64,
|
||||
}
|
||||
|
||||
static IDENTITY_CACHE: std::sync::LazyLock<RwLock<HashMap<String, CachedIdentity>>> =
|
||||
std::sync::LazyLock::new(|| RwLock::new(HashMap::new()));
|
||||
|
||||
/// Invalidate cached identity prompt for a given agent (call when soul.md changes).
|
||||
pub fn invalidate_identity_cache(agent_id: &str) {
|
||||
let cache = &*IDENTITY_CACHE;
|
||||
// Non-blocking: try to take the write lock; skip invalidation if it is contended
|
||||
if let Ok(mut guard) = cache.try_write() {
|
||||
guard.remove(agent_id);
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple hash for cache invalidation — uses string content hash.
|
||||
fn content_hash(s: &str) -> u64 {
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = std::collections::hash_map::DefaultHasher::new();
|
||||
s.hash(&mut hasher);
|
||||
hasher.finish()
|
||||
}
|
||||
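The diff stores `soul_hash` alongside the cached prompt but does not yet read it. One way such a hash could be used for validation is sketched below; `soul_changed` is a hypothetical helper, not something the patch defines. The standard library documents that the algorithm behind `DefaultHasher` is unspecified and may change between releases, so a value like this is only suitable for in-process comparison, not for persistence.

```rust
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

/// Same shape as `content_hash` in the diff: hash the string content to a u64.
fn content_hash(s: &str) -> u64 {
    let mut hasher = DefaultHasher::new();
    s.hash(&mut hasher);
    hasher.finish()
}

/// Hypothetical validation step: a cached entry is stale when the current
/// soul content no longer hashes to the value stored at cache time.
fn soul_changed(cached_hash: u64, current_soul: &str) -> bool {
    content_hash(current_soul) != cached_hash
}
```

With a check like this, a cache hit could be verified against the current soul content instead of relying only on explicit invalidation calls.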
|
||||
/// Run pre-conversation intelligence hooks
|
||||
///
|
||||
/// Builds identity-enhanced system prompt (SOUL.md + instructions) and
|
||||
@@ -27,10 +61,29 @@ pub async fn pre_conversation_hook(
|
||||
_user_message: &str,
|
||||
identity_state: &IdentityManagerState,
|
||||
) -> Result<String, String> {
|
||||
// Build identity-enhanced system prompt (SOUL.md + instructions)
|
||||
// Memory context is injected by MemoryMiddleware in the kernel middleware chain,
|
||||
// not here, to avoid duplicate injection.
|
||||
let enhanced_prompt = match build_identity_prompt(agent_id, "", identity_state).await {
|
||||
// Check identity prompt cache first (avoids mutex + disk I/O)
|
||||
let cache = &*IDENTITY_CACHE;
|
||||
{
|
||||
let guard = cache.read().await;
|
||||
if let Some(cached) = guard.get(agent_id) {
|
||||
// Cache hit — still need continuity context, but skip identity build
|
||||
let continuity_context = build_continuity_context(agent_id, _user_message).await;
|
||||
let mut result = cached.prompt.clone();
|
||||
if !continuity_context.is_empty() {
|
||||
result.push_str(&continuity_context);
|
||||
}
|
||||
debug!("[intelligence_hooks] Identity cache HIT for agent {}", agent_id);
|
||||
return Ok(result);
|
||||
}
|
||||
}
|
||||
|
||||
// Cache miss — build identity prompt and continuity context in parallel
|
||||
let (identity_result, continuity_context) = tokio::join!(
|
||||
build_identity_prompt_cached(agent_id, "", identity_state, cache),
|
||||
build_continuity_context(agent_id, _user_message)
|
||||
);
|
||||
|
||||
let enhanced_prompt = match identity_result {
|
||||
Ok(prompt) => prompt,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
@@ -41,9 +94,6 @@ pub async fn pre_conversation_hook(
|
||||
}
|
||||
};
|
||||
|
||||
// Cross-session continuity: check for unresolved pain points and recent experiences
|
||||
let continuity_context = build_continuity_context(agent_id, _user_message).await;
|
||||
|
||||
let mut result = enhanced_prompt;
|
||||
if !continuity_context.is_empty() {
|
||||
result.push_str(&continuity_context);
|
||||
@@ -56,12 +106,15 @@ pub async fn pre_conversation_hook(
|
||||
///
|
||||
/// 1. Record interaction for heartbeat engine
|
||||
/// 2. Record conversation for reflection engine, trigger reflection if needed
|
||||
/// 3. Detect identity signals and write back to identity files
|
||||
pub async fn post_conversation_hook(
|
||||
agent_id: &str,
|
||||
_user_message: &str,
|
||||
_heartbeat_state: &HeartbeatEngineState,
|
||||
reflection_state: &ReflectionEngineState,
|
||||
llm_driver: Option<Arc<dyn LlmDriver>>,
|
||||
identity_state: &IdentityManagerState,
|
||||
app: &tauri::AppHandle,
|
||||
) {
|
||||
// Step 1: Record interaction for heartbeat
|
||||
crate::intelligence::heartbeat::record_interaction(agent_id);
|
||||
@@ -200,6 +253,73 @@ pub async fn post_conversation_hook(
|
||||
reflection_result.improvements.len()
|
||||
);
|
||||
}
|
||||
|
||||
// Step 3: Detect identity signals from recent memory extraction and write back
|
||||
if let Ok(storage) = crate::viking_commands::get_storage().await {
|
||||
let identity_prefix = format!("agent://{}/identity/", agent_id);
|
||||
|
||||
// Check for agent_name identity signal
|
||||
let agent_name_uri = format!("{}agent-name", identity_prefix);
|
||||
if let Ok(Some(entry)) = VikingStorage::get(storage.as_ref(), &agent_name_uri).await {
|
||||
// Extract name from content like "助手的名字是小马"
|
||||
let name = entry.content.strip_prefix("助手的名字是")
|
||||
.map(|n| n.trim().to_string())
|
||||
.unwrap_or_else(|| entry.content.clone());
|
||||
|
||||
if !name.is_empty() {
|
||||
// Update IdentityFiles.soul to include the agent name
|
||||
let mut manager = identity_state.lock().await;
|
||||
let current_soul = manager.get_file(agent_id, crate::intelligence::identity::IdentityFile::Soul);
|
||||
|
||||
// Only update if the name isn't already in the soul
|
||||
if !current_soul.contains(&name) {
|
||||
let updated_soul = if current_soul.is_empty() {
|
||||
format!("# ZCLAW 人格\n\n你的名字是{}。\n\n你是一个成长性的中文 AI 助手。", name)
|
||||
} else if current_soul.contains("你的名字是") || current_soul.contains("你的名字:") {
|
||||
// Replace existing name line
|
||||
let re = regex::Regex::new(r"你的名字是[^\n]+").unwrap();
|
||||
re.replace(&current_soul, format!("你的名字是{}", name)).to_string()
|
||||
} else {
|
||||
// Prepend name to existing soul
|
||||
format!("你的名字是{}。\n\n{}", name, current_soul)
|
||||
};
|
||||
|
||||
if let Err(e) = manager.update_file(agent_id, "soul", &updated_soul) {
|
||||
warn!("[intelligence_hooks] Failed to update soul with agent name: {}", e);
|
||||
} else {
|
||||
debug!("[intelligence_hooks] Updated agent name to '{}' in soul", name);
|
||||
// Invalidate cache since soul.md changed
|
||||
invalidate_identity_cache(agent_id);
|
||||
}
|
||||
}
|
||||
drop(manager);
|
||||
|
||||
// Emit event for frontend to update AgentConfig.name
|
||||
let _ = app.emit("zclaw:agent-identity-updated", serde_json::json!({
|
||||
"agentId": agent_id,
|
||||
"agentName": name,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
// Check for user_name identity signal
|
||||
let user_name_uri = format!("{}user-name", identity_prefix);
|
||||
if let Ok(Some(entry)) = VikingStorage::get(storage.as_ref(), &user_name_uri).await {
|
||||
let name = entry.content.strip_prefix("用户的名字是")
|
||||
.map(|n| n.trim().to_string())
|
||||
.unwrap_or_else(|| entry.content.clone());
|
||||
|
||||
if !name.is_empty() {
|
||||
let mut manager = identity_state.lock().await;
|
||||
let profile = manager.get_file(agent_id, crate::intelligence::identity::IdentityFile::UserProfile);
|
||||
|
||||
if !profile.contains(&name) {
|
||||
manager.append_to_user_profile(agent_id, &format!("- 用户名字: {}", name));
|
||||
debug!("[intelligence_hooks] Appended user name '{}' to profile", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
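The "replace the existing name line, otherwise prepend one" step in the hook above can be illustrated without the `regex` crate. The sketch below deliberately swaps in plain `str` operations, uses an English marker in place of the localized one, and omits the trailing punctuation the patch adds when prepending; `upsert_name` is a hypothetical name.

```rust
/// Hypothetical helper: if `marker` occurs in `soul`, replace the text from the
/// marker to the end of that line with `marker + name`; otherwise prepend a new
/// name line followed by a blank line.
fn upsert_name(soul: &str, marker: &str, name: &str) -> String {
    match soul.find(marker) {
        Some(start) => {
            // End of the line containing the marker (or end of input).
            let end = soul[start..]
                .find('\n')
                .map(|i| start + i)
                .unwrap_or(soul.len());
            format!("{}{}{}{}", &soul[..start], marker, name, &soul[end..])
        }
        None => format!("{}{}\n\n{}", marker, name, soul),
    }
}
```

Unlike the pattern in the diff, which requires at least one character after the marker, this sketch also rewrites a line where the marker is followed directly by a newline.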
|
||||
/// Build memory context by searching VikingStorage for relevant memories
|
||||
@@ -270,21 +390,34 @@ async fn build_memory_context(
|
||||
Ok(context)
|
||||
}
|
||||
|
||||
/// Build identity-enhanced system prompt
|
||||
async fn build_identity_prompt(
|
||||
/// Build identity-enhanced system prompt and cache the result.
|
||||
async fn build_identity_prompt_cached(
|
||||
agent_id: &str,
|
||||
memory_context: &str,
|
||||
identity_state: &IdentityManagerState,
|
||||
cache: &RwLock<HashMap<String, CachedIdentity>>,
|
||||
) -> Result<String, String> {
|
||||
// IdentityManagerState is Arc<tokio::sync::Mutex<AgentIdentityManager>>
|
||||
// tokio::sync::Mutex::lock() returns MutexGuard directly
|
||||
let mut manager = identity_state.lock().await;
|
||||
|
||||
// Read current soul content for hashing
|
||||
let soul_content = manager.get_file(agent_id, crate::intelligence::identity::IdentityFile::Soul);
|
||||
let soul_hash = content_hash(&soul_content);
|
||||
|
||||
let prompt = manager.build_system_prompt(
|
||||
agent_id,
|
||||
if memory_context.is_empty() { None } else { Some(memory_context) },
|
||||
).await;
|
||||
|
||||
// Cache the result
|
||||
drop(manager); // Release lock before acquiring write guard
|
||||
{
|
||||
let mut guard = cache.write().await;
|
||||
guard.insert(agent_id.to_string(), CachedIdentity {
|
||||
prompt: prompt.clone(),
|
||||
soul_hash,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ use zclaw_types::{AgentConfig, AgentId, AgentInfo};
|
||||
|
||||
use super::{validate_agent_id, KernelState};
|
||||
use crate::intelligence::validation::validate_string_length;
|
||||
use crate::intelligence::identity::IdentityManagerState;
|
||||
use crate::intelligence::identity::{IdentityFile, IdentityManagerState};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Request / Response types
|
||||
@@ -185,16 +185,23 @@ pub async fn agent_get(
|
||||
|
||||
let mut info = kernel.get_agent(&id);
|
||||
|
||||
// Extend with UserProfile if available
|
||||
// Extend with UserProfile if available (reads from same MemoryStore pool as middleware writes to)
|
||||
if let Some(ref mut agent_info) = info {
|
||||
if let Ok(storage) = crate::viking_commands::get_storage().await {
|
||||
let profile_store = zclaw_memory::UserProfileStore::new(storage.pool().clone());
|
||||
if let Ok(Some(profile)) = profile_store.get(&agent_id).await {
|
||||
let memory_store = kernel.memory();
|
||||
let profile_store = zclaw_memory::UserProfileStore::new(memory_store.pool());
|
||||
match profile_store.get(&agent_id).await {
|
||||
Ok(Some(profile)) => {
|
||||
match serde_json::to_value(&profile) {
|
||||
Ok(val) => agent_info.user_profile = Some(val),
|
||||
Err(e) => tracing::warn!("[agent_get] Failed to serialize UserProfile: {}", e),
|
||||
}
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::debug!("[agent_get] No UserProfile found for agent {}", agent_id);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("[agent_get] Failed to read UserProfile: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -228,6 +235,7 @@ pub async fn agent_delete(
|
||||
#[tauri::command]
|
||||
pub async fn agent_update(
|
||||
state: State<'_, KernelState>,
|
||||
identity_state: State<'_, IdentityManagerState>,
|
||||
agent_id: String,
|
||||
updates: AgentUpdateRequest,
|
||||
) -> Result<AgentInfo, String> {
|
||||
@@ -246,6 +254,20 @@ pub async fn agent_update(
|
||||
|
||||
// Apply updates
|
||||
if let Some(name) = updates.name {
|
||||
// Sync name to identity soul so next session's system prompt includes it
|
||||
let mut identity_mgr = identity_state.lock().await;
|
||||
let current_soul = identity_mgr.get_file(&agent_id, IdentityFile::Soul);
|
||||
let updated_soul = if current_soul.is_empty() {
|
||||
format!("# ZCLAW 人格\n\n你的名字是{}。\n\n你是一个成长性的中文 AI 助手。", name)
|
||||
} else if current_soul.contains("你的名字是") {
|
||||
let re = regex::Regex::new(r"你的名字是[^\n]+").unwrap();
|
||||
re.replace(&current_soul, format!("你的名字是{}", name)).to_string()
|
||||
} else {
|
||||
format!("你的名字是{}。\n\n{}", name, current_soul)
|
||||
};
|
||||
let _ = identity_mgr.update_file(&agent_id, "soul", &updated_soul);
|
||||
drop(identity_mgr);
|
||||
|
||||
config.name = name;
|
||||
}
|
||||
if let Some(description) = updates.description {
|
||||
|
||||
@@ -7,6 +7,7 @@ use zclaw_types::AgentId;
|
||||
|
||||
use super::{validate_agent_id, KernelState, SessionStreamGuard, StreamCancelFlags};
|
||||
use crate::intelligence::validation::validate_string_length;
|
||||
use zclaw_runtime::LoopEvent;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Request / Response types
|
||||
@@ -60,6 +61,47 @@ pub enum StreamChatEvent {
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
/// Translate a runtime LoopEvent into a Tauri StreamChatEvent.
|
||||
///
|
||||
/// Hand tools (name starts with "hand_") are mapped to HandStart/HandEnd
|
||||
/// variants; all other tool events use ToolStart/ToolEnd.
|
||||
fn translate_event(event: &zclaw_runtime::LoopEvent) -> StreamChatEvent {
|
||||
match event {
|
||||
LoopEvent::Delta(delta) => StreamChatEvent::Delta { delta: delta.clone() },
|
||||
LoopEvent::ThinkingDelta(delta) => StreamChatEvent::ThinkingDelta { delta: delta.clone() },
|
||||
LoopEvent::ToolStart { name, input } => {
|
||||
if name.starts_with("hand_") {
|
||||
StreamChatEvent::HandStart { name: name.clone(), params: input.clone() }
|
||||
} else {
|
||||
StreamChatEvent::ToolStart { name: name.clone(), input: input.clone() }
|
||||
}
|
||||
}
|
||||
LoopEvent::ToolEnd { name, output } => {
|
||||
if name.starts_with("hand_") {
|
||||
StreamChatEvent::HandEnd { name: name.clone(), result: output.clone() }
|
||||
} else {
|
||||
StreamChatEvent::ToolEnd { name: name.clone(), output: output.clone() }
|
||||
}
|
||||
}
|
||||
LoopEvent::SubtaskStatus { task_id, description, status, detail } => {
|
||||
StreamChatEvent::SubtaskStatus {
|
||||
task_id: task_id.clone(),
|
||||
description: description.clone(),
|
||||
status: status.clone(),
|
||||
detail: detail.clone(),
|
||||
}
|
||||
}
|
||||
LoopEvent::IterationStart { iteration, max_iterations } => {
|
||||
StreamChatEvent::IterationStart { iteration: *iteration, max_iterations: *max_iterations }
|
||||
}
|
||||
LoopEvent::Complete(result) => StreamChatEvent::Complete {
|
||||
input_tokens: result.input_tokens,
|
||||
output_tokens: result.output_tokens,
|
||||
},
|
||||
LoopEvent::Error(message) => StreamChatEvent::Error { message: message.clone() },
|
||||
}
|
||||
}
|
||||
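The prefix-based routing in `translate_event` can be exercised on its own with reduced event types. The enums below are hypothetical, trimmed-down stand-ins for `LoopEvent` and `StreamChatEvent` that carry only the tool name; the real variants in the diff carry additional payload fields.

```rust
/// Hypothetical, reduced stand-in for the runtime's LoopEvent.
#[derive(Debug, PartialEq)]
enum DemoLoopEvent {
    ToolStart { name: String },
    ToolEnd { name: String },
}

/// Hypothetical, reduced stand-in for StreamChatEvent.
#[derive(Debug, PartialEq)]
enum DemoStreamEvent {
    HandStart { name: String },
    HandEnd { name: String },
    ToolStart { name: String },
    ToolEnd { name: String },
}

/// Tools whose name starts with "hand_" map to the Hand* variants;
/// every other tool keeps the Tool* variants.
fn translate(event: &DemoLoopEvent) -> DemoStreamEvent {
    match event {
        DemoLoopEvent::ToolStart { name } if name.starts_with("hand_") => {
            DemoStreamEvent::HandStart { name: name.clone() }
        }
        DemoLoopEvent::ToolStart { name } => DemoStreamEvent::ToolStart { name: name.clone() },
        DemoLoopEvent::ToolEnd { name } if name.starts_with("hand_") => {
            DemoStreamEvent::HandEnd { name: name.clone() }
        }
        DemoLoopEvent::ToolEnd { name } => DemoStreamEvent::ToolEnd { name: name.clone() },
    }
}
```

Keeping the mapping in a pure function like this makes the routing rule easy to unit test separately from the streaming loop.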
|
||||
/// Streaming chat request
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -218,156 +260,71 @@ pub async fn agent_chat_stream(
|
||||
).await.unwrap_or_default();
|
||||
|
||||
// --- Schedule intent interception ---
|
||||
// If the user's message contains a schedule intent (e.g. "每天早上9点提醒我查房"),
|
||||
// parse it with NlScheduleParser, create a trigger, and return confirmation
|
||||
// directly without calling the LLM.
|
||||
let mut captured_parsed: Option<zclaw_runtime::nl_schedule::ParsedSchedule> = None;
|
||||
|
||||
if zclaw_runtime::nl_schedule::has_schedule_intent(&message) {
|
||||
let parse_result = zclaw_runtime::nl_schedule::parse_nl_schedule(&message, &id);
|
||||
|
||||
match parse_result {
|
||||
zclaw_runtime::nl_schedule::ScheduleParseResult::Exact(ref parsed)
|
||||
if parsed.confidence >= 0.8 =>
|
||||
{
|
||||
// Try to create a schedule trigger
|
||||
let kernel_lock = state.lock().await;
|
||||
if let Some(kernel) = kernel_lock.as_ref() {
|
||||
// Use UUID fragment to avoid collision under high concurrency
|
||||
let trigger_id = format!(
|
||||
"sched_{}_{}",
|
||||
chrono::Utc::now().timestamp_millis(),
|
||||
&uuid::Uuid::new_v4().to_string()[..8]
|
||||
);
|
||||
let trigger_config = zclaw_hands::TriggerConfig {
|
||||
id: trigger_id.clone(),
|
||||
name: parsed.task_description.clone(),
|
||||
hand_id: "_reminder".to_string(),
|
||||
trigger_type: zclaw_hands::TriggerType::Schedule {
|
||||
cron: parsed.cron_expression.clone(),
|
||||
},
|
||||
enabled: true,
|
||||
// 60/hour = once per minute max, reasonable for scheduled tasks
|
||||
max_executions_per_hour: 60,
|
||||
};
|
||||
|
||||
match kernel.create_trigger(trigger_config).await {
|
||||
Ok(_entry) => {
|
||||
tracing::info!(
|
||||
"[agent_chat_stream] Schedule trigger created: {} (cron: {})",
|
||||
trigger_id, parsed.cron_expression
|
||||
);
|
||||
captured_parsed = Some(parsed.clone());
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"[agent_chat_stream] Failed to create schedule trigger, falling through to LLM: {}",
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Ambiguous, Unclear, or low confidence — let LLM handle it naturally
|
||||
tracing::debug!(
|
||||
"[agent_chat_stream] Schedule intent detected but not confident enough, falling through to LLM"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get the streaming receiver while holding the lock, then release it
|
||||
// NOTE: When schedule_intercepted, llm_driver is None so post_conversation_hook
|
||||
// (memory extraction, heartbeat, reflection) is intentionally skipped —
|
||||
// schedule confirmations are system messages, not user conversations.
|
||||
let (mut rx, llm_driver) = if let Some(parsed) = captured_parsed {
|
||||
// Schedule was intercepted — build confirmation message directly
|
||||
let confirm_msg = format!(
|
||||
"已为您设置定时任务:\n\n- **任务**:{}\n- **时间**:{}\n- **Cron**:`{}`\n\n任务已激活,将在设定时间自动执行。",
|
||||
parsed.task_description,
|
||||
parsed.natural_description,
|
||||
parsed.cron_expression,
|
||||
);
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(32);
|
||||
if tx.send(zclaw_runtime::LoopEvent::Delta(confirm_msg)).await.is_err() {
|
||||
tracing::warn!("[agent_chat_stream] Failed to send confirm msg to new channel");
|
||||
}
|
||||
if tx.send(zclaw_runtime::LoopEvent::Complete(
|
||||
zclaw_runtime::AgentLoopResult {
|
||||
response: String::new(),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
iterations: 1,
|
||||
}
|
||||
)).await.is_err() {
|
||||
tracing::warn!("[agent_chat_stream] Failed to send complete to new channel");
|
||||
}
|
||||
drop(tx);
|
||||
(rx, None)
|
||||
} else {
|
||||
// Normal LLM chat path
|
||||
// Try to intercept schedule intents (e.g. "每天早上9点提醒我查房") at the kernel level.
|
||||
// If intercepted, returns a pre-built confirmation stream — no LLM call needed.
|
||||
let (mut rx, llm_driver) = {
|
||||
let kernel_lock = state.lock().await;
|
||||
let kernel = kernel_lock.as_ref()
|
||||
.ok_or_else(|| {
|
||||
// Cleanup on error: release guard + cancel flag
|
||||
err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
err_cleanup_guard.remove(&err_cleanup_session_id);
|
||||
err_cleanup_cancel.remove(&err_cleanup_session_id);
|
||||
"Kernel not initialized. Call kernel_init first.".to_string()
|
||||
})?;
|
||||
.ok_or_else(|| "Kernel not initialized. Call kernel_init first.".to_string())?;
|
||||
|
||||
let driver = Some(kernel.driver());
|
||||
|
||||
let prompt_arg = if enhanced_prompt.is_empty() { None } else { Some(enhanced_prompt) };
|
||||
|
||||
let session_id_parsed = if session_id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
match uuid::Uuid::parse_str(&session_id) {
|
||||
Ok(uuid) => Some(zclaw_types::SessionId::from_uuid(uuid)),
|
||||
Err(e) => {
|
||||
// Cleanup on error
|
||||
err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
err_cleanup_guard.remove(&err_cleanup_session_id);
|
||||
err_cleanup_cancel.remove(&err_cleanup_session_id);
|
||||
return Err(format!(
|
||||
"Invalid session_id '{}': {}. Cannot reuse conversation context.",
|
||||
session_id, e
|
||||
));
|
||||
}
|
||||
match kernel.try_intercept_schedule(&message, &id).await {
|
||||
Ok(Some(intercept)) => {
|
||||
tracing::info!("[agent_chat_stream] Schedule intercepted: {}", intercept.task_description);
|
||||
(intercept.rx, None)
|
||||
}
|
||||
};
|
||||
// Build chat mode config from request parameters
|
||||
let chat_mode_config = zclaw_kernel::ChatModeConfig {
|
||||
thinking_enabled: request.thinking_enabled,
|
||||
reasoning_effort: request.reasoning_effort.clone(),
|
||||
plan_mode: request.plan_mode,
|
||||
subagent_enabled: request.subagent_enabled,
|
||||
};
|
||||
_ => {
|
||||
// No interception or error — normal LLM chat path
|
||||
let driver = Some(kernel.driver());
|
||||
|
||||
let rx = kernel.send_message_stream_with_prompt(
|
||||
&id,
|
||||
message.clone(),
|
||||
prompt_arg,
|
||||
session_id_parsed,
|
||||
Some(chat_mode_config),
|
||||
request.model.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
// Cleanup on error
|
||||
err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
err_cleanup_guard.remove(&err_cleanup_session_id);
|
||||
err_cleanup_cancel.remove(&err_cleanup_session_id);
|
||||
format!("Failed to start streaming: {}", e)
|
||||
})?;
|
||||
(rx, driver)
|
||||
let prompt_arg = if enhanced_prompt.is_empty() { None } else { Some(enhanced_prompt) };
|
||||
|
||||
let session_id_parsed = if session_id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
match uuid::Uuid::parse_str(&session_id) {
|
||||
Ok(uuid) => Some(zclaw_types::SessionId::from_uuid(uuid)),
|
||||
Err(e) => {
|
||||
err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
err_cleanup_guard.remove(&err_cleanup_session_id);
|
||||
err_cleanup_cancel.remove(&err_cleanup_session_id);
|
||||
return Err(format!(
|
||||
"Invalid session_id '{}': {}. Cannot reuse conversation context.",
|
||||
session_id, e
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let chat_mode_config = zclaw_kernel::ChatModeConfig {
|
||||
thinking_enabled: request.thinking_enabled,
|
||||
reasoning_effort: request.reasoning_effort.clone(),
|
||||
plan_mode: request.plan_mode,
|
||||
subagent_enabled: request.subagent_enabled,
|
||||
};
|
||||
|
||||
let rx = kernel.send_message_stream_with_prompt(
|
||||
&id,
|
||||
message.clone(),
|
||||
prompt_arg,
|
||||
session_id_parsed,
|
||||
Some(chat_mode_config),
|
||||
request.model.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
err_cleanup_flag.store(false, std::sync::atomic::Ordering::SeqCst);
|
||||
err_cleanup_guard.remove(&err_cleanup_session_id);
|
||||
err_cleanup_cancel.remove(&err_cleanup_session_id);
|
||||
format!("Failed to start streaming: {}", e)
|
||||
})?;
|
||||
(rx, driver)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let hb_state = heartbeat_state.inner().clone();
|
||||
let rf_state = reflection_state.inner().clone();
|
||||
let id_state_hook = identity_state.inner().clone();
|
||||
|
||||
// Clone the guard map for cleanup in the spawned task
|
||||
let guard_map: SessionStreamGuard = stream_guard.inner().clone();
|
||||
@@ -415,69 +372,28 @@ pub async fn agent_chat_stream(
|
||||
|
||||
match tokio::time::timeout(stream_timeout, rx.recv()).await {
|
||||
Ok(Some(event)) => {
|
||||
let stream_event = match &event {
|
||||
LoopEvent::Delta(delta) => {
|
||||
tracing::trace!("[agent_chat_stream] Delta: {} bytes", delta.len());
|
||||
StreamChatEvent::Delta { delta: delta.clone() }
|
||||
// Fire post-conversation hooks before translating (memory extraction, heartbeat, reflection)
|
||||
if let LoopEvent::Complete(result) = &event {
|
||||
tracing::info!("[agent_chat_stream] Complete: input_tokens={}, output_tokens={}",
|
||||
result.input_tokens, result.output_tokens);
|
||||
let agent_id_hook = agent_id_str.clone();
|
||||
let message_hook = message.clone();
|
||||
let hb = hb_state.clone();
|
||||
let rf = rf_state.clone();
|
||||
let driver = llm_driver.clone();
|
||||
let id_state = id_state_hook.clone();
|
||||
let app_hook = app.clone();
|
||||
if driver.is_none() {
|
||||
tracing::debug!("[agent_chat_stream] Post-hook firing without LLM driver (schedule intercept path)");
|
||||
}
|
||||
LoopEvent::ThinkingDelta(delta) => {
|
||||
tracing::trace!("[agent_chat_stream] ThinkingDelta: {} bytes", delta.len());
|
||||
StreamChatEvent::ThinkingDelta { delta: delta.clone() }
|
||||
}
|
||||
LoopEvent::ToolStart { name, input } => {
|
||||
tracing::debug!("[agent_chat_stream] ToolStart: {}", name);
|
||||
if name.starts_with("hand_") {
|
||||
StreamChatEvent::HandStart { name: name.clone(), params: input.clone() }
|
||||
} else {
|
||||
StreamChatEvent::ToolStart { name: name.clone(), input: input.clone() }
|
||||
}
|
||||
}
|
||||
LoopEvent::ToolEnd { name, output } => {
|
||||
tracing::debug!("[agent_chat_stream] ToolEnd: {}", name);
|
||||
if name.starts_with("hand_") {
|
||||
StreamChatEvent::HandEnd { name: name.clone(), result: output.clone() }
|
||||
} else {
|
||||
StreamChatEvent::ToolEnd { name: name.clone(), output: output.clone() }
|
||||
}
|
||||
}
|
||||
LoopEvent::SubtaskStatus { task_id, description, status, detail } => {
|
||||
tracing::debug!("[agent_chat_stream] SubtaskStatus: {} - {} (id={})", description, status, task_id);
|
||||
StreamChatEvent::SubtaskStatus {
|
||||
task_id: task_id.clone(),
|
||||
description: description.clone(),
|
||||
status: status.clone(),
|
||||
detail: detail.clone(),
|
||||
}
|
||||
}
|
||||
LoopEvent::IterationStart { iteration, max_iterations } => {
|
||||
tracing::debug!("[agent_chat_stream] IterationStart: {}/{}", iteration, max_iterations);
|
||||
StreamChatEvent::IterationStart { iteration: *iteration, max_iterations: *max_iterations }
|
||||
}
|
||||
LoopEvent::Complete(result) => {
|
||||
tracing::info!("[agent_chat_stream] Complete: input_tokens={}, output_tokens={}",
|
||||
result.input_tokens, result.output_tokens);
|
||||
tokio::spawn(async move {
|
||||
crate::intelligence_hooks::post_conversation_hook(
|
||||
&agent_id_hook, &message_hook, &hb, &rf, driver, &id_state, &app_hook,
|
||||
).await;
|
||||
});
|
||||
}
|
||||
|
||||
let agent_id_hook = agent_id_str.clone();
|
||||
let message_hook = message.clone();
|
||||
let hb = hb_state.clone();
|
||||
let rf = rf_state.clone();
|
||||
let driver = llm_driver.clone();
|
||||
tokio::spawn(async move {
|
||||
crate::intelligence_hooks::post_conversation_hook(
|
||||
&agent_id_hook, &message_hook, &hb, &rf, driver,
|
||||
).await;
|
||||
});
|
||||
|
||||
StreamChatEvent::Complete {
|
||||
input_tokens: result.input_tokens,
|
||||
output_tokens: result.output_tokens,
|
||||
}
|
||||
}
|
||||
LoopEvent::Error(message) => {
|
||||
tracing::warn!("[agent_chat_stream] Error: {}", message);
|
||||
StreamChatEvent::Error { message: message.clone() }
|
||||
}
|
||||
};
|
||||
let stream_event = translate_event(&event);
|
||||
|
||||
if let Err(e) = app.emit("stream:chunk", serde_json::json!({
|
||||
"sessionId": session_id,
|
||||
|
||||
@@ -241,6 +241,7 @@ pub async fn orchestration_execute(
|
||||
network_allowed: true,
|
||||
file_access_allowed: true,
|
||||
llm: None,
|
||||
tool_definitions: Vec::new(),
|
||||
};
|
||||
|
||||
// Execute orchestration
|
||||
|
||||
@@ -174,8 +174,9 @@ pub async fn skill_create(
|
||||
tags: vec![],
|
||||
category: None,
|
||||
triggers: request.triggers,
|
||||
tools: vec![], // P2-19: Include tools field
|
||||
tools: vec![],
|
||||
enabled: request.enabled.unwrap_or(true),
|
||||
body: None,
|
||||
};
|
||||
|
||||
kernel.create_skill(manifest.clone())
|
||||
@@ -221,8 +222,9 @@ pub async fn skill_update(
|
||||
tags: existing.tags.clone(),
|
||||
category: existing.category.clone(),
|
||||
triggers: request.triggers.unwrap_or(existing.triggers),
|
||||
tools: existing.tools.clone(), // P2-19: Preserve tools on update
|
||||
tools: existing.tools.clone(),
|
||||
enabled: request.enabled.unwrap_or(existing.enabled),
|
||||
body: existing.body.clone(),
|
||||
};
|
||||
|
||||
let result = kernel.update_skill(&SkillId::new(&id), updated)
|
||||
@@ -277,6 +279,7 @@ impl From<SkillContext> for zclaw_skills::SkillContext {
|
||||
network_allowed: true,
|
||||
file_access_allowed: true,
|
||||
llm: None, // Injected by Kernel.execute_skill()
|
||||
tool_definitions: Vec::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -212,6 +212,12 @@ pub fn run() {
|
||||
if let Err(e) = rt.block_on(intelligence::pain_aggregator::init_pain_storage(pool)) {
|
||||
tracing::error!("[PainStorage] Init failed: {}, pain points will not persist", e);
|
||||
}
|
||||
|
||||
// Initialize experience extractor for suggestion enrichment.
|
||||
// Graceful degradation: failure does not block app startup.
|
||||
if let Err(e) = rt.block_on(intelligence::experience::init_experience_extractor()) {
|
||||
tracing::warn!("[ExperienceExtractor] Init failed: {}, suggestion context will be empty", e);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -381,6 +387,8 @@ pub fn run() {
|
||||
intelligence::heartbeat::heartbeat_update_memory_stats,
|
||||
intelligence::heartbeat::heartbeat_record_correction,
|
||||
intelligence::heartbeat::heartbeat_record_interaction,
|
||||
intelligence::heartbeat::heartbeat_update_pain_points,
|
||||
intelligence::heartbeat::heartbeat_update_experiences,
|
||||
// Health Snapshot (on-demand query)
|
||||
intelligence::health_snapshot::health_snapshot,
|
||||
// Context Compactor
|
||||
@@ -433,6 +441,8 @@ pub fn run() {
|
||||
intelligence::pain_aggregator::butler_update_proposal_status,
|
||||
// Industry config loader
|
||||
viking_commands::viking_load_industry_keywords,
|
||||
// Experience finder for suggestion enrichment
|
||||
intelligence::experience::experience_find_relevant,
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
|
||||
@@ -602,9 +602,11 @@ fn parse_uri(uri: &str) -> Result<(String, MemoryType, String), String> {
|
||||
|
||||
/// Configure embedding for semantic memory search
|
||||
/// Configures SqliteStorage (VikingStorage) embedding for FTS5 + semantic search.
|
||||
/// Also propagates to Kernel's skill router and memory retriever.
|
||||
// @connected
|
||||
#[tauri::command]
|
||||
pub async fn viking_configure_embedding(
|
||||
kernel_state: tauri::State<'_, crate::kernel_commands::KernelState>,
|
||||
provider: String,
|
||||
api_key: String,
|
||||
model: Option<String>,
|
||||
@@ -621,12 +623,28 @@ pub async fn viking_configure_embedding(
|
||||
|
||||
let client_viking = crate::llm::EmbeddingClient::new(config_viking);
|
||||
let adapter = crate::embedding_adapter::TauriEmbeddingAdapter::new(client_viking);
|
||||
let arc_adapter = std::sync::Arc::new(adapter);
|
||||
|
||||
// 1. Configure SqliteStorage (existing behavior)
|
||||
storage
|
||||
.configure_embedding(std::sync::Arc::new(adapter))
|
||||
.configure_embedding(arc_adapter.clone())
|
||||
.await
|
||||
.map_err(|e| format!("Failed to configure embedding: {}", e))?;
|
||||
|
||||
// 2. Propagate to Kernel for skill router + memory retriever
|
||||
{
|
||||
let mut kernel_lock = kernel_state.lock().await;
|
||||
if let Some(ref mut k) = *kernel_lock {
|
||||
k.set_embedding_client(arc_adapter);
|
||||
tracing::info!("[VikingCommands] Embedding propagated to Kernel skill router + memory retriever");
|
||||
} else {
|
||||
tracing::warn!(
|
||||
"[VikingCommands] Kernel not initialized, embedding only applied to SqliteStorage. \
|
||||
It will be applied when Kernel boots."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
tracing::info!("[VikingCommands] Embedding configured with provider: {}", provider);
|
||||
|
||||
Ok(EmbeddingConfigResult {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { useState, useEffect } from 'react';
|
||||
import { Brain, Loader2 } from 'lucide-react';
|
||||
import { listVikingResources } from '../../lib/viking-client';
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { Brain, Loader2, ChevronDown, ChevronRight, User } from 'lucide-react';
|
||||
import { listVikingResources, readVikingResource } from '../../lib/viking-client';
|
||||
import { invoke } from '@tauri-apps/api/core';
|
||||
|
||||
interface MemorySectionProps {
|
||||
agentId: string;
|
||||
@@ -11,29 +12,140 @@ interface MemoryEntry {
|
||||
uri: string;
|
||||
name: string;
|
||||
resourceType: string;
|
||||
size?: number;
|
||||
modifiedAt?: string;
|
||||
summary?: string;
|
||||
loading?: boolean;
|
||||
}
|
||||
|
||||
type MemoryGroup = 'preferences' | 'knowledge' | 'experience' | 'sessions' | 'other';
|
||||
|
||||
interface UserProfile {
|
||||
industry?: string;
|
||||
role?: string;
|
||||
expertise_level?: string;
|
||||
communication_style?: string;
|
||||
preferred_language?: string;
|
||||
recent_topics?: string[];
|
||||
active_pain_points?: string[];
|
||||
preferred_tools?: string[];
|
||||
confidence?: number;
|
||||
}
|
||||
|
||||
const GROUP_LABELS: Record<MemoryGroup, string> = {
|
||||
preferences: '偏好',
|
||||
knowledge: '知识',
|
||||
experience: '经验',
|
||||
sessions: '会话',
|
||||
other: '其他',
|
||||
};
|
||||
|
||||
const GROUP_ORDER: MemoryGroup[] = ['preferences', 'knowledge', 'experience', 'sessions', 'other'];
|
||||
|
||||
function classifyGroup(resourceType: string): MemoryGroup {
|
||||
if (Object.prototype.hasOwnProperty.call(GROUP_LABELS, resourceType)) return resourceType as MemoryGroup;
|
||||
return 'other';
|
||||
}
|
||||
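The classification step is easier to follow next to the grouping it feeds. The sketch below is a standalone illustration, not the component's code: `Entry`, `KNOWN_GROUPS`, `toGroup`, `groupEntries`, and the sample URIs are hypothetical. It uses an explicit membership list, so any unexpected type string, including names such as `constructor`, ends up in `other`.

```typescript
type MemoryGroup = 'preferences' | 'knowledge' | 'experience' | 'sessions' | 'other';

// Hypothetical entry shape: only the fields this sketch needs.
interface Entry {
  uri: string;
  resourceType: string;
}

const KNOWN_GROUPS: MemoryGroup[] = ['preferences', 'knowledge', 'experience', 'sessions', 'other'];

// Map an arbitrary resource type onto a known group, defaulting to 'other'.
function toGroup(resourceType: string): MemoryGroup {
  return (KNOWN_GROUPS as string[]).indexOf(resourceType) === -1
    ? 'other'
    : (resourceType as MemoryGroup);
}

// Bucket entries by group; groups with no entries are simply absent.
function groupEntries(entries: Entry[]): Partial<Record<MemoryGroup, Entry[]>> {
  const out: Partial<Record<MemoryGroup, Entry[]>> = {};
  for (const entry of entries) {
    const group = toGroup(entry.resourceType);
    const bucket = out[group] ?? [];
    bucket.push(entry);
    out[group] = bucket;
  }
  return out;
}

const grouped = groupEntries([
  { uri: 'agent://a1/preferences/theme', resourceType: 'preferences' },
  { uri: 'agent://a1/misc/note', resourceType: 'scratch' },
  { uri: 'agent://a1/misc/odd', resourceType: 'constructor' },
]);
```

Leaving empty groups out of the result means a renderer can iterate a fixed display order and skip whatever is missing.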
|
||||
function formatDate(iso?: string): string {
|
||||
if (!iso) return '';
|
||||
try {
|
||||
return new Date(iso).toLocaleDateString('zh-CN', { month: 'short', day: 'numeric' });
|
||||
} catch {
|
||||
return '';
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch user profile from agent_get Tauri command
|
||||
async function fetchUserProfile(agentId: string): Promise<UserProfile | null> {
|
||||
try {
|
||||
const result = await invoke<{ userProfile?: UserProfile } | null>('agent_get', { agentId });
|
||||
return result?.userProfile ?? null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function MemorySection({ agentId, refreshKey }: MemorySectionProps) {
|
||||
const [memories, setMemories] = useState<MemoryEntry[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [expandedGroups, setExpandedGroups] = useState<Set<MemoryGroup>>(new Set(['preferences', 'knowledge']));
|
||||
const [profile, setProfile] = useState<UserProfile | null>(null);
|
||||
const [_profileLoading, setProfileLoading] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
const loadMemories = useCallback(async () => {
|
||||
if (!agentId) return;
|
||||
|
||||
setLoading(true);
|
||||
// Query all memory resources under agent:// (preferences/knowledge/experience/sessions)
|
||||
listVikingResources(`agent://${agentId}/`)
|
||||
.then((entries) => {
|
||||
setMemories(entries as MemoryEntry[]);
|
||||
})
|
||||
.catch(() => {
|
||||
// Memory path may not exist yet — show empty state
|
||||
setMemories([]);
|
||||
})
|
||||
.finally(() => setLoading(false));
|
||||
}, [agentId, refreshKey]);
|
||||
try {
|
||||
const entries = await listVikingResources(`agent://${agentId}/`);
|
||||
const typed = entries as MemoryEntry[];
|
||||
|
||||
if (loading) {
|
||||
// Load L1 summaries in parallel (one request per entry; requests are not batched)
|
||||
const enriched = await Promise.all(
|
||||
typed.map(async (entry) => {
|
||||
try {
|
||||
const summary = await readVikingResource(entry.uri, 'L1');
|
||||
return { ...entry, summary: summary || entry.name };
|
||||
} catch {
|
||||
return { ...entry, summary: entry.name };
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
setMemories(enriched);
|
||||
} catch {
|
||||
setMemories([]);
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [agentId]);
|
||||
|
||||
const loadProfile = useCallback(async () => {
|
||||
if (!agentId) return;
|
||||
setProfileLoading(true);
|
||||
try {
|
||||
const p = await fetchUserProfile(agentId);
|
||||
setProfile(p);
|
||||
} catch {
|
||||
setProfile(null);
|
||||
} finally {
|
||||
setProfileLoading(false);
|
||||
}
|
||||
}, [agentId]);
|
||||
|
||||
useEffect(() => {
|
||||
loadMemories();
|
||||
loadProfile();
|
||||
}, [loadMemories, loadProfile, refreshKey]);
|
||||
|
||||
// Group memories by type
|
||||
const grouped = memories.reduce<Record<MemoryGroup, MemoryEntry[]>>((acc, m) => {
|
||||
const group = classifyGroup(m.resourceType);
|
||||
if (!acc[group]) acc[group] = [];
|
||||
acc[group].push(m);
|
||||
return acc;
|
||||
}, {} as Record<MemoryGroup, MemoryEntry[]>);
|
||||
|
||||
const nonEmptyGroups = GROUP_ORDER.filter((g) => (grouped[g]?.length ?? 0) > 0);
|
||||
const totalMemories = memories.length;
|
||||
|
||||
const toggleGroup = (group: MemoryGroup) => {
|
||||
setExpandedGroups((prev) => {
|
||||
const next = new Set(prev);
|
||||
if (next.has(group)) next.delete(group);
|
||||
else next.add(group);
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const hasProfile = profile && (
|
||||
profile.industry || profile.role || profile.communication_style ||
|
||||
(profile.recent_topics && profile.recent_topics.length > 0) ||
|
||||
(profile.preferred_tools && profile.preferred_tools.length > 0)
|
||||
);
|
||||
|
||||
if (loading && memories.length === 0) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-8">
|
||||
<Loader2 className="w-5 h-5 text-gray-400 animate-spin" />
|
||||
@@ -41,7 +153,7 @@ export function MemorySection({ agentId, refreshKey }: MemorySectionProps) {
|
||||
);
|
||||
}
|
||||
|
||||
if (memories.length === 0) {
|
||||
if (totalMemories === 0 && !hasProfile) {
|
||||
return (
|
||||
<div className="text-center py-8">
|
||||
<Brain className="w-8 h-8 mx-auto mb-2 text-gray-300 dark:text-gray-600" />
|
||||
@@ -54,20 +166,114 @@ export function MemorySection({ agentId, refreshKey }: MemorySectionProps) {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="space-y-2">
|
||||
{memories.map((memory) => (
|
||||
<div
|
||||
key={memory.uri}
|
||||
className="rounded-lg border border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800 px-3 py-2"
|
||||
>
|
||||
<div className="text-sm text-gray-900 dark:text-gray-100 truncate">
|
||||
{memory.name}
|
||||
<div className="space-y-3">
|
||||
{/* User Profile Card */}
|
||||
{hasProfile && (
|
||||
<div className="rounded-lg border border-blue-100 dark:border-blue-900/30 bg-blue-50/50 dark:bg-blue-900/10 px-3 py-2.5">
|
||||
<div className="flex items-center gap-1.5 mb-2">
|
||||
<User className="w-3.5 h-3.5 text-blue-500" />
|
||||
<span className="text-xs font-medium text-blue-700 dark:text-blue-300">用户画像</span>
|
||||
{profile.confidence !== undefined && profile.confidence > 0 && (
|
||||
<span className="text-[10px] text-blue-400 dark:text-blue-500 ml-auto">
|
||||
置信度 {Math.round(profile.confidence * 100)}%
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="text-xs text-gray-400 dark:text-gray-500 truncate mt-0.5">
|
||||
{memory.uri}
|
||||
<div className="space-y-1.5">
|
||||
{profile.industry && (
|
||||
<ProfileField label="行业" value={profile.industry} />
|
||||
)}
|
||||
{profile.role && (
|
||||
<ProfileField label="角色" value={profile.role} />
|
||||
)}
|
||||
{profile.expertise_level && (
|
||||
<ProfileField label="专业水平" value={profile.expertise_level} />
|
||||
)}
|
||||
{profile.communication_style && (
|
||||
<ProfileField label="沟通风格" value={profile.communication_style} />
|
||||
)}
|
||||
{profile.recent_topics && profile.recent_topics.length > 0 && (
|
||||
<div className="flex flex-wrap gap-1 items-center">
|
||||
<span className="text-[10px] text-gray-500 dark:text-gray-400 shrink-0">近期话题</span>
|
||||
{profile.recent_topics.slice(0, 8).map((topic) => (
|
||||
<span key={topic} className="inline-block text-[10px] px-1.5 py-0.5 rounded bg-white dark:bg-gray-800 text-gray-600 dark:text-gray-300 border border-gray-200 dark:border-gray-700">
|
||||
{topic}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
{profile.preferred_tools && profile.preferred_tools.length > 0 && (
|
||||
<div className="flex flex-wrap gap-1 items-center">
|
||||
<span className="text-[10px] text-gray-500 dark:text-gray-400 shrink-0">常用工具</span>
|
||||
{profile.preferred_tools.map((tool) => (
|
||||
<span key={tool} className="inline-block text-[10px] px-1.5 py-0.5 rounded bg-purple-50 dark:bg-purple-900/20 text-purple-600 dark:text-purple-400">
|
||||
{tool}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
)}
|
||||
|
||||
{/* Memory Groups */}
|
||||
{nonEmptyGroups.map((group) => {
|
||||
const isExpanded = expandedGroups.has(group);
|
||||
const items = grouped[group] ?? [];
|
||||
return (
|
||||
<div key={group}>
|
||||
<button
|
||||
onClick={() => toggleGroup(group)}
|
||||
className="flex items-center gap-1.5 w-full text-left hover:bg-gray-50 dark:hover:bg-gray-800/50 rounded px-1 py-1 transition-colors"
|
||||
>
|
||||
{isExpanded ? (
|
||||
<ChevronDown className="w-3.5 h-3.5 text-gray-400" />
|
||||
) : (
|
||||
<ChevronRight className="w-3.5 h-3.5 text-gray-400" />
|
||||
)}
|
||||
<span className="text-xs font-medium text-gray-700 dark:text-gray-300">
|
||||
{GROUP_LABELS[group]}
|
||||
</span>
|
||||
<span className="text-[10px] text-gray-400 dark:text-gray-500">
|
||||
{items.length}
|
||||
</span>
|
||||
</button>
|
||||
{isExpanded && (
|
||||
<div className="mt-1 space-y-1.5 pl-1">
|
||||
{items.map((memory) => (
|
||||
<div
|
||||
key={memory.uri}
|
||||
className="rounded-lg border border-gray-200 dark:border-gray-700 bg-white dark:bg-gray-800/50 px-3 py-2"
|
||||
>
|
||||
<div className="text-xs text-gray-800 dark:text-gray-200 leading-relaxed">
|
||||
{memory.summary || memory.name}
|
||||
</div>
|
||||
<div className="flex items-center gap-2 mt-1">
|
||||
<span className="text-[10px] text-gray-400 dark:text-gray-500">
|
||||
{memory.name}
|
||||
</span>
|
||||
{memory.modifiedAt && (
|
||||
<span className="text-[10px] text-gray-400 dark:text-gray-500 ml-auto">
|
||||
{formatDate(memory.modifiedAt)}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function ProfileField({ label, value }: { label: string; value: string }) {
|
||||
return (
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-[10px] text-gray-500 dark:text-gray-400 shrink-0 w-14">{label}</span>
|
||||
<span className="text-xs text-gray-700 dark:text-gray-300">{value}</span>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -34,7 +34,6 @@ import { ModelSelector } from './ai/ModelSelector';
|
||||
import { isTauriRuntime } from '../lib/tauri-gateway';
|
||||
import { SuggestionChips } from './ai/SuggestionChips';
|
||||
import { PipelineResultPreview } from './pipeline/PipelineResultPreview';
|
||||
import { PresentationContainer } from './presentation/PresentationContainer';
|
||||
// TokenMeter temporarily unused — using inline text counter instead
|
||||
|
||||
// Default heights for virtualized messages
|
||||
@@ -54,7 +53,7 @@ export function ChatArea({ compact, onOpenDetail }: { compact?: boolean; onOpenD
|
||||
const {
|
||||
messages, isStreaming, isLoading,
|
||||
sendMessage: sendToGateway, initStreamListener,
|
||||
chatMode, setChatMode, suggestions,
|
||||
chatMode, setChatMode, suggestions, suggestionsLoading,
|
||||
totalInputTokens, totalOutputTokens,
|
||||
cancelStream,
|
||||
} = useChatStore();
|
||||
@@ -88,12 +87,17 @@ export function ChatArea({ compact, onOpenDetail }: { compact?: boolean; onOpenD
|
||||
const models = useMemo(() => {
|
||||
const failed = failedModelIds.current;
|
||||
if (isLoggedIn && saasModels.length > 0) {
|
||||
return saasModels.map(m => ({
|
||||
id: m.alias || m.id,
|
||||
name: m.alias || m.id,
|
||||
provider: m.provider_id,
|
||||
available: !failed.has(m.alias || m.id),
|
||||
}));
|
||||
return saasModels
|
||||
.filter(m => {
|
||||
const name = (m.alias || m.id).toLowerCase();
|
||||
return !name.includes('embedding');
|
||||
})
|
||||
.map(m => ({
|
||||
id: m.alias || m.id,
|
||||
name: m.alias || m.id,
|
||||
provider: undefined,
|
||||
available: !failed.has(m.alias || m.id),
|
||||
}));
|
||||
}
|
||||
if (configModels.length > 0) {
|
||||
return configModels;
|
||||
@@ -210,6 +214,8 @@ export function ChatArea({ compact, onOpenDetail }: { compact?: boolean; onOpenD
|
||||
'hand-execution-complete',
|
||||
(event) => {
|
||||
const { handId, success, error } = event.payload;
|
||||
const streaming = useChatStore.getState().isStreaming;
|
||||
if (!streaming) return;
|
||||
useChatStore.getState().addMessage({
|
||||
id: crypto.randomUUID(),
|
||||
role: 'hand',
|
||||
@@ -499,10 +505,11 @@ export function ChatArea({ compact, onOpenDetail }: { compact?: boolean; onOpenD
|
||||
<div className="flex-shrink-0 p-4 bg-white dark:bg-gray-900">
|
||||
<div className="max-w-4xl mx-auto">
|
||||
{/* Suggestion chips */}
|
||||
{!isStreaming && suggestions.length > 0 && !messages.some(m => m.error) && (
|
||||
{!isStreaming && !messages.some(m => m.error) && (suggestions.length > 0 || suggestionsLoading) && (
|
||||
<SuggestionChips
|
||||
suggestions={suggestions}
|
||||
onSelect={(text) => { setInput(text); textareaRef.current?.focus(); }}
|
||||
loading={suggestionsLoading}
|
||||
onSelect={(text) => { setInput(text); textareaRef.current?.focus(); setTimeout(() => handleSend(), 0); }}
|
||||
className="mb-3"
|
||||
/>
|
||||
)}
|
||||
@@ -630,10 +637,64 @@ export function ChatArea({ compact, onOpenDetail }: { compact?: boolean; onOpenD
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Strip LLM tool-usage narration from response content.
|
||||
* When the LLM calls tools (search, fetch, etc.), it often narrates its reasoning
|
||||
* in English ("Now let me execute...", "I need to provide...", "I keep getting errors...")
|
||||
* and Chinese ("让我执行...", "让我尝试..."). These are internal thoughts, not user-facing content.
|
||||
*/
|
||||
function stripToolNarration(content: string): string {
|
||||
// Process line-by-line to preserve markdown structure (headings, lists, paragraphs)
|
||||
const lines = content.split('\n');
|
||||
const filtered = lines.filter(line => {
|
||||
const t = line.trim();
|
||||
// Keep empty lines (paragraph breaks in markdown)
|
||||
if (!t) return true;
|
||||
// Keep markdown structural lines (headings, list items, blockquotes, code fences, horizontal rules, table rows)
|
||||
if (/^(#{1,6}\s|[-*+]\s|\d+\.\s|>|\s*```|---|\|)/.test(t)) return true;
|
||||
// English narration patterns
|
||||
if (/^(?:Now )?[Ll]et me\s/i.test(t)) return false;
|
||||
if (/^I\s+(?:need to|keep getting|should|will try|have to|can try|must)\s/i.test(t)) return false;
|
||||
if (/^The hand_researcher\s/i.test(t)) return false;
|
||||
// Chinese narration patterns
|
||||
if (/^让我(?:执行|尝试|使用|进一步|调用|运行)/.test(t)) return false;
|
||||
if (/^好的,让我为您/.test(t)) return false;
|
||||
return true;
|
||||
});
|
||||
const result = filtered.join('\n');
|
||||
return result || content;
|
||||
}
|
||||
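A quick way to see what this line filter keeps and drops is a reduced, standalone copy. The sketch below is an illustration only: `dropNarrationLines` is a hypothetical name, and it reproduces just two of the narration patterns from the function above.

```typescript
// Hypothetical, reduced variant of the filter: two narration patterns only.
function dropNarrationLines(content: string): string {
  const kept = content.split('\n').filter((line) => {
    const t = line.trim();
    if (!t) return true; // keep paragraph breaks
    // Markdown structure (headings, list items, quotes, table rows) is always kept.
    if (/^(#{1,6}\s|[-*+]\s|\d+\.\s|>|\|)/.test(t)) return true;
    if (/^(?:now )?let me\s/i.test(t)) return false;
    if (/^I\s+(?:need to|will try|have to)\s/i.test(t)) return false;
    return true;
  });
  // Fall back to the original text if every line was filtered out.
  return kept.join('\n') || content;
}

const reply = [
  'Let me search for that.',
  '## Results',
  '- Let me is kept here because list items are structural',
  'The answer is 42.',
].join('\n');

const cleaned = dropNarrationLines(reply);
const onlyNarration = dropNarrationLines('Let me think.');
```

Two behaviors are worth noting. The structural check runs before the narration checks, so a list item that happens to start with "Let me" survives. A reply made up entirely of narration is returned unchanged rather than as an empty string.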
|
||||
/**
|
||||
* Strip dangling clarification references from text when ask_clarification tool was called.
|
||||
* When the LLM calls ask_clarification, it often ends its text with phrases like
|
||||
* "比如:" / "以下信息" / "以下选项" that reference the tool output — but the tool output
|
||||
* is rendered in a separate ClarificationCard, so these become confusing dead-end sentences.
|
||||
*/
|
||||
function stripDanglingClarificationRef(text: string, hasClarificationTool: boolean): string {
|
||||
if (!hasClarificationTool || !text) return text;
|
||||
// Match trailing dangling references in Chinese and English
|
||||
const patterns = [
|
||||
/[,,]\s*可以(?:提供以下|告诉我更多细节,)?(?:信息|选项|方向|细节|分类|类型)[::]\s*$/,
|
||||
/[,,]\s*比如[::]\s*$/,
|
||||
/[,,]\s*(?:例如|譬如|如以下)[::]\s*$/,
|
||||
/,\s*(?:for example|such as|like|the following)[::]?\s*$/i,
|
||||
];
|
||||
for (const pat of patterns) {
|
||||
const stripped = text.replace(pat, '');
|
||||
if (stripped !== text) return stripped;
|
||||
}
|
||||
return text;
|
||||
}
|
||||
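The trailing-reference stripping can be tried in isolation. This is a minimal sketch under stated assumptions: `trimDanglingLeadIn` is a hypothetical name, only two patterns are reproduced, and the character classes are assumed to accept both ASCII and full-width punctuation.

```typescript
// Hypothetical, reduced variant: two patterns only. The character classes
// are assumed to accept both ASCII and full-width commas and colons.
function trimDanglingLeadIn(text: string, hasClarificationTool: boolean): string {
  if (!hasClarificationTool || !text) return text;
  const patterns: RegExp[] = [
    /[,,]\s*比如[::]\s*$/,
    /,\s*(?:for example|such as|like|the following)[::]?\s*$/i,
  ];
  for (const pat of patterns) {
    const stripped = text.replace(pat, '');
    if (stripped !== text) return stripped; // stop at the first pattern that matches
  }
  return text;
}

const english = trimDanglingLeadIn('Which area interests you, for example:', true);
const chinese = trimDanglingLeadIn('请告诉我您的需求,比如:', true);
const untouched = trimDanglingLeadIn('Which area interests you, for example:', false);
```

Because every pattern is anchored with `$`, only a lead-in at the very end of the text is removed. The same phrase in the middle of a sentence is left alone, and nothing is stripped unless the clarification tool was actually called.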
|
||||
function MessageBubble({ message, onRetry }: { message: Message; setInput?: (text: string) => void; onRetry?: () => void }) {
|
||||
if (message.role === 'tool') {
|
||||
return null;
|
||||
}
|
||||
// Hand status/result messages are internal — search results are already in the LLM reply
|
||||
if (message.role === 'hand') {
|
||||
return null;
|
||||
}
|
||||
|
||||
const isUser = message.role === 'user';
|
||||
const isThinking = message.streaming && !message.content;
|
||||
@@ -710,15 +771,18 @@ function MessageBubble({ message, onRetry }: { message: Message; setInput?: (tex
|
||||
? (isUser
|
||||
? message.content
|
||||
: <StreamingText
|
||||
content={message.content}
|
||||
content={stripDanglingClarificationRef(
|
||||
stripToolNarration(message.content),
|
||||
toolCallSteps?.some(s => s.toolName === 'ask_clarification') ?? false,
|
||||
)}
|
||||
isStreaming={!!message.streaming}
|
||||
className="text-gray-700 dark:text-gray-200"
|
||||
/>
|
||||
)
|
||||
: '...'}
|
||||
</div>
|
||||
{/* Pipeline / Hand result presentation */}
|
||||
{!isUser && (message.role === 'workflow' || message.role === 'hand') && message.workflowResult && typeof message.workflowResult === 'object' && message.workflowResult !== null && (
|
||||
{/* Pipeline result presentation */}
|
||||
{!isUser && message.role === 'workflow' && message.workflowResult && typeof message.workflowResult === 'object' && message.workflowResult !== null && (
|
||||
<div className="mt-3">
|
||||
<PipelineResultPreview
|
||||
outputs={message.workflowResult as Record<string, unknown>}
|
||||
@@ -726,11 +790,6 @@ function MessageBubble({ message, onRetry }: { message: Message; setInput?: (tex
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{!isUser && message.role === 'hand' && message.handResult && typeof message.handResult === 'object' && message.handResult !== null && !message.workflowResult && (
|
||||
<div className="mt-3">
|
||||
<PresentationContainer data={message.handResult} />
|
||||
</div>
|
||||
)}
|
||||
{message.error && (
|
||||
<div className="flex items-center gap-2 mt-2">
|
||||
<p className="text-xs text-red-500">{message.error}</p>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff