feat(saas): 接通 embedding 模型管理全栈
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled

数据库 migration 已有 is_embedding/model_type 列但全栈未使用。
打通 4 层: ModelRow → ModelInfo/CRUD → CachedModel → Admin 前端。
relay/models 端点也返回 is_embedding 字段,前端可按类型过滤。
This commit is contained in:
iven
2026-04-12 08:10:50 +08:00
parent b0a304ca82
commit 5599cefc41
7 changed files with 45 additions and 18 deletions

View File

@@ -21,6 +21,8 @@ pub struct CachedModel {
pub supports_streaming: bool,
pub supports_vision: bool,
pub enabled: bool,
pub is_embedding: bool,
pub model_type: String,
pub pricing_input: f64,
pub pricing_output: f64,
}
@@ -111,15 +113,15 @@ impl AppCache {
self.providers.retain(|k, _| provider_keys.contains(k));
// Load models (key = model_id for relay lookup) — insert-then-retain
let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, f64, f64)> = sqlx::query_as(
let model_rows: Vec<(String, String, String, String, i64, i64, bool, bool, bool, bool, String, f64, f64)> = sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens,
supports_streaming, supports_vision, enabled, pricing_input, pricing_output
supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output
FROM models"
).fetch_all(db).await?;
let model_keys: HashSet<String> = model_rows.iter().map(|(_, _, mid, ..)| mid.clone()).collect();
for (id, provider_id, model_id, alias, context_window, max_output_tokens,
supports_streaming, supports_vision, enabled, pricing_input, pricing_output) in &model_rows
supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output) in &model_rows
{
self.models.insert(model_id.clone(), CachedModel {
id: id.clone(),
@@ -131,6 +133,8 @@ impl AppCache {
supports_streaming: *supports_streaming,
supports_vision: *supports_vision,
enabled: *enabled,
is_embedding: *is_embedding,
model_type: model_type.clone(),
pricing_input: *pricing_input,
pricing_output: *pricing_output,
});

View File

@@ -162,13 +162,13 @@ pub async fn list_models(
let (count_sql, data_sql) = if provider_id.is_some() {
(
"SELECT COUNT(*) FROM models WHERE provider_id = $1",
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3",
)
} else {
(
"SELECT COUNT(*) FROM models",
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2",
)
};
@@ -186,7 +186,7 @@ pub async fn list_models(
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
let items = rows.into_iter().map(|r| {
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, is_embedding: r.is_embedding, model_type: r.model_type.clone(), pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
}).collect();
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
@@ -225,15 +225,17 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
let max_out = req.max_output_tokens.unwrap_or(4096);
let streaming = req.supports_streaming.unwrap_or(true);
let vision = req.supports_vision.unwrap_or(false);
let is_embedding = req.is_embedding.unwrap_or(false);
let model_type = req.model_type.as_deref().unwrap_or(if is_embedding { "embedding" } else { "chat" });
let pi = req.pricing_input.unwrap_or(0.0);
let po = req.pricing_output.unwrap_or(0.0);
sqlx::query(
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $12, $13, $13)"
)
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(req.alias.as_deref().unwrap_or(&req.model_id))
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(is_embedding).bind(model_type).bind(pi).bind(po).bind(&now)
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型 '{}' 在 Provider '{}'", req.model_id, req.provider_id)))?;
get_model(db, &id).await
@@ -242,7 +244,7 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
let row: Option<ModelRow> =
sqlx::query_as(
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
FROM models WHERE id = $1"
)
.bind(model_id)
@@ -251,7 +253,7 @@ pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, is_embedding: r.is_embedding, model_type: r.model_type.clone(), pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
}
pub async fn update_model(
@@ -269,10 +271,12 @@ pub async fn update_model(
supports_streaming = COALESCE($4, supports_streaming),
supports_vision = COALESCE($5, supports_vision),
enabled = COALESCE($6, enabled),
pricing_input = COALESCE($7, pricing_input),
pricing_output = COALESCE($8, pricing_output),
updated_at = $9
WHERE id = $10"
is_embedding = COALESCE($7, is_embedding),
model_type = COALESCE($8, model_type),
pricing_input = COALESCE($9, pricing_input),
pricing_output = COALESCE($10, pricing_output),
updated_at = $11
WHERE id = $12"
)
.bind(req.alias.as_deref())
.bind(req.context_window)
@@ -280,6 +284,8 @@ pub async fn update_model(
.bind(req.supports_streaming)
.bind(req.supports_vision)
.bind(req.enabled)
.bind(req.is_embedding)
.bind(req.model_type.as_deref())
.bind(req.pricing_input)
.bind(req.pricing_output)
.bind(&now)

View File

@@ -56,6 +56,8 @@ pub struct ModelInfo {
pub supports_streaming: bool,
pub supports_vision: bool,
pub enabled: bool,
pub is_embedding: bool,
pub model_type: String,
pub pricing_input: f64,
pub pricing_output: f64,
pub created_at: String,
@@ -71,6 +73,8 @@ pub struct CreateModelRequest {
pub max_output_tokens: Option<i64>,
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
pub is_embedding: Option<bool>,
pub model_type: Option<String>,
pub pricing_input: Option<f64>,
pub pricing_output: Option<f64>,
}
@@ -83,6 +87,8 @@ pub struct UpdateModelRequest {
pub supports_streaming: Option<bool>,
pub supports_vision: Option<bool>,
pub enabled: Option<bool>,
pub is_embedding: Option<bool>,
pub model_type: Option<String>,
pub pricing_input: Option<f64>,
pub pricing_output: Option<f64>,
}

View File

@@ -14,6 +14,8 @@ pub struct ModelRow {
pub supports_streaming: bool,
pub supports_vision: bool,
pub enabled: bool,
pub is_embedding: bool,
pub model_type: String,
pub pricing_input: f64,
pub pricing_output: f64,
pub created_at: String,

View File

@@ -373,9 +373,10 @@ pub async fn list_available_models(
_ctx: Extension<AuthContext>,
) -> SaasResult<Json<Vec<serde_json::Value>>> {
// 单次 JOIN 查询替代 2 次全量加载
let rows: Vec<(String, String, String, i64, i64, bool, bool)> = sqlx::query_as(
let rows: Vec<(String, String, String, i64, i64, bool, bool, bool, String)> = sqlx::query_as(
"SELECT m.model_id, m.provider_id, m.alias, m.context_window,
m.max_output_tokens, m.supports_streaming, m.supports_vision
m.max_output_tokens, m.supports_streaming, m.supports_vision,
m.is_embedding, m.model_type
FROM models m
INNER JOIN providers p ON m.provider_id = p.id
WHERE m.enabled = true AND p.enabled = true
@@ -385,7 +386,7 @@ pub async fn list_available_models(
.await?;
let mut available: Vec<serde_json::Value> = rows.into_iter()
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision)| {
.map(|(model_id, provider_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, is_embedding, model_type)| {
serde_json::json!({
"id": model_id,
"provider_id": provider_id,
@@ -394,6 +395,8 @@ pub async fn list_available_models(
"max_output_tokens": max_output_tokens,
"supports_streaming": supports_streaming,
"supports_vision": supports_vision,
"is_embedding": is_embedding,
"model_type": model_type,
})
})
.collect();