182 lines
5.6 KiB
Rust
182 lines
5.6 KiB
Rust
use futures::Stream;
|
|
use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set};
|
|
use sha2::{Digest, Sha256};
|
|
use std::pin::Pin;
|
|
use uuid::Uuid;
|
|
|
|
use crate::dto::{AnalysisType, GenerateRequest};
|
|
use crate::entity::ai_analysis;
|
|
use crate::error::{AiError, AiResult};
|
|
use crate::prompt::PromptRenderer;
|
|
use crate::provider::AiProvider;
|
|
use crate::sanitization::SanitizationService;
|
|
|
|
pub struct AnalysisService {
|
|
pub provider: Box<dyn AiProvider>,
|
|
pub sanitizer: SanitizationService,
|
|
pub renderer: PromptRenderer,
|
|
pub db: sea_orm::DatabaseConnection,
|
|
}
|
|
|
|
impl AnalysisService {
|
|
pub fn new(provider: Box<dyn AiProvider>, db: sea_orm::DatabaseConnection) -> Self {
|
|
Self {
|
|
provider,
|
|
sanitizer: SanitizationService::new(),
|
|
renderer: PromptRenderer::new(),
|
|
db,
|
|
}
|
|
}
|
|
|
|
/// 执行流式分析 — 返回 SSE 事件流
|
|
pub async fn stream_analyze(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
_user_id: Uuid,
|
|
patient_id: Uuid,
|
|
analysis_type: AnalysisType,
|
|
source_ref: String,
|
|
system_prompt: String,
|
|
user_template: String,
|
|
sanitized_data: serde_json::Value,
|
|
model: String,
|
|
temperature: f32,
|
|
max_tokens: u32,
|
|
) -> AiResult<(
|
|
Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>,
|
|
uuid::Uuid,
|
|
String,
|
|
)> {
|
|
let analysis_id = Uuid::now_v7();
|
|
let input_hash = self.compute_hash(&sanitized_data);
|
|
let provider_name = self.provider.name().to_string();
|
|
|
|
// 1. 渲染 Prompt
|
|
let user_prompt = self.renderer.render(&user_template, &sanitized_data)?;
|
|
|
|
// 2. 创建分析记录
|
|
self.create_analysis_record(
|
|
analysis_id,
|
|
tenant_id,
|
|
patient_id,
|
|
analysis_type.as_str(),
|
|
&source_ref,
|
|
&input_hash,
|
|
&provider_name,
|
|
&model,
|
|
)
|
|
.await?;
|
|
|
|
// 3. 调用 AI 流式生成
|
|
let req = GenerateRequest {
|
|
system_prompt,
|
|
user_prompt,
|
|
model,
|
|
temperature,
|
|
max_tokens,
|
|
};
|
|
let stream = self.provider.stream_generate(req).await?;
|
|
|
|
Ok((stream, analysis_id, provider_name))
|
|
}
|
|
|
|
/// 更新分析记录为完成
|
|
pub async fn complete_analysis(
|
|
&self,
|
|
analysis_id: Uuid,
|
|
content: String,
|
|
metadata: serde_json::Value,
|
|
) -> AiResult<()> {
|
|
let entity = ai_analysis::Entity::find_by_id(analysis_id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?;
|
|
|
|
let mut active: ai_analysis::ActiveModel = entity.into();
|
|
active.status = Set("completed".into());
|
|
active.result_content = Set(Some(content));
|
|
active.result_metadata = Set(Some(metadata));
|
|
active.updated_at = Set(chrono::Utc::now());
|
|
active.update(&self.db).await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// 标记分析失败
|
|
pub async fn fail_analysis(&self, analysis_id: Uuid, error: String) -> AiResult<()> {
|
|
let entity = ai_analysis::Entity::find_by_id(analysis_id)
|
|
.one(&self.db)
|
|
.await?
|
|
.ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?;
|
|
|
|
let mut active: ai_analysis::ActiveModel = entity.into();
|
|
active.status = Set("failed".into());
|
|
active.error_message = Set(Some(error));
|
|
active.updated_at = Set(chrono::Utc::now());
|
|
active.update(&self.db).await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// 查找缓存
|
|
pub async fn find_cached(
|
|
&self,
|
|
tenant_id: Uuid,
|
|
input_hash: &str,
|
|
prompt_version: i32,
|
|
) -> AiResult<Option<ai_analysis::Model>> {
|
|
let result = ai_analysis::Entity::find()
|
|
.filter(ai_analysis::Column::TenantId.eq(tenant_id))
|
|
.filter(ai_analysis::Column::InputDataHash.eq(input_hash))
|
|
.filter(ai_analysis::Column::PromptVersion.eq(prompt_version))
|
|
.filter(ai_analysis::Column::Status.eq("completed"))
|
|
.filter(ai_analysis::Column::DeletedAt.is_null())
|
|
.one(&self.db)
|
|
.await?;
|
|
Ok(result)
|
|
}
|
|
|
|
fn compute_hash(&self, data: &serde_json::Value) -> String {
|
|
let canonical = serde_json::to_string(data).unwrap_or_default();
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(canonical.as_bytes());
|
|
hex::encode(hasher.finalize())
|
|
}
|
|
|
|
async fn create_analysis_record(
|
|
&self,
|
|
id: Uuid,
|
|
tenant_id: Uuid,
|
|
patient_id: Uuid,
|
|
analysis_type: &str,
|
|
source_ref: &str,
|
|
input_hash: &str,
|
|
_provider: &str,
|
|
model: &str,
|
|
) -> AiResult<()> {
|
|
let now = chrono::Utc::now();
|
|
let active = ai_analysis::ActiveModel {
|
|
id: Set(id),
|
|
tenant_id: Set(tenant_id),
|
|
patient_id: Set(patient_id),
|
|
analysis_type: Set(analysis_type.into()),
|
|
source_ref: Set(source_ref.into()),
|
|
prompt_id: Set(Uuid::nil()),
|
|
prompt_version: Set(1),
|
|
model_used: Set(model.into()),
|
|
input_data_hash: Set(input_hash.into()),
|
|
sanitized_input: Set(None),
|
|
result_content: Set(None),
|
|
result_metadata: Set(None),
|
|
status: Set("streaming".into()),
|
|
error_message: Set(None),
|
|
created_at: Set(now),
|
|
updated_at: Set(now),
|
|
created_by: Set(None),
|
|
updated_by: Set(None),
|
|
deleted_at: Set(None),
|
|
version_lock: Set(1),
|
|
};
|
|
active.insert(&self.db).await?;
|
|
Ok(())
|
|
}
|
|
}
|