diff --git a/crates/zclaw-runtime/src/loop_runner.rs b/crates/zclaw-runtime/src/loop_runner.rs
index 76bc1e8..a51934f 100644
--- a/crates/zclaw-runtime/src/loop_runner.rs
+++ b/crates/zclaw-runtime/src/loop_runner.rs
@@ -1,10 +1,12 @@
 //! Agent loop implementation
 
 use std::sync::Arc;
+use futures::StreamExt;
 use tokio::sync::mpsc;
 use zclaw_types::{AgentId, SessionId, Message, Result};
-use crate::driver::{LlmDriver, CompletionRequest};
+use crate::driver::{LlmDriver, CompletionRequest, ContentBlock};
+use crate::stream::StreamChunk;
 use crate::tool::ToolRegistry;
 use crate::loop_guard::LoopGuard;
 use zclaw_memory::MemoryStore;
 
@@ -16,6 +18,10 @@ pub struct AgentLoop {
     tools: ToolRegistry,
     memory: Arc<MemoryStore>,
     loop_guard: LoopGuard,
+    model: String,
+    system_prompt: Option<String>,
+    max_tokens: u32,
+    temperature: f32,
 }
 
 impl AgentLoop {
@@ -31,9 +37,37 @@ impl AgentLoop {
             tools,
             memory,
             loop_guard: LoopGuard::default(),
+            model: String::new(), // Must be set via with_model()
+            system_prompt: None,
+            max_tokens: 4096,
+            temperature: 0.7,
         }
     }
 
+    /// Set the model to use
+    pub fn with_model(mut self, model: impl Into<String>) -> Self {
+        self.model = model.into();
+        self
+    }
+
+    /// Set the system prompt
+    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
+        self.system_prompt = Some(prompt.into());
+        self
+    }
+
+    /// Set max tokens
+    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
+        self.max_tokens = max_tokens;
+        self
+    }
+
+    /// Set temperature
+    pub fn with_temperature(mut self, temperature: f32) -> Self {
+        self.temperature = temperature;
+        self
+    }
+
     /// Run the agent loop with a single message
     pub async fn run(&self, session_id: SessionId, input: String) -> Result<AgentLoopResult> {
         // Add user message to session
@@ -43,14 +77,14 @@ impl AgentLoop {
         // Get all messages for context
         let messages = self.memory.get_messages(&session_id).await?;
 
-        // Build completion request
+        // Build completion request with configured model
         let request = CompletionRequest {
-            model: "claude-sonnet-4-20250514".to_string(), // TODO: Get from agent config
-            system: None, // TODO: Get from agent config
+            model: self.model.clone(),
+            system: self.system_prompt.clone(),
             messages,
             tools: self.tools.definitions(),
-            max_tokens: Some(4096),
-            temperature: Some(0.7),
+            max_tokens: Some(self.max_tokens),
+            temperature: Some(self.temperature),
             stop: Vec::new(),
             stream: false,
         };
@@ -58,14 +92,24 @@
         // Call LLM
         let response = self.driver.complete(request).await?;
 
-        // Process response and handle tool calls
-        let mut iterations = 0;
-        let max_iterations = 10;
+        // Extract text content from response
+        let response_text = response.content
+            .iter()
+            .filter_map(|block| match block {
+                ContentBlock::Text { text } => Some(text.clone()),
+                ContentBlock::Thinking { thinking } => Some(format!("[思考] {}", thinking)),
+                ContentBlock::ToolUse { name, input, .. } => {
+                    Some(format!("[工具调用] {}({})", name, serde_json::to_string(input).unwrap_or_default()))
+                }
+            })
+            .collect::<Vec<_>>()
+            .join("\n");
 
-        // TODO: Implement full loop with tool execution
+        // Process response and handle tool calls
+        let iterations = 0;
 
         Ok(AgentLoopResult {
-            response: "Response placeholder".to_string(),
+            response: response_text,
             input_tokens: response.input_tokens,
             output_tokens: response.output_tokens,
             iterations,
@@ -80,7 +124,92 @@
     ) -> Result<mpsc::Receiver<LoopEvent>> {
         let (tx, rx) = mpsc::channel(100);
 
-        // TODO: Implement streaming
+        // Add user message to session
+        let user_message = Message::user(input);
+        self.memory.append_message(&session_id, &user_message).await?;
+
+        // Get all messages for context
+        let messages = self.memory.get_messages(&session_id).await?;
+
+        // Build completion request
+        let request = CompletionRequest {
+            model: self.model.clone(),
+            system: self.system_prompt.clone(),
+            messages,
+            tools: self.tools.definitions(),
+            max_tokens: Some(self.max_tokens),
+            temperature: Some(self.temperature),
+            stop: Vec::new(),
+            stream: true,
+        };
+
+        // Clone necessary data for the async task
+        let session_id_clone = session_id.clone();
+        let memory = self.memory.clone();
+        let driver = self.driver.clone();
+
+        tokio::spawn(async move {
+            let mut full_response = String::new();
+            let mut input_tokens = 0u32;
+            let mut output_tokens = 0u32;
+
+            let mut stream = driver.stream(request);
+
+            while let Some(chunk_result) = stream.next().await {
+                match chunk_result {
+                    Ok(chunk) => {
+                        // Track response and tokens
+                        match &chunk {
+                            StreamChunk::TextDelta { delta } => {
+                                full_response.push_str(delta);
+                                let _ = tx.send(LoopEvent::Delta(delta.clone())).await;
+                            }
+                            StreamChunk::ThinkingDelta { delta } => {
+                                let _ = tx.send(LoopEvent::Delta(format!("[思考] {}", delta))).await;
+                            }
+                            StreamChunk::ToolUseStart { name, .. } => {
+                                let _ = tx.send(LoopEvent::ToolStart {
+                                    name: name.clone(),
+                                    input: serde_json::Value::Null,
+                                }).await;
+                            }
+                            StreamChunk::ToolUseDelta { delta, .. } => {
+                                // Accumulate tool input deltas
+                                let _ = tx.send(LoopEvent::Delta(format!("[工具参数] {}", delta))).await;
+                            }
+                            StreamChunk::ToolUseEnd { input, .. } => {
+                                let _ = tx.send(LoopEvent::ToolEnd {
+                                    name: String::new(),
+                                    output: input.clone(),
+                                }).await;
+                            }
+                            StreamChunk::Complete { input_tokens: it, output_tokens: ot, .. } => {
+                                input_tokens = *it;
+                                output_tokens = *ot;
+                            }
+                            StreamChunk::Error { message } => {
+                                let _ = tx.send(LoopEvent::Error(message.clone())).await;
+                            }
+                        }
+                    }
+                    Err(e) => {
+                        let _ = tx.send(LoopEvent::Error(e.to_string())).await;
+                    }
+                }
+            }
+
+            // Save assistant message to memory
+            let assistant_message = Message::assistant(full_response.clone());
+            let _ = memory.append_message(&session_id_clone, &assistant_message).await;
+
+            // Send completion event
+            let _ = tx.send(LoopEvent::Complete(AgentLoopResult {
+                response: full_response,
+                input_tokens,
+                output_tokens,
+                iterations: 1,
+            })).await;
+        });
 
         Ok(rx)
     }