110 lines
3.5 KiB
Rust
110 lines
3.5 KiB
Rust
//! 用户画像增量更新器
|
||
//! 从 CombinedExtraction 的 profile_signals 更新 UserProfileStore
|
||
//! 不额外调用 LLM,纯规则驱动
|
||
|
||
use crate::types::CombinedExtraction;
|
||
|
||
/// User profile updater.
///
/// Receives the `profile_signals` carried by a `CombinedExtraction` and applies
/// them to a user profile through a caller-supplied callback. The updater itself
/// holds no state, so it is freely copyable.
#[derive(Debug, Clone, Copy)]
pub struct UserProfileUpdater;
|
||
|
||
impl UserProfileUpdater {
|
||
pub fn new() -> Self {
|
||
Self
|
||
}
|
||
|
||
/// 从提取结果更新用户画像
|
||
/// profile_store 通过闭包注入,避免 zclaw-growth 依赖 zclaw-memory
|
||
pub async fn update<F>(
|
||
&self,
|
||
user_id: &str,
|
||
extraction: &CombinedExtraction,
|
||
update_fn: F,
|
||
) -> zclaw_types::Result<()>
|
||
where
|
||
F: Fn(&str, &str, &str) -> zclaw_types::Result<()> + Send + Sync,
|
||
{
|
||
let signals = &extraction.profile_signals;
|
||
|
||
if let Some(ref industry) = signals.industry {
|
||
update_fn(user_id, "industry", industry)?;
|
||
}
|
||
|
||
if let Some(ref style) = signals.communication_style {
|
||
update_fn(user_id, "communication_style", style)?;
|
||
}
|
||
|
||
// pain_point 和 preferred_tool 使用单独的方法(有去重和容量限制)
|
||
// 这些通过 GrowthIntegration 中的具体调用处理
|
||
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
impl Default for UserProfileUpdater {
|
||
fn default() -> Self {
|
||
Self::new()
|
||
}
|
||
}
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Every `(user_id, field, value)` triple passed to the injected `update_fn`.
    type Calls = Arc<Mutex<Vec<(String, String, String)>>>;

    /// Builds an `update_fn` that records each call it receives, together with a
    /// handle for reading the recorded calls back after `update` has finished.
    fn recorder() -> (
        Calls,
        impl Fn(&str, &str, &str) -> zclaw_types::Result<()> + Send + Sync,
    ) {
        let calls: Calls = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&calls);
        let update_fn = move |uid: &str, field: &str, val: &str| -> zclaw_types::Result<()> {
            sink.lock()
                .unwrap()
                .push((uid.to_string(), field.to_string(), val.to_string()));
            Ok(())
        };
        (calls, update_fn)
    }

    /// Shorthand for an expected recorded call.
    fn call(uid: &str, field: &str, val: &str) -> (String, String, String) {
        (uid.to_string(), field.to_string(), val.to_string())
    }

    #[tokio::test]
    async fn test_update_industry() {
        let (calls, update_fn) = recorder();
        let mut extraction = CombinedExtraction::default();
        extraction.profile_signals.industry = Some("healthcare".to_string());

        let updater = UserProfileUpdater::new();
        updater.update("user1", &extraction, update_fn).await.unwrap();

        // The user id must be forwarded unchanged, not only the field/value pair.
        assert_eq!(
            *calls.lock().unwrap(),
            vec![call("user1", "industry", "healthcare")]
        );
    }

    #[tokio::test]
    async fn test_update_communication_style_only() {
        let (calls, update_fn) = recorder();
        let mut extraction = CombinedExtraction::default();
        extraction.profile_signals.communication_style = Some("formal".to_string());

        let updater = UserProfileUpdater::new();
        updater.update("user2", &extraction, update_fn).await.unwrap();

        assert_eq!(
            *calls.lock().unwrap(),
            vec![call("user2", "communication_style", "formal")]
        );
    }

    #[tokio::test]
    async fn test_update_no_signals() {
        let (calls, update_fn) = recorder();
        let extraction = CombinedExtraction::default();

        let updater = UserProfileUpdater::new();
        updater.update("user1", &extraction, update_fn).await.unwrap();

        // With no signals present, the profile store must not be touched at all.
        assert!(calls.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_multiple_signals() {
        let (calls, update_fn) = recorder();
        let mut extraction = CombinedExtraction::default();
        extraction.profile_signals.industry = Some("ecommerce".to_string());
        extraction.profile_signals.communication_style = Some("concise".to_string());

        let updater = UserProfileUpdater::new();
        updater.update("user1", &extraction, update_fn).await.unwrap();

        // Pin both the contents and the write order, not just the call count.
        assert_eq!(
            *calls.lock().unwrap(),
            vec![
                call("user1", "industry", "ecommerce"),
                call("user1", "communication_style", "concise"),
            ]
        );
    }
}
|