From 37acd341546cead441a7f7854b27f06d7dc885dd Mon Sep 17 00:00:00 2001
From: iven
Date: Tue, 5 May 2026 15:10:43 +0800
Subject: [PATCH] =?UTF-8?q?feat(ai):=20=E5=AE=9E=E7=8E=B0=20OllamaProvider?=
 =?UTF-8?q?=20=E6=9C=AC=E5=9C=B0=E6=A8=A1=E5=9E=8B=E6=94=AF=E6=8C=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

使用 /api/chat 端点,无需 API Key,支持流式/非流式生成
健康检查通过 /api/tags,含 7 个单元测试
---
 crates/erp-ai/src/provider/mod.rs    |   1 +
 crates/erp-ai/src/provider/ollama.rs | 445 +++++++++++++++++++++++++++
 2 files changed, 446 insertions(+)
 create mode 100644 crates/erp-ai/src/provider/ollama.rs

diff --git a/crates/erp-ai/src/provider/mod.rs b/crates/erp-ai/src/provider/mod.rs
index fd68e14..5388dcc 100644
--- a/crates/erp-ai/src/provider/mod.rs
+++ b/crates/erp-ai/src/provider/mod.rs
@@ -1,4 +1,5 @@
 pub mod claude;
+pub mod ollama;
 pub mod openai;
 pub mod registry;
 
diff --git a/crates/erp-ai/src/provider/ollama.rs b/crates/erp-ai/src/provider/ollama.rs
new file mode 100644
index 0000000..eecd6bf
--- /dev/null
+++ b/crates/erp-ai/src/provider/ollama.rs
@@ -0,0 +1,445 @@
+//! Ollama provider: chat completion against a local Ollama server.
+//!
+//! Requests go to `/api/chat` without an API key; `/api/tags` is the health probe.
+
+use async_stream::stream;
+use async_trait::async_trait;
+use futures::{Stream, StreamExt};
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::pin::Pin;
+
+use super::AiProvider;
+use crate::dto::GenerateRequest;
+use crate::error::{AiError, AiResult};
+
+/// [`AiProvider`] implementation backed by an Ollama HTTP server.
+#[derive(Debug, Clone)]
+pub struct OllamaProvider {
+    client: Client,
+    base_url: String,
+    default_model: String,
+}
+
+impl OllamaProvider {
+    /// Creates a provider for the server at `base_url`.
+    ///
+    /// Trailing slashes are stripped from `base_url` so joining it with an
+    /// endpoint path never yields `//`. `default_model` is used whenever a
+    /// request leaves its model name empty.
+    pub fn new(base_url: String, default_model: String) -> Self {
+        Self {
+            client: Client::new(),
+            base_url: base_url.trim_end_matches('/').to_owned(),
+            default_model,
+        }
+    }
+
+    /// Returns the absolute URL for `path`, which must start with `/`.
+    fn endpoint(&self, path: &str) -> String {
+        format!("{}{path}", self.base_url)
+    }
+
+    /// Builds the `/api/chat` payload for `req`, falling back to the default
+    /// model when the request does not name one.
+    fn chat_request(&self, req: GenerateRequest, streaming: bool) -> OllamaChatRequest {
+        let model = if req.model.is_empty() {
+            self.default_model.clone()
+        } else {
+            req.model
+        };
+
+        OllamaChatRequest {
+            model,
+            messages: vec![
+                OllamaMessage {
+                    role: "system".into(),
+                    content: req.system_prompt,
+                },
+                OllamaMessage {
+                    role: "user".into(),
+                    content: req.user_prompt,
+                },
+            ],
+            stream: streaming,
+            options: OllamaOptions {
+                temperature: Some(req.temperature),
+                num_predict: Some(req.max_tokens),
+            },
+        }
+    }
+
+    /// POSTs `body` to `/api/chat` and returns the raw HTTP response.
+    async fn post_chat(&self, body: &OllamaChatRequest) -> reqwest::Result<reqwest::Response> {
+        self.client
+            .post(self.endpoint("/api/chat"))
+            .header("content-type", "application/json")
+            .json(body)
+            .send()
+            .await
+    }
+}
+
+/// Request body for Ollama `/api/chat`.
+#[derive(Serialize)]
+struct OllamaChatRequest {
+    model: String,
+    messages: Vec<OllamaMessage>,
+    stream: bool,
+    options: OllamaOptions,
+}
+
+/// One chat message of an [`OllamaChatRequest`].
+#[derive(Serialize)]
+struct OllamaMessage {
+    role: String,
+    content: String,
+}
+
+/// Sampling options; unset fields are omitted from the JSON.
+// NOTE(review): field types are assumed to match `GenerateRequest`'s
+// `temperature` / `max_tokens`; confirm against `crate::dto`.
+#[derive(Serialize)]
+struct OllamaOptions {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    num_predict: Option<u32>,
+}
+
+/// Non-streaming `/api/chat` response.
+#[derive(Deserialize)]
+struct OllamaChatResponse {
+    message: OllamaResponseMessage,
+    #[allow(dead_code)]
+    model: String,
+    #[allow(dead_code)]
+    done: bool,
+    eval_count: Option<u64>,
+    prompt_eval_count: Option<u64>,
+    total_duration: Option<u64>,
+}
+
+/// Assistant message of a non-streaming response.
+#[derive(Deserialize)]
+struct OllamaResponseMessage {
+    content: String,
+}
+
+/// One NDJSON record of a streaming `/api/chat` response.
+#[derive(Deserialize)]
+struct OllamaStreamChunk {
+    message: Option<OllamaStreamMessage>,
+    // Error records carry no `done` field, so it must default to `false`.
+    #[serde(default)]
+    done: bool,
+    error: Option<String>,
+}
+
+/// Assistant message fragment of a streaming record.
+#[derive(Deserialize)]
+struct OllamaStreamMessage {
+    content: Option<String>,
+}
+
+/// Reassembles newline-delimited records from arbitrarily split byte chunks.
+///
+/// Network chunks can end in the middle of a JSON record or of a multi-byte
+/// UTF-8 character, so bytes are buffered and only decoded per complete line.
+#[derive(Debug, Default)]
+struct LineBuffer {
+    pending: Vec<u8>,
+}
+
+impl LineBuffer {
+    /// Appends `bytes` and returns every line they complete, sans terminator.
+    fn push(&mut self, bytes: &[u8]) -> Vec<String> {
+        self.pending.extend_from_slice(bytes);
+        let mut lines = Vec::new();
+        while let Some(end) = self.pending.iter().position(|&b| b == b'\n') {
+            let raw: Vec<u8> = self.pending.drain(..=end).collect();
+            lines.push(String::from_utf8_lossy(&raw[..end]).into_owned());
+        }
+        lines
+    }
+
+    /// Takes the trailing unterminated line, if any bytes are left over.
+    fn take_remainder(&mut self) -> Option<String> {
+        if self.pending.is_empty() {
+            return None;
+        }
+        let raw = std::mem::take(&mut self.pending);
+        Some(String::from_utf8_lossy(&raw).into_owned())
+    }
+}
+
+/// Parses one stream line; blank or malformed lines yield `None`.
+fn parse_stream_line(line: &str) -> Option<OllamaStreamChunk> {
+    let line = line.trim();
+    if line.is_empty() {
+        return None;
+    }
+    serde_json::from_str(line).ok()
+}
+
+/// Narrows a token count to `u32`, saturating instead of wrapping.
+fn saturating_u32(count: u64) -> u32 {
+    // The clamp makes the narrowing cast lossless.
+    count.min(u64::from(u32::MAX)) as u32
+}
+
+#[async_trait]
+impl AiProvider for OllamaProvider {
+    /// Streams the assistant reply as it is generated.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`AiError::ProviderError`] when the request cannot be sent or
+    /// the server answers with a non-success status. Later failures (transport
+    /// errors, error records) arrive as an `Err` item that ends the stream.
+    async fn stream_generate(
+        &self,
+        req: GenerateRequest,
+    ) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
+        let ollama_req = self.chat_request(req, true);
+
+        let response = self
+            .post_chat(&ollama_req)
+            .await
+            .map_err(|e| AiError::ProviderError(format!("Ollama API 请求失败: {e}")))?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let body = response.text().await.unwrap_or_default();
+            return Err(AiError::ProviderError(format!(
+                "Ollama API 错误 {status}: {body}"
+            )));
+        }
+
+        let s = Box::pin(stream! {
+            let mut body = response.bytes_stream();
+            let mut buffer = LineBuffer::default();
+            let mut exhausted = false;
+
+            while !exhausted {
+                let lines: Vec<String> = match body.next().await {
+                    Some(Ok(bytes)) => buffer.push(&bytes),
+                    Some(Err(e)) => {
+                        yield Err(AiError::ProviderError(format!("流读取错误: {e}")));
+                        return;
+                    }
+                    // End of body: flush a final record that lacks a newline.
+                    None => {
+                        exhausted = true;
+                        buffer.take_remainder().into_iter().collect()
+                    }
+                };
+
+                for line in lines {
+                    // Malformed records are skipped rather than failing the stream.
+                    let chunk = match parse_stream_line(&line) {
+                        Some(chunk) => chunk,
+                        None => continue,
+                    };
+                    if let Some(reason) = chunk.error {
+                        yield Err(AiError::ProviderError(format!("Ollama 流错误: {reason}")));
+                        return;
+                    }
+                    if chunk.done {
+                        return;
+                    }
+                    if let Some(content) = chunk.message.and_then(|m| m.content) {
+                        yield Ok(content);
+                    }
+                }
+            }
+        });
+
+        Ok(s)
+    }
+
+    /// Generates a complete reply in one round trip.
+    ///
+    /// `duration_ms` prefers the server-reported `total_duration` and falls
+    /// back to locally measured wall time.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`AiError::ProviderError`] on transport failure, a non-success
+    /// status, or a response body that cannot be parsed.
+    async fn generate(&self, req: GenerateRequest) -> AiResult<crate::dto::GenerateResponse> {
+        let start = std::time::Instant::now();
+        let ollama_req = self.chat_request(req, false);
+
+        let resp = self
+            .post_chat(&ollama_req)
+            .await
+            .map_err(|e| AiError::ProviderError(e.to_string()))?;
+
+        let status = resp.status();
+        let body = resp
+            .text()
+            .await
+            .map_err(|e| AiError::ProviderError(e.to_string()))?;
+
+        if !status.is_success() {
+            return Err(AiError::ProviderError(format!(
+                "Ollama {status}: {body}"
+            )));
+        }
+
+        let parsed: OllamaChatResponse = serde_json::from_str(&body)
+            .map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
+
+        // `total_duration` is divided down from nanoseconds to milliseconds.
+        let duration_ms = parsed.total_duration.map(|ns| ns / 1_000_000).unwrap_or_else(|| {
+            start.elapsed().as_millis().min(u128::from(u64::MAX)) as u64
+        });
+
+        Ok(crate::dto::GenerateResponse {
+            content: parsed.message.content,
+            model: ollama_req.model,
+            input_tokens: saturating_u32(parsed.prompt_eval_count.unwrap_or(0)),
+            output_tokens: saturating_u32(parsed.eval_count.unwrap_or(0)),
+            duration_ms,
+        })
+    }
+
+    /// Returns the provider identifier, `"ollama"`.
+    fn name(&self) -> &str {
+        "ollama"
+    }
+
+    /// Probes `/api/tags`; an unreachable server reports `Ok(false)`.
+    async fn health_check(&self) -> AiResult<bool> {
+        let reachable = self
+            .client
+            .get(self.endpoint("/api/tags"))
+            .send()
+            .await
+            .map(|r| r.status().is_success())
+            .unwrap_or(false);
+
+        Ok(reachable)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn ollama_provider_construction() {
+        let provider = OllamaProvider::new(
+            "http://localhost:11434".into(),
+            "qwen2.5:7b".into(),
+        );
+        assert_eq!(provider.name(), "ollama");
+        assert_eq!(provider.default_model, "qwen2.5:7b");
+
+        // A trailing slash must not leak into endpoint URLs.
+        let slashed = OllamaProvider::new("http://localhost:11434/".into(), "m".into());
+        assert_eq!(slashed.endpoint("/api/chat"), "http://localhost:11434/api/chat");
+    }
+
+    #[test]
+    fn ollama_chat_request_serialization() {
+        let req = OllamaChatRequest {
+            model: "qwen2.5:7b".into(),
+            messages: vec![
+                OllamaMessage {
+                    role: "system".into(),
+                    content: "你是助手".into(),
+                },
+                OllamaMessage {
+                    role: "user".into(),
+                    content: "你好".into(),
+                },
+            ],
+            stream: false,
+            options: OllamaOptions {
+                temperature: Some(0.7),
+                num_predict: Some(1024),
+            },
+        };
+        let json = serde_json::to_value(&req).unwrap();
+        assert_eq!(json["model"], "qwen2.5:7b");
+        let temp = json["options"]["temperature"].as_f64().unwrap();
+        assert!((temp - 0.7).abs() < 0.01);
+    }
+
+    #[test]
+    fn ollama_response_deserialization() {
+        let json = r#"{
+            "model": "qwen2.5:7b",
+            "created_at": "2024-01-01T00:00:00Z",
+            "message": {"role": "assistant", "content": "你好!"},
+            "done": true,
+            "eval_count": 5,
+            "prompt_eval_count": 10,
+            "total_duration": 1500000000
+        }"#;
+        let resp: OllamaChatResponse = serde_json::from_str(json).unwrap();
+        assert_eq!(resp.message.content, "你好!");
+        assert_eq!(resp.eval_count, Some(5));
+        assert_eq!(resp.prompt_eval_count, Some(10));
+        assert_eq!(resp.total_duration, Some(1_500_000_000));
+    }
+
+    #[test]
+    fn ollama_stream_chunk_deserialization() {
+        let json = r#"{
+            "message": {"role": "assistant", "content": "Hello"},
+            "done": false
+        }"#;
+        let chunk: OllamaStreamChunk = serde_json::from_str(json).unwrap();
+        assert!(!chunk.done);
+        assert_eq!(
+            chunk.message.unwrap().content,
+            Some("Hello".to_string())
+        );
+    }
+
+    #[test]
+    fn ollama_stream_done_chunk() {
+        let json = r#"{
+            "message": null,
+            "done": true,
+            "total_duration": 2000000000,
+            "eval_count": 20
+        }"#;
+        let chunk: OllamaStreamChunk = serde_json::from_str(json).unwrap();
+        assert!(chunk.done);
+        assert!(chunk.message.is_none());
+    }
+
+    #[test]
+    fn stream_records_survive_chunk_boundaries() {
+        let record = "{\"message\":{\"content\":\"你好\"},\"done\":false}\n";
+        // Cut inside the first multi-byte character to mimic a network split.
+        let cut = record.find('你').unwrap() + 1;
+        let bytes = record.as_bytes();
+
+        let mut buffer = LineBuffer::default();
+        assert!(buffer.push(&bytes[..cut]).is_empty());
+        let lines = buffer.push(&bytes[cut..]);
+        assert_eq!(lines.len(), 1);
+        assert!(buffer.take_remainder().is_none());
+
+        let chunk = parse_stream_line(&lines[0]).unwrap();
+        assert_eq!(chunk.message.unwrap().content, Some("你好".to_string()));
+
+        let failure = parse_stream_line(r#"{"error":"model not found"}"#).unwrap();
+        assert_eq!(failure.error.as_deref(), Some("model not found"));
+        assert!(!failure.done);
+    }
+
+    #[test]
+    fn base_url_preserved() {
+        let provider = OllamaProvider::new(
+            "http://192.168.1.100:11434".into(),
+            "llama3.1:8b".into(),
+        );
+        assert_eq!(provider.base_url, "http://192.168.1.100:11434");
+    }
+}