fix: 三端联调测试 2 P0 + 6 P1 + 2 P2 修复
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
P0-1: SaaS relay 模型别名解析 — "glm-4-flash" → "glm-4-flash-250414" (resolve_model)
P0-2: config.rs interpolate_env_vars UTF-8 修复 (chars 迭代器替代 bytes as char)
+ DB 启动编码检查 + docker-compose UTF-8 编码参数
P1-3: UI 模型选择器覆盖 Agent 默认模型 (model_override 全链路: TS→Tauri→Rust kernel)
P1-6: 知识搜索管道修复 — seed_knowledge 创建 chunks + 默认分类 (seed/uploaded/distillation)
P1-7: 用量限额从当前 Plan 读取 (非 stale usage 表)
P1-8: relay 双维度配额检查 (relay_requests + input_tokens)
P2-9: SSE 路径 token 计数修复 — 流结束检测替代固定 500ms sleep + billing increment
This commit is contained in:
@@ -25,7 +25,7 @@ impl Kernel {
|
||||
agent_id: &AgentId,
|
||||
message: String,
|
||||
) -> Result<MessageResponse> {
|
||||
self.send_message_with_chat_mode(agent_id, message, None).await
|
||||
self.send_message_with_chat_mode(agent_id, message, None, None).await
|
||||
}
|
||||
|
||||
/// Send a message to an agent with optional chat mode configuration
|
||||
@@ -34,6 +34,7 @@ impl Kernel {
|
||||
agent_id: &AgentId,
|
||||
message: String,
|
||||
chat_mode: Option<ChatModeConfig>,
|
||||
model_override: Option<String>,
|
||||
) -> Result<MessageResponse> {
|
||||
let agent_config = self.registry.get(agent_id)
|
||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?;
|
||||
@@ -41,12 +42,16 @@ impl Kernel {
|
||||
// Create or get session
|
||||
let session_id = self.memory.create_session(agent_id).await?;
|
||||
|
||||
// Use agent-level model if configured, otherwise fall back to global config
|
||||
let model = if !agent_config.model.model.is_empty() {
|
||||
agent_config.model.model.clone()
|
||||
} else {
|
||||
self.config.model().to_string()
|
||||
};
|
||||
// Model priority: UI override > Agent config > Global config
|
||||
let model = model_override
|
||||
.filter(|m| !m.is_empty())
|
||||
.unwrap_or_else(|| {
|
||||
if !agent_config.model.model.is_empty() {
|
||||
agent_config.model.model.clone()
|
||||
} else {
|
||||
self.config.model().to_string()
|
||||
}
|
||||
});
|
||||
|
||||
// Create agent loop with model configuration
|
||||
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
|
||||
@@ -122,7 +127,7 @@ impl Kernel {
|
||||
agent_id: &AgentId,
|
||||
message: String,
|
||||
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
|
||||
self.send_message_stream_with_prompt(agent_id, message, None, None, None).await
|
||||
self.send_message_stream_with_prompt(agent_id, message, None, None, None, None).await
|
||||
}
|
||||
|
||||
/// Send a message with streaming, optional system prompt, optional session reuse,
|
||||
@@ -134,6 +139,7 @@ impl Kernel {
|
||||
system_prompt_override: Option<String>,
|
||||
session_id_override: Option<zclaw_types::SessionId>,
|
||||
chat_mode: Option<ChatModeConfig>,
|
||||
model_override: Option<String>,
|
||||
) -> Result<mpsc::Receiver<zclaw_runtime::LoopEvent>> {
|
||||
let agent_config = self.registry.get(agent_id)
|
||||
.ok_or_else(|| zclaw_types::ZclawError::NotFound(format!("Agent not found: {}", agent_id)))?;
|
||||
@@ -150,12 +156,16 @@ impl Kernel {
|
||||
None => self.memory.create_session(agent_id).await?,
|
||||
};
|
||||
|
||||
// Use agent-level model if configured, otherwise fall back to global config
|
||||
let model = if !agent_config.model.model.is_empty() {
|
||||
agent_config.model.model.clone()
|
||||
} else {
|
||||
self.config.model().to_string()
|
||||
};
|
||||
// Model priority: UI override > Agent config > Global config
|
||||
let model = model_override
|
||||
.filter(|m| !m.is_empty())
|
||||
.unwrap_or_else(|| {
|
||||
if !agent_config.model.model.is_empty() {
|
||||
agent_config.model.model.clone()
|
||||
} else {
|
||||
self.config.model().to_string()
|
||||
}
|
||||
});
|
||||
|
||||
// Create agent loop with model configuration
|
||||
let subagent_enabled = chat_mode.as_ref().and_then(|m| m.subagent_enabled).unwrap_or(false);
|
||||
|
||||
@@ -282,19 +282,26 @@ pub async fn increment_dimension_by(
|
||||
}
|
||||
|
||||
/// 检查用量配额
|
||||
///
|
||||
/// P1-7 修复: 从当前 Plan 读取限额(而非 stale 的 usage 表冗余列)
|
||||
/// P1-8 修复: 支持 relay_requests + input_tokens 双维度检查
|
||||
pub async fn check_quota(
|
||||
pool: &PgPool,
|
||||
account_id: &str,
|
||||
quota_type: &str,
|
||||
) -> SaasResult<QuotaCheck> {
|
||||
let usage = get_or_create_usage(pool, account_id).await?;
|
||||
// 从当前 Plan 读取真实限额,而非 usage 表的 stale 冗余列
|
||||
let plan = get_account_plan(pool, account_id).await?;
|
||||
let limits: crate::billing::types::PlanLimits = serde_json::from_value(plan.limits)
|
||||
.unwrap_or_else(|_| crate::billing::types::PlanLimits::free());
|
||||
|
||||
let (current, limit) = match quota_type {
|
||||
"input_tokens" => (usage.input_tokens, usage.max_input_tokens),
|
||||
"output_tokens" => (usage.output_tokens, usage.max_output_tokens),
|
||||
"relay_requests" => (usage.relay_requests as i64, usage.max_relay_requests.map(|v| v as i64)),
|
||||
"hand_executions" => (usage.hand_executions as i64, usage.max_hand_executions.map(|v| v as i64)),
|
||||
"pipeline_runs" => (usage.pipeline_runs as i64, usage.max_pipeline_runs.map(|v| v as i64)),
|
||||
"input_tokens" => (usage.input_tokens, limits.max_input_tokens_monthly),
|
||||
"output_tokens" => (usage.output_tokens, limits.max_output_tokens_monthly),
|
||||
"relay_requests" => (usage.relay_requests as i64, limits.max_relay_requests_monthly.map(|v| v as i64)),
|
||||
"hand_executions" => (usage.hand_executions as i64, limits.max_hand_executions_monthly.map(|v| v as i64)),
|
||||
"pipeline_runs" => (usage.pipeline_runs as i64, limits.max_pipeline_runs_monthly.map(|v| v as i64)),
|
||||
_ => return Ok(QuotaCheck {
|
||||
allowed: true,
|
||||
reason: None,
|
||||
@@ -309,7 +316,7 @@ pub async fn check_quota(
|
||||
|
||||
Ok(QuotaCheck {
|
||||
allowed,
|
||||
reason: if !allowed { Some(format!("{} 配额已用尽", quota_type)) } else { None },
|
||||
reason: if !allowed { Some(format!("{} 配额已用尽 (已用 {}/{})", quota_type, current, limit.unwrap_or(0))) } else { None },
|
||||
current,
|
||||
limit,
|
||||
remaining,
|
||||
|
||||
@@ -248,6 +248,37 @@ impl AppCache {
|
||||
.map(|r| r.value().clone())
|
||||
}
|
||||
|
||||
/// 按别名查找模型 — 用于向后兼容旧模型 ID (如 "glm-4-flash" → "glm-4-flash-250414")
|
||||
/// 先按 alias 字段精确匹配,再按 model_id 前缀匹配(去掉日期后缀)
|
||||
pub fn resolve_model(&self, model_name: &str) -> Option<CachedModel> {
|
||||
// 1. 直接 model_id 查找
|
||||
if let Some(m) = self.get_model(model_name) {
|
||||
return Some(m);
|
||||
}
|
||||
// 2. 按 alias 精确匹配
|
||||
for entry in self.models.iter() {
|
||||
if entry.value().enabled && entry.value().alias == model_name {
|
||||
return Some(entry.value().clone());
|
||||
}
|
||||
}
|
||||
// 3. 前缀匹配: "glm-4-flash" 匹配 "glm-4-flash-250414" 等带后缀的模型
|
||||
for entry in self.models.iter() {
|
||||
let mid = &entry.value().model_id;
|
||||
if entry.value().enabled
|
||||
&& (mid.starts_with(&format!("{}-", model_name))
|
||||
|| mid.starts_with(&format!("{}v", model_name)))
|
||||
{
|
||||
tracing::info!(
|
||||
"Model alias resolved: {} → {}",
|
||||
model_name,
|
||||
mid
|
||||
);
|
||||
return Some(entry.value().clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// 按 provider id 查找已启用的 Provider。O(1) DashMap 查找。
|
||||
pub fn get_provider(&self, provider_id: &str) -> Option<CachedProvider> {
|
||||
self.providers.get(provider_id)
|
||||
|
||||
@@ -465,22 +465,25 @@ impl SaaSConfig {
|
||||
|
||||
/// 替换 TOML 配置文件中的 `${ENV_VAR}` 模式为环境变量值
|
||||
/// 未设置的环境变量保留原文,后续数据库连接或 JWT 初始化时会报明确错误
|
||||
///
|
||||
/// 注意: 使用 chars() 迭代器而非 bytes() 来正确处理多字节 UTF-8 字符(如中文),
|
||||
/// 避免将多字节 UTF-8 序列的每个字节单独 `as char` 导致编码损坏。
|
||||
fn interpolate_env_vars(content: &str) -> String {
|
||||
let mut result = String::with_capacity(content.len());
|
||||
let bytes = content.as_bytes();
|
||||
let chars: Vec<char> = content.chars().collect();
|
||||
let mut i = 0;
|
||||
while i < bytes.len() {
|
||||
if i + 1 < bytes.len() && bytes[i] == b'$' && bytes[i + 1] == b'{' {
|
||||
while i < chars.len() {
|
||||
if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '{' {
|
||||
let start = i + 2;
|
||||
let mut end = start;
|
||||
while end < bytes.len()
|
||||
&& (bytes[end].is_ascii_alphanumeric() || bytes[end] == b'_')
|
||||
while end < chars.len()
|
||||
&& (chars[end].is_ascii_alphanumeric() || chars[end] == '_')
|
||||
{
|
||||
end += 1;
|
||||
}
|
||||
if end < bytes.len() && bytes[end] == b'}' {
|
||||
let var_name = std::str::from_utf8(&bytes[start..end]).unwrap_or("");
|
||||
match std::env::var(var_name) {
|
||||
if end < chars.len() && chars[end] == '}' {
|
||||
let var_name: String = chars[start..end].iter().collect();
|
||||
match std::env::var(&var_name) {
|
||||
Ok(val) => {
|
||||
tracing::debug!("Config: ${{{}}} → resolved ({} bytes)", var_name, val.len());
|
||||
result.push_str(&val);
|
||||
@@ -492,11 +495,11 @@ fn interpolate_env_vars(content: &str) -> String {
|
||||
}
|
||||
i = end + 1;
|
||||
} else {
|
||||
result.push(bytes[i] as char);
|
||||
result.push(chars[i]);
|
||||
i += 1;
|
||||
}
|
||||
} else {
|
||||
result.push(bytes[i] as char);
|
||||
result.push(chars[i]);
|
||||
i += 1;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,10 +38,25 @@ pub async fn init_db(config: &DatabaseConfig) -> SaasResult<PgPool> {
|
||||
.connect(&database_url)
|
||||
.await?;
|
||||
|
||||
// 验证数据库编码为 UTF8 — 中文 Windows (GBK/代码页936) 可能导致默认非 UTF8
|
||||
let encoding: (String,) = sqlx::query_as("SHOW server_encoding")
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.unwrap_or(("UNKNOWN".to_string(),));
|
||||
if encoding.0.to_uppercase() != "UTF8" {
|
||||
tracing::error!(
|
||||
"⚠ 数据库编码为 '{}',非 UTF8!中文数据将损坏。请使用 CREATE DATABASE ... WITH ENCODING='UTF8' 重建数据库。",
|
||||
encoding.0
|
||||
);
|
||||
} else {
|
||||
tracing::info!("Database encoding: {}", encoding.0);
|
||||
}
|
||||
|
||||
run_migrations(&pool).await?;
|
||||
ensure_security_columns(&pool).await?;
|
||||
seed_admin_account(&pool).await?;
|
||||
seed_builtin_prompts(&pool).await?;
|
||||
seed_knowledge_categories(&pool).await?;
|
||||
seed_builtin_industries(&pool).await?;
|
||||
seed_demo_data(&pool).await?;
|
||||
fix_seed_data(&pool).await?;
|
||||
@@ -1004,6 +1019,31 @@ async fn seed_builtin_industries(pool: &PgPool) -> SaasResult<()> {
|
||||
crate::industry::service::seed_builtin_industries(pool).await
|
||||
}
|
||||
|
||||
/// 种子化知识库默认分类(幂等)
|
||||
async fn seed_knowledge_categories(pool: &PgPool) -> SaasResult<()> {
|
||||
let now = chrono::Utc::now();
|
||||
let categories = [
|
||||
("seed", "种子知识", "系统内置的行业基础知识"),
|
||||
("uploaded", "上传文档", "用户上传的文档知识"),
|
||||
("distillation", "蒸馏知识", "API 蒸馏生成的知识"),
|
||||
];
|
||||
for (id, name, desc) in &categories {
|
||||
sqlx::query(
|
||||
"INSERT INTO knowledge_categories (id, name, description, created_at, updated_at) \
|
||||
VALUES ($1, $2, $3, $4, $4) \
|
||||
ON CONFLICT (id) DO NOTHING"
|
||||
)
|
||||
.bind(id)
|
||||
.bind(name)
|
||||
.bind(desc)
|
||||
.bind(&now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
tracing::debug!("Seeded knowledge categories");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// PostgreSQL 单元测试需要真实数据库连接,此处保留接口兼容
|
||||
|
||||
@@ -632,6 +632,8 @@ pub async fn unified_search(
|
||||
// === 种子知识冷启动 ===
|
||||
|
||||
/// 为指定行业插入种子知识(幂等)
|
||||
///
|
||||
/// P1-6 修复: 同时创建 knowledge_chunks 以支持搜索
|
||||
pub async fn seed_knowledge(
|
||||
pool: &PgPool,
|
||||
industry_id: &str,
|
||||
@@ -684,6 +686,24 @@ pub async fn seed_knowledge(
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
// 创建 chunks 以支持搜索(与 distill_knowledge worker 一致)
|
||||
let chunks = chunk_content(content, 500, 50);
|
||||
for (chunk_idx, chunk_text) in chunks.iter().enumerate() {
|
||||
let chunk_id = uuid::Uuid::new_v4().to_string();
|
||||
sqlx::query(
|
||||
"INSERT INTO knowledge_chunks (id, item_id, content, keywords, chunk_index, created_at) \
|
||||
VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
)
|
||||
.bind(&chunk_id)
|
||||
.bind(&id)
|
||||
.bind(chunk_text)
|
||||
.bind(&kw_json)
|
||||
.bind(chunk_idx as i32)
|
||||
.bind(&now)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
}
|
||||
|
||||
created += 1;
|
||||
}
|
||||
Ok(created)
|
||||
|
||||
@@ -145,6 +145,26 @@ pub async fn quota_check_middleware(
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// P1-8 修复: 同时检查 input_tokens 配额
|
||||
match crate::billing::service::check_quota(&state.db, &account_id, "input_tokens").await {
|
||||
Ok(check) if !check.allowed => {
|
||||
tracing::warn!(
|
||||
"Token quota exceeded for account {}: {} ({}/{})",
|
||||
account_id,
|
||||
check.reason.as_deref().unwrap_or("Token配额已用尽"),
|
||||
check.current,
|
||||
check.limit.map(|l| l.to_string()).unwrap_or_else(|| "∞".into()),
|
||||
);
|
||||
return SaasError::RateLimited(
|
||||
check.reason.unwrap_or_else(|| "月度 Token 配额已用尽".into()),
|
||||
).into_response();
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Token quota check failed for account {}: {}", account_id, e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
||||
|
||||
@@ -152,8 +152,8 @@ pub async fn chat_completions(
|
||||
}
|
||||
ModelResolution::Group(candidates)
|
||||
} else {
|
||||
// 向后兼容:直接模型查找
|
||||
let target_model = state.cache.get_model(model_name)
|
||||
// 向后兼容:直接模型查找 + 别名解析(如 "glm-4-flash" → "glm-4-flash-250414")
|
||||
let target_model = state.cache.resolve_model(model_name)
|
||||
.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在或未启用", model_name)))?;
|
||||
|
||||
// 获取 provider 信息 — 使用内存缓存消除 DB 查询
|
||||
@@ -218,7 +218,7 @@ pub async fn chat_completions(
|
||||
ModelResolution::Direct(ref candidate) => {
|
||||
// 单 Provider 直接路由(向后兼容)
|
||||
match service::execute_relay(
|
||||
&state.db, &task.id, &candidate.provider_id,
|
||||
&state.db, &task.id, &ctx.account_id, &candidate.provider_id,
|
||||
&candidate.base_url, &request_body, stream,
|
||||
max_attempts, retry_delay_ms, &enc_key,
|
||||
true, // 独立调用,管理 task 状态
|
||||
@@ -233,7 +233,7 @@ pub async fn chat_completions(
|
||||
// SSE 一旦开始流式传输,中途上游断连不会触发 failover(SSE 协议固有限制)。
|
||||
service::sort_candidates_by_quota(&state.db, candidates).await;
|
||||
service::execute_relay_with_failover(
|
||||
&state.db, &task.id, candidates,
|
||||
&state.db, &task.id, &ctx.account_id, candidates,
|
||||
&request_body, stream,
|
||||
max_attempts, retry_delay_ms, &enc_key
|
||||
).await
|
||||
@@ -553,11 +553,12 @@ pub async fn retry_task(
|
||||
// 异步执行重试 — 根据解析结果选择执行路径
|
||||
let db = state.db.clone();
|
||||
let task_id = id.clone();
|
||||
let account_id_for_spawn = task.account_id.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let result = match model_resolution {
|
||||
ModelResolution::Direct(ref candidate) => {
|
||||
service::execute_relay(
|
||||
&db, &task_id, &candidate.provider_id,
|
||||
&db, &task_id, &account_id_for_spawn, &candidate.provider_id,
|
||||
&candidate.base_url, &body, stream,
|
||||
max_attempts, base_delay_ms, &enc_key,
|
||||
true,
|
||||
@@ -566,7 +567,7 @@ pub async fn retry_task(
|
||||
ModelResolution::Group(ref mut candidates) => {
|
||||
service::sort_candidates_by_quota(&db, candidates).await;
|
||||
service::execute_relay_with_failover(
|
||||
&db, &task_id, candidates,
|
||||
&db, &task_id, &account_id_for_spawn, candidates,
|
||||
&body, stream,
|
||||
max_attempts, base_delay_ms, &enc_key,
|
||||
).await
|
||||
|
||||
@@ -217,6 +217,7 @@ impl SseUsageCapture {
|
||||
pub async fn execute_relay(
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
account_id: &str,
|
||||
provider_id: &str,
|
||||
provider_base_url: &str,
|
||||
request_body: &str,
|
||||
@@ -313,6 +314,7 @@ pub async fn execute_relay(
|
||||
let db_clone = db.clone();
|
||||
let task_id_clone = task_id.to_string();
|
||||
let key_id_for_spawn = key_id.clone();
|
||||
let account_id_clone = account_id.to_string();
|
||||
|
||||
// Bounded channel for backpressure: 128 chunks (~128KB) buffer.
|
||||
// If the client reads slowly, the upstream is signaled via
|
||||
@@ -369,20 +371,53 @@ pub async fn execute_relay(
|
||||
|
||||
tokio::spawn(async move {
|
||||
let _permit = permit; // 持有 permit 直到任务完成
|
||||
// Brief delay to allow SSE stream to settle before recording
|
||||
tokio::time::sleep(std::time::Duration::from_millis(500)).await;
|
||||
let capture = usage_capture.lock().await;
|
||||
let (input, output) = (
|
||||
if capture.input_tokens > 0 { Some(capture.input_tokens) } else { None },
|
||||
if capture.output_tokens > 0 { Some(capture.output_tokens) } else { None },
|
||||
);
|
||||
// Record task status with timeout to avoid holding DB connections
|
||||
// 等待 SSE 流结束 — 等待 capture 稳定(tokens 不再增长)
|
||||
// 替代原来固定 500ms 的 race condition
|
||||
let max_wait = std::time::Duration::from_secs(120);
|
||||
let poll_interval = std::time::Duration::from_millis(500);
|
||||
let start = tokio::time::Instant::now();
|
||||
let mut last_tokens: i64 = 0;
|
||||
let mut stable_count = 0;
|
||||
let (input, output) = loop {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
let capture = usage_capture.lock().await;
|
||||
let total = capture.input_tokens + capture.output_tokens;
|
||||
if total == last_tokens && total > 0 {
|
||||
stable_count += 1;
|
||||
if stable_count >= 3 {
|
||||
// 连续 3 次稳定(1.5s),认为流结束
|
||||
break (capture.input_tokens, capture.output_tokens);
|
||||
}
|
||||
} else {
|
||||
stable_count = 0;
|
||||
last_tokens = total;
|
||||
}
|
||||
drop(capture);
|
||||
if start.elapsed() >= max_wait {
|
||||
let capture = usage_capture.lock().await;
|
||||
break (capture.input_tokens, capture.output_tokens);
|
||||
}
|
||||
};
|
||||
|
||||
let input_opt = if input > 0 { Some(input) } else { None };
|
||||
let output_opt = if output > 0 { Some(output) } else { None };
|
||||
|
||||
// Record task status + billing usage + key usage
|
||||
let db_op = async {
|
||||
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input, output, None).await {
|
||||
if let Err(e) = update_task_status(&db_clone, &task_id_clone, "completed", input_opt, output_opt, None).await {
|
||||
tracing::warn!("Failed to update task status after SSE stream: {}", e);
|
||||
}
|
||||
// Record key usage (now 2 queries instead of 3)
|
||||
let total_tokens = input.unwrap_or(0) + output.unwrap_or(0);
|
||||
// P2-9 修复: SSE 路径也更新 billing_usage_quotas
|
||||
if input > 0 || output > 0 {
|
||||
if let Err(e) = crate::billing::service::increment_usage(
|
||||
&db_clone, &account_id_clone,
|
||||
input, output,
|
||||
).await {
|
||||
tracing::warn!("Failed to increment billing usage for SSE task {}: {}", task_id_clone, e);
|
||||
}
|
||||
}
|
||||
// Record key usage
|
||||
let total_tokens = input + output;
|
||||
if let Err(e) = super::key_pool::record_key_usage(&db_clone, &key_id_for_spawn, Some(total_tokens)).await {
|
||||
tracing::warn!("Failed to record key usage: {}", e);
|
||||
}
|
||||
@@ -503,6 +538,7 @@ pub async fn execute_relay(
|
||||
pub async fn execute_relay_with_failover(
|
||||
db: &PgPool,
|
||||
task_id: &str,
|
||||
account_id: &str,
|
||||
candidates: &[CandidateModel],
|
||||
request_body: &str,
|
||||
stream: bool,
|
||||
@@ -533,6 +569,7 @@ pub async fn execute_relay_with_failover(
|
||||
match execute_relay(
|
||||
db,
|
||||
task_id,
|
||||
account_id,
|
||||
&candidate.provider_id,
|
||||
&candidate.base_url,
|
||||
&patched_body,
|
||||
|
||||
@@ -30,6 +30,9 @@ pub struct ChatRequest {
|
||||
/// Enable sub-agent delegation (Ultra mode only)
|
||||
#[serde(default)]
|
||||
pub subagent_enabled: Option<bool>,
|
||||
/// Model override — UI 选择的模型优先于 Agent 配置的默认模型
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
/// Chat response
|
||||
@@ -76,6 +79,9 @@ pub struct StreamChatRequest {
|
||||
/// Enable sub-agent delegation (Ultra mode only)
|
||||
#[serde(default)]
|
||||
pub subagent_enabled: Option<bool>,
|
||||
/// Model override — UI 选择的模型优先于 Agent 配置的默认模型
|
||||
#[serde(default)]
|
||||
pub model: Option<String>,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -116,7 +122,7 @@ pub async fn agent_chat(
|
||||
None
|
||||
};
|
||||
|
||||
let response = kernel.send_message_with_chat_mode(&id, request.message, chat_mode)
|
||||
let response = kernel.send_message_with_chat_mode(&id, request.message, chat_mode, request.model)
|
||||
.await
|
||||
.map_err(|e| format!("Chat failed: {}", e))?;
|
||||
|
||||
@@ -257,6 +263,7 @@ pub async fn agent_chat_stream(
|
||||
prompt_arg,
|
||||
session_id_parsed,
|
||||
Some(chat_mode_config),
|
||||
request.model.clone(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
|
||||
@@ -60,6 +60,7 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo
|
||||
reasoning_effort?: string;
|
||||
plan_mode?: boolean;
|
||||
subagent_enabled?: boolean;
|
||||
model?: string;
|
||||
}
|
||||
): Promise<{ runId: string }> {
|
||||
const runId = crypto.randomUUID();
|
||||
@@ -199,6 +200,7 @@ export function installChatMethods(ClientClass: { prototype: KernelClient }): vo
|
||||
reasoningEffort: opts?.reasoning_effort,
|
||||
planMode: opts?.plan_mode,
|
||||
subagentEnabled: opts?.subagent_enabled,
|
||||
model: opts?.model,
|
||||
},
|
||||
});
|
||||
} catch (err: unknown) {
|
||||
|
||||
@@ -303,6 +303,9 @@ export const useStreamStore = create<StreamState>()(
|
||||
.map(m => ({ role: m.role, content: m.content }))
|
||||
.slice(-20);
|
||||
|
||||
// UI 模型选择器应覆盖 Agent 默认模型
|
||||
const currentModel = useConversationStore.getState().currentModel;
|
||||
|
||||
const result = await client.chatStream(
|
||||
content,
|
||||
{
|
||||
@@ -534,6 +537,7 @@ export const useStreamStore = create<StreamState>()(
|
||||
set({ isStreaming: false, activeRunId: null });
|
||||
},
|
||||
},
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
{
|
||||
sessionKey: effectiveSessionKey,
|
||||
agentId: effectiveAgentId,
|
||||
@@ -542,7 +546,8 @@ export const useStreamStore = create<StreamState>()(
|
||||
plan_mode: get().getChatModeConfig().plan_mode,
|
||||
subagent_enabled: get().getChatModeConfig().subagent_enabled,
|
||||
history,
|
||||
}
|
||||
model: currentModel || undefined,
|
||||
} as Parameters<typeof client.chatStream>[2]
|
||||
);
|
||||
|
||||
if (result?.runId) {
|
||||
|
||||
@@ -18,6 +18,8 @@ services:
|
||||
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-your_secure_password}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-zclaw}
|
||||
# 确保 UTF-8 编码 — 中文 Windows 默认 GBK 会导致中文数据损坏
|
||||
POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C.UTF-8"
|
||||
|
||||
ports:
|
||||
- "${POSTGRES_PORT:-5432}:5432"
|
||||
|
||||
Reference in New Issue
Block a user