Compare commits

...

10 Commits

19 changed files with 1316 additions and 66 deletions

View File

@@ -0,0 +1,8 @@
pub mod orchestrator;
pub mod registry;
pub mod tool;
pub mod tools;
pub use orchestrator::AgentOrchestrator;
pub use registry::ToolRegistry;
pub use tool::{AgentTool, DisplayHint, ToolContext, ToolResult};

View File

@@ -0,0 +1,117 @@
use std::sync::Arc;
use super::registry::ToolRegistry;
use super::tool::ToolContext;
use crate::dto::{ChatMessage, ChatMessageRole};
use crate::error::AiResult;
use crate::provider::AiProvider;
/// Agent Orchestrator — 执行 ReAct 循环
pub struct AgentOrchestrator {
provider: Arc<dyn AiProvider>,
tool_registry: Arc<ToolRegistry>,
max_iterations: usize,
}
/// Agent 运行结果
pub struct AgentRunResult {
pub reply: String,
pub total_input_tokens: u32,
pub total_output_tokens: u32,
pub iterations: usize,
}
impl AgentOrchestrator {
pub fn new(provider: Arc<dyn AiProvider>, tool_registry: Arc<ToolRegistry>) -> Self {
Self {
provider,
tool_registry,
max_iterations: 5,
}
}
/// 执行 Agent ReAct 循环
pub async fn run(
&self,
system_prompt: &str,
messages: &mut Vec<ChatMessage>,
ctx: &ToolContext,
) -> AiResult<AgentRunResult> {
let tools = self.tool_registry.tool_definitions();
let mut iterations = 0;
let mut total_input_tokens = 0u32;
let mut total_output_tokens = 0u32;
loop {
iterations += 1;
let response = self
.provider
.generate_with_tools(
messages.clone(),
tools.clone(),
system_prompt,
"auto",
0.7,
2048,
)
.await?;
if let Some(ref usage) = response.usage {
total_input_tokens += usage.input;
total_output_tokens += usage.output;
}
// 如果没有 tool_callsAgent 给出最终回复
let tool_calls = match response.tool_calls {
Some(tc) if !tc.is_empty() => tc,
_ => {
return Ok(AgentRunResult {
reply: response.content.unwrap_or_default(),
total_input_tokens,
total_output_tokens,
iterations,
});
}
};
// 达到上限:强制结束
if iterations >= self.max_iterations {
messages.push(ChatMessage {
role: ChatMessageRole::User,
content: "(系统提示:已收集足够信息,请直接总结回复用户,不要再调用工具)"
.to_string(),
tool_calls: None,
tool_call_id: None,
});
continue;
}
// 将 assistant 的 tool_calls 加入消息历史
messages.push(ChatMessage {
role: ChatMessageRole::Assistant,
content: response.content.unwrap_or_default(),
tool_calls: Some(tool_calls.clone()),
tool_call_id: None,
});
// 执行每个 Tool Call
for tc in &tool_calls {
let tool_result = match self.tool_registry.get(&tc.name) {
Some(tool) => {
let result = tool.execute(ctx, tc.arguments.clone()).await;
result.output
}
None => format!("未知 Tool: {}", tc.name),
};
messages.push(ChatMessage {
role: ChatMessageRole::Tool,
content: tool_result,
tool_calls: None,
tool_call_id: Some(tc.id.clone()),
});
}
}
}
}

View File

@@ -0,0 +1,47 @@
use std::collections::HashMap;
use std::sync::Arc;
use super::tool::AgentTool;
/// Tool 注册表 — 管理所有可用的 Agent Tool
pub struct ToolRegistry {
tools: HashMap<String, Arc<dyn AgentTool>>,
}
impl Default for ToolRegistry {
fn default() -> Self {
Self::new()
}
}
impl ToolRegistry {
pub fn new() -> Self {
Self {
tools: HashMap::new(),
}
}
pub fn register(&mut self, tool: Arc<dyn AgentTool>) {
self.tools.insert(tool.name().to_string(), tool);
}
pub fn get(&self, name: &str) -> Option<&Arc<dyn AgentTool>> {
self.tools.get(name)
}
pub fn all_tools(&self) -> Vec<&Arc<dyn AgentTool>> {
self.tools.values().collect()
}
/// 生成传给 LLM 的 ToolDefinition 列表
pub fn tool_definitions(&self) -> Vec<crate::dto::ToolDefinition> {
self.tools
.values()
.map(|t| crate::dto::ToolDefinition {
name: t.name().to_string(),
description: t.description().to_string(),
parameters: t.parameters_schema(),
})
.collect()
}
}

View File

@@ -0,0 +1,55 @@
use async_trait::async_trait;
use erp_core::health_provider::HealthDataProvider;
use sea_orm::DatabaseConnection;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
/// Agent Tool trait — 所有 Agent 可调用的工具都实现此 trait
#[async_trait]
pub trait AgentTool: Send + Sync {
fn name(&self) -> &str;
fn description(&self) -> &str;
fn parameters_schema(&self) -> serde_json::Value;
async fn execute(&self, ctx: &ToolContext, params: serde_json::Value) -> ToolResult;
}
/// Tool 执行上下文 — 包含安全过滤后的租户/用户信息
pub struct ToolContext {
pub tenant_id: Uuid,
pub user_id: Uuid,
pub patient_id: Option<Uuid>,
pub db: DatabaseConnection,
pub health_provider: Arc<dyn HealthDataProvider>,
}
/// Tool 执行结果
pub struct ToolResult {
pub output: String,
pub display_hint: Option<DisplayHint>,
}
/// 前端渲染提示 — 告诉前端如何富化展示 Tool 返回的数据
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DisplayHint {
VitalCard {
indicator_type: String,
values: Vec<(String, f64)>,
unit: String,
},
LabReportCard {
report_date: String,
abnormal_count: usize,
},
ActionConfirm {
action_type: String,
summary: String,
confirm_payload: serde_json::Value,
},
RiskAlert {
level: String,
message: String,
},
Text,
}

View File

@@ -0,0 +1,5 @@
// Agent Tool 实现 — Phase 0 添加 query_patient_vitals
pub mod query_vitals;
pub use query_vitals::QueryPatientVitalsTool;

View File

@@ -0,0 +1,96 @@
use async_trait::async_trait;
use erp_core::health_provider::TimeRange;
use crate::agent::tool::{AgentTool, DisplayHint, ToolContext, ToolResult};
/// 查询患者最近体征数据(血压/血糖/心率等)
pub struct QueryPatientVitalsTool;
#[async_trait]
impl AgentTool for QueryPatientVitalsTool {
fn name(&self) -> &str {
"query_patient_vitals"
}
fn description(&self) -> &str {
"查询患者最近的体征数据(血压、血糖、心率、体重、血氧等)。需要提供天数范围(默认 7 天)。"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"days": {
"type": "integer",
"description": "查询最近多少天的数据,默认 7 天"
}
}
})
}
async fn execute(&self, ctx: &ToolContext, params: serde_json::Value) -> ToolResult {
let patient_id = match ctx.patient_id {
Some(id) => id,
None => {
return ToolResult {
output: "未关联患者档案,无法查询体征数据".to_string(),
display_hint: None,
};
}
};
let days = params["days"].as_i64().unwrap_or(7);
let now = chrono::Utc::now();
let start = now - chrono::Duration::days(days);
let range = TimeRange { start, end: now };
let metrics = vec![
"systolic_bp_morning".into(),
"diastolic_bp_morning".into(),
"heart_rate".into(),
"blood_sugar".into(),
];
match ctx
.health_provider
.get_vital_signs(ctx.tenant_id, patient_id, &metrics, &range)
.await
{
Ok(vitals) => {
if vitals.is_empty() {
return ToolResult {
output: "该时间段内无体征数据".to_string(),
display_hint: None,
};
}
let mut output = String::from("最近体征数据:\n");
for v in &vitals {
output.push_str(&format!("- {}:", v.metric));
let values_str: Vec<String> = v
.values
.iter()
.take(10)
.map(|(date, val)| format!("{}={}", date, val))
.collect();
output.push_str(&format!(" ({})\n", values_str.join(", ")));
}
let display_hint = vitals.first().map(|v| DisplayHint::VitalCard {
indicator_type: v.metric.clone(),
values: v.values.iter().take(10).cloned().collect(),
unit: v.unit.clone(),
});
ToolResult {
output,
display_hint,
}
}
Err(e) => ToolResult {
output: format!("查询体征数据失败: {}", e),
display_hint: None,
},
}
}
}

View File

@@ -104,6 +104,51 @@ pub enum AnalysisSseEvent {
Error { message: String },
}
// === Agent Function Calling DTO ===
/// Agent 对话消息
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: ChatMessageRole,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatMessageRole {
User,
Assistant,
Tool,
}
/// Tool 定义(传给 LLM 的 Function Schema
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// LLM 返回的 Tool Call
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
pub name: String,
pub arguments: serde_json::Value,
}
/// Agent 专用生成响应
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentGenerateResponse {
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
pub usage: Option<TokenUsage>,
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -49,6 +49,9 @@ pub enum AiError {
#[error("AI 配置错误: {0}")]
ConfigError(String),
#[error("不支持的操作: {0}")]
UnsupportedOperation(String),
}
impl From<AiError> for AppError {

View File

@@ -4,7 +4,10 @@ use erp_core::rbac::require_permission;
use erp_core::types::{ApiResponse, TenantContext};
use serde::{Deserialize, Serialize};
use crate::dto::GenerateRequest;
use crate::agent::tool::ToolContext;
use crate::agent::tools::QueryPatientVitalsTool;
use crate::agent::{AgentOrchestrator, ToolRegistry};
use crate::dto::{ChatMessage, ChatMessageRole};
use crate::state::AiState;
// === 请求 / 响应 ===
@@ -13,6 +16,8 @@ use crate::state::AiState;
pub struct ChatRequest {
pub message: String,
pub history: Option<Vec<ChatHistoryItem>>,
/// 可选:关联患者 ID从用户档案中获取
pub patient_id: Option<uuid::Uuid>,
}
#[derive(Debug, Deserialize, Serialize, utoipa::ToSchema)]
@@ -25,26 +30,46 @@ pub struct ChatHistoryItem {
pub struct ChatResponse {
pub reply: String,
pub message_id: String,
pub iterations: usize,
}
const SYSTEM_PROMPT: &str = r#"你是 HMS 健康管理平台的 AI 客服助手"小华"。你的职责是:
1. 回答用户的健康咨询问题
2. 帮助用户了解体检报告指标
3. 提供预约挂号、用药提醒等服务指导
4. 推荐健康生活方式
const SYSTEM_PROMPT: &str = r#"你是 HMS 健康管理平台的 AI 健康顾问"小华"。
注意:
- 你不能替代医生的诊断,遇到需要诊断的问题请建议用户就医
- 不能推荐具体药物,只能提供一般性健康建议
- 语气要亲切、专业、耐心
- 回复要简洁明了,避免过长
- 如果用户问的问题超出健康范围,礼貌引导回到健康话题"#;
## 核心策略
根据用户表达的内容和情绪,自然地采用以下策略方向:
1. 【情绪安抚】当用户表达焦虑、恐惧、沮丧时:
- 先共情认可感受,不急于给建议
- 用通俗语言解释,避免医学术语
- 分享积极案例,降低恐惧感
2. 【医疗科普】当用户询问指标含义、疾病知识时:
- 调用 search_medical_knowledge 获取准确信息(如可用)
- 用比喻和类比让老年患者也能理解
- 强调"具体请以医生诊断为准"
3. 【服务推荐】当用户表达就医需求或身体不适时:
- 调用 query_appointments 查看已有预约(如可用)
- 主动提出帮用户预约
4. 【风险预警】当用户描述的症状或数据异常时:
- 调用 query_patient_vitals 查看体征数据
- 明确告知风险等级和需要注意的事项
- 高风险时建议尽快就医
5. 【引导到院】当用户有明确就诊意向或高风险预警时:
- 提供科室位置、出诊医生信息
- 建议用户联系前台预约
## 策略不是互斥的,你可以在一轮对话中自然切换。
## 永远不要:推荐具体药物、给出明确诊断、替代医生建议。
## 如果没有可用的工具数据,就基于常识回答,并建议用户咨询医生。"#;
#[utoipa::path(
post,
path = "/ai/chat",
request_body = ChatRequest,
responses((status = 200, description = "AI 客服回复")),
responses((status = 200, description = "AI Agent 回复")),
tag = "AI 客服",
security(("bearer_auth" = [])),
)]
@@ -69,30 +94,41 @@ where
));
}
let user_prompt = match body.history {
Some(ref hist) if !hist.is_empty() => {
let filtered: Vec<&ChatHistoryItem> = hist
.iter()
.filter(|h| h.role == "user" || h.role == "assistant")
.collect();
let start = filtered.len().saturating_sub(10);
let ctx: String = filtered[start..]
.iter()
.map(|h| {
format!(
"{}: {}",
if h.role == "user" { "用户" } else { "助手" },
h.content
)
})
.collect::<Vec<_>>()
.join("\n");
format!("历史对话:\n{}\n\n用户最新消息: {}", ctx, message)
}
_ => message.to_string(),
};
let ai_state = AiState::from_ref(&state);
// 构建 Agent 消息历史
let mut messages = vec![];
// 将前端传来的历史转换为 Agent ChatMessage
if let Some(ref hist) = body.history {
let filtered: Vec<&ChatHistoryItem> = hist
.iter()
.filter(|h| h.role == "user" || h.role == "assistant")
.collect();
let start = filtered.len().saturating_sub(10);
for h in &filtered[start..] {
messages.push(ChatMessage {
role: if h.role == "user" {
ChatMessageRole::User
} else {
ChatMessageRole::Assistant
},
content: h.content.clone(),
tool_calls: None,
tool_call_id: None,
});
}
}
// 添加当前用户消息
messages.push(ChatMessage {
role: ChatMessageRole::User,
content: message.to_string(),
tool_calls: None,
tool_call_id: None,
});
// 解析 Provider
let resolved = ai_state
.provider_registry
.resolve("auto")
@@ -102,37 +138,51 @@ where
erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into())
})?;
let req = GenerateRequest {
system_prompt: SYSTEM_PROMPT.to_string(),
user_prompt,
model: String::new(),
temperature: 0.7,
max_tokens: 1024,
// 构建 ToolRegistry — Phase 0 只有 query_patient_vitals
let mut registry = ToolRegistry::new();
registry.register(std::sync::Arc::new(QueryPatientVitalsTool));
let tool_ctx = ToolContext {
tenant_id: ctx.tenant_id,
user_id: ctx.user_id,
patient_id: body.patient_id,
db: ai_state.db.clone(),
health_provider: ai_state.health_provider.clone(),
};
tracing::info!(
tenant_id = %ctx.tenant_id,
user_id = %ctx.user_id,
patient_id = ?body.patient_id,
msg_len = message.len(),
"AI chat request"
"AI Agent chat request"
);
let resp = resolved.provider().generate(req).await.map_err(|e| {
tracing::error!(error = %e, "AI chat generate failed");
erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into())
})?;
// 执行 Agent ReAct 循环
let provider_arc = resolved.into_arc();
let orchestrator = AgentOrchestrator::new(provider_arc, std::sync::Arc::new(registry));
let result = orchestrator
.run(SYSTEM_PROMPT, &mut messages, &tool_ctx)
.await
.map_err(|e| {
tracing::error!(error = %e, "AI Agent run failed");
erp_core::error::AppError::Internal("AI 服务暂时不可用,请稍后再试".into())
})?;
let message_id = uuid::Uuid::now_v7().to_string();
tracing::info!(
tenant_id = %ctx.tenant_id,
message_id = %message_id,
tokens = resp.output_tokens,
"AI chat response sent"
iterations = result.iterations,
input_tokens = result.total_input_tokens,
output_tokens = result.total_output_tokens,
"AI Agent response sent"
);
Ok(Json(ApiResponse::ok(ChatResponse {
reply: resp.content,
reply: result.reply,
message_id,
iterations: result.iterations,
})))
}

View File

@@ -1,3 +1,4 @@
pub mod agent;
pub mod config;
pub mod copilot;
pub mod dto;

View File

@@ -100,6 +100,31 @@ impl ErpModule for AiModule {
description: "创建/编辑/删除 Copilot 规则".into(),
module: "ai".into(),
},
// AI 客服会话权限
PermissionDescriptor {
code: "ai.chat.send".into(),
name: "AI 客服对话".into(),
description: "向 AI 客服发送消息".into(),
module: "ai".into(),
},
PermissionDescriptor {
code: "ai.chat.session.list".into(),
name: "查看 AI 会话列表".into(),
description: "查看用户的 AI 客服会话列表".into(),
module: "ai".into(),
},
PermissionDescriptor {
code: "ai.chat.session.manage".into(),
name: "管理 AI 会话".into(),
description: "创建/关闭 AI 客服会话".into(),
module: "ai".into(),
},
PermissionDescriptor {
code: "ai.chat.session.history".into(),
name: "查看 AI 会话历史".into(),
description: "查看 AI 客服会话消息历史".into(),
module: "ai".into(),
},
]
}

View File

@@ -31,6 +31,13 @@ impl ClaudeProvider {
}
}
#[derive(Serialize)]
#[serde(untagged)]
enum ClaudeContent {
Text(String),
Blocks(Vec<serde_json::Value>),
}
#[derive(Serialize)]
struct ClaudeRequest {
model: String,
@@ -38,13 +45,22 @@ struct ClaudeRequest {
temperature: f32,
system: String,
messages: Vec<ClaudeMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ClaudeToolDef>>,
stream: bool,
}
#[derive(Serialize)]
struct ClaudeMessage {
role: String,
content: String,
content: ClaudeContent,
}
#[derive(Serialize)]
struct ClaudeToolDef {
name: String,
description: String,
input_schema: serde_json::Value,
}
#[derive(Deserialize)]
@@ -88,8 +104,9 @@ impl AiProvider for ClaudeProvider {
system: req.system_prompt,
messages: vec![ClaudeMessage {
role: "user".into(),
content: req.user_prompt,
content: ClaudeContent::Text(req.user_prompt),
}],
tools: None,
stream: true,
};
@@ -153,8 +170,9 @@ impl AiProvider for ClaudeProvider {
system: req.system_prompt,
messages: vec![ClaudeMessage {
role: "user".into(),
content: req.user_prompt,
content: ClaudeContent::Text(req.user_prompt),
}],
tools: None,
stream: false,
};
@@ -223,4 +241,138 @@ impl AiProvider for ClaudeProvider {
Err(_) => Ok(false),
}
}
async fn generate_with_tools(
&self,
messages: Vec<crate::dto::ChatMessage>,
tools: Vec<crate::dto::ToolDefinition>,
system_prompt: &str,
model: &str,
temperature: f32,
max_tokens: u32,
) -> AiResult<crate::dto::AgentGenerateResponse> {
use crate::dto::ChatMessageRole;
let claude_messages: Vec<ClaudeMessage> = messages
.iter()
.map(|m| {
let role = match m.role {
ChatMessageRole::User => "user",
ChatMessageRole::Assistant => "assistant",
ChatMessageRole::Tool => "user",
};
let content = match m.role {
ChatMessageRole::Tool => {
let blocks = vec![serde_json::json!({
"type": "tool_result",
"tool_use_id": m.tool_call_id.as_deref().unwrap_or(""),
"content": m.content,
})];
ClaudeContent::Blocks(blocks)
}
ChatMessageRole::Assistant if m.tool_calls.is_some() => {
let mut blocks = vec![];
if !m.content.is_empty() {
blocks.push(serde_json::json!({
"type": "text",
"text": m.content,
}));
}
for tc in m.tool_calls.as_ref().unwrap() {
blocks.push(serde_json::json!({
"type": "tool_use",
"id": tc.id,
"name": tc.name,
"input": tc.arguments,
}));
}
ClaudeContent::Blocks(blocks)
}
_ => ClaudeContent::Text(m.content.clone()),
};
ClaudeMessage {
role: role.to_string(),
content,
}
})
.collect();
let claude_tools: Vec<ClaudeToolDef> = tools
.into_iter()
.map(|t| ClaudeToolDef {
name: t.name,
description: t.description,
input_schema: t.parameters,
})
.collect();
let req = ClaudeRequest {
model: model.to_string(),
max_tokens,
temperature,
system: system_prompt.to_string(),
messages: claude_messages,
tools: Some(claude_tools),
stream: false,
};
let resp = self
.client
.post(format!("{}/v1/messages", self.base_url))
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&req)
.send()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
let status = resp.status();
let body = resp
.text()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
if !status.is_success() {
return Err(AiError::ProviderError(format!("Claude {status}: {body}")));
}
let parsed: serde_json::Value = serde_json::from_str(&body)
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
let mut content_text = None;
let mut tool_calls = None;
if let Some(blocks) = parsed["content"].as_array() {
for block in blocks {
match block["type"].as_str() {
Some("text") => {
content_text = block["text"].as_str().map(|s| s.to_string());
}
Some("tool_use") => {
let tc = crate::dto::ToolCall {
id: block["id"].as_str().unwrap_or_default().to_string(),
name: block["name"].as_str().unwrap_or_default().to_string(),
arguments: block["input"].clone(),
};
tool_calls.get_or_insert_with(Vec::new).push(tc);
}
_ => {}
}
}
}
let usage = parsed["usage"].as_object().map(|u| crate::dto::TokenUsage {
input: u["input_tokens"].as_u64().unwrap_or(0) as u32,
output: u["output_tokens"].as_u64().unwrap_or(0) as u32,
});
Ok(crate::dto::AgentGenerateResponse {
content: content_text,
tool_calls,
usage,
})
}
}

View File

@@ -27,4 +27,20 @@ pub trait AiProvider: Send + Sync {
/// 健康检查
async fn health_check(&self) -> AiResult<bool>;
/// Agent 专用生成方法 — 支持 Function Calling
/// 不支持 FC 的 Provider 使用默认实现(返回错误)
async fn generate_with_tools(
&self,
_messages: Vec<crate::dto::ChatMessage>,
_tools: Vec<crate::dto::ToolDefinition>,
_system_prompt: &str,
_model: &str,
_temperature: f32,
_max_tokens: u32,
) -> AiResult<crate::dto::AgentGenerateResponse> {
Err(crate::error::AiError::UnsupportedOperation(
"Function Calling not supported by this provider".into(),
))
}
}

View File

@@ -300,6 +300,20 @@ impl AiProvider for OllamaProvider {
Err(_) => Ok(false),
}
}
async fn generate_with_tools(
&self,
_messages: Vec<crate::dto::ChatMessage>,
_tools: Vec<crate::dto::ToolDefinition>,
_system_prompt: &str,
_model: &str,
_temperature: f32,
_max_tokens: u32,
) -> AiResult<crate::dto::AgentGenerateResponse> {
Err(AiError::UnsupportedOperation(
"Ollama does not support Function Calling. Use Claude or OpenAI provider for Agent features.".into(),
))
}
}
#[cfg(test)]

View File

@@ -34,13 +34,32 @@ struct ChatRequest {
max_tokens: u32,
temperature: f32,
messages: Vec<ChatMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ChatToolDef>>,
stream: bool,
}
#[derive(Serialize)]
struct ChatToolDef {
r#type: String,
function: ChatFunctionDef,
}
#[derive(Serialize)]
struct ChatFunctionDef {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Serialize)]
struct ChatMessage {
role: String,
content: String,
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ChatToolCallResp>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
}
#[derive(Deserialize)]
@@ -54,9 +73,24 @@ struct ChatChoice {
message: ChatMessageResp,
}
#[derive(Deserialize)]
#[derive(Deserialize, Serialize)]
struct ChatMessageResp {
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<ChatToolCallResp>>,
}
#[derive(Debug, Deserialize, Serialize)]
struct ChatToolCallResp {
id: String,
r#type: String,
function: ChatFunctionResp,
}
#[derive(Debug, Deserialize, Serialize)]
struct ChatFunctionResp {
name: String,
arguments: String,
}
#[derive(Deserialize)]
@@ -99,13 +133,18 @@ impl AiProvider for OpenAIProvider {
messages: vec![
ChatMessage {
role: "system".into(),
content: req.system_prompt,
content: Some(req.system_prompt),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".into(),
content: req.user_prompt,
content: Some(req.user_prompt),
tool_calls: None,
tool_call_id: None,
},
],
tools: None,
stream: true,
};
@@ -175,13 +214,18 @@ impl AiProvider for OpenAIProvider {
messages: vec![
ChatMessage {
role: "system".into(),
content: req.system_prompt,
content: Some(req.system_prompt),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".into(),
content: req.user_prompt,
content: Some(req.user_prompt),
tool_calls: None,
tool_call_id: None,
},
],
tools: None,
stream: false,
};
@@ -245,6 +289,138 @@ impl AiProvider for OpenAIProvider {
Err(_) => Ok(false),
}
}
async fn generate_with_tools(
&self,
messages: Vec<crate::dto::ChatMessage>,
tools: Vec<crate::dto::ToolDefinition>,
system_prompt: &str,
model: &str,
temperature: f32,
max_tokens: u32,
) -> AiResult<crate::dto::AgentGenerateResponse> {
use crate::dto::ChatMessageRole;
let model = if model == "auto" || model.is_empty() {
self.default_model.clone()
} else {
model.to_string()
};
let mut chat_messages = vec![ChatMessage {
role: "system".into(),
content: Some(system_prompt.to_string()),
tool_calls: None,
tool_call_id: None,
}];
for m in &messages {
let (role, content) = match m.role {
ChatMessageRole::User => ("user", Some(m.content.clone())),
ChatMessageRole::Assistant => (
"assistant",
if m.content.is_empty() {
None
} else {
Some(m.content.clone())
},
),
ChatMessageRole::Tool => ("tool", Some(m.content.clone())),
};
let tool_calls = m.tool_calls.as_ref().map(|tcs| {
tcs.iter()
.map(|tc| ChatToolCallResp {
id: tc.id.clone(),
r#type: "function".into(),
function: ChatFunctionResp {
name: tc.name.clone(),
arguments: tc.arguments.to_string(),
},
})
.collect::<Vec<_>>()
});
chat_messages.push(ChatMessage {
role: role.into(),
content,
tool_calls,
tool_call_id: m.tool_call_id.clone(),
});
}
let chat_tools: Vec<ChatToolDef> = tools
.into_iter()
.map(|t| ChatToolDef {
r#type: "function".into(),
function: ChatFunctionDef {
name: t.name,
description: t.description,
parameters: t.parameters,
},
})
.collect();
let req = ChatRequest {
model: model.clone(),
max_tokens,
temperature,
messages: chat_messages,
tools: Some(chat_tools),
stream: false,
};
let resp = self
.client
.post(format!("{}/v1/chat/completions", self.base_url))
.header("Authorization", format!("Bearer {}", self.api_key))
.header("content-type", "application/json")
.json(&req)
.send()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
let status = resp.status();
let body = resp
.text()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
if !status.is_success() {
return Err(AiError::ProviderError(format!("OpenAI {status}: {body}")));
}
let parsed: ChatResponse = serde_json::from_str(&body)
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
let msg = parsed
.choices
.first()
.map(|c| &c.message)
.ok_or_else(|| AiError::ProviderError("无响应选项".into()))?;
let tool_calls = msg.tool_calls.as_ref().map(|tcs| {
tcs.iter()
.map(|tc| crate::dto::ToolCall {
id: tc.id.clone(),
name: tc.function.name.clone(),
arguments: serde_json::from_str(&tc.function.arguments)
.unwrap_or(serde_json::Value::Null),
})
.collect::<Vec<_>>()
});
let usage = parsed.usage.map(|u| crate::dto::TokenUsage {
input: u.prompt_tokens,
output: u.completion_tokens,
});
Ok(crate::dto::AgentGenerateResponse {
content: msg.content.clone(),
tool_calls,
usage,
})
}
}
#[cfg(test)]
@@ -271,13 +447,18 @@ mod tests {
messages: vec![
ChatMessage {
role: "system".into(),
content: "你是助手".into(),
content: Some("你是助手".into()),
tool_calls: None,
tool_call_id: None,
},
ChatMessage {
role: "user".into(),
content: "你好".into(),
content: Some("你好".into()),
tool_calls: None,
tool_call_id: None,
},
],
tools: None,
stream: false,
};
let json = serde_json::to_value(&req).unwrap();

View File

@@ -39,6 +39,20 @@ pub trait HealthDataProvider: Send + Sync {
metrics: &[String],
range: &TimeRange,
) -> AppResult<TrendAnalysisDto>;
/// 获取患者即将到来的预约
async fn get_upcoming_appointments(
&self,
tenant_id: Uuid,
patient_id: Uuid,
) -> AppResult<Vec<AppointmentSummaryDto>>;
/// 获取患者当前用药列表
async fn get_medication_list(
&self,
tenant_id: Uuid,
patient_id: Uuid,
) -> AppResult<Vec<MedicationSummaryDto>>;
}
// === DTO 定义 ===
@@ -152,3 +166,21 @@ pub struct AnomalyInfo {
pub std_dev: f64,
pub deviation: f64,
}
// === Agent 新增 DTO ===
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppointmentSummaryDto {
pub id: Uuid,
pub department: String,
pub doctor_name: String,
pub scheduled_at: String,
pub status: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MedicationSummaryDto {
pub name: String,
pub dosage: String,
pub frequency: String,
}

View File

@@ -3,16 +3,16 @@ use chrono::Datelike;
use erp_core::crypto::{self as pii, PiiCrypto};
use erp_core::error::{AppError, AppResult};
use erp_core::health_provider::{
AnomalyInfo, HealthDataProvider, HealthReportDto, LabItemDto, LabReportDto,
MetricTrendAnalysis, PatientSummaryDto, RegressionStats, ReportSectionDto, TimeRange,
TrendAnalysisDto, TrendDirection, VitalSignDto,
AnomalyInfo, AppointmentSummaryDto, HealthDataProvider, HealthReportDto, LabItemDto,
LabReportDto, MedicationSummaryDto, MetricTrendAnalysis, PatientSummaryDto, RegressionStats,
ReportSectionDto, TimeRange, TrendAnalysisDto, TrendDirection, VitalSignDto,
};
use num_traits::ToPrimitive;
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder};
use sea_orm::{ColumnTrait, EntityTrait, QueryFilter, QueryOrder, QuerySelect};
use uuid::Uuid;
use crate::entity::{
diagnosis, health_record, lab_report, medication_record, patient, vital_signs,
appointment, diagnosis, health_record, lab_report, medication_record, patient, vital_signs,
};
pub struct HealthDataProviderImpl {
@@ -557,4 +557,70 @@ impl HealthDataProvider for HealthDataProviderImpl {
metrics: metric_results,
})
}
async fn get_upcoming_appointments(
&self,
tenant_id: Uuid,
patient_id: Uuid,
) -> AppResult<Vec<AppointmentSummaryDto>> {
let _ = find_patient(&self.db, tenant_id, patient_id).await?;
let today = chrono::Utc::now().date_naive();
let records = appointment::Entity::find()
.filter(appointment::Column::TenantId.eq(tenant_id))
.filter(appointment::Column::PatientId.eq(patient_id))
.filter(appointment::Column::DeletedAt.is_null())
.filter(appointment::Column::AppointmentDate.gte(today))
.filter(
appointment::Column::Status
.is_in(vec!["scheduled".to_string(), "confirmed".to_string()]),
)
.order_by_asc(appointment::Column::AppointmentDate)
.order_by_asc(appointment::Column::StartTime)
.limit(10)
.all(&self.db)
.await?;
let result = records
.into_iter()
.map(|r| AppointmentSummaryDto {
id: r.id,
department: r.appointment_type,
doctor_name: r
.doctor_id
.map_or("待定".to_string(), |_| "医生".to_string()),
scheduled_at: format!("{} {}", r.appointment_date, r.start_time),
status: r.status,
})
.collect();
Ok(result)
}
async fn get_medication_list(
&self,
tenant_id: Uuid,
patient_id: Uuid,
) -> AppResult<Vec<MedicationSummaryDto>> {
let _ = find_patient(&self.db, tenant_id, patient_id).await?;
let records = medication_record::Entity::find()
.filter(medication_record::Column::TenantId.eq(tenant_id))
.filter(medication_record::Column::PatientId.eq(patient_id))
.filter(medication_record::Column::DeletedAt.is_null())
.filter(medication_record::Column::IsCurrent.eq(true))
.all(&self.db)
.await?;
let result = records
.into_iter()
.map(|m| MedicationSummaryDto {
name: m.medication_name,
dosage: m.dosage.unwrap_or_default(),
frequency: m.frequency.unwrap_or_default(),
})
.collect();
Ok(result)
}
}

View File

@@ -149,6 +149,7 @@ mod m20260513_000144_enforce_version_optimistic_lock;
mod m20260513_000145_seed_missing_permissions;
mod m20260515_000146_seed_menu_permissions_phase2;
mod m20260516_000147_seed_ai_chat_permission;
mod m20260518_000148_create_ai_chat_tables;
pub struct Migrator;
@@ -305,6 +306,7 @@ impl MigratorTrait for Migrator {
Box::new(m20260513_000145_seed_missing_permissions::Migration),
Box::new(m20260515_000146_seed_menu_permissions_phase2::Migration),
Box::new(m20260516_000147_seed_ai_chat_permission::Migration),
Box::new(m20260518_000148_create_ai_chat_tables::Migration),
]
}
}

View File

@@ -0,0 +1,335 @@
use sea_orm_migration::prelude::*;
#[derive(DeriveMigrationName)]
pub struct Migration;
#[async_trait::async_trait]
impl MigrationTrait for Migration {
async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> {
// ai_chat_sessions — AI 会话表
manager
.create_table(
Table::create()
.table(AiChatSessions::Table)
.col(
ColumnDef::new(AiChatSessions::Id)
.uuid()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(AiChatSessions::TenantId).uuid().not_null())
.col(ColumnDef::new(AiChatSessions::UserId).uuid().not_null())
.col(ColumnDef::new(AiChatSessions::PatientId).uuid().null())
.col(ColumnDef::new(AiChatSessions::Title).string_len(255).null())
.col(
ColumnDef::new(AiChatSessions::Status)
.string_len(20)
.not_null()
.default("active"),
)
.col(ColumnDef::new(AiChatSessions::Metadata).json().null())
.col(
ColumnDef::new(AiChatSessions::CreatedAt)
.timestamp_with_time_zone()
.not_null()
.default(Expr::current_timestamp()),
)
.col(
ColumnDef::new(AiChatSessions::UpdatedAt)
.timestamp_with_time_zone()
.not_null()
.default(Expr::current_timestamp()),
)
.col(ColumnDef::new(AiChatSessions::CreatedBy).uuid().null())
.col(ColumnDef::new(AiChatSessions::UpdatedBy).uuid().null())
.col(
ColumnDef::new(AiChatSessions::DeletedAt)
.timestamp_with_time_zone()
.null(),
)
.col(
ColumnDef::new(AiChatSessions::VersionLock)
.integer()
.not_null()
.default(1),
)
.to_owned(),
)
.await?;
manager
.create_index(
Index::create()
.name("idx_ai_chat_sessions_tenant_user")
.table(AiChatSessions::Table)
.col(AiChatSessions::TenantId)
.col(AiChatSessions::UserId)
.to_owned(),
)
.await?;
// ai_chat_messages — AI 聊天消息表
manager
.create_table(
Table::create()
.table(AiChatMessages::Table)
.col(
ColumnDef::new(AiChatMessages::Id)
.uuid()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(AiChatMessages::TenantId).uuid().not_null())
.col(ColumnDef::new(AiChatMessages::SessionId).uuid().not_null())
.col(
ColumnDef::new(AiChatMessages::Role)
.string_len(20)
.not_null(),
)
.col(ColumnDef::new(AiChatMessages::Content).text().null())
.col(ColumnDef::new(AiChatMessages::ToolCalls).json().null())
.col(
ColumnDef::new(AiChatMessages::ToolCallId)
.string_len(100)
.null(),
)
.col(ColumnDef::new(AiChatMessages::TokenCount).integer().null())
.col(
ColumnDef::new(AiChatMessages::CreatedAt)
.timestamp_with_time_zone()
.not_null()
.default(Expr::current_timestamp()),
)
.col(
ColumnDef::new(AiChatMessages::UpdatedAt)
.timestamp_with_time_zone()
.not_null()
.default(Expr::current_timestamp()),
)
.col(ColumnDef::new(AiChatMessages::CreatedBy).uuid().null())
.col(ColumnDef::new(AiChatMessages::UpdatedBy).uuid().null())
.col(
ColumnDef::new(AiChatMessages::DeletedAt)
.timestamp_with_time_zone()
.null(),
)
.col(
ColumnDef::new(AiChatMessages::VersionLock)
.integer()
.not_null()
.default(1),
)
.foreign_key(
ForeignKey::create()
.name("fk_ai_chat_messages_session")
.from(AiChatMessages::Table, AiChatMessages::SessionId)
.to(AiChatSessions::Table, AiChatSessions::Id)
.on_delete(ForeignKeyAction::Cascade),
)
.to_owned(),
)
.await?;
manager
.create_index(
Index::create()
.name("idx_ai_chat_messages_session")
.table(AiChatMessages::Table)
.col(AiChatMessages::TenantId)
.col(AiChatMessages::SessionId)
.col(AiChatMessages::CreatedAt)
.to_owned(),
)
.await?;
// ai_tool_call_logs — AI 工具调用日志append-only
manager
.create_table(
Table::create()
.table(AiToolCallLogs::Table)
.col(
ColumnDef::new(AiToolCallLogs::Id)
.uuid()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(AiToolCallLogs::TenantId).uuid().not_null())
.col(ColumnDef::new(AiToolCallLogs::SessionId).uuid().not_null())
.col(ColumnDef::new(AiToolCallLogs::MessageId).uuid().not_null())
.col(
ColumnDef::new(AiToolCallLogs::ToolName)
.string_len(100)
.not_null(),
)
.col(ColumnDef::new(AiToolCallLogs::Parameters).json().null())
.col(ColumnDef::new(AiToolCallLogs::ResultSummary).text().null())
.col(ColumnDef::new(AiToolCallLogs::ExecutionMs).integer().null())
.col(ColumnDef::new(AiToolCallLogs::Success).boolean().not_null())
.col(
ColumnDef::new(AiToolCallLogs::CreatedAt)
.timestamp_with_time_zone()
.not_null()
.default(Expr::current_timestamp()),
)
.col(ColumnDef::new(AiToolCallLogs::CreatedBy).uuid().null())
.to_owned(),
)
.await?;
manager
.create_index(
Index::create()
.name("idx_ai_tool_call_logs_session")
.table(AiToolCallLogs::Table)
.col(AiToolCallLogs::TenantId)
.col(AiToolCallLogs::SessionId)
.to_owned(),
)
.await?;
// ai_user_profiles — 用户长期画像
manager
.create_table(
Table::create()
.table(AiUserProfiles::Table)
.col(
ColumnDef::new(AiUserProfiles::Id)
.uuid()
.not_null()
.primary_key(),
)
.col(ColumnDef::new(AiUserProfiles::TenantId).uuid().not_null())
.col(ColumnDef::new(AiUserProfiles::UserId).uuid().not_null())
.col(ColumnDef::new(AiUserProfiles::Preferences).json().null())
.col(ColumnDef::new(AiUserProfiles::HealthInterests).array(ColumnType::Text))
.col(ColumnDef::new(AiUserProfiles::FrequentTopics).array(ColumnType::Text))
.col(
ColumnDef::new(AiUserProfiles::PersonalitySummary)
.text()
.null(),
)
.col(
ColumnDef::new(AiUserProfiles::LastUpdatedAt)
.timestamp_with_time_zone()
.null(),
)
.col(
ColumnDef::new(AiUserProfiles::CreatedAt)
.timestamp_with_time_zone()
.not_null()
.default(Expr::current_timestamp()),
)
.col(
ColumnDef::new(AiUserProfiles::UpdatedAt)
.timestamp_with_time_zone()
.not_null()
.default(Expr::current_timestamp()),
)
.col(
ColumnDef::new(AiUserProfiles::DeletedAt)
.timestamp_with_time_zone()
.null(),
)
.col(
ColumnDef::new(AiUserProfiles::VersionLock)
.integer()
.not_null()
.default(1),
)
.index(
Index::create()
.name("uq_ai_user_profiles_tenant_user")
.col(AiUserProfiles::TenantId)
.col(AiUserProfiles::UserId)
.unique(),
)
.to_owned(),
)
.await
}
async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> {
manager
.drop_table(Table::drop().table(AiUserProfiles::Table).to_owned())
.await?;
manager
.drop_table(Table::drop().table(AiToolCallLogs::Table).to_owned())
.await?;
manager
.drop_table(Table::drop().table(AiChatMessages::Table).to_owned())
.await?;
manager
.drop_table(Table::drop().table(AiChatSessions::Table).to_owned())
.await
}
}
#[derive(DeriveIden)]
enum AiChatSessions {
Table,
Id,
TenantId,
UserId,
PatientId,
Title,
Status,
Metadata,
CreatedAt,
UpdatedAt,
CreatedBy,
UpdatedBy,
DeletedAt,
VersionLock,
}
#[derive(DeriveIden)]
enum AiChatMessages {
Table,
Id,
TenantId,
SessionId,
Role,
Content,
ToolCalls,
ToolCallId,
TokenCount,
CreatedAt,
UpdatedAt,
CreatedBy,
UpdatedBy,
DeletedAt,
VersionLock,
}
#[derive(DeriveIden)]
enum AiToolCallLogs {
Table,
Id,
TenantId,
SessionId,
MessageId,
ToolName,
Parameters,
ResultSummary,
ExecutionMs,
Success,
CreatedAt,
CreatedBy,
}
#[derive(DeriveIden)]
enum AiUserProfiles {
Table,
Id,
TenantId,
UserId,
Preferences,
HealthInterests,
FrequentTopics,
PersonalitySummary,
LastUpdatedAt,
CreatedAt,
UpdatedAt,
DeletedAt,
VersionLock,
}