97 lines
2.9 KiB
Rust
97 lines
2.9 KiB
Rust
//! Session management types
|
|
|
|
use serde::{Deserialize, Serialize};
|
|
use chrono::{DateTime, Utc};
|
|
use zclaw_types::{SessionId, AgentId, Message};
|
|
|
|
/// A conversation session
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Session {
|
|
pub id: SessionId,
|
|
pub agent_id: AgentId,
|
|
pub messages: Vec<Message>,
|
|
pub created_at: DateTime<Utc>,
|
|
pub updated_at: DateTime<Utc>,
|
|
/// Token count estimate
|
|
pub token_count: usize,
|
|
}
|
|
|
|
impl Session {
|
|
pub fn new(agent_id: AgentId) -> Self {
|
|
Self {
|
|
id: SessionId::new(),
|
|
agent_id,
|
|
messages: Vec::new(),
|
|
created_at: Utc::now(),
|
|
updated_at: Utc::now(),
|
|
token_count: 0,
|
|
}
|
|
}
|
|
|
|
/// Add a message to the session
|
|
pub fn add_message(&mut self, message: Message) {
|
|
// Simple token estimation: ~4 chars per token
|
|
let tokens = self.estimate_tokens(&message);
|
|
self.messages.push(message);
|
|
self.token_count += tokens;
|
|
self.updated_at = Utc::now();
|
|
}
|
|
|
|
/// Estimate token count for a message
|
|
fn estimate_tokens(&self, message: &Message) -> usize {
|
|
let text = match message {
|
|
Message::User { content } => content,
|
|
Message::Assistant { content, thinking } => {
|
|
thinking.as_ref().map(|t| t.as_str()).unwrap_or("");
|
|
content
|
|
}
|
|
Message::System { content } => content,
|
|
Message::ToolUse { input, .. } => {
|
|
return serde_json::to_string(input).map(|s| s.len() / 4).unwrap_or(0);
|
|
}
|
|
Message::ToolResult { output, .. } => {
|
|
return serde_json::to_string(output).map(|s| s.len() / 4).unwrap_or(0);
|
|
}
|
|
};
|
|
text.len() / 4
|
|
}
|
|
|
|
/// Check if session exceeds context window
|
|
pub fn exceeds_threshold(&self, max_tokens: usize, threshold: f32) -> bool {
|
|
let threshold_tokens = (max_tokens as f32 * threshold) as usize;
|
|
self.token_count > threshold_tokens
|
|
}
|
|
|
|
/// Compact the session by keeping only recent messages
|
|
pub fn compact(&mut self, keep_last: usize) {
|
|
if self.messages.len() <= keep_last {
|
|
return;
|
|
}
|
|
|
|
// Keep system messages and last N messages
|
|
let system_messages: Vec<_> = self.messages.iter()
|
|
.filter(|m| matches!(m, Message::System { .. }))
|
|
.cloned()
|
|
.collect();
|
|
|
|
let recent_messages: Vec<_> = self.messages.iter()
|
|
.rev()
|
|
.take(keep_last)
|
|
.cloned()
|
|
.collect::<Vec<_>>()
|
|
.into_iter()
|
|
.rev()
|
|
.collect();
|
|
|
|
self.messages = [system_messages, recent_messages].concat();
|
|
self.recalculate_token_count();
|
|
self.updated_at = Utc::now();
|
|
}
|
|
|
|
fn recalculate_token_count(&mut self) {
|
|
self.token_count = self.messages.iter()
|
|
.map(|m| self.estimate_tokens(m))
|
|
.sum();
|
|
}
|
|
}
|