use axum::Json; use axum::extract::{Extension, FromRef, State}; use erp_core::rbac::require_permission; use erp_core::types::{ApiResponse, TenantContext}; use serde::{Deserialize, Serialize}; use crate::agent::orchestrator::AgentRunParams; use crate::agent::tool::ToolContext; use crate::agent::tools::QueryPatientVitalsTool; use crate::agent::{AgentOrchestrator, ToolRegistry}; use crate::config_resolver; use crate::dto::{ChatMessage, ChatMessageRole}; use crate::state::AiState; // === 请求 / 响应 === #[derive(Debug, Deserialize, utoipa::ToSchema)] pub struct ChatRequest { pub message: String, pub history: Option>, /// 可选:关联患者 ID(从用户档案中获取) pub patient_id: Option, } #[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)] pub struct ChatHistoryItem { pub role: String, pub content: String, } #[derive(Debug, Serialize, utoipa::ToSchema)] pub struct ChatResponse { pub reply: String, pub message_id: String, pub iterations: usize, } #[utoipa::path( post, path = "/ai/chat", request_body = ChatRequest, responses((status = 200, description = "AI Agent 回复")), tag = "AI 客服", security(("bearer_auth" = [])), )] pub async fn chat( Extension(ctx): Extension, State(state): State, Json(body): Json, ) -> Result>, erp_core::error::AppError> where AiState: FromRef, S: Clone + Send + Sync + 'static, { require_permission(&ctx, "ai.chat.send")?; let message = body.message.trim(); if message.is_empty() { return Err(erp_core::error::AppError::Validation("消息不能为空".into())); } if message.len() > 2000 { return Err(erp_core::error::AppError::Validation( "消息长度不能超过 2000 字".into(), )); } let ai_state = AiState::from_ref(&state); // 从 settings 表加载 AI 配置(替代硬编码) let config = config_resolver::load_ai_config(ctx.tenant_id, &ai_state.db).await; // 构建 Agent 消息历史 let mut messages = vec![]; // 将前端传来的历史转换为 Agent ChatMessage if let Some(ref hist) = body.history { let filtered: Vec<&ChatHistoryItem> = hist .iter() .filter(|h| h.role == "user" || h.role == "assistant") .collect(); let start = filtered.len().saturating_sub(10); for h in &filtered[start..] { messages.push(ChatMessage { role: if h.role == "user" { ChatMessageRole::User } else { ChatMessageRole::Assistant }, content: h.content.clone(), tool_calls: None, tool_call_id: None, }); } } // 添加当前用户消息 messages.push(ChatMessage { role: ChatMessageRole::User, content: message.to_string(), tool_calls: None, tool_call_id: None, }); // 解析 Provider — Agent 需要 Function Calling,精确获取 Claude/OpenAI let provider_arc = ai_state .provider_registry .get_provider("claude") .or_else(|| ai_state.provider_registry.get_provider("openai")) .ok_or_else(|| { tracing::error!("No FC-capable provider found (need claude or openai)"); erp_core::error::AppError::Internal( "AI Agent 暂时不可用,需要 Claude 或 OpenAI 提供商".into(), ) })?; // 构建 ToolRegistry — Phase 0 只有 query_patient_vitals let mut registry = ToolRegistry::new(); registry.register(std::sync::Arc::new(QueryPatientVitalsTool)); let tool_ctx = ToolContext { tenant_id: ctx.tenant_id, user_id: ctx.user_id, patient_id: body.patient_id, db: ai_state.db.clone(), health_provider: ai_state.health_provider.clone(), }; let run_params = AgentRunParams { model: config.agent.model, temperature: config.agent.temperature, max_tokens: config.agent.max_tokens, max_iterations: config.agent.max_iterations, }; tracing::info!( tenant_id = %ctx.tenant_id, user_id = %ctx.user_id, patient_id = ?body.patient_id, msg_len = message.len(), model = %run_params.model, temperature = run_params.temperature, max_tokens = run_params.max_tokens, max_iterations = run_params.max_iterations, "AI Agent chat request" ); let provider_name = provider_arc.name().to_string(); // 执行 Agent ReAct 循环 let orchestrator = AgentOrchestrator::new(provider_arc, std::sync::Arc::new(registry)); let result = orchestrator .run( &config.agent.system_prompt, &mut messages, &tool_ctx, &run_params, ) .await .map_err(|e| { tracing::error!(error = %e, "AI Agent run failed"); erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into()) })?; let message_id = uuid::Uuid::now_v7().to_string(); tracing::info!( tenant_id = %ctx.tenant_id, message_id = %message_id, iterations = result.iterations, input_tokens = result.total_input_tokens, output_tokens = result.total_output_tokens, "AI Agent response sent" ); // 记录用量的 token 消耗 if let Err(e) = ai_state .usage .log_usage( ctx.tenant_id, &provider_name, &run_params.model, "chat", result.total_input_tokens as u32, result.total_output_tokens as u32, 0, 0, false, ) .await { tracing::warn!(error = %e, "Failed to log chat usage"); } Ok(Json(ApiResponse::ok(ChatResponse { reply: result.reply, message_id, iterations: result.iterations, }))) }