feat(ai): 实现 OllamaProvider 本地模型支持
使用 /api/chat 端点,无需 API Key,支持流式/非流式生成 健康检查通过 /api/tags,含 7 个单元测试
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
pub mod claude;
|
pub mod claude;
|
||||||
|
pub mod ollama;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod registry;
|
pub mod registry;
|
||||||
|
|
||||||
|
|||||||
344
crates/erp-ai/src/provider/ollama.rs
Normal file
344
crates/erp-ai/src/provider/ollama.rs
Normal file
@@ -0,0 +1,344 @@
|
|||||||
|
use async_stream::stream;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use futures::{Stream, StreamExt};
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::pin::Pin;
|
||||||
|
|
||||||
|
use super::AiProvider;
|
||||||
|
use crate::dto::GenerateRequest;
|
||||||
|
use crate::error::{AiError, AiResult};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct OllamaProvider {
|
||||||
|
client: Client,
|
||||||
|
base_url: String,
|
||||||
|
default_model: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaProvider {
|
||||||
|
pub fn new(base_url: String, default_model: String) -> Self {
|
||||||
|
Self {
|
||||||
|
client: Client::new(),
|
||||||
|
base_url,
|
||||||
|
default_model,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama /api/chat 请求格式
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct OllamaChatRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<OllamaMessage>,
|
||||||
|
stream: bool,
|
||||||
|
options: OllamaOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct OllamaMessage {
|
||||||
|
role: String,
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
struct OllamaOptions {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
num_predict: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama /api/chat 非流式响应
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OllamaChatResponse {
|
||||||
|
message: OllamaResponseMessage,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
model: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
done: bool,
|
||||||
|
eval_count: Option<u64>,
|
||||||
|
prompt_eval_count: Option<u64>,
|
||||||
|
total_duration: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OllamaResponseMessage {
|
||||||
|
content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ollama /api/chat 流式响应
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OllamaStreamChunk {
|
||||||
|
message: Option<OllamaStreamMessage>,
|
||||||
|
done: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
struct OllamaStreamMessage {
|
||||||
|
content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl AiProvider for OllamaProvider {
|
||||||
|
async fn stream_generate(
|
||||||
|
&self,
|
||||||
|
req: GenerateRequest,
|
||||||
|
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
|
||||||
|
let model = if req.model.is_empty() {
|
||||||
|
self.default_model.clone()
|
||||||
|
} else {
|
||||||
|
req.model
|
||||||
|
};
|
||||||
|
|
||||||
|
let ollama_req = OllamaChatRequest {
|
||||||
|
model,
|
||||||
|
messages: vec![
|
||||||
|
OllamaMessage {
|
||||||
|
role: "system".into(),
|
||||||
|
content: req.system_prompt,
|
||||||
|
},
|
||||||
|
OllamaMessage {
|
||||||
|
role: "user".into(),
|
||||||
|
content: req.user_prompt,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
stream: true,
|
||||||
|
options: OllamaOptions {
|
||||||
|
temperature: Some(req.temperature),
|
||||||
|
num_predict: Some(req.max_tokens),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/api/chat", self.base_url))
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json(&ollama_req)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AiError::ProviderError(format!("Ollama API 请求失败: {e}")))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
return Err(AiError::ProviderError(format!(
|
||||||
|
"Ollama API 错误 {status}: {body}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let s = Box::pin(stream! {
|
||||||
|
let mut stream = response.bytes_stream();
|
||||||
|
while let Some(chunk_result) = stream.next().await {
|
||||||
|
let bytes = match chunk_result {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
yield Err(AiError::ProviderError(format!("流读取错误: {e}")));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = String::from_utf8_lossy(&bytes);
|
||||||
|
for line in text.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if let Ok(chunk) = serde_json::from_str::<OllamaStreamChunk>(line) {
|
||||||
|
if chunk.done {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if let Some(msg) = chunk.message {
|
||||||
|
if let Some(content) = msg.content {
|
||||||
|
yield Ok(content);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn generate(&self, req: GenerateRequest) -> AiResult<crate::dto::GenerateResponse> {
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
|
||||||
|
let model = if req.model.is_empty() {
|
||||||
|
self.default_model.clone()
|
||||||
|
} else {
|
||||||
|
req.model.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let ollama_req = OllamaChatRequest {
|
||||||
|
model: model.clone(),
|
||||||
|
messages: vec![
|
||||||
|
OllamaMessage {
|
||||||
|
role: "system".into(),
|
||||||
|
content: req.system_prompt,
|
||||||
|
},
|
||||||
|
OllamaMessage {
|
||||||
|
role: "user".into(),
|
||||||
|
content: req.user_prompt,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
stream: false,
|
||||||
|
options: OllamaOptions {
|
||||||
|
temperature: Some(req.temperature),
|
||||||
|
num_predict: Some(req.max_tokens),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/api/chat", self.base_url))
|
||||||
|
.header("content-type", "application/json")
|
||||||
|
.json(&ollama_req)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
|
let status = resp.status();
|
||||||
|
let body = resp
|
||||||
|
.text()
|
||||||
|
.await
|
||||||
|
.map_err(|e| AiError::ProviderError(e.to_string()))?;
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
return Err(AiError::ProviderError(format!(
|
||||||
|
"Ollama {status}: {body}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let parsed: OllamaChatResponse = serde_json::from_str(&body)
|
||||||
|
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
|
||||||
|
|
||||||
|
let duration_ms = parsed
|
||||||
|
.total_duration
|
||||||
|
.map(|ns| ns / 1_000_000)
|
||||||
|
.unwrap_or_else(|| start.elapsed().as_millis() as u64);
|
||||||
|
|
||||||
|
let input_tokens = parsed.prompt_eval_count.unwrap_or(0) as u32;
|
||||||
|
let output_tokens = parsed.eval_count.unwrap_or(0) as u32;
|
||||||
|
|
||||||
|
Ok(crate::dto::GenerateResponse {
|
||||||
|
content: parsed.message.content,
|
||||||
|
model,
|
||||||
|
input_tokens,
|
||||||
|
output_tokens,
|
||||||
|
duration_ms,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"ollama"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(&self) -> AiResult<bool> {
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.get(format!("{}/api/tags", self.base_url))
|
||||||
|
.send()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match resp {
|
||||||
|
Ok(r) => Ok(r.status().is_success()),
|
||||||
|
Err(_) => Ok(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ollama_provider_construction() {
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"http://localhost:11434".into(),
|
||||||
|
"qwen2.5:7b".into(),
|
||||||
|
);
|
||||||
|
assert_eq!(provider.name(), "ollama");
|
||||||
|
assert_eq!(provider.default_model, "qwen2.5:7b");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ollama_chat_request_serialization() {
|
||||||
|
let req = OllamaChatRequest {
|
||||||
|
model: "qwen2.5:7b".into(),
|
||||||
|
messages: vec![
|
||||||
|
OllamaMessage {
|
||||||
|
role: "system".into(),
|
||||||
|
content: "你是助手".into(),
|
||||||
|
},
|
||||||
|
OllamaMessage {
|
||||||
|
role: "user".into(),
|
||||||
|
content: "你好".into(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
stream: false,
|
||||||
|
options: OllamaOptions {
|
||||||
|
temperature: Some(0.7),
|
||||||
|
num_predict: Some(1024),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let json = serde_json::to_value(&req).unwrap();
|
||||||
|
assert_eq!(json["model"], "qwen2.5:7b");
|
||||||
|
let temp = json["options"]["temperature"].as_f64().unwrap();
|
||||||
|
assert!((temp - 0.7).abs() < 0.01);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ollama_response_deserialization() {
|
||||||
|
let json = r#"{
|
||||||
|
"model": "qwen2.5:7b",
|
||||||
|
"created_at": "2024-01-01T00:00:00Z",
|
||||||
|
"message": {"role": "assistant", "content": "你好!"},
|
||||||
|
"done": true,
|
||||||
|
"eval_count": 5,
|
||||||
|
"prompt_eval_count": 10,
|
||||||
|
"total_duration": 1500000000
|
||||||
|
}"#;
|
||||||
|
let resp: OllamaChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(resp.message.content, "你好!");
|
||||||
|
assert_eq!(resp.eval_count, Some(5));
|
||||||
|
assert_eq!(resp.prompt_eval_count, Some(10));
|
||||||
|
assert_eq!(resp.total_duration, Some(1_500_000_000));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ollama_stream_chunk_deserialization() {
|
||||||
|
let json = r#"{
|
||||||
|
"message": {"role": "assistant", "content": "Hello"},
|
||||||
|
"done": false
|
||||||
|
}"#;
|
||||||
|
let chunk: OllamaStreamChunk = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(!chunk.done);
|
||||||
|
assert_eq!(
|
||||||
|
chunk.message.unwrap().content,
|
||||||
|
Some("Hello".to_string())
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ollama_stream_done_chunk() {
|
||||||
|
let json = r#"{
|
||||||
|
"message": null,
|
||||||
|
"done": true,
|
||||||
|
"total_duration": 2000000000,
|
||||||
|
"eval_count": 20
|
||||||
|
}"#;
|
||||||
|
let chunk: OllamaStreamChunk = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(chunk.done);
|
||||||
|
assert!(chunk.message.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn base_url_preserved() {
|
||||||
|
let provider = OllamaProvider::new(
|
||||||
|
"http://192.168.1.100:11434".into(),
|
||||||
|
"llama3.1:8b".into(),
|
||||||
|
);
|
||||||
|
assert_eq!(provider.base_url, "http://192.168.1.100:11434");
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user