Files
hms/crates/erp-ai/src/handler/chat_handler.rs
iven 7e3d27ecf3 feat(ai): Phase 1A 收尾 — 用量记录 + 健康摘要端点 + 小程序组件
- chat_handler 添加 log_usage 精确记录 token 消耗(provider + model)
- SSE build_sse_stream 添加估算 token 用量记录(4 字符 ≈ 1 token)
- 新增 GET /ai/health-summary 端点聚合患者洞察+分析记录
- 小程序 AiHealthSummaryCard 组件(风险等级+洞察统计+摘要列表)
- 小程序 services/ai-analysis 新增 getHealthSummary API
2026-05-18 23:20:06 +08:00

200 lines
6.0 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::Json;
use axum::extract::{Extension, FromRef, State};
use erp_core::rbac::require_permission;
use erp_core::types::{ApiResponse, TenantContext};
use serde::{Deserialize, Serialize};
use crate::agent::orchestrator::AgentRunParams;
use crate::agent::tool::ToolContext;
use crate::agent::tools::QueryPatientVitalsTool;
use crate::agent::{AgentOrchestrator, ToolRegistry};
use crate::config_resolver;
use crate::dto::{ChatMessage, ChatMessageRole};
use crate::state::AiState;
// === Request / response types ===

// Request body of `POST /ai/chat`.
// Plain `//` comments are used on these types on purpose: the `utoipa::ToSchema` derive copies
// `///` doc comments into the generated OpenAPI schema, so `///` text is API-visible.
#[derive(Debug, Deserialize, utoipa::ToSchema)]
pub struct ChatRequest {
    // The user's new message. The chat handler trims it, then rejects it if empty or too long.
    pub message: String,
    // Earlier conversation turns supplied by the client, if any.
    pub history: Option<Vec<ChatHistoryItem>>,
    /// 可选:关联患者 ID从用户档案中获取
    // English gloss of the line above: "optional: associated patient ID, obtained from the
    // user profile".
    // NOTE(review): despite that note, this value is deserialized straight from the request
    // body — confirm where the patient/user relationship is enforced.
    pub patient_id: Option<uuid::Uuid>,
}
// One earlier conversation turn sent by the client.
// `role` is free-form text: the chat handler only forwards items whose role is exactly
// "user" or "assistant" and silently drops any other value.
#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)]
pub struct ChatHistoryItem {
    // Speaker of this turn ("user" or "assistant" are the values the handler keeps).
    pub role: String,
    // Text of this turn, forwarded to the agent unchanged.
    pub content: String,
}
// Successful result of `POST /ai/chat`; the handler returns it wrapped in `ApiResponse`.
#[derive(Debug, Serialize, utoipa::ToSchema)]
pub struct ChatResponse {
    // Final reply text produced by the agent run.
    pub reply: String,
    // Server-generated identifier for this reply (a UUIDv7 rendered as a string).
    pub message_id: String,
    // Iteration count reported by the agent run that produced `reply`.
    pub iterations: usize,
}
#[utoipa::path(
post,
path = "/ai/chat",
request_body = ChatRequest,
responses((status = 200, description = "AI Agent 回复")),
tag = "AI 客服",
security(("bearer_auth" = [])),
)]
pub async fn chat<S>(
Extension(ctx): Extension<TenantContext>,
State(state): State<S>,
Json(body): Json<ChatRequest>,
) -> Result<Json<ApiResponse<ChatResponse>>, erp_core::error::AppError>
where
AiState: FromRef<S>,
S: Clone + Send + Sync + 'static,
{
require_permission(&ctx, "ai.chat.send")?;
let message = body.message.trim();
if message.is_empty() {
return Err(erp_core::error::AppError::Validation("消息不能为空".into()));
}
if message.len() > 2000 {
return Err(erp_core::error::AppError::Validation(
"消息长度不能超过 2000 字".into(),
));
}
let ai_state = AiState::from_ref(&state);
// 从 settings 表加载 AI 配置(替代硬编码)
let config = config_resolver::load_ai_config(ctx.tenant_id, &ai_state.db).await;
// 构建 Agent 消息历史
let mut messages = vec![];
// 将前端传来的历史转换为 Agent ChatMessage
if let Some(ref hist) = body.history {
let filtered: Vec<&ChatHistoryItem> = hist
.iter()
.filter(|h| h.role == "user" || h.role == "assistant")
.collect();
let start = filtered.len().saturating_sub(10);
for h in &filtered[start..] {
messages.push(ChatMessage {
role: if h.role == "user" {
ChatMessageRole::User
} else {
ChatMessageRole::Assistant
},
content: h.content.clone(),
tool_calls: None,
tool_call_id: None,
});
}
}
// 添加当前用户消息
messages.push(ChatMessage {
role: ChatMessageRole::User,
content: message.to_string(),
tool_calls: None,
tool_call_id: None,
});
// 解析 Provider — Agent 需要 Function Calling精确获取 Claude/OpenAI
let provider_arc = ai_state
.provider_registry
.get_provider("claude")
.or_else(|| ai_state.provider_registry.get_provider("openai"))
.ok_or_else(|| {
tracing::error!("No FC-capable provider found (need claude or openai)");
erp_core::error::AppError::Internal(
"AI Agent 暂时不可用,需要 Claude 或 OpenAI 提供商".into(),
)
})?;
// 构建 ToolRegistry — Phase 0 只有 query_patient_vitals
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(QueryPatientVitalsTool));
let tool_ctx = ToolContext {
tenant_id: ctx.tenant_id,
user_id: ctx.user_id,
patient_id: body.patient_id,
db: ai_state.db.clone(),
health_provider: ai_state.health_provider.clone(),
};
let run_params = AgentRunParams {
model: config.agent.model,
temperature: config.agent.temperature,
max_tokens: config.agent.max_tokens,
max_iterations: config.agent.max_iterations,
};
tracing::info!(
tenant_id = %ctx.tenant_id,
user_id = %ctx.user_id,
patient_id = ?body.patient_id,
msg_len = message.len(),
model = %run_params.model,
temperature = run_params.temperature,
max_tokens = run_params.max_tokens,
max_iterations = run_params.max_iterations,
"AI Agent chat request"
);
let provider_name = provider_arc.name().to_string();
// 执行 Agent ReAct 循环
let orchestrator = AgentOrchestrator::new(provider_arc, std::sync::Arc::new(registry));
let result = orchestrator
.run(
&config.agent.system_prompt,
&mut messages,
&tool_ctx,
&run_params,
)
.await
.map_err(|e| {
tracing::error!(error = %e, "AI Agent run failed");
erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into())
})?;
let message_id = uuid::Uuid::now_v7().to_string();
tracing::info!(
tenant_id = %ctx.tenant_id,
message_id = %message_id,
iterations = result.iterations,
input_tokens = result.total_input_tokens,
output_tokens = result.total_output_tokens,
"AI Agent response sent"
);
// 记录用量的 token 消耗
if let Err(e) = ai_state
.usage
.log_usage(
ctx.tenant_id,
&provider_name,
&run_params.model,
"chat",
result.total_input_tokens as u32,
result.total_output_tokens as u32,
0,
0,
false,
)
.await
{
tracing::warn!(error = %e, "Failed to log chat usage");
}
Ok(Json(ApiResponse::ok(ChatResponse {
reply: result.reply,
message_id,
iterations: result.iterations,
})))
}