feat(saas): 接通 embedding 模型管理全栈
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
Some checks failed
CI / Lint & TypeCheck (push) Has been cancelled
CI / Unit Tests (push) Has been cancelled
CI / Build Frontend (push) Has been cancelled
CI / Rust Check (push) Has been cancelled
CI / Security Scan (push) Has been cancelled
CI / E2E Tests (push) Has been cancelled
数据库 migration 已有 is_embedding/model_type 列但全栈未使用。 打通 4 层: ModelRow → ModelInfo/CRUD → CachedModel → Admin 前端。 relay/models 端点也返回 is_embedding 字段,前端可按类型过滤。
This commit is contained in:
@@ -162,13 +162,13 @@ pub async fn list_models(
|
||||
let (count_sql, data_sql) = if provider_id.is_some() {
|
||||
(
|
||||
"SELECT COUNT(*) FROM models WHERE provider_id = $1",
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||||
FROM models WHERE provider_id = $1 ORDER BY alias LIMIT $2 OFFSET $3",
|
||||
)
|
||||
} else {
|
||||
(
|
||||
"SELECT COUNT(*) FROM models",
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||||
FROM models ORDER BY provider_id, alias LIMIT $1 OFFSET $2",
|
||||
)
|
||||
};
|
||||
@@ -186,7 +186,7 @@ pub async fn list_models(
|
||||
let rows = query.bind(ps as i64).bind(offset).fetch_all(db).await?;
|
||||
|
||||
let items = rows.into_iter().map(|r| {
|
||||
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
|
||||
ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, is_embedding: r.is_embedding, model_type: r.model_type.clone(), pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at }
|
||||
}).collect();
|
||||
|
||||
Ok(PaginatedResponse { items, total: total.0, page: p, page_size: ps })
|
||||
@@ -225,15 +225,17 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
|
||||
let max_out = req.max_output_tokens.unwrap_or(4096);
|
||||
let streaming = req.supports_streaming.unwrap_or(true);
|
||||
let vision = req.supports_vision.unwrap_or(false);
|
||||
let is_embedding = req.is_embedding.unwrap_or(false);
|
||||
let model_type = req.model_type.as_deref().unwrap_or(if is_embedding { "embedding" } else { "chat" });
|
||||
let pi = req.pricing_input.unwrap_or(0.0);
|
||||
let po = req.pricing_output.unwrap_or(0.0);
|
||||
|
||||
sqlx::query(
|
||||
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $11)"
|
||||
"INSERT INTO models (id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, true, $9, $10, $11, $12, $13, $13)"
|
||||
)
|
||||
.bind(&id).bind(&req.provider_id).bind(&req.model_id).bind(req.alias.as_deref().unwrap_or(&req.model_id))
|
||||
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(pi).bind(po).bind(&now)
|
||||
.bind(ctx).bind(max_out).bind(streaming).bind(vision).bind(is_embedding).bind(model_type).bind(pi).bind(po).bind(&now)
|
||||
.execute(db).await.map_err(|e| SaasError::from_sqlx_unique(e, &format!("模型 '{}' 在 Provider '{}'", req.model_id, req.provider_id)))?;
|
||||
|
||||
get_model(db, &id).await
|
||||
@@ -242,7 +244,7 @@ pub async fn create_model(db: &PgPool, req: &CreateModelRequest) -> SaasResult<M
|
||||
pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||
let row: Option<ModelRow> =
|
||||
sqlx::query_as(
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||||
"SELECT id, provider_id, model_id, alias, context_window, max_output_tokens, supports_streaming, supports_vision, enabled, is_embedding, model_type, pricing_input, pricing_output, created_at::TEXT, updated_at::TEXT
|
||||
FROM models WHERE id = $1"
|
||||
)
|
||||
.bind(model_id)
|
||||
@@ -251,7 +253,7 @@ pub async fn get_model(db: &PgPool, model_id: &str) -> SaasResult<ModelInfo> {
|
||||
|
||||
let r = row.ok_or_else(|| SaasError::NotFound(format!("模型 {} 不存在", model_id)))?;
|
||||
|
||||
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
|
||||
Ok(ModelInfo { id: r.id, provider_id: r.provider_id, model_id: r.model_id, alias: r.alias, context_window: r.context_window, max_output_tokens: r.max_output_tokens, supports_streaming: r.supports_streaming, supports_vision: r.supports_vision, enabled: r.enabled, is_embedding: r.is_embedding, model_type: r.model_type.clone(), pricing_input: r.pricing_input, pricing_output: r.pricing_output, created_at: r.created_at, updated_at: r.updated_at })
|
||||
}
|
||||
|
||||
pub async fn update_model(
|
||||
@@ -269,10 +271,12 @@ pub async fn update_model(
|
||||
supports_streaming = COALESCE($4, supports_streaming),
|
||||
supports_vision = COALESCE($5, supports_vision),
|
||||
enabled = COALESCE($6, enabled),
|
||||
pricing_input = COALESCE($7, pricing_input),
|
||||
pricing_output = COALESCE($8, pricing_output),
|
||||
updated_at = $9
|
||||
WHERE id = $10"
|
||||
is_embedding = COALESCE($7, is_embedding),
|
||||
model_type = COALESCE($8, model_type),
|
||||
pricing_input = COALESCE($9, pricing_input),
|
||||
pricing_output = COALESCE($10, pricing_output),
|
||||
updated_at = $11
|
||||
WHERE id = $12"
|
||||
)
|
||||
.bind(req.alias.as_deref())
|
||||
.bind(req.context_window)
|
||||
@@ -280,6 +284,8 @@ pub async fn update_model(
|
||||
.bind(req.supports_streaming)
|
||||
.bind(req.supports_vision)
|
||||
.bind(req.enabled)
|
||||
.bind(req.is_embedding)
|
||||
.bind(req.model_type.as_deref())
|
||||
.bind(req.pricing_input)
|
||||
.bind(req.pricing_output)
|
||||
.bind(&now)
|
||||
|
||||
@@ -56,6 +56,8 @@ pub struct ModelInfo {
|
||||
pub supports_streaming: bool,
|
||||
pub supports_vision: bool,
|
||||
pub enabled: bool,
|
||||
pub is_embedding: bool,
|
||||
pub model_type: String,
|
||||
pub pricing_input: f64,
|
||||
pub pricing_output: f64,
|
||||
pub created_at: String,
|
||||
@@ -71,6 +73,8 @@ pub struct CreateModelRequest {
|
||||
pub max_output_tokens: Option<i64>,
|
||||
pub supports_streaming: Option<bool>,
|
||||
pub supports_vision: Option<bool>,
|
||||
pub is_embedding: Option<bool>,
|
||||
pub model_type: Option<String>,
|
||||
pub pricing_input: Option<f64>,
|
||||
pub pricing_output: Option<f64>,
|
||||
}
|
||||
@@ -83,6 +87,8 @@ pub struct UpdateModelRequest {
|
||||
pub supports_streaming: Option<bool>,
|
||||
pub supports_vision: Option<bool>,
|
||||
pub enabled: Option<bool>,
|
||||
pub is_embedding: Option<bool>,
|
||||
pub model_type: Option<String>,
|
||||
pub pricing_input: Option<f64>,
|
||||
pub pricing_output: Option<f64>,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user