diff --git a/crates/erp-ai/src/dto.rs b/crates/erp-ai/src/dto.rs new file mode 100644 index 0000000..ac69185 --- /dev/null +++ b/crates/erp-ai/src/dto.rs @@ -0,0 +1,102 @@ +use serde::{Deserialize, Serialize}; + +// === 分析请求 === + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnalyzeRequest { + pub analysis_type: AnalysisType, + pub source_ref: String, + pub options: AnalyzeOptions, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum AnalysisType { + LabReport, + Trends, + CheckupPlan, + ReportSummary, +} + +impl AnalysisType { + pub fn as_str(&self) -> &str { + match self { + Self::LabReport => "lab_report", + Self::Trends => "trend", + Self::CheckupPlan => "checkup_plan", + Self::ReportSummary => "report_summary", + } + } + + pub fn prompt_name(&self) -> &str { + match self { + Self::LabReport => "lab_report_interpretation", + Self::Trends => "health_trend_analysis", + Self::CheckupPlan => "personalized_checkup_plan", + Self::ReportSummary => "report_summary_generation", + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AnalyzeOptions { + pub detail_level: Option, + pub language: Option, +} + +impl Default for AnalyzeOptions { + fn default() -> Self { + Self { + detail_level: Some("patient_friendly".into()), + language: Some("zh-CN".into()), + } + } +} + +// === AI Provider 请求/响应 === + +#[derive(Debug, Clone)] +pub struct GenerateRequest { + pub system_prompt: String, + pub user_prompt: String, + pub model: String, + pub temperature: f32, + pub max_tokens: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerateResponse { + pub content: String, + pub model: String, + pub input_tokens: u32, + pub output_tokens: u32, + pub duration_ms: u64, +} + +// === SSE 事件 === + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenUsage { + pub input: u32, + pub output: u32, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum AnalysisSseEvent { + #[serde(rename = "chunk")] + Chunk { content: String, index: u32 }, + #[serde(rename = "metadata")] + Metadata { + model: String, + tokens: TokenUsage, + duration_ms: u64, + }, + #[serde(rename = "done")] + Done { + analysis_id: uuid::Uuid, + status: String, + }, + #[serde(rename = "error")] + Error { message: String }, +} diff --git a/crates/erp-ai/src/lib.rs b/crates/erp-ai/src/lib.rs index e47f7c3..c4b3d26 100644 --- a/crates/erp-ai/src/lib.rs +++ b/crates/erp-ai/src/lib.rs @@ -1,4 +1,6 @@ +pub mod dto; pub mod entity; pub mod error; +pub mod provider; pub use error::{AiError, AiResult}; diff --git a/crates/erp-ai/src/provider/claude.rs b/crates/erp-ai/src/provider/claude.rs new file mode 100644 index 0000000..8b95707 --- /dev/null +++ b/crates/erp-ai/src/provider/claude.rs @@ -0,0 +1,227 @@ +use async_stream::stream; +use async_trait::async_trait; +use futures::{Stream, StreamExt}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::pin::Pin; + +use super::AiProvider; +use crate::dto::GenerateRequest; +use crate::error::{AiError, AiResult}; + +#[derive(Debug, Clone)] +pub struct ClaudeProvider { + client: Client, + api_key: String, + base_url: String, +} + +impl ClaudeProvider { + pub fn new(api_key: String) -> Self { + Self { + client: Client::new(), + api_key, + base_url: "https://api.anthropic.com".into(), + } + } + + pub fn with_base_url(mut self, url: String) -> Self { + self.base_url = url; + self + } +} + +#[derive(Serialize)] +struct ClaudeRequest { + model: String, + max_tokens: u32, + temperature: f32, + system: String, + messages: Vec, + stream: bool, +} + +#[derive(Serialize)] +struct ClaudeMessage { + role: String, + content: String, +} + +#[derive(Deserialize)] +struct ClaudeStreamEvent { + #[serde(rename = "type")] + event_type: String, + delta: Option, + message: Option, +} + +#[derive(Deserialize)] +struct ClaudeDelta { + text: Option, +} + +#[derive(Deserialize)] +struct ClaudeMessageResp { + usage: Option, +} + +#[derive(Deserialize)] +struct ClaudeUsage { + input_tokens: u32, + output_tokens: u32, +} + +#[async_trait] +impl AiProvider for ClaudeProvider { + async fn stream_generate( + &self, + req: GenerateRequest, + ) -> AiResult> + Send>>> { + let claude_req = ClaudeRequest { + model: req.model, + max_tokens: req.max_tokens, + temperature: req.temperature, + system: req.system_prompt, + messages: vec![ClaudeMessage { + role: "user".into(), + content: req.user_prompt, + }], + stream: true, + }; + + let response = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&claude_req) + .send() + .await + .map_err(|e| AiError::ProviderError(format!("Claude API 请求失败: {e}")))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(AiError::ProviderError(format!( + "Claude API 错误 {status}: {body}" + ))); + } + + let stream = Box::pin(stream! { + let mut stream = response.bytes_stream(); + while let Some(chunk_result) = stream.next().await { + let bytes = match chunk_result { + Ok(b) => b, + Err(e) => { + yield Err(AiError::ProviderError(format!("流读取错误: {e}"))); + break; + } + }; + + let text = String::from_utf8_lossy(&bytes); + for line in text.lines() { + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + return; + } + if let Ok(event) = serde_json::from_str::(data) { + if event.event_type == "content_block_delta" { + if let Some(delta) = event.delta { + if let Some(text) = delta.text { + yield Ok(text); + } + } + } + } + } + } + } + }); + + Ok(stream) + } + + async fn generate(&self, req: GenerateRequest) -> AiResult { + let start = std::time::Instant::now(); + + let claude_req = ClaudeRequest { + model: req.model.clone(), + max_tokens: req.max_tokens, + temperature: req.temperature, + system: req.system_prompt, + messages: vec![ClaudeMessage { + role: "user".into(), + content: req.user_prompt, + }], + stream: false, + }; + + let resp = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&claude_req) + .send() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + let status = resp.status(); + let body = resp + .text() + .await + .map_err(|e| AiError::ProviderError(e.to_string()))?; + + if !status.is_success() { + return Err(AiError::ProviderError(format!( + "Claude {status}: {body}" + ))); + } + + let parsed: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?; + + let content = parsed["content"][0]["text"] + .as_str() + .unwrap_or("") + .to_string(); + + let input_tokens = parsed["usage"]["input_tokens"].as_u64().unwrap_or(0) as u32; + let output_tokens = parsed["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32; + + Ok(crate::dto::GenerateResponse { + content, + model: req.model, + input_tokens, + output_tokens, + duration_ms: start.elapsed().as_millis() as u64, + }) + } + + fn name(&self) -> &str { + "claude" + } + + async fn health_check(&self) -> AiResult { + let resp = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("x-api-key", &self.api_key) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&serde_json::json!({ + "model": "claude-sonnet-4-6", + "max_tokens": 1, + "messages": [{"role": "user", "content": "hi"}] + })) + .send() + .await; + + match resp { + Ok(r) => Ok(r.status().is_success() || r.status().as_u16() == 400), + Err(_) => Ok(false), + } + } +} diff --git a/crates/erp-ai/src/provider/mod.rs b/crates/erp-ai/src/provider/mod.rs new file mode 100644 index 0000000..1f35e19 --- /dev/null +++ b/crates/erp-ai/src/provider/mod.rs @@ -0,0 +1,27 @@ +pub mod claude; + +use async_trait::async_trait; +use futures::Stream; +use std::pin::Pin; + +use crate::dto::GenerateRequest; +use crate::error::AiResult; + +/// AI 提供商 trait +#[async_trait] +pub trait AiProvider: Send + Sync { + /// 流式生成 + async fn stream_generate( + &self, + req: GenerateRequest, + ) -> AiResult> + Send>>>; + + /// 非流式生成 + async fn generate(&self, req: GenerateRequest) -> AiResult; + + /// 提供商名称 + fn name(&self) -> &str; + + /// 健康检查 + async fn health_check(&self) -> AiResult; +}