// Files
// zclaw_openfang/crates/zclaw-memory/src/session.rs
//
// 97 lines
// 2.9 KiB
// Rust

//! Session management types
use serde::{Deserialize, Serialize};
use chrono::{DateTime, Utc};
use zclaw_types::{SessionId, AgentId, Message};
/// A conversation session
///
/// Holds the message history for one agent together with timestamps and a
/// running (heuristic) token estimate used for context-window management.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
/// Identifier of this session (freshly generated by `Session::new`).
pub id: SessionId,
/// Agent this session was created for (the `agent_id` passed to `Session::new`).
pub agent_id: AgentId,
/// Messages in the session; appended by `add_message` and pruned by `compact`.
pub messages: Vec<Message>,
/// When the session was created (UTC).
pub created_at: DateTime<Utc>,
/// Last time the session was modified via `add_message` or `compact` (UTC).
pub updated_at: DateTime<Utc>,
/// Token count estimate
///
/// A rough heuristic, not a real tokenizer count. It is incremented by
/// `add_message` and fully recomputed by `compact`.
pub token_count: usize,
}
impl Session {
    /// Create an empty session for `agent_id` with a freshly generated [`SessionId`].
    ///
    /// `created_at` and `updated_at` are set to the same instant.
    pub fn new(agent_id: AgentId) -> Self {
        // Read the clock once so a brand-new session has identical timestamps.
        let now = Utc::now();
        Self {
            id: SessionId::new(),
            agent_id,
            messages: Vec::new(),
            created_at: now,
            updated_at: now,
            token_count: 0,
        }
    }

    /// Add a message to the session.
    ///
    /// Appends `message`, adds its estimated token count to `token_count`
    /// and sets `updated_at` to the current time.
    pub fn add_message(&mut self, message: Message) {
        // Estimate before pushing: `push` takes ownership of `message`.
        let tokens = self.estimate_tokens(&message);
        self.messages.push(message);
        self.token_count += tokens;
        self.updated_at = Utc::now();
    }

    /// Estimate the token count of a single message.
    ///
    /// Heuristic: roughly 4 bytes of UTF-8 text per token (byte length, not
    /// character count).
    /// - User / System: length of `content`.
    /// - Assistant: length of `content` plus the length of the optional
    ///   `thinking` text.
    /// - ToolUse / ToolResult: length of the JSON serialization of
    ///   `input` / `output`. A serialization failure counts as 0.
    fn estimate_tokens(&self, message: &Message) -> usize {
        let bytes = match message {
            Message::User { content } => content.len(),
            Message::Assistant { content, thinking } => {
                // Thinking text also occupies context, so count it too.
                let thinking_bytes = thinking.as_ref().map_or(0, |t| t.as_str().len());
                content.len() + thinking_bytes
            }
            Message::System { content } => content.len(),
            Message::ToolUse { input, .. } => {
                serde_json::to_string(input).map_or(0, |s| s.len())
            }
            Message::ToolResult { output, .. } => {
                serde_json::to_string(output).map_or(0, |s| s.len())
            }
        };
        bytes / 4
    }

    /// Check if session exceeds context window.
    ///
    /// Returns `true` when `token_count` is strictly greater than
    /// `max_tokens * threshold`. `threshold` is expected to be a fraction
    /// such as `0.8`. The product is computed in `f32` and truncated toward
    /// zero; a negative or NaN product becomes 0.
    pub fn exceeds_threshold(&self, max_tokens: usize, threshold: f32) -> bool {
        let threshold_tokens = (max_tokens as f32 * threshold) as usize;
        self.token_count > threshold_tokens
    }

    /// Compact the session by keeping only recent messages.
    ///
    /// Keeps the last `keep_last` messages plus every system message that
    /// precedes them. Relative order is preserved, and a system message inside
    /// the kept tail is not duplicated. Does nothing when the session already
    /// has at most `keep_last` messages. Otherwise it recomputes `token_count`
    /// and sets `updated_at`.
    pub fn compact(&mut self, keep_last: usize) {
        if self.messages.len() <= keep_last {
            return;
        }

        let split = self.messages.len() - keep_last;
        // Detach the tail that is kept unconditionally.
        let mut recent = self.messages.split_off(split);
        // `self.messages` now holds only the older prefix: keep just its
        // system messages, then re-attach the tail.
        self.messages
            .retain(|m| matches!(m, Message::System { .. }));
        self.messages.append(&mut recent);

        self.recalculate_token_count();
        self.updated_at = Utc::now();
    }

    /// Recompute `token_count` from scratch over all current messages.
    fn recalculate_token_count(&mut self) {
        self.token_count = self
            .messages
            .iter()
            .map(|m| self.estimate_tokens(m))
            .sum();
    }
}