diff --git a/crates/erp-ai/src/lib.rs b/crates/erp-ai/src/lib.rs index d80659c..c83f9c1 100644 --- a/crates/erp-ai/src/lib.rs +++ b/crates/erp-ai/src/lib.rs @@ -4,5 +4,6 @@ pub mod error; pub mod prompt; pub mod provider; pub mod sanitization; +pub mod service; pub use error::{AiError, AiResult}; diff --git a/crates/erp-ai/src/service/analysis.rs b/crates/erp-ai/src/service/analysis.rs new file mode 100644 index 0000000..68f810e --- /dev/null +++ b/crates/erp-ai/src/service/analysis.rs @@ -0,0 +1,181 @@ +use futures::Stream; +use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set}; +use sha2::{Digest, Sha256}; +use std::pin::Pin; +use uuid::Uuid; + +use crate::dto::{AnalysisType, GenerateRequest}; +use crate::entity::ai_analysis; +use crate::error::{AiError, AiResult}; +use crate::prompt::PromptRenderer; +use crate::provider::AiProvider; +use crate::sanitization::SanitizationService; + +pub struct AnalysisService { + pub provider: Box, + pub sanitizer: SanitizationService, + pub renderer: PromptRenderer, + pub db: sea_orm::DatabaseConnection, +} + +impl AnalysisService { + pub fn new(provider: Box, db: sea_orm::DatabaseConnection) -> Self { + Self { + provider, + sanitizer: SanitizationService::new(), + renderer: PromptRenderer::new(), + db, + } + } + + /// 执行流式分析 — 返回 SSE 事件流 + pub async fn stream_analyze( + &self, + tenant_id: Uuid, + _user_id: Uuid, + patient_id: Uuid, + analysis_type: AnalysisType, + source_ref: String, + system_prompt: String, + user_template: String, + sanitized_data: serde_json::Value, + model: String, + temperature: f32, + max_tokens: u32, + ) -> AiResult<( + Pin> + Send>>, + uuid::Uuid, + String, + )> { + let analysis_id = Uuid::now_v7(); + let input_hash = self.compute_hash(&sanitized_data); + let provider_name = self.provider.name().to_string(); + + // 1. 渲染 Prompt + let user_prompt = self.renderer.render(&user_template, &sanitized_data)?; + + // 2. 创建分析记录 + self.create_analysis_record( + analysis_id, + tenant_id, + patient_id, + analysis_type.as_str(), + &source_ref, + &input_hash, + &provider_name, + &model, + ) + .await?; + + // 3. 调用 AI 流式生成 + let req = GenerateRequest { + system_prompt, + user_prompt, + model, + temperature, + max_tokens, + }; + let stream = self.provider.stream_generate(req).await?; + + Ok((stream, analysis_id, provider_name)) + } + + /// 更新分析记录为完成 + pub async fn complete_analysis( + &self, + analysis_id: Uuid, + content: String, + metadata: serde_json::Value, + ) -> AiResult<()> { + let entity = ai_analysis::Entity::find_by_id(analysis_id) + .one(&self.db) + .await? + .ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?; + + let mut active: ai_analysis::ActiveModel = entity.into(); + active.status = Set("completed".into()); + active.result_content = Set(Some(content)); + active.result_metadata = Set(Some(metadata)); + active.updated_at = Set(chrono::Utc::now()); + active.update(&self.db).await?; + Ok(()) + } + + /// 标记分析失败 + pub async fn fail_analysis(&self, analysis_id: Uuid, error: String) -> AiResult<()> { + let entity = ai_analysis::Entity::find_by_id(analysis_id) + .one(&self.db) + .await? + .ok_or_else(|| AiError::AnalysisNotFound(analysis_id.to_string()))?; + + let mut active: ai_analysis::ActiveModel = entity.into(); + active.status = Set("failed".into()); + active.error_message = Set(Some(error)); + active.updated_at = Set(chrono::Utc::now()); + active.update(&self.db).await?; + Ok(()) + } + + /// 查找缓存 + pub async fn find_cached( + &self, + tenant_id: Uuid, + input_hash: &str, + prompt_version: i32, + ) -> AiResult> { + let result = ai_analysis::Entity::find() + .filter(ai_analysis::Column::TenantId.eq(tenant_id)) + .filter(ai_analysis::Column::InputDataHash.eq(input_hash)) + .filter(ai_analysis::Column::PromptVersion.eq(prompt_version)) + .filter(ai_analysis::Column::Status.eq("completed")) + .filter(ai_analysis::Column::DeletedAt.is_null()) + .one(&self.db) + .await?; + Ok(result) + } + + fn compute_hash(&self, data: &serde_json::Value) -> String { + let canonical = serde_json::to_string(data).unwrap_or_default(); + let mut hasher = Sha256::new(); + hasher.update(canonical.as_bytes()); + hex::encode(hasher.finalize()) + } + + async fn create_analysis_record( + &self, + id: Uuid, + tenant_id: Uuid, + patient_id: Uuid, + analysis_type: &str, + source_ref: &str, + input_hash: &str, + _provider: &str, + model: &str, + ) -> AiResult<()> { + let now = chrono::Utc::now(); + let active = ai_analysis::ActiveModel { + id: Set(id), + tenant_id: Set(tenant_id), + patient_id: Set(patient_id), + analysis_type: Set(analysis_type.into()), + source_ref: Set(source_ref.into()), + prompt_id: Set(Uuid::nil()), + prompt_version: Set(1), + model_used: Set(model.into()), + input_data_hash: Set(input_hash.into()), + sanitized_input: Set(None), + result_content: Set(None), + result_metadata: Set(None), + status: Set("streaming".into()), + error_message: Set(None), + created_at: Set(now), + updated_at: Set(now), + created_by: Set(None), + updated_by: Set(None), + deleted_at: Set(None), + version_lock: Set(1), + }; + active.insert(&self.db).await?; + Ok(()) + } +} diff --git a/crates/erp-ai/src/service/mod.rs b/crates/erp-ai/src/service/mod.rs new file mode 100644 index 0000000..087d0ca --- /dev/null +++ b/crates/erp-ai/src/service/mod.rs @@ -0,0 +1,3 @@ +pub mod analysis; +pub mod prompt; +pub mod usage; diff --git a/crates/erp-ai/src/service/prompt.rs b/crates/erp-ai/src/service/prompt.rs new file mode 100644 index 0000000..0ae018f --- /dev/null +++ b/crates/erp-ai/src/service/prompt.rs @@ -0,0 +1,67 @@ +use sea_orm::{ActiveModelTrait, ColumnTrait, EntityTrait, QueryFilter, Set}; +use uuid::Uuid; + +use crate::entity::ai_prompt; +use crate::error::{AiError, AiResult}; + +pub struct PromptService { + pub db: sea_orm::DatabaseConnection, +} + +impl PromptService { + pub fn new(db: sea_orm::DatabaseConnection) -> Self { + Self { db } + } + + /// 获取当前激活的 Prompt 模板 + pub async fn get_active_prompt( + &self, + tenant_id: Uuid, + name: &str, + ) -> AiResult { + ai_prompt::Entity::find() + .filter(ai_prompt::Column::TenantId.eq(tenant_id)) + .filter(ai_prompt::Column::Name.eq(name)) + .filter(ai_prompt::Column::IsActive.eq(true)) + .filter(ai_prompt::Column::DeletedAt.is_null()) + .one(&self.db) + .await? + .ok_or_else(|| AiError::PromptNotFound(name.into())) + } + + /// 新建 Prompt + pub async fn create_prompt( + &self, + tenant_id: Uuid, + user_id: Uuid, + name: String, + system_prompt: String, + user_prompt_template: String, + model_config: serde_json::Value, + category: String, + ) -> AiResult { + let id = Uuid::now_v7(); + let now = chrono::Utc::now(); + let active = ai_prompt::ActiveModel { + id: Set(id), + tenant_id: Set(tenant_id), + name: Set(name), + description: Set(String::new()), + system_prompt: Set(system_prompt), + user_prompt_template: Set(user_prompt_template), + variables_schema: Set(None), + model_config: Set(model_config), + version: Set(1), + is_active: Set(true), + category: Set(category), + tags: Set(None), + created_at: Set(now), + updated_at: Set(now), + created_by: Set(Some(user_id)), + updated_by: Set(Some(user_id)), + deleted_at: Set(None), + version_lock: Set(1), + }; + Ok(active.insert(&self.db).await?) + } +} diff --git a/crates/erp-ai/src/service/usage.rs b/crates/erp-ai/src/service/usage.rs new file mode 100644 index 0000000..1c20d4b --- /dev/null +++ b/crates/erp-ai/src/service/usage.rs @@ -0,0 +1,45 @@ +use sea_orm::ActiveModelTrait; +use sea_orm::Set; +use uuid::Uuid; + +use crate::entity::ai_usage; +use crate::error::AiResult; + +pub struct UsageService { + pub db: sea_orm::DatabaseConnection, +} + +impl UsageService { + pub fn new(db: sea_orm::DatabaseConnection) -> Self { + Self { db } + } + + pub async fn log_usage( + &self, + tenant_id: Uuid, + provider: &str, + model: &str, + analysis_type: &str, + input_tokens: u32, + output_tokens: u32, + duration_ms: u64, + cost_cents: i32, + is_cache_hit: bool, + ) -> AiResult { + let id = Uuid::now_v7(); + let active = ai_usage::ActiveModel { + id: Set(id), + tenant_id: Set(tenant_id), + provider: Set(provider.into()), + model: Set(model.into()), + analysis_type: Set(analysis_type.into()), + input_tokens: Set(input_tokens as i32), + output_tokens: Set(output_tokens as i32), + duration_ms: Set(duration_ms as i32), + cost_cents: Set(cost_cents), + is_cache_hit: Set(is_cache_hit), + created_at: Set(chrono::Utc::now()), + }; + Ok(active.insert(&self.db).await?) + } +}