Files
zclaw_openfang/crates/zclaw-growth/src/profile_updater.rs

110 lines
3.5 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

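//! Incremental user-profile updater.
//!
//! Updates the `UserProfileStore` from the `profile_signals` of a `CombinedExtraction`.
//! It makes no additional LLM calls; the update is purely rule-driven.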
use crate::types::CombinedExtraction;
/// User-profile updater.
///
/// Takes the `profile_signals` from a [`CombinedExtraction`] and applies them to the
/// user's profile through a caller-supplied callback. The type carries no state.
#[derive(Debug, Clone, Copy)]
pub struct UserProfileUpdater;
impl UserProfileUpdater {
pub fn new() -> Self {
Self
}
/// 从提取结果更新用户画像
/// profile_store 通过闭包注入,避免 zclaw-growth 依赖 zclaw-memory
pub async fn update<F>(
&self,
user_id: &str,
extraction: &CombinedExtraction,
update_fn: F,
) -> zclaw_types::Result<()>
where
F: Fn(&str, &str, &str) -> zclaw_types::Result<()> + Send + Sync,
{
let signals = &extraction.profile_signals;
if let Some(ref industry) = signals.industry {
update_fn(user_id, "industry", industry)?;
}
if let Some(ref style) = signals.communication_style {
update_fn(user_id, "communication_style", style)?;
}
// pain_point 和 preferred_tool 使用单独的方法(有去重和容量限制)
// 这些通过 GrowthIntegration 中的具体调用处理
Ok(())
}
}
impl Default for UserProfileUpdater {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Calls recorded by [`recorder`], as `(user_id, field, value)` in invocation order.
    type Calls = Arc<Mutex<Vec<(String, String, String)>>>;

    /// Builds an always-succeeding update callback that records every invocation.
    /// Also returns a handle to the recorded calls.
    fn recorder() -> (
        Calls,
        impl Fn(&str, &str, &str) -> zclaw_types::Result<()> + Send + Sync,
    ) {
        let calls: Calls = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&calls);
        let update_fn = move |uid: &str, field: &str, val: &str| -> zclaw_types::Result<()> {
            sink.lock()
                .unwrap()
                .push((uid.to_string(), field.to_string(), val.to_string()));
            Ok(())
        };
        (calls, update_fn)
    }

    /// Snapshot of the calls recorded so far.
    fn recorded(calls: &Calls) -> Vec<(String, String, String)> {
        calls.lock().unwrap().clone()
    }

    /// Builds the expected `(user_id, field, value)` tuple for an assertion.
    fn call(uid: &str, field: &str, val: &str) -> (String, String, String) {
        (uid.to_string(), field.to_string(), val.to_string())
    }

    #[tokio::test]
    async fn test_update_industry() {
        let (calls, update_fn) = recorder();
        let mut extraction = CombinedExtraction::default();
        extraction.profile_signals.industry = Some("healthcare".to_string());
        let updater = UserProfileUpdater::new();
        updater.update("user1", &extraction, update_fn).await.unwrap();
        assert_eq!(
            recorded(&calls),
            vec![call("user1", "industry", "healthcare")]
        );
    }

    #[tokio::test]
    async fn test_update_communication_style_only() {
        let (calls, update_fn) = recorder();
        let mut extraction = CombinedExtraction::default();
        extraction.profile_signals.communication_style = Some("concise".to_string());
        let updater = UserProfileUpdater::new();
        updater.update("user2", &extraction, update_fn).await.unwrap();
        assert_eq!(
            recorded(&calls),
            vec![call("user2", "communication_style", "concise")]
        );
    }

    #[tokio::test]
    async fn test_update_no_signals() {
        let (calls, update_fn) = recorder();
        let extraction = CombinedExtraction::default();
        let updater = UserProfileUpdater::new();
        updater.update("user1", &extraction, update_fn).await.unwrap();
        // With no signals present, the store must not be touched at all.
        assert!(recorded(&calls).is_empty());
    }

    #[tokio::test]
    async fn test_update_multiple_signals() {
        let (calls, update_fn) = recorder();
        let mut extraction = CombinedExtraction::default();
        extraction.profile_signals.industry = Some("ecommerce".to_string());
        extraction.profile_signals.communication_style = Some("concise".to_string());
        let updater = UserProfileUpdater::new();
        updater.update("user1", &extraction, update_fn).await.unwrap();
        // Both fields are written, for the right user, industry first.
        assert_eq!(
            recorded(&calls),
            vec![
                call("user1", "industry", "ecommerce"),
                call("user1", "communication_style", "concise"),
            ]
        );
    }
}