feat(ai): AiProvider trait + Claude SSE 流式实现 + DTO 定义
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
102
crates/erp-ai/src/dto.rs
Normal file
102
crates/erp-ai/src/dto.rs
Normal file
@@ -0,0 +1,102 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// === 分析请求 ===
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnalyzeRequest {
|
||||
pub analysis_type: AnalysisType,
|
||||
pub source_ref: String,
|
||||
pub options: AnalyzeOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AnalysisType {
|
||||
LabReport,
|
||||
Trends,
|
||||
CheckupPlan,
|
||||
ReportSummary,
|
||||
}
|
||||
|
||||
impl AnalysisType {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
Self::LabReport => "lab_report",
|
||||
Self::Trends => "trend",
|
||||
Self::CheckupPlan => "checkup_plan",
|
||||
Self::ReportSummary => "report_summary",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prompt_name(&self) -> &str {
|
||||
match self {
|
||||
Self::LabReport => "lab_report_interpretation",
|
||||
Self::Trends => "health_trend_analysis",
|
||||
Self::CheckupPlan => "personalized_checkup_plan",
|
||||
Self::ReportSummary => "report_summary_generation",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnalyzeOptions {
|
||||
pub detail_level: Option<String>,
|
||||
pub language: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for AnalyzeOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
detail_level: Some("patient_friendly".into()),
|
||||
language: Some("zh-CN".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// === AI Provider 请求/响应 ===
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GenerateRequest {
|
||||
pub system_prompt: String,
|
||||
pub user_prompt: String,
|
||||
pub model: String,
|
||||
pub temperature: f32,
|
||||
pub max_tokens: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GenerateResponse {
|
||||
pub content: String,
|
||||
pub model: String,
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
pub duration_ms: u64,
|
||||
}
|
||||
|
||||
// === SSE 事件 ===
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenUsage {
|
||||
pub input: u32,
|
||||
pub output: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum AnalysisSseEvent {
|
||||
#[serde(rename = "chunk")]
|
||||
Chunk { content: String, index: u32 },
|
||||
#[serde(rename = "metadata")]
|
||||
Metadata {
|
||||
model: String,
|
||||
tokens: TokenUsage,
|
||||
duration_ms: u64,
|
||||
},
|
||||
#[serde(rename = "done")]
|
||||
Done {
|
||||
analysis_id: uuid::Uuid,
|
||||
status: String,
|
||||
},
|
||||
#[serde(rename = "error")]
|
||||
Error { message: String },
|
||||
}
|
||||
@@ -1,4 +1,6 @@
|
||||
pub mod dto;
|
||||
pub mod entity;
|
||||
pub mod error;
|
||||
pub mod provider;
|
||||
|
||||
pub use error::{AiError, AiResult};
|
||||
|
||||
227
crates/erp-ai/src/provider/claude.rs
Normal file
227
crates/erp-ai/src/provider/claude.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use async_stream::stream;
|
||||
use async_trait::async_trait;
|
||||
use futures::{Stream, StreamExt};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::pin::Pin;
|
||||
|
||||
use super::AiProvider;
|
||||
use crate::dto::GenerateRequest;
|
||||
use crate::error::{AiError, AiResult};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeProvider {
|
||||
client: Client,
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
impl ClaudeProvider {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
Self {
|
||||
client: Client::new(),
|
||||
api_key,
|
||||
base_url: "https://api.anthropic.com".into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_base_url(mut self, url: String) -> Self {
|
||||
self.base_url = url;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ClaudeRequest {
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
temperature: f32,
|
||||
system: String,
|
||||
messages: Vec<ClaudeMessage>,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ClaudeMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeStreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
event_type: String,
|
||||
delta: Option<ClaudeDelta>,
|
||||
message: Option<ClaudeMessageResp>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeDelta {
|
||||
text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeMessageResp {
|
||||
usage: Option<ClaudeUsage>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct ClaudeUsage {
|
||||
input_tokens: u32,
|
||||
output_tokens: u32,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AiProvider for ClaudeProvider {
|
||||
async fn stream_generate(
|
||||
&self,
|
||||
req: GenerateRequest,
|
||||
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
|
||||
let claude_req = ClaudeRequest {
|
||||
model: req.model,
|
||||
max_tokens: req.max_tokens,
|
||||
temperature: req.temperature,
|
||||
system: req.system_prompt,
|
||||
messages: vec![ClaudeMessage {
|
||||
role: "user".into(),
|
||||
content: req.user_prompt,
|
||||
}],
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(format!("{}/v1/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&claude_req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AiError::ProviderError(format!("Claude API 请求失败: {e}")))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
return Err(AiError::ProviderError(format!(
|
||||
"Claude API 错误 {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let stream = Box::pin(stream! {
|
||||
let mut stream = response.bytes_stream();
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let bytes = match chunk_result {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
yield Err(AiError::ProviderError(format!("流读取错误: {e}")));
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
for line in text.lines() {
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
if data == "[DONE]" {
|
||||
return;
|
||||
}
|
||||
if let Ok(event) = serde_json::from_str::<ClaudeStreamEvent>(data) {
|
||||
if event.event_type == "content_block_delta" {
|
||||
if let Some(delta) = event.delta {
|
||||
if let Some(text) = delta.text {
|
||||
yield Ok(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
|
||||
async fn generate(&self, req: GenerateRequest) -> AiResult<crate::dto::GenerateResponse> {
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
let claude_req = ClaudeRequest {
|
||||
model: req.model.clone(),
|
||||
max_tokens: req.max_tokens,
|
||||
temperature: req.temperature,
|
||||
system: req.system_prompt,
|
||||
messages: vec![ClaudeMessage {
|
||||
role: "user".into(),
|
||||
content: req.user_prompt,
|
||||
}],
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/v1/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&claude_req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||
|
||||
let status = resp.status();
|
||||
let body = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(AiError::ProviderError(format!(
|
||||
"Claude {status}: {body}"
|
||||
)));
|
||||
}
|
||||
|
||||
let parsed: serde_json::Value = serde_json::from_str(&body)
|
||||
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
|
||||
|
||||
let content = parsed["content"][0]["text"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
let input_tokens = parsed["usage"]["input_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
let output_tokens = parsed["usage"]["output_tokens"].as_u64().unwrap_or(0) as u32;
|
||||
|
||||
Ok(crate::dto::GenerateResponse {
|
||||
content,
|
||||
model: req.model,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
duration_ms: start.elapsed().as_millis() as u64,
|
||||
})
|
||||
}
|
||||
|
||||
fn name(&self) -> &str {
|
||||
"claude"
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> AiResult<bool> {
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/v1/messages", self.base_url))
|
||||
.header("x-api-key", &self.api_key)
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"model": "claude-sonnet-4-6",
|
||||
"max_tokens": 1,
|
||||
"messages": [{"role": "user", "content": "hi"}]
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok(r) => Ok(r.status().is_success() || r.status().as_u16() == 400),
|
||||
Err(_) => Ok(false),
|
||||
}
|
||||
}
|
||||
}
|
||||
27
crates/erp-ai/src/provider/mod.rs
Normal file
27
crates/erp-ai/src/provider/mod.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
pub mod claude;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
|
||||
use crate::dto::GenerateRequest;
|
||||
use crate::error::AiResult;
|
||||
|
||||
/// AI 提供商 trait
|
||||
#[async_trait]
|
||||
pub trait AiProvider: Send + Sync {
|
||||
/// 流式生成
|
||||
async fn stream_generate(
|
||||
&self,
|
||||
req: GenerateRequest,
|
||||
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>>;
|
||||
|
||||
/// 非流式生成
|
||||
async fn generate(&self, req: GenerateRequest) -> AiResult<crate::dto::GenerateResponse>;
|
||||
|
||||
/// 提供商名称
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// 健康检查
|
||||
async fn health_check(&self) -> AiResult<bool>;
|
||||
}
|
||||
Reference in New Issue
Block a user