feat(ai): 实现 OllamaProvider 本地模型支持

使用 /api/chat 端点,无需 API Key,支持流式/非流式生成
健康检查通过 /api/tags,含 7 个单元测试
This commit is contained in:
iven
2026-05-05 15:10:43 +08:00
parent b728618d61
commit 37acd34154
2 changed files with 345 additions and 0 deletions

View File

@@ -1,4 +1,5 @@
pub mod claude;
pub mod ollama;
pub mod openai;
pub mod registry;

View File

@@ -0,0 +1,344 @@
use async_stream::stream;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use super::AiProvider;
use crate::dto::GenerateRequest;
use crate::error::{AiError, AiResult};
/// AI provider backed by an Ollama server reached over HTTP.
///
/// Chat requests are posted to `{base_url}/api/chat`, and the health check
/// queries `{base_url}/api/tags`.
#[derive(Debug, Clone)]
pub struct OllamaProvider {
/// HTTP client used for every request this provider sends.
client: Client,
/// Root URL of the Ollama server; endpoint paths are appended to it verbatim.
base_url: String,
/// Model name used when a `GenerateRequest` arrives with an empty `model`.
default_model: String,
}
impl OllamaProvider {
    /// Creates a provider that talks to the Ollama server at `base_url`.
    ///
    /// * `base_url` - root URL of the server, e.g. `http://localhost:11434`.
    ///   Trailing `/` characters are removed, because endpoint paths are later
    ///   appended as `"{base_url}/api/..."` and a trailing slash would produce
    ///   a `//api/...` path.
    /// * `default_model` - model name used when a request does not name one.
    pub fn new(mut base_url: String, default_model: String) -> Self {
        // Truncate in place instead of allocating a second string.
        let trimmed_len = base_url.trim_end_matches('/').len();
        base_url.truncate(trimmed_len);

        Self {
            client: Client::new(),
            base_url,
            default_model,
        }
    }
}
// Request body format for Ollama's /api/chat endpoint.
#[derive(Serialize)]
struct OllamaChatRequest {
/// Name of the model that should produce the reply.
model: String,
/// Conversation to send; this file always sends a system message followed
/// by a user message.
messages: Vec<OllamaMessage>,
/// `true` for the streaming call path, `false` for the single-response path.
stream: bool,
/// Generation options forwarded with the request.
options: OllamaOptions,
}
/// One entry of the `messages` array in an [`OllamaChatRequest`].
#[derive(Serialize)]
struct OllamaMessage {
/// Speaker of the message; this file uses `"system"` and `"user"`.
role: String,
/// Text of the message.
content: String,
}
/// The `options` object of an [`OllamaChatRequest`].
///
/// Fields that are `None` are left out of the serialized JSON entirely.
#[derive(Serialize)]
struct OllamaOptions {
/// Sampling temperature, populated from `GenerateRequest::temperature`.
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
/// Populated from `GenerateRequest::max_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
num_predict: Option<u32>,
}
// Non-streaming response from Ollama's /api/chat endpoint.
#[derive(Deserialize)]
struct OllamaChatResponse {
/// The generated reply.
message: OllamaResponseMessage,
/// Required in the payload but not read by this file.
#[allow(dead_code)]
model: String,
/// Required in the payload but not read by this file.
#[allow(dead_code)]
done: bool,
/// Reported as `output_tokens`; treated as 0 when absent.
eval_count: Option<u64>,
/// Reported as `input_tokens`; treated as 0 when absent.
prompt_eval_count: Option<u64>,
/// Divided by 1_000_000 to obtain `duration_ms` (i.e. handled as
/// nanoseconds); wall-clock time is used instead when absent.
total_duration: Option<u64>,
}
/// The `message` object of a non-streaming [`OllamaChatResponse`].
///
/// Only `content` is declared, so any other keys in the JSON object are ignored.
#[derive(Deserialize)]
struct OllamaResponseMessage {
/// Generated text returned to the caller.
content: String,
}
// One JSON object of a streaming response from Ollama's /api/chat endpoint.
#[derive(Deserialize)]
struct OllamaStreamChunk {
/// Text fragment carried by this chunk; may be missing or `null`.
message: Option<OllamaStreamMessage>,
/// `true` on the chunk that ends the stream.
done: bool,
}
/// The `message` object inside an [`OllamaStreamChunk`].
#[derive(Deserialize)]
struct OllamaStreamMessage {
/// Piece of generated text, if the chunk carries one.
content: Option<String>,
}
#[async_trait]
impl AiProvider for OllamaProvider {
async fn stream_generate(
&self,
req: GenerateRequest,
) -> AiResult<Pin<Box<dyn Stream<Item = AiResult<String>> + Send>>> {
let model = if req.model.is_empty() {
self.default_model.clone()
} else {
req.model
};
let ollama_req = OllamaChatRequest {
model,
messages: vec![
OllamaMessage {
role: "system".into(),
content: req.system_prompt,
},
OllamaMessage {
role: "user".into(),
content: req.user_prompt,
},
],
stream: true,
options: OllamaOptions {
temperature: Some(req.temperature),
num_predict: Some(req.max_tokens),
},
};
let response = self
.client
.post(format!("{}/api/chat", self.base_url))
.header("content-type", "application/json")
.json(&ollama_req)
.send()
.await
.map_err(|e| AiError::ProviderError(format!("Ollama API 请求失败: {e}")))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(AiError::ProviderError(format!(
"Ollama API 错误 {status}: {body}"
)));
}
let s = Box::pin(stream! {
let mut stream = response.bytes_stream();
while let Some(chunk_result) = stream.next().await {
let bytes = match chunk_result {
Ok(b) => b,
Err(e) => {
yield Err(AiError::ProviderError(format!("流读取错误: {e}")));
break;
}
};
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
let line = line.trim();
if line.is_empty() {
continue;
}
if let Ok(chunk) = serde_json::from_str::<OllamaStreamChunk>(line) {
if chunk.done {
return;
}
if let Some(msg) = chunk.message {
if let Some(content) = msg.content {
yield Ok(content);
}
}
}
}
}
});
Ok(s)
}
async fn generate(&self, req: GenerateRequest) -> AiResult<crate::dto::GenerateResponse> {
let start = std::time::Instant::now();
let model = if req.model.is_empty() {
self.default_model.clone()
} else {
req.model.clone()
};
let ollama_req = OllamaChatRequest {
model: model.clone(),
messages: vec![
OllamaMessage {
role: "system".into(),
content: req.system_prompt,
},
OllamaMessage {
role: "user".into(),
content: req.user_prompt,
},
],
stream: false,
options: OllamaOptions {
temperature: Some(req.temperature),
num_predict: Some(req.max_tokens),
},
};
let resp = self
.client
.post(format!("{}/api/chat", self.base_url))
.header("content-type", "application/json")
.json(&ollama_req)
.send()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
let status = resp.status();
let body = resp
.text()
.await
.map_err(|e| AiError::ProviderError(e.to_string()))?;
if !status.is_success() {
return Err(AiError::ProviderError(format!(
"Ollama {status}: {body}"
)));
}
let parsed: OllamaChatResponse = serde_json::from_str(&body)
.map_err(|e| AiError::ProviderError(format!("解析响应失败: {e}")))?;
let duration_ms = parsed
.total_duration
.map(|ns| ns / 1_000_000)
.unwrap_or_else(|| start.elapsed().as_millis() as u64);
let input_tokens = parsed.prompt_eval_count.unwrap_or(0) as u32;
let output_tokens = parsed.eval_count.unwrap_or(0) as u32;
Ok(crate::dto::GenerateResponse {
content: parsed.message.content,
model,
input_tokens,
output_tokens,
duration_ms,
})
}
fn name(&self) -> &str {
"ollama"
}
async fn health_check(&self) -> AiResult<bool> {
let resp = self
.client
.get(format!("{}/api/tags", self.base_url))
.send()
.await;
match resp {
Ok(r) => Ok(r.status().is_success()),
Err(_) => Ok(false),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider from string slices to keep the tests short.
    fn make_provider(url: &str, model: &str) -> OllamaProvider {
        OllamaProvider::new(url.to_string(), model.to_string())
    }

    /// The constructor keeps the default model and the provider reports its name.
    #[test]
    fn ollama_provider_construction() {
        let p = make_provider("http://localhost:11434", "qwen2.5:7b");
        assert_eq!(p.name(), "ollama");
        assert_eq!(p.default_model, "qwen2.5:7b");
    }

    /// A chat request serializes its model name and temperature option.
    #[test]
    fn ollama_chat_request_serialization() {
        let messages = [("system", "你是助手"), ("user", "你好")]
            .iter()
            .map(|&(role, content)| OllamaMessage {
                role: role.into(),
                content: content.into(),
            })
            .collect();
        let request = OllamaChatRequest {
            model: "qwen2.5:7b".into(),
            messages,
            stream: false,
            options: OllamaOptions {
                temperature: Some(0.7),
                num_predict: Some(1024),
            },
        };

        let encoded = serde_json::to_value(&request).unwrap();
        assert_eq!(encoded["model"], "qwen2.5:7b");
        // f32 -> f64 widening is inexact, so compare with a tolerance.
        let temperature = encoded["options"]["temperature"].as_f64().unwrap();
        assert!((temperature - 0.7).abs() < 0.01);
    }

    /// A complete non-streaming payload yields the text and all counters.
    #[test]
    fn ollama_response_deserialization() {
        let payload = r#"{
"model": "qwen2.5:7b",
"created_at": "2024-01-01T00:00:00Z",
"message": {"role": "assistant", "content": "你好!"},
"done": true,
"eval_count": 5,
"prompt_eval_count": 10,
"total_duration": 1500000000
}"#;
        let OllamaChatResponse {
            message,
            eval_count,
            prompt_eval_count,
            total_duration,
            ..
        } = serde_json::from_str(payload).unwrap();

        assert_eq!(message.content, "你好!");
        assert_eq!(eval_count, Some(5));
        assert_eq!(prompt_eval_count, Some(10));
        assert_eq!(total_duration, Some(1_500_000_000));
    }

    /// An intermediate stream chunk exposes its text and is not marked done.
    #[test]
    fn ollama_stream_chunk_deserialization() {
        let payload = r#"{
"message": {"role": "assistant", "content": "Hello"},
"done": false
}"#;
        let OllamaStreamChunk { message, done } = serde_json::from_str(payload).unwrap();

        assert!(!done);
        assert_eq!(message.unwrap().content, Some("Hello".to_string()));
    }

    /// The closing chunk parses even with a null message and extra fields.
    #[test]
    fn ollama_stream_done_chunk() {
        let payload = r#"{
"message": null,
"done": true,
"total_duration": 2000000000,
"eval_count": 20
}"#;
        let OllamaStreamChunk { message, done } = serde_json::from_str(payload).unwrap();

        assert!(done);
        assert!(message.is_none());
    }

    /// A base URL without a trailing slash is stored exactly as given.
    #[test]
    fn base_url_preserved() {
        let p = make_provider("http://192.168.1.100:11434", "llama3.1:8b");
        assert_eq!(p.base_url, "http://192.168.1.100:11434");
    }
}