Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled
359 lines
10 KiB
Rust
359 lines
10 KiB
Rust
//! Embedding driver for vector-based semantic memory.
|
|
//!
|
|
//! Provides an `EmbeddingDriver` trait and an OpenAI-compatible implementation
|
|
//! that works with any provider offering a `/v1/embeddings` endpoint (OpenAI,
|
|
//! Groq, Together, Fireworks, Ollama, etc.).
|
|
|
|
use async_trait::async_trait;
|
|
use openfang_types::model_catalog::{
|
|
FIREWORKS_BASE_URL, GROQ_BASE_URL, LMSTUDIO_BASE_URL, MISTRAL_BASE_URL, OLLAMA_BASE_URL,
|
|
OPENAI_BASE_URL, TOGETHER_BASE_URL, VLLM_BASE_URL,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use tracing::{debug, warn};
|
|
use zeroize::Zeroizing;
|
|
|
|
/// Error type for embedding operations.
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum EmbeddingError {
|
|
#[error("HTTP error: {0}")]
|
|
Http(String),
|
|
#[error("API error (status {status}): {message}")]
|
|
Api { status: u16, message: String },
|
|
#[error("Parse error: {0}")]
|
|
Parse(String),
|
|
#[error("Missing API key: {0}")]
|
|
MissingApiKey(String),
|
|
}
|
|
|
|
/// Configuration for creating an embedding driver.
|
|
#[derive(Debug, Clone)]
|
|
pub struct EmbeddingConfig {
|
|
/// Provider name (openai, groq, together, ollama, etc.).
|
|
pub provider: String,
|
|
/// Model name (e.g., "text-embedding-3-small", "all-MiniLM-L6-v2").
|
|
pub model: String,
|
|
/// API key (resolved from env var).
|
|
pub api_key: String,
|
|
/// Base URL for the API.
|
|
pub base_url: String,
|
|
}
|
|
|
|
/// Trait for computing text embeddings.
|
|
#[async_trait]
|
|
pub trait EmbeddingDriver: Send + Sync {
|
|
/// Compute embedding vectors for a batch of texts.
|
|
async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
|
|
|
|
/// Compute embedding for a single text.
|
|
async fn embed_one(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
|
|
let results = self.embed(&[text]).await?;
|
|
results
|
|
.into_iter()
|
|
.next()
|
|
.ok_or_else(|| EmbeddingError::Parse("Empty embedding response".to_string()))
|
|
}
|
|
|
|
/// Return the dimensionality of embeddings produced by this driver.
|
|
fn dimensions(&self) -> usize;
|
|
}
|
|
|
|
/// OpenAI-compatible embedding driver.
|
|
///
|
|
/// Works with any provider that implements the `/v1/embeddings` endpoint:
|
|
/// OpenAI, Groq, Together, Fireworks, Ollama, vLLM, LM Studio, etc.
|
|
pub struct OpenAIEmbeddingDriver {
|
|
api_key: Zeroizing<String>,
|
|
base_url: String,
|
|
model: String,
|
|
client: reqwest::Client,
|
|
dims: usize,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct EmbedRequest<'a> {
|
|
model: &'a str,
|
|
input: &'a [&'a str],
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct EmbedResponse {
|
|
data: Vec<EmbedData>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct EmbedData {
|
|
embedding: Vec<f32>,
|
|
}
|
|
|
|
impl OpenAIEmbeddingDriver {
|
|
/// Create a new OpenAI-compatible embedding driver.
|
|
pub fn new(config: EmbeddingConfig) -> Result<Self, EmbeddingError> {
|
|
// Infer dimensions from model name (common models)
|
|
let dims = infer_dimensions(&config.model);
|
|
|
|
Ok(Self {
|
|
api_key: Zeroizing::new(config.api_key),
|
|
base_url: config.base_url,
|
|
model: config.model,
|
|
client: reqwest::Client::new(),
|
|
dims,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Infer embedding dimensions from model name.
|
|
fn infer_dimensions(model: &str) -> usize {
|
|
match model {
|
|
// OpenAI
|
|
"text-embedding-3-small" => 1536,
|
|
"text-embedding-3-large" => 3072,
|
|
"text-embedding-ada-002" => 1536,
|
|
// Sentence Transformers / local models
|
|
"all-MiniLM-L6-v2" => 384,
|
|
"all-MiniLM-L12-v2" => 384,
|
|
"all-mpnet-base-v2" => 768,
|
|
"nomic-embed-text" => 768,
|
|
"mxbai-embed-large" => 1024,
|
|
// Default to 1536 (most common)
|
|
_ => 1536,
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl EmbeddingDriver for OpenAIEmbeddingDriver {
|
|
async fn embed(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
|
|
if texts.is_empty() {
|
|
return Ok(vec![]);
|
|
}
|
|
|
|
let url = format!("{}/embeddings", self.base_url);
|
|
let body = EmbedRequest {
|
|
model: &self.model,
|
|
input: texts,
|
|
};
|
|
|
|
let mut req = self.client.post(&url).json(&body);
|
|
if !self.api_key.as_str().is_empty() {
|
|
req = req.header("Authorization", format!("Bearer {}", self.api_key.as_str()));
|
|
}
|
|
|
|
let resp = req
|
|
.send()
|
|
.await
|
|
.map_err(|e| EmbeddingError::Http(e.to_string()))?;
|
|
let status = resp.status().as_u16();
|
|
|
|
if status != 200 {
|
|
let body_text = resp.text().await.unwrap_or_default();
|
|
return Err(EmbeddingError::Api {
|
|
status,
|
|
message: body_text,
|
|
});
|
|
}
|
|
|
|
let data: EmbedResponse = resp
|
|
.json()
|
|
.await
|
|
.map_err(|e| EmbeddingError::Parse(e.to_string()))?;
|
|
|
|
// Update dimensions from actual response if available
|
|
let embeddings: Vec<Vec<f32>> = data.data.into_iter().map(|d| d.embedding).collect();
|
|
|
|
debug!(
|
|
"Embedded {} texts (dims={})",
|
|
embeddings.len(),
|
|
embeddings.first().map(|e| e.len()).unwrap_or(0)
|
|
);
|
|
|
|
Ok(embeddings)
|
|
}
|
|
|
|
fn dimensions(&self) -> usize {
|
|
self.dims
|
|
}
|
|
}
|
|
|
|
/// Create an embedding driver from kernel config.
|
|
pub fn create_embedding_driver(
|
|
provider: &str,
|
|
model: &str,
|
|
api_key_env: &str,
|
|
) -> Result<Box<dyn EmbeddingDriver + Send + Sync>, EmbeddingError> {
|
|
let api_key = if api_key_env.is_empty() {
|
|
String::new()
|
|
} else {
|
|
std::env::var(api_key_env).unwrap_or_default()
|
|
};
|
|
|
|
let base_url = match provider {
|
|
"openai" => OPENAI_BASE_URL.to_string(),
|
|
"groq" => GROQ_BASE_URL.to_string(),
|
|
"together" => TOGETHER_BASE_URL.to_string(),
|
|
"fireworks" => FIREWORKS_BASE_URL.to_string(),
|
|
"mistral" => MISTRAL_BASE_URL.to_string(),
|
|
"ollama" => OLLAMA_BASE_URL.to_string(),
|
|
"vllm" => VLLM_BASE_URL.to_string(),
|
|
"lmstudio" => LMSTUDIO_BASE_URL.to_string(),
|
|
other => {
|
|
warn!("Unknown embedding provider '{other}', using OpenAI-compatible format");
|
|
format!("https://{other}/v1")
|
|
}
|
|
};
|
|
|
|
// SECURITY: Warn when embedding requests will be sent to an external API
|
|
let is_local = base_url.contains("localhost")
|
|
|| base_url.contains("127.0.0.1")
|
|
|| base_url.contains("[::1]");
|
|
if !is_local {
|
|
warn!(
|
|
provider = %provider,
|
|
base_url = %base_url,
|
|
"Embedding driver configured to send data to external API — text content will leave this machine"
|
|
);
|
|
}
|
|
|
|
let config = EmbeddingConfig {
|
|
provider: provider.to_string(),
|
|
model: model.to_string(),
|
|
api_key,
|
|
base_url,
|
|
};
|
|
|
|
let driver = OpenAIEmbeddingDriver::new(config)?;
|
|
Ok(Box::new(driver))
|
|
}
|
|
|
|
/// Compute cosine similarity between two vectors.
|
|
///
|
|
/// Returns a value in [-1.0, 1.0] where 1.0 = identical direction.
|
|
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
|
if a.len() != b.len() || a.is_empty() {
|
|
return 0.0;
|
|
}
|
|
|
|
let mut dot = 0.0f32;
|
|
let mut norm_a = 0.0f32;
|
|
let mut norm_b = 0.0f32;
|
|
|
|
for i in 0..a.len() {
|
|
dot += a[i] * b[i];
|
|
norm_a += a[i] * a[i];
|
|
norm_b += b[i] * b[i];
|
|
}
|
|
|
|
let denom = norm_a.sqrt() * norm_b.sqrt();
|
|
if denom < f32::EPSILON {
|
|
0.0
|
|
} else {
|
|
dot / denom
|
|
}
|
|
}
|
|
|
|
/// Serialize an embedding vector to bytes (for SQLite BLOB storage).
|
|
pub fn embedding_to_bytes(embedding: &[f32]) -> Vec<u8> {
|
|
let mut bytes = Vec::with_capacity(embedding.len() * 4);
|
|
for &val in embedding {
|
|
bytes.extend_from_slice(&val.to_le_bytes());
|
|
}
|
|
bytes
|
|
}
|
|
|
|
/// Deserialize an embedding vector from bytes.
|
|
pub fn embedding_from_bytes(bytes: &[u8]) -> Vec<f32> {
|
|
bytes
|
|
.chunks_exact(4)
|
|
.map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
|
|
.collect()
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_identical() {
|
|
let a = vec![1.0, 0.0, 0.0];
|
|
let b = vec![1.0, 0.0, 0.0];
|
|
let sim = cosine_similarity(&a, &b);
|
|
assert!((sim - 1.0).abs() < 1e-6);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_orthogonal() {
|
|
let a = vec![1.0, 0.0];
|
|
let b = vec![0.0, 1.0];
|
|
let sim = cosine_similarity(&a, &b);
|
|
assert!(sim.abs() < 1e-6);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_opposite() {
|
|
let a = vec![1.0, 0.0];
|
|
let b = vec![-1.0, 0.0];
|
|
let sim = cosine_similarity(&a, &b);
|
|
assert!((sim + 1.0).abs() < 1e-6);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_real_vectors() {
|
|
let a = vec![0.1, 0.2, 0.3, 0.4];
|
|
let b = vec![0.1, 0.2, 0.3, 0.4];
|
|
let sim = cosine_similarity(&a, &b);
|
|
assert!((sim - 1.0).abs() < 1e-5);
|
|
|
|
let c = vec![0.4, 0.3, 0.2, 0.1];
|
|
let sim2 = cosine_similarity(&a, &c);
|
|
assert!(sim2 > 0.0 && sim2 < 1.0); // Similar but not identical
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_empty() {
|
|
let sim = cosine_similarity(&[], &[]);
|
|
assert_eq!(sim, 0.0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_cosine_similarity_length_mismatch() {
|
|
let a = vec![1.0, 2.0];
|
|
let b = vec![1.0, 2.0, 3.0];
|
|
let sim = cosine_similarity(&a, &b);
|
|
assert_eq!(sim, 0.0);
|
|
}
|
|
|
|
#[test]
|
|
fn test_embedding_roundtrip() {
|
|
let embedding = vec![0.1, -0.5, 1.23456, 0.0, -1e10, 1e10];
|
|
let bytes = embedding_to_bytes(&embedding);
|
|
let recovered = embedding_from_bytes(&bytes);
|
|
assert_eq!(embedding.len(), recovered.len());
|
|
for (a, b) in embedding.iter().zip(recovered.iter()) {
|
|
assert!((a - b).abs() < f32::EPSILON);
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn test_embedding_bytes_empty() {
|
|
let bytes = embedding_to_bytes(&[]);
|
|
assert!(bytes.is_empty());
|
|
let recovered = embedding_from_bytes(&bytes);
|
|
assert!(recovered.is_empty());
|
|
}
|
|
|
|
#[test]
|
|
fn test_infer_dimensions() {
|
|
assert_eq!(infer_dimensions("text-embedding-3-small"), 1536);
|
|
assert_eq!(infer_dimensions("all-MiniLM-L6-v2"), 384);
|
|
assert_eq!(infer_dimensions("nomic-embed-text"), 768);
|
|
assert_eq!(infer_dimensions("unknown-model"), 1536); // default
|
|
}
|
|
|
|
#[test]
|
|
fn test_create_embedding_driver_ollama() {
|
|
// Should succeed even without API key (ollama is local)
|
|
let driver = create_embedding_driver("ollama", "all-MiniLM-L6-v2", "");
|
|
assert!(driver.is_ok());
|
|
assert_eq!(driver.unwrap().dimensions(), 384);
|
|
}
|
|
}
|