//! LLM Client Module //! //! Provides LLM API integration for memory extraction. //! Supports multiple providers with a unified interface. //! //! Note: Some fields are reserved for future streaming and provider selection features #![allow(dead_code)] use serde::{Deserialize, Serialize}; use std::collections::HashMap; // === Types === #[derive(Debug, Clone)] pub struct LlmConfig { pub provider: String, pub api_key: String, pub endpoint: Option, pub model: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmMessage { pub role: String, pub content: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmRequest { pub messages: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmResponse { pub content: String, pub model: Option, pub usage: Option, pub finish_reason: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmUsage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } // === Provider Configuration === #[derive(Debug, Clone)] pub struct ProviderConfig { pub name: String, pub endpoint: String, pub default_model: String, pub supports_streaming: bool, } pub fn get_provider_configs() -> HashMap { let mut configs = HashMap::new(); configs.insert( "doubao".to_string(), ProviderConfig { name: "Doubao (火山引擎)".to_string(), endpoint: "https://ark.cn-beijing.volces.com/api/v3".to_string(), default_model: "doubao-pro-32k".to_string(), supports_streaming: true, }, ); configs.insert( "openai".to_string(), ProviderConfig { name: "OpenAI".to_string(), endpoint: "https://api.openai.com/v1".to_string(), default_model: "gpt-4o".to_string(), supports_streaming: true, }, ); configs.insert( "anthropic".to_string(), ProviderConfig { name: "Anthropic".to_string(), endpoint: "https://api.anthropic.com/v1".to_string(), default_model: "claude-sonnet-4-20250514".to_string(), supports_streaming: false, }, ); configs } // === LLM Client === pub struct LlmClient { config: LlmConfig, provider_config: Option, } impl LlmClient { pub fn new(config: LlmConfig) -> Self { let provider_config = get_provider_configs() .get(&config.provider) .cloned(); Self { config, provider_config, } } /// Complete a chat completion request pub async fn complete(&self, messages: Vec) -> Result { let endpoint = self.config.endpoint.clone() .or_else(|| { self.provider_config .as_ref() .map(|c| c.endpoint.clone()) }) .unwrap_or_else(|| "https://ark.cn-beijing.volces.com/api/v3".to_string()); let model = self.config.model.clone() .or_else(|| { self.provider_config .as_ref() .map(|c| c.default_model.clone()) }) .unwrap_or_else(|| "doubao-pro-32k".to_string()); let request = LlmRequest { messages, model: Some(model), temperature: Some(0.3), max_tokens: Some(2000), }; self.call_api(&endpoint, &request).await } /// Call LLM API async fn call_api(&self, endpoint: &str, request: &LlmRequest) -> Result { let client = reqwest::Client::new(); let response = client .post(format!("{}/chat/completions", endpoint)) .header("Authorization", format!("Bearer {}", self.config.api_key)) .header("Content-Type", "application/json") .json(&request) .send() .await .map_err(|e| format!("LLM API request failed: {}", e))?; if !response.status().is_success() { let status = response.status(); let body = response.text().await.unwrap_or_default(); return Err(format!("LLM API error {}: {}", status, body)); } let json: serde_json::Value = response .json() .await .map_err(|e| format!("Failed to parse LLM response: {}", e))?; // Parse response (OpenAI-compatible format) let content = json .get("choices") .and_then(|c| c.get(0)) .and_then(|c| c.get("message")) .and_then(|m| m.get("content")) .and_then(|c| c.as_str()) .ok_or("Invalid LLM response format")? .to_string(); let usage = json .get("usage") .map(|u| LlmUsage { prompt_tokens: u.get("prompt_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32, completion_tokens: u.get("completion_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32, total_tokens: u.get("total_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as u32, }); Ok(LlmResponse { content, model: self.config.model.clone(), usage, finish_reason: json .get("choices") .and_then(|c| c.get(0)) .and_then(|c| c.get("finish_reason")) .and_then(|v| v.as_str()) .map(String::from), }) } } // === Tauri Commands === #[tauri::command] pub async fn llm_complete( provider: String, api_key: String, messages: Vec, model: Option, ) -> Result { let config = LlmConfig { provider, api_key, endpoint: None, model, }; let client = LlmClient::new(config); client.complete(messages).await } #[cfg(test)] mod tests { use super::*; #[test] fn test_provider_configs() { let configs = get_provider_configs(); assert!(configs.contains_key("doubao")); assert!(configs.contains_key("openai")); assert!(configs.contains_key("anthropic")); } #[test] fn test_llm_client_creation() { let config = LlmConfig { provider: "doubao".to_string(), api_key: "test_key".to_string(), endpoint: None, model: None, }; let client = LlmClient::new(config); assert!(client.provider_config.is_some()); } }