fix: 三端联调测试 2 P0 + 6 P1 + 2 P2 修复
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

P0-1: SaaS relay 模型别名解析 — "glm-4-flash" → "glm-4-flash-250414" (resolve_model)
P0-2: config.rs interpolate_env_vars UTF-8 修复 (chars 迭代器替代 bytes as char)
      + DB 启动编码检查 + docker-compose UTF-8 编码参数

P1-3: UI 模型选择器覆盖 Agent 默认模型 (model_override 全链路: TS→Tauri→Rust kernel)
P1-6: 知识搜索管道修复 — seed_knowledge 创建 chunks + 默认分类 (seed/uploaded/distillation)
P1-7: 用量限额从当前 Plan 读取 (非 stale usage 表)
P1-8: relay 双维度配额检查 (relay_requests + input_tokens)

P2-9: SSE 路径 token 计数修复 — 流结束检测替代固定 500ms sleep + billing increment
This commit is contained in:
iven
2026-04-14 00:17:08 +08:00
parent 0903a0d652
commit 4c3136890b
13 changed files with 234 additions and 49 deletions

View File

@@ -25,7 +25,7 @@ impl Kernel {
agent_id: &AgentId,
message: String,
) -> Result<MessageResponse> {
self.send_message_with_chat_mode(agent_id, message, None).await
self.send_message_with_chat_mode(agent_id, message, None, None).await
}
/// Send a message to an agent with optional chat mode configuration
@@ -34,6 +34,7 @@ impl Kernel {
agent_id: &AgentId,
message: String,
chat_mode: Option<ChatModeConfig>,
model_override: Option<String>,
) -> Result<MessageResponse> {
let agent_config = self.registry.get(agent_id)
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?;
@@ -41,12 +42,16 @@ impl Kernel {
// Create or get session
let session_id = self.memory.create_session(agent_id).await?;
// Use agent-level model if configured, otherwise fall back to global config
let model = if !agent_config.model.model.is_empty() {
agent_config.model.model.clone()
} else {
self.config.model().to_string()
};
// Model priority: UI override > Agent config > Global config
let model = model_override
.filter(|m| !m.is_empty())
.unwrap_or_else(|| {
if !agent_config.model.model.is_empty() {
agent_config.model.model.clone()
} else {
self.config.model().to_string()
}
});
// Create agent loop with model configuration
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
@@ -122,7 +127,7 @@ impl Kernel {
agent_id: &AgentId,
message: String,
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
self.send_message_stream_with_prompt(agent_id, message, None, None, None).await
self.send_message_stream_with_prompt(agent_id, message, None, None, None, None).await
}
/// Send a message with streaming, optional system prompt, optional session reuse,
@@ -134,6 +139,7 @@ impl Kernel {
system_prompt_override: Option<String>,
session_id_override: Option<zclaw_types::SessionId>,
chat_mode: Option<ChatModeConfig>,
model_override: Option<String>,
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
let agent_config = self.registry.get(agent_id)
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?;
@@ -150,12 +156,16 @@ impl Kernel {
None => self.memory.create_session(agent_id).await?,
};
// Use agent-level model if configured, otherwise fall back to global config
let model = if !agent_config.model.model.is_empty() {
agent_config.model.model.clone()
} else {
self.config.model().to_string()
};
// Model priority: UI override > Agent config > Global config
let model = model_override
.filter(|m| !m.is_empty())
.unwrap_or_else(|| {
if !agent_config.model.model.is_empty() {
agent_config.model.model.clone()
} else {
self.config.model().to_string()
}
});
// Create agent loop with model configuration
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);

View File

@@ -282,19 +282,26 @@ pub async fn increment_dimension_by(
}
/// 检查用量配额
///
/// P1-7 修复: 从当前 Plan 读取限额(而非 stale 的 usage 表冗余列)
/// P1-8 修复: 支持 relay_requests + input_tokens 双维度检查
pub async fn check_quota(
pool: &PgPool,
account_id: &str,
quota_type: &str,
) -> SaasResult<QuotaCheck> {
let usage = get_or_create_usage(pool, account_id).await?;
// 从当前 Plan 读取真实限额,而非 usage 表的 stale 冗余列
let plan = get_account_plan(pool, account_id).await?;
let limits: crate::billing::types::PlanLimits = serde_json::from_value(plan.limits)
.unwrap_or_else(|_| crate::billing::types::PlanLimits::free());
let (current, limit) = match quota_type {
"input_tokens" => (usage.input_tokens, usage.max_input_tokens),
"output_tokens" => (usage.output_tokens, usage.max_output_tokens),
"relay_requests" => (usage.relay_requests as i64, usage.max_relay_requests.map(|v| v as i64)),
"hand_executions" => (usage.hand_executions as i64, usage.max_hand_executions.map(|v| v as i64)),
"pipeline_runs" => (usage.pipeline_runs as i64, usage.max_pipeline_runs.map(|v| v as i64)),
"input_tokens" => (usage.input_tokens, limits.max_input_tokens_monthly),
"output_tokens" => (usage.output_tokens, limits.max_output_tokens_monthly),
"relay_requests" => (usage.relay_requests as i64, limits.max_relay_requests_monthly.map(|v| v as i64)),
"hand_executions" => (usage.hand_executions as i64, limits.max_hand_executions_monthly.map(|v| v as i64)),
"pipeline_runs" => (usage.pipeline_runs as i64, limits.max_pipeline_runs_monthly.map(|v| v as i64)),
_ => return Ok(QuotaCheck {
allowed: true,
reason: None,
@@ -309,7 +316,7 @@ pub async fn check_quota(
Ok(QuotaCheck {
allowed,
reason: if !allowed { Some(format!("{} 配额已用尽", quota_type)) } else { None },
reason: if !allowed { Some(format!("{} 配额已用尽 (已用 {}/{})", quota_type, current, limit.unwrap_or(0))) } else { None },
current,
limit,
remaining,

View File

@@ -248,6 +248,37 @@ impl AppCache {
.map(|r| r.value().clone())
}
/// 按别名查找模型 — 用于向后兼容旧模型 ID (如 "glm-4-flash" → "glm-4-flash-250414")
/// 先按 alias 字段精确匹配,再按 model_id 前缀匹配(去掉日期后缀)
pub fn resolve_model(&self, model_name: &str) -> Option<CachedModel> {
// 1. 直接 model_id 查找
if let Some(m) = self.get_model(model_name) {
return Some(m);
}
// 2. 按 alias 精确匹配
for entry in self.models.iter() {
if entry.value().enabled && entry.value().alias == model_name {
return Some(entry.value().clone());
}
}
// 3. 前缀匹配: "glm-4-flash" 匹配 "glm-4-flash-250414" 等带后缀的模型
for entry in self.models.iter() {
let mid = &entry.value().model_id;
if entry.value().enabled
&& (mid.starts_with(&format!("{}-", model_name))
|| mid.starts_with(&format!("{}v", model_name)))
{
tracing::info!(
"Model alias resolved: {} → {}",
model_name,
mid
);
return Some(entry.value().clone());
}
}
None
}
/// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。
pub fn get_provider(&self, provider_id: &str) -> Option<CachedProvider> {
self.providers.get(provider_id)

View File

@@ -465,22 +465,25 @@ impl SaaSConfig {
/// 替换 TOML 配置文件中的 `${ENV_VAR}` 模式为环境变量值
/// 未设置的环境变量保留原文,后续数据库连接或 JWT 初始化时会报明确错误
///
/// 注意: 使用 chars() 迭代器而非 bytes() 来正确处理多字节 UTF-8 字符(如中文),
/// 避免将多字节 UTF-8 序列的每个字节单独 `as char` 导致编码损坏。
fn interpolate_env_vars(content: &str) -> String {
let mut result = String::with_capacity(content.len());
let bytes = content.as_bytes();
let chars: Vec<char> = content.chars().collect();
let mut i = 0;
while i < bytes.len() {
if i + 1 < bytes.len() && bytes[i] == b'$' && bytes[i + 1] == b'{' {
while i < chars.len() {
if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '{' {
let start = i + 2;
let mut end = start;
while end < bytes.len()
&& (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_')
while end < chars.len()
&& (chars[end].is_ascii_alphanumeric() || chars[end] == '_')
{
end += 1;
}
if end < bytes.len() && bytes[end] == b'}' {
let var_name = std::str::from_utf8(&bytes[start..end]).unwrap_or("");
match std::env::var(var_name) {
if end < chars.len() && chars[end] == '}' {
let var_name: String = chars[start..end].iter().collect();
match std::env::var(&var_name) {
Ok(val) => {
tracing::debug!("Config: ${{{}}} → resolved ({} bytes)", var_name, val.len());
result.push_str(&val);
@@ -492,11 +495,11 @@ fn interpolate_env_vars(content: &str) -> String {
}
i = end + 1;
} else {
result.push(bytes[i] as char);
result.push(chars[i]);
i += 1;
}
} else {
result.push(bytes[i] as char);
result.push(chars[i]);
i += 1;
}
}

View File

@@ -38,10 +38,25 @@ pub async fn init_db(config: &DatabaseConfig) -> SaasResult<PgPool> {
.connect(&database_url)
.await?;
// 验证数据库编码为 UTF8 — 中文 Windows (GBK/代码页936) 可能导致默认非 UTF8
let encoding: (String,) = sqlx::query_as("SHOW server_encoding")
.fetch_one(&pool)
.await
.unwrap_or(("UNKNOWN".to_string(),));
if encoding.0.to_uppercase() != "UTF8" {
tracing::error!(
"⚠ 数据库编码为 '{}',非 UTF8中文数据将损坏。请使用 CREATE DATABASE ... WITH ENCODING='UTF8' 重建数据库。",
encoding.0
);
} else {
tracing::info!("Database encoding: {}", encoding.0);
}
run_migrations(&pool).await?;
ensure_security_columns(&pool).await?;
seed_admin_account(&pool).await?;
seed_builtin_prompts(&pool).await?;
seed_knowledge_categories(&pool).await?;
seed_builtin_industries(&pool).await?;
seed_demo_data(&pool).await?;
fix_seed_data(&pool).await?;
@@ -1004,6 +1019,31 @@ async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> {
crate::industry::service::seed_builtin_industries(pool).await
}
/// 种子化知识库默认分类(幂等)
async fn seed_knowledge_categories(pool: &PgPool) -> SaasResult<()> {
let now = chrono::Utc::now();
let categories = [
("seed", "种子知识", "系统内置的行业基础知识"),
("uploaded", "上传文档", "用户上传的文档知识"),
("distillation", "蒸馏知识", "API 蒸馏生成的知识"),
];
for (id, name, desc) in &categories {
sqlx::query(
"INSERT INTO knowledge_categories (id, name, description, created_at, updated_at) \
VALUES ($1, $2, $3, $4, $4) \
ON CONFLICT (id) DO NOTHING"
)
.bind(id)
.bind(name)
.bind(desc)
.bind(&now)
.execute(pool)
.await?;
}
tracing::debug!("Seeded knowledge categories");
Ok(())
}
#[cfg(test)]
mod tests {
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容

View File

@@ -632,6 +632,8 @@ pub async fn unified_search(
// === 种子知识冷启动 ===
/// 为指定行业插入种子知识(幂等)
///
/// P1-6 修复: 同时创建 knowledge_chunks 以支持搜索
pub async fn seed_knowledge(
pool: &PgPool,
industry_id: &str,
@@ -684,6 +686,24 @@ pub async fn seed_knowledge(
.execute(pool)
.await?;
// 创建 chunks 以支持搜索(与 distill_knowledge worker 一致)
let chunks = chunk_content(content, 500, 50);
for (chunk_idx, chunk_text) in chunks.iter().enumerate() {
let chunk_id = uuid::Uuid::new_v4().to_string();
sqlx::query(
"INSERT INTO knowledge_chunks (id, item_id, content, keywords, chunk_index, created_at) \
VALUES ($1, $2, $3, $4, $5, $6)"
)
.bind(&chunk_id)
.bind(&id)
.bind(chunk_text)
.bind(&kw_json)
.bind(chunk_idx as i32)
.bind(&now)
.execute(pool)
.await?;
}
created += 1;
}
Ok(created)

View File

@@ -145,6 +145,26 @@ pub async fn quota_check_middleware(
_ => {}
}
// P1-8 修复: 同时检查 input_tokens 配额
match crate::billing::service::check_quota(&state.db, &account_id, "input_tokens").await {
Ok(check) if !check.allowed => {
tracing::warn!(
"Token quota exceeded for account {}: {} ({}/{})",
account_id,
check.reason.as_deref().unwrap_or("Token配额已用尽"),
check.current,
check.limit.map(|l| l.to_string()).unwrap_or_else(|| "".into()),
);
return SaasError::RateLimited(
check.reason.unwrap_or_else(|| "月度 Token 配额已用尽".into()),
).into_response();
}
Err(e) => {
tracing::warn!("Token quota check failed for account {}: {}", account_id, e);
}
_ => {}
}
next.run(req).await
}

View File

@@ -152,8 +152,8 @@ pub async fn chat_completions(
}
ModelResolution::Group(candidates)
} else {
// 向后兼容:直接模型查找
let target_model = state.cache.get_model(model_name)
// 向后兼容:直接模型查找 + 别名解析(如 "glm-4-flash" → "glm-4-flash-250414"
let target_model = state.cache.resolve_model(model_name)
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
@@ -218,7 +218,7 @@ pub async fn chat_completions(
ModelResolution::Direct(ref candidate) => {
// 单 Provider 直接路由(向后兼容)
match service::execute_relay(
&state.db, &task.id, &candidate.provider_id,
&state.db, &task.id, &ctx.account_id, &candidate.provider_id,
&candidate.base_url, &request_body, stream,
max_attempts, retry_delay_ms, &enc_key,
true, // 独立调用,管理 task 状态
@@ -233,7 +233,7 @@ pub async fn chat_completions(
// SSE 一旦开始流式传输,中途上游断连不会触发 failoverSSE 协议固有限制)。
service::sort_candidates_by_quota(&state.db, candidates).await;
service::execute_relay_with_failover(
&state.db, &task.id, candidates,
&state.db, &task.id, &ctx.account_id, candidates,
&request_body, stream,
max_attempts, retry_delay_ms, &enc_key
).await
@@ -553,11 +553,12 @@ pub async fn retry_task(
// 异步执行重试 — 根据解析结果选择执行路径
let db = state.db.clone();
let task_id = id.clone();
let account_id_for_spawn = task.account_id.clone();
let handle = tokio::spawn(async move {
let result = match model_resolution {
ModelResolution::Direct(ref candidate) => {
service::execute_relay(
&db, &task_id, &candidate.provider_id,
&db, &task_id, &account_id_for_spawn, &candidate.provider_id,
&candidate.base_url, &body, stream,
max_attempts, base_delay_ms, &enc_key,
true,
@@ -566,7 +567,7 @@ pub async fn retry_task(
ModelResolution::Group(ref mut candidates) => {
service::sort_candidates_by_quota(&db, candidates).await;
service::execute_relay_with_failover(
&db, &task_id, candidates,
&db, &task_id, &account_id_for_spawn, candidates,
&body, stream,
max_attempts, base_delay_ms, &enc_key,
).await

View File

@@ -217,6 +217,7 @@ impl SseUsageCapture {
pub async fn execute_relay(
db: &PgPool,
task_id: &str,
account_id: &str,
provider_id: &str,
provider_base_url: &str,
request_body: &str,
@@ -313,6 +314,7 @@ pub async fn execute_relay(
let db_clone = db.clone();
let task_id_clone = task_id.to_string();
let key_id_for_spawn = key_id.clone();
let account_id_clone = account_id.to_string();
// Bounded channel for backpressure: 128 chunks (~128KB) buffer.
// If the client reads slowly, the upstream is signaled via
@@ -369,20 +371,53 @@ pub async fn execute_relay(
tokio::spawn(async move {
let _permit = permit; // 持有 permit 直到任务完成
// Brief delay to allow SSE stream to settle before recording
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
let capture = usage_capture.lock().await;
let (input, output) = (
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
);
// Record task status with timeout to avoid holding DB connections
// 等待 SSE 流结束 — 等待 capture 稳定tokens 不再增长)
// 替代原来固定 500ms 的 race condition
let max_wait = std::time::Duration::from_secs(120);
let poll_interval = std::time::Duration::from_millis(500);
let start = tokio::time::Instant::now();
let mut last_tokens: i64 = 0;
let mut stable_count = 0;
let (input, output) = loop {
tokio::time::sleep(poll_interval).await;
let capture = usage_capture.lock().await;
let total = capture.input_tokens + capture.output_tokens;
if total == last_tokens && total > 0 {
stable_count += 1;
if stable_count >= 3 {
// 连续 3 次稳定1.5s),认为流结束
break (capture.input_tokens, capture.output_tokens);
}
} else {
stable_count = 0;
last_tokens = total;
}
drop(capture);
if start.elapsed() >= max_wait {
let capture = usage_capture.lock().await;
break (capture.input_tokens, capture.output_tokens);
}
};
let input_opt = if input > 0 { Some(input) } else { None };
let output_opt = if output > 0 { Some(output) } else { None };
// Record task status + billing usage + key usage
let db_op = async {
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input_opt, output_opt, None).await {
tracing::warn!("Failed to update task status after SSE stream: {}", e);
}
// Record key usage (now 2 queries instead of 3)
let total_tokens = input.unwrap_or(0) + output.unwrap_or(0);
// P2-9 修复: SSE 路径也更新 billing_usage_quotas
if input > 0 || output > 0 {
if let Err(e) = crate::billing::service::increment_usage(
&db_clone, &account_id_clone,
input, output,
).await {
tracing::warn!("Failed to increment billing usage for SSE task {}: {}", task_id_clone, e);
}
}
// Record key usage
let total_tokens = input + output;
if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await {
tracing::warn!("Failed to record key usage: {}", e);
}
@@ -503,6 +538,7 @@ pub async fn execute_relay(
pub async fn execute_relay_with_failover(
db: &PgPool,
task_id: &str,
account_id: &str,
candidates: &[CandidateModel],
request_body: &str,
stream: bool,
@@ -533,6 +569,7 @@ pub async fn execute_relay_with_failover(
match execute_relay(
db,
task_id,
account_id,
&candidate.provider_id,
&candidate.base_url,
&patched_body,

View File

@@ -30,6 +30,9 @@ pub struct ChatRequest {
/// Enable sub-agent delegation (Ultra mode only)
#[serde(default)]
pub subagent_enabled: Option<bool>,
/// Model override — UI 选择的模型优先于 Agent 配置的默认模型
#[serde(default)]
pub model: Option<String>,
}
/// Chat response
@@ -76,6 +79,9 @@ pub struct StreamChatRequest {
/// Enable sub-agent delegation (Ultra mode only)
#[serde(default)]
pub subagent_enabled: Option<bool>,
/// Model override — UI 选择的模型优先于 Agent 配置的默认模型
#[serde(default)]
pub model: Option<String>,
}
// ---------------------------------------------------------------------------
@@ -116,7 +122,7 @@ pub async fn agent_chat(
None
};
let response = kernel.send_message_with_chat_mode(&id, request.message, chat_mode)
let response = kernel.send_message_with_chat_mode(&id, request.message, chat_mode, request.model)
.await
.map_err(|e| format!("Chat failed: {}", e))?;
@@ -257,6 +263,7 @@ pub async fn agent_chat_stream(
prompt_arg,
session_id_parsed,
Some(chat_mode_config),
request.model.clone(),
)
.await
.map_err(|e| {

View File

@@ -60,6 +60,7 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo
reasoning_effort?: string;
plan_mode?: boolean;
subagent_enabled?: boolean;
model?: string;
}
): Promise<{ runId: string }> {
const runId = crypto.randomUUID();
@@ -199,6 +200,7 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo
reasoningEffort: opts?.reasoning_effort,
planMode: opts?.plan_mode,
subagentEnabled: opts?.subagent_enabled,
model: opts?.model,
},
});
} catch (err: unknown) {

View File

@@ -303,6 +303,9 @@ export const useStreamStore = create<StreamState>()(
.map(m => ({ role: m.role, content: m.content }))
.slice(-20);
// UI 模型选择器应覆盖 Agent 默认模型
const currentModel = useConversationStore.getState().currentModel;
const result = await client.chatStream(
content,
{
@@ -534,6 +537,7 @@ export const useStreamStore = create<StreamState>()(
set({ isStreaming: false, activeRunId: null });
},
},
// eslint-disable-next-line @typescript-eslint/no-explicit-any
{
sessionKey: effectiveSessionKey,
agentId: effectiveAgentId,
@@ -542,7 +546,8 @@ export const useStreamStore = create<StreamState>()(
plan_mode: get().getChatModeConfig().plan_mode,
subagent_enabled: get().getChatModeConfig().subagent_enabled,
history,
}
model: currentModel || undefined,
} as Parameters<typeof client.chatStream>[2]
);
if (result?.runId) {

View File

@@ -18,6 +18,8 @@ services:
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-your_secure_password}
POSTGRES_DB: ${POSTGRES_DB:-zclaw}
# 确保 UTF-8 编码 — 中文 Windows 默认 GBK 会导致中文数据损坏
POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C.UTF-8"
ports:
- "${POSTGRES_PORT:-5432}:5432"