初始化提交
Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled
Some checks failed
CI / Check / macos-latest (push) Has been cancelled
CI / Check / ubuntu-latest (push) Has been cancelled
CI / Check / windows-latest (push) Has been cancelled
CI / Test / macos-latest (push) Has been cancelled
CI / Test / ubuntu-latest (push) Has been cancelled
CI / Test / windows-latest (push) Has been cancelled
CI / Clippy (push) Has been cancelled
CI / Format (push) Has been cancelled
CI / Security Audit (push) Has been cancelled
CI / Secrets Scan (push) Has been cancelled
CI / Install Script Smoke Test (push) Has been cancelled
This commit is contained in:
41
crates/openfang-api/Cargo.toml
Normal file
41
crates/openfang-api/Cargo.toml
Normal file
@@ -0,0 +1,41 @@
|
||||
[package]
|
||||
name = "openfang-api"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "HTTP/WebSocket API server for the OpenFang Agent OS daemon"
|
||||
|
||||
[dependencies]
|
||||
openfang-types = { path = "../openfang-types" }
|
||||
openfang-kernel = { path = "../openfang-kernel" }
|
||||
openfang-runtime = { path = "../openfang-runtime" }
|
||||
openfang-memory = { path = "../openfang-memory" }
|
||||
openfang-channels = { path = "../openfang-channels" }
|
||||
openfang-wire = { path = "../openfang-wire" }
|
||||
openfang-skills = { path = "../openfang-skills" }
|
||||
openfang-hands = { path = "../openfang-hands" }
|
||||
openfang-extensions = { path = "../openfang-extensions" }
|
||||
openfang-migrate = { path = "../openfang-migrate" }
|
||||
dashmap = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
governor = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
subtle = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
1691
crates/openfang-api/src/channel_bridge.rs
Normal file
1691
crates/openfang-api/src/channel_bridge.rs
Normal file
File diff suppressed because it is too large
Load Diff
16
crates/openfang-api/src/lib.rs
Normal file
16
crates/openfang-api/src/lib.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! HTTP/WebSocket API server for the OpenFang Agent OS daemon.
|
||||
//!
|
||||
//! Exposes agent management, status, and chat via JSON REST endpoints.
|
||||
//! The kernel runs in-process; the CLI connects over HTTP.
|
||||
|
||||
pub mod channel_bridge;
|
||||
pub mod middleware;
|
||||
pub mod openai_compat;
|
||||
pub mod rate_limiter;
|
||||
pub mod routes;
|
||||
pub mod server;
|
||||
pub mod stream_chunker;
|
||||
pub mod stream_dedup;
|
||||
pub mod types;
|
||||
pub mod webchat;
|
||||
pub mod ws;
|
||||
206
crates/openfang-api/src/middleware.rs
Normal file
206
crates/openfang-api/src/middleware.rs
Normal file
@@ -0,0 +1,206 @@
|
||||
//! Production middleware for the OpenFang API server.
|
||||
//!
|
||||
//! Provides:
|
||||
//! - Request ID generation and propagation
|
||||
//! - Per-endpoint structured request logging
|
||||
//! - In-memory rate limiting (per IP)
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request, Response, StatusCode};
|
||||
use axum::middleware::Next;
|
||||
use std::time::Instant;
|
||||
use tracing::info;
|
||||
|
||||
/// Request ID header name (standard).
|
||||
pub const REQUEST_ID_HEADER: &str = "x-request-id";
|
||||
|
||||
/// Middleware: inject a unique request ID and log the request/response.
|
||||
pub async fn request_logging(request: Request<Body>, next: Next) -> Response<Body> {
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
let method = request.method().clone();
|
||||
let uri = request.uri().path().to_string();
|
||||
let start = Instant::now();
|
||||
|
||||
let mut response = next.run(request).await;
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
let status = response.status().as_u16();
|
||||
|
||||
info!(
|
||||
request_id = %request_id,
|
||||
method = %method,
|
||||
path = %uri,
|
||||
status = status,
|
||||
latency_ms = elapsed.as_millis() as u64,
|
||||
"API request"
|
||||
);
|
||||
|
||||
// Inject the request ID into the response
|
||||
if let Ok(header_val) = request_id.parse() {
|
||||
response.headers_mut().insert(REQUEST_ID_HEADER, header_val);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// Bearer token authentication middleware.
|
||||
///
|
||||
/// When `api_key` is non-empty, all requests must include
|
||||
/// `Authorization: Bearer <api_key>`. If the key is empty, auth is bypassed.
|
||||
pub async fn auth(
|
||||
axum::extract::State(api_key): axum::extract::State<String>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response<Body> {
|
||||
// If no API key configured, restrict to loopback addresses only.
|
||||
if api_key.is_empty() {
|
||||
let is_loopback = request
|
||||
.extensions()
|
||||
.get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
|
||||
.map(|ci| ci.0.ip().is_loopback())
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_loopback {
|
||||
tracing::warn!(
|
||||
"Rejected non-localhost request: no API key configured. \
|
||||
Set api_key in config.toml for remote access."
|
||||
);
|
||||
return Response::builder()
|
||||
.status(StatusCode::FORBIDDEN)
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::json!({
|
||||
"error": "No API key configured. Remote access denied. Configure api_key in ~/.openfang/config.toml"
|
||||
})
|
||||
.to_string(),
|
||||
))
|
||||
.unwrap_or_default();
|
||||
}
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
// Public endpoints that don't require auth (dashboard needs these)
|
||||
let path = request.uri().path();
|
||||
if path == "/"
|
||||
|| path == "/api/health"
|
||||
|| path == "/api/health/detail"
|
||||
|| path == "/api/status"
|
||||
|| path == "/api/version"
|
||||
|| path == "/api/agents"
|
||||
|| path == "/api/profiles"
|
||||
|| path == "/api/config"
|
||||
|| path.starts_with("/api/uploads/")
|
||||
// Dashboard read endpoints — allow unauthenticated so the SPA can
|
||||
// render before the user enters their API key.
|
||||
|| path == "/api/models"
|
||||
|| path == "/api/models/aliases"
|
||||
|| path == "/api/providers"
|
||||
|| path == "/api/budget"
|
||||
|| path == "/api/budget/agents"
|
||||
|| path.starts_with("/api/budget/agents/")
|
||||
|| path == "/api/network/status"
|
||||
|| path == "/api/a2a/agents"
|
||||
|| path == "/api/approvals"
|
||||
|| path.starts_with("/api/approvals/")
|
||||
|| path == "/api/channels"
|
||||
|| path == "/api/skills"
|
||||
|| path == "/api/sessions"
|
||||
|| path == "/api/integrations"
|
||||
|| path == "/api/integrations/available"
|
||||
|| path == "/api/integrations/health"
|
||||
|| path == "/api/workflows"
|
||||
|| path == "/api/logs/stream"
|
||||
|| path.starts_with("/api/cron/")
|
||||
|| path.starts_with("/api/providers/github-copilot/oauth/")
|
||||
{
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
// Check Authorization: Bearer <token> header
|
||||
let bearer_token = request
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
// SECURITY: Use constant-time comparison to prevent timing attacks.
|
||||
let header_auth = bearer_token.map(|token| {
|
||||
use subtle::ConstantTimeEq;
|
||||
if token.len() != api_key.len() {
|
||||
return false;
|
||||
}
|
||||
token.as_bytes().ct_eq(api_key.as_bytes()).into()
|
||||
});
|
||||
|
||||
// Also check ?token= query parameter (for EventSource/SSE clients that
|
||||
// cannot set custom headers, same approach as WebSocket auth).
|
||||
let query_token = request
|
||||
.uri()
|
||||
.query()
|
||||
.and_then(|q| q.split('&').find_map(|pair| pair.strip_prefix("token=")));
|
||||
|
||||
// SECURITY: Use constant-time comparison to prevent timing attacks.
|
||||
let query_auth = query_token.map(|token| {
|
||||
use subtle::ConstantTimeEq;
|
||||
if token.len() != api_key.len() {
|
||||
return false;
|
||||
}
|
||||
token.as_bytes().ct_eq(api_key.as_bytes()).into()
|
||||
});
|
||||
|
||||
// Accept if either auth method matches
|
||||
if header_auth == Some(true) || query_auth == Some(true) {
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
// Determine error message: was a credential provided but wrong, or missing entirely?
|
||||
let credential_provided = header_auth.is_some() || query_auth.is_some();
|
||||
let error_msg = if credential_provided {
|
||||
"Invalid API key"
|
||||
} else {
|
||||
"Missing Authorization: Bearer <api_key> header"
|
||||
};
|
||||
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.header("www-authenticate", "Bearer")
|
||||
.body(Body::from(
|
||||
serde_json::json!({"error": error_msg}).to_string(),
|
||||
))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Security headers middleware — applied to ALL API responses.
|
||||
pub async fn security_headers(request: Request<Body>, next: Next) -> Response<Body> {
|
||||
let mut response = next.run(request).await;
|
||||
let headers = response.headers_mut();
|
||||
headers.insert("x-content-type-options", "nosniff".parse().unwrap());
|
||||
headers.insert("x-frame-options", "DENY".parse().unwrap());
|
||||
headers.insert("x-xss-protection", "1; mode=block".parse().unwrap());
|
||||
// All JS/CSS is bundled inline — only external resource is Google Fonts.
|
||||
headers.insert(
|
||||
"content-security-policy",
|
||||
"default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com https://fonts.gstatic.com; img-src 'self' data: blob:; connect-src 'self' ws://localhost:* ws://127.0.0.1:* wss://localhost:* wss://127.0.0.1:*; font-src 'self' https://fonts.gstatic.com; media-src 'self' blob:; frame-src 'self' blob:; object-src 'none'; base-uri 'self'; form-action 'self'"
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"referrer-policy",
|
||||
"strict-origin-when-cross-origin".parse().unwrap(),
|
||||
);
|
||||
headers.insert(
|
||||
"cache-control",
|
||||
"no-store, no-cache, must-revalidate".parse().unwrap(),
|
||||
);
|
||||
response
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_request_id_header_constant() {
|
||||
assert_eq!(REQUEST_ID_HEADER, "x-request-id");
|
||||
}
|
||||
}
|
||||
773
crates/openfang-api/src/openai_compat.rs
Normal file
773
crates/openfang-api/src/openai_compat.rs
Normal file
@@ -0,0 +1,773 @@
|
||||
//! OpenAI-compatible `/v1/chat/completions` API endpoint.
|
||||
//!
|
||||
//! Allows any OpenAI-compatible client library to talk to OpenFang agents.
|
||||
//! The `model` field resolves to an agent (by name, UUID, or `openfang:<name>`),
|
||||
//! and the messages are forwarded to the agent's LLM loop.
|
||||
//!
|
||||
//! Supports both streaming (SSE) and non-streaming responses.
|
||||
|
||||
use crate::routes::AppState;
|
||||
use axum::extract::State;
|
||||
use axum::http::StatusCode;
|
||||
use axum::response::sse::{Event as SseEvent, KeepAlive, Sse};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::Json;
|
||||
use openfang_runtime::kernel_handle::KernelHandle;
|
||||
use openfang_runtime::llm_driver::StreamEvent;
|
||||
use openfang_types::agent::AgentId;
|
||||
use openfang_types::message::{ContentBlock, Message, MessageContent, Role, StopReason};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use tracing::warn;
|
||||
|
||||
// ── Request types ──────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ChatCompletionRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<OaiMessage>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OaiMessage {
|
||||
pub role: String,
|
||||
#[serde(default)]
|
||||
pub content: OaiContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Default)]
|
||||
#[serde(untagged)]
|
||||
pub enum OaiContent {
|
||||
Text(String),
|
||||
Parts(Vec<OaiContentPart>),
|
||||
#[default]
|
||||
Null,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum OaiContentPart {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl { image_url: OaiImageUrlRef },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OaiImageUrlRef {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
// ── Response types ──────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionResponse {
|
||||
id: String,
|
||||
object: &'static str,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<Choice>,
|
||||
usage: UsageInfo,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Choice {
|
||||
index: u32,
|
||||
message: ChoiceMessage,
|
||||
finish_reason: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChoiceMessage {
|
||||
role: &'static str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<OaiToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct UsageInfo {
|
||||
prompt_tokens: u64,
|
||||
completion_tokens: u64,
|
||||
total_tokens: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChatCompletionChunk {
|
||||
id: String,
|
||||
object: &'static str,
|
||||
created: u64,
|
||||
model: String,
|
||||
choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChunkChoice {
|
||||
index: u32,
|
||||
delta: ChunkDelta,
|
||||
finish_reason: Option<&'static str>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ChunkDelta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
role: Option<&'static str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<OaiToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
struct OaiToolCall {
|
||||
index: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "type")]
|
||||
call_type: Option<&'static str>,
|
||||
function: OaiToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
struct OaiToolCallFunction {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ModelObject {
|
||||
id: String,
|
||||
object: &'static str,
|
||||
created: u64,
|
||||
owned_by: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ModelListResponse {
|
||||
object: &'static str,
|
||||
data: Vec<ModelObject>,
|
||||
}
|
||||
|
||||
// ── Agent resolution ────────────────────────────────────────────────────────
|
||||
|
||||
fn resolve_agent(state: &AppState, model: &str) -> Option<(AgentId, String)> {
|
||||
// 1. "openfang:<name>" → find agent by name
|
||||
if let Some(name) = model.strip_prefix("openfang:") {
|
||||
if let Some(entry) = state.kernel.registry.find_by_name(name) {
|
||||
return Some((entry.id, entry.name.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Valid UUID → find agent by ID
|
||||
if let Ok(id) = model.parse::<AgentId>() {
|
||||
if let Some(entry) = state.kernel.registry.get(id) {
|
||||
return Some((entry.id, entry.name.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Plain string → try as agent name
|
||||
if let Some(entry) = state.kernel.registry.find_by_name(model) {
|
||||
return Some((entry.id, entry.name.clone()));
|
||||
}
|
||||
|
||||
// 4. Fallback → first registered agent
|
||||
let agents = state.kernel.registry.list();
|
||||
agents.first().map(|e| (e.id, e.name.clone()))
|
||||
}
|
||||
|
||||
// ── Message conversion ──────────────────────────────────────────────────────
|
||||
|
||||
fn convert_messages(oai_messages: &[OaiMessage]) -> Vec<Message> {
|
||||
oai_messages
|
||||
.iter()
|
||||
.filter_map(|m| {
|
||||
let role = match m.role.as_str() {
|
||||
"user" => Role::User,
|
||||
"assistant" => Role::Assistant,
|
||||
"system" => Role::System,
|
||||
_ => Role::User,
|
||||
};
|
||||
|
||||
let content = match &m.content {
|
||||
OaiContent::Text(text) => MessageContent::Text(text.clone()),
|
||||
OaiContent::Parts(parts) => {
|
||||
let blocks: Vec<ContentBlock> = parts
|
||||
.iter()
|
||||
.filter_map(|part| match part {
|
||||
OaiContentPart::Text { text } => {
|
||||
Some(ContentBlock::Text { text: text.clone() })
|
||||
}
|
||||
OaiContentPart::ImageUrl { image_url } => {
|
||||
// Parse data URI: data:{media_type};base64,{data}
|
||||
if let Some(rest) = image_url.url.strip_prefix("data:") {
|
||||
let parts: Vec<&str> = rest.splitn(2, ',').collect();
|
||||
if parts.len() == 2 {
|
||||
let media_type = parts[0]
|
||||
.strip_suffix(";base64")
|
||||
.unwrap_or(parts[0])
|
||||
.to_string();
|
||||
let data = parts[1].to_string();
|
||||
Some(ContentBlock::Image { media_type, data })
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
// URL-based images not supported (would require fetching)
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
if blocks.is_empty() {
|
||||
return None;
|
||||
}
|
||||
MessageContent::Blocks(blocks)
|
||||
}
|
||||
OaiContent::Null => return None,
|
||||
};
|
||||
|
||||
Some(Message { role, content })
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// ── Handlers ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// POST /v1/chat/completions
|
||||
pub async fn chat_completions(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(req): Json<ChatCompletionRequest>,
|
||||
) -> impl IntoResponse {
|
||||
let (agent_id, agent_name) = match resolve_agent(&state, &req.model) {
|
||||
Some(pair) => pair,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("No agent found for model '{}'", req.model),
|
||||
"type": "invalid_request_error",
|
||||
"code": "model_not_found"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
// Extract the last user message as the input
|
||||
let messages = convert_messages(&req.messages);
|
||||
let last_user_msg = messages
|
||||
.iter()
|
||||
.rev()
|
||||
.find(|m| m.role == Role::User)
|
||||
.map(|m| m.content.text_content())
|
||||
.unwrap_or_default();
|
||||
|
||||
if last_user_msg.is_empty() {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": "No user message found in request",
|
||||
"type": "invalid_request_error",
|
||||
"code": "missing_message"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let request_id = format!("chatcmpl-{}", uuid::Uuid::new_v4());
|
||||
let created = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
if req.stream {
|
||||
// Streaming response
|
||||
return match stream_response(
|
||||
state,
|
||||
agent_id,
|
||||
agent_name,
|
||||
&last_user_msg,
|
||||
request_id,
|
||||
created,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(sse) => sse.into_response(),
|
||||
Err(e) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": format!("{e}"),
|
||||
"type": "server_error"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response(),
|
||||
};
|
||||
}
|
||||
|
||||
// Non-streaming response
|
||||
let kernel_handle: Arc<dyn KernelHandle> = state.kernel.clone() as Arc<dyn KernelHandle>;
|
||||
match state
|
||||
.kernel
|
||||
.send_message_with_handle(agent_id, &last_user_msg, Some(kernel_handle))
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
let response = ChatCompletionResponse {
|
||||
id: request_id,
|
||||
object: "chat.completion",
|
||||
created,
|
||||
model: agent_name,
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: ChoiceMessage {
|
||||
role: "assistant",
|
||||
content: Some(result.response),
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: "stop",
|
||||
}],
|
||||
usage: UsageInfo {
|
||||
prompt_tokens: result.total_usage.input_tokens,
|
||||
completion_tokens: result.total_usage.output_tokens,
|
||||
total_tokens: result.total_usage.input_tokens
|
||||
+ result.total_usage.output_tokens,
|
||||
},
|
||||
};
|
||||
Json(serde_json::to_value(&response).unwrap_or_default()).into_response()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("OpenAI compat: agent error: {e}");
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(serde_json::json!({
|
||||
"error": {
|
||||
"message": "Agent processing failed",
|
||||
"type": "server_error"
|
||||
}
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an SSE stream response for streaming completions.
|
||||
async fn stream_response(
|
||||
state: Arc<AppState>,
|
||||
agent_id: AgentId,
|
||||
agent_name: String,
|
||||
message: &str,
|
||||
request_id: String,
|
||||
created: u64,
|
||||
) -> Result<axum::response::Response, String> {
|
||||
let kernel_handle: Arc<dyn KernelHandle> = state.kernel.clone() as Arc<dyn KernelHandle>;
|
||||
|
||||
let (mut rx, _handle) = state
|
||||
.kernel
|
||||
.send_message_streaming(agent_id, message, Some(kernel_handle))
|
||||
.map_err(|e| format!("Streaming setup failed: {e}"))?;
|
||||
|
||||
let (tx, stream_rx) = tokio::sync::mpsc::channel::<Result<SseEvent, Infallible>>(64);
|
||||
|
||||
// Send initial role delta
|
||||
let first_chunk = ChatCompletionChunk {
|
||||
id: request_id.clone(),
|
||||
object: "chat.completion.chunk",
|
||||
created,
|
||||
model: agent_name.clone(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: ChunkDelta {
|
||||
role: Some("assistant"),
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
let _ = tx
|
||||
.send(Ok(SseEvent::default().data(
|
||||
serde_json::to_string(&first_chunk).unwrap_or_default(),
|
||||
)))
|
||||
.await;
|
||||
|
||||
// Helper to build a chunk with a delta and optional finish_reason.
|
||||
fn make_chunk(
|
||||
id: &str,
|
||||
created: u64,
|
||||
model: &str,
|
||||
delta: ChunkDelta,
|
||||
finish_reason: Option<&'static str>,
|
||||
) -> String {
|
||||
let chunk = ChatCompletionChunk {
|
||||
id: id.to_string(),
|
||||
object: "chat.completion.chunk",
|
||||
created,
|
||||
model: model.to_string(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta,
|
||||
finish_reason,
|
||||
}],
|
||||
};
|
||||
serde_json::to_string(&chunk).unwrap_or_default()
|
||||
}
|
||||
|
||||
// Spawn forwarder task — streams ALL iterations until the agent loop channel closes.
|
||||
let req_id = request_id.clone();
|
||||
tokio::spawn(async move {
|
||||
// Tracks current tool_call index within each LLM iteration.
|
||||
let mut tool_index: u32 = 0;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
let json = match event {
|
||||
StreamEvent::TextDelta { text } => make_chunk(
|
||||
&req_id,
|
||||
created,
|
||||
&agent_name,
|
||||
ChunkDelta {
|
||||
role: None,
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
},
|
||||
None,
|
||||
),
|
||||
StreamEvent::ToolUseStart { id, name } => {
|
||||
let idx = tool_index;
|
||||
tool_index += 1;
|
||||
make_chunk(
|
||||
&req_id,
|
||||
created,
|
||||
&agent_name,
|
||||
ChunkDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
tool_calls: Some(vec![OaiToolCall {
|
||||
index: idx,
|
||||
id: Some(id),
|
||||
call_type: Some("function"),
|
||||
function: OaiToolCallFunction {
|
||||
name: Some(name),
|
||||
arguments: Some(String::new()),
|
||||
},
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
)
|
||||
}
|
||||
StreamEvent::ToolInputDelta { text } => {
|
||||
// tool_index already incremented past current tool, so current = index - 1
|
||||
let idx = tool_index.saturating_sub(1);
|
||||
make_chunk(
|
||||
&req_id,
|
||||
created,
|
||||
&agent_name,
|
||||
ChunkDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
tool_calls: Some(vec![OaiToolCall {
|
||||
index: idx,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: OaiToolCallFunction {
|
||||
name: None,
|
||||
arguments: Some(text),
|
||||
},
|
||||
}]),
|
||||
},
|
||||
None,
|
||||
)
|
||||
}
|
||||
StreamEvent::ContentComplete { stop_reason, .. } => {
|
||||
// ToolUse → reset tool index for next iteration, do NOT finish.
|
||||
// EndTurn/MaxTokens/StopSequence → continue, wait for channel close.
|
||||
if matches!(stop_reason, StopReason::ToolUse) {
|
||||
tool_index = 0;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
// ToolUseEnd, ToolExecutionResult, ThinkingDelta, PhaseChange — skip
|
||||
_ => continue,
|
||||
};
|
||||
if tx.send(Ok(SseEvent::default().data(json))).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Channel closed — agent loop is fully done. Send finish + [DONE].
|
||||
let final_json = make_chunk(
|
||||
&req_id,
|
||||
created,
|
||||
&agent_name,
|
||||
ChunkDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
},
|
||||
Some("stop"),
|
||||
);
|
||||
let _ = tx.send(Ok(SseEvent::default().data(final_json))).await;
|
||||
let _ = tx.send(Ok(SseEvent::default().data("[DONE]"))).await;
|
||||
});
|
||||
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(stream_rx);
|
||||
Ok(Sse::new(stream)
|
||||
.keep_alive(KeepAlive::default())
|
||||
.into_response())
|
||||
}
|
||||
|
||||
/// GET /v1/models — List available agents as OpenAI model objects.
|
||||
pub async fn list_models(State(state): State<Arc<AppState>>) -> impl IntoResponse {
|
||||
let agents = state.kernel.registry.list();
|
||||
let created = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
|
||||
let models: Vec<ModelObject> = agents
|
||||
.iter()
|
||||
.map(|e| ModelObject {
|
||||
id: format!("openfang:{}", e.name),
|
||||
object: "model",
|
||||
created,
|
||||
owned_by: "openfang".to_string(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
Json(
|
||||
serde_json::to_value(&ModelListResponse {
|
||||
object: "list",
|
||||
data: models,
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_oai_content_deserialize_string() {
|
||||
let json = r#"{"role":"user","content":"hello"}"#;
|
||||
let msg: OaiMessage = serde_json::from_str(json).unwrap();
|
||||
assert!(matches!(msg.content, OaiContent::Text(ref t) if t == "hello"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_oai_content_deserialize_parts() {
|
||||
let json = r#"{"role":"user","content":[{"type":"text","text":"what is this?"},{"type":"image_url","image_url":{"url":"data:image/png;base64,abc123"}}]}"#;
|
||||
let msg: OaiMessage = serde_json::from_str(json).unwrap();
|
||||
assert!(matches!(msg.content, OaiContent::Parts(ref p) if p.len() == 2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_messages_text() {
|
||||
let oai = vec![
|
||||
OaiMessage {
|
||||
role: "system".to_string(),
|
||||
content: OaiContent::Text("You are helpful.".to_string()),
|
||||
},
|
||||
OaiMessage {
|
||||
role: "user".to_string(),
|
||||
content: OaiContent::Text("Hello!".to_string()),
|
||||
},
|
||||
];
|
||||
let msgs = convert_messages(&oai);
|
||||
assert_eq!(msgs.len(), 2);
|
||||
assert_eq!(msgs[0].role, Role::System);
|
||||
assert_eq!(msgs[1].role, Role::User);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_convert_messages_with_image() {
|
||||
let oai = vec![OaiMessage {
|
||||
role: "user".to_string(),
|
||||
content: OaiContent::Parts(vec![
|
||||
OaiContentPart::Text {
|
||||
text: "What is this?".to_string(),
|
||||
},
|
||||
OaiContentPart::ImageUrl {
|
||||
image_url: OaiImageUrlRef {
|
||||
url: "data:image/png;base64,iVBORw0KGgo=".to_string(),
|
||||
},
|
||||
},
|
||||
]),
|
||||
}];
|
||||
let msgs = convert_messages(&oai);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
match &msgs[0].content {
|
||||
MessageContent::Blocks(blocks) => {
|
||||
assert_eq!(blocks.len(), 2);
|
||||
assert!(matches!(&blocks[0], ContentBlock::Text { .. }));
|
||||
assert!(matches!(&blocks[1], ContentBlock::Image { .. }));
|
||||
}
|
||||
_ => panic!("Expected Blocks"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_response_serialization() {
|
||||
let resp = ChatCompletionResponse {
|
||||
id: "chatcmpl-test".to_string(),
|
||||
object: "chat.completion",
|
||||
created: 1234567890,
|
||||
model: "test-agent".to_string(),
|
||||
choices: vec![Choice {
|
||||
index: 0,
|
||||
message: ChoiceMessage {
|
||||
role: "assistant",
|
||||
content: Some("Hello!".to_string()),
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: "stop",
|
||||
}],
|
||||
usage: UsageInfo {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
};
|
||||
let json = serde_json::to_value(&resp).unwrap();
|
||||
assert_eq!(json["object"], "chat.completion");
|
||||
assert_eq!(json["choices"][0]["message"]["content"], "Hello!");
|
||||
assert_eq!(json["usage"]["total_tokens"], 15);
|
||||
// tool_calls should be omitted when None
|
||||
assert!(json["choices"][0]["message"].get("tool_calls").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_serialization() {
|
||||
let chunk = ChatCompletionChunk {
|
||||
id: "chatcmpl-test".to_string(),
|
||||
object: "chat.completion.chunk",
|
||||
created: 1234567890,
|
||||
model: "test-agent".to_string(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: ChunkDelta {
|
||||
role: None,
|
||||
content: Some("Hello".to_string()),
|
||||
tool_calls: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
let json = serde_json::to_value(&chunk).unwrap();
|
||||
assert_eq!(json["object"], "chat.completion.chunk");
|
||||
assert_eq!(json["choices"][0]["delta"]["content"], "Hello");
|
||||
assert!(json["choices"][0]["delta"]["role"].is_null());
|
||||
// tool_calls should be omitted when None
|
||||
assert!(json["choices"][0]["delta"].get("tool_calls").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_call_serialization() {
|
||||
let tc = OaiToolCall {
|
||||
index: 0,
|
||||
id: Some("call_abc123".to_string()),
|
||||
call_type: Some("function"),
|
||||
function: OaiToolCallFunction {
|
||||
name: Some("get_weather".to_string()),
|
||||
arguments: Some(r#"{"location":"NYC"}"#.to_string()),
|
||||
},
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert_eq!(json["index"], 0);
|
||||
assert_eq!(json["id"], "call_abc123");
|
||||
assert_eq!(json["type"], "function");
|
||||
assert_eq!(json["function"]["name"], "get_weather");
|
||||
assert_eq!(json["function"]["arguments"], r#"{"location":"NYC"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chunk_delta_with_tool_calls() {
|
||||
let chunk = ChatCompletionChunk {
|
||||
id: "chatcmpl-test".to_string(),
|
||||
object: "chat.completion.chunk",
|
||||
created: 1234567890,
|
||||
model: "test-agent".to_string(),
|
||||
choices: vec![ChunkChoice {
|
||||
index: 0,
|
||||
delta: ChunkDelta {
|
||||
role: None,
|
||||
content: None,
|
||||
tool_calls: Some(vec![OaiToolCall {
|
||||
index: 0,
|
||||
id: Some("call_1".to_string()),
|
||||
call_type: Some("function"),
|
||||
function: OaiToolCallFunction {
|
||||
name: Some("search".to_string()),
|
||||
arguments: Some(String::new()),
|
||||
},
|
||||
}]),
|
||||
},
|
||||
finish_reason: None,
|
||||
}],
|
||||
};
|
||||
let json = serde_json::to_value(&chunk).unwrap();
|
||||
let tc = &json["choices"][0]["delta"]["tool_calls"][0];
|
||||
assert_eq!(tc["index"], 0);
|
||||
assert_eq!(tc["id"], "call_1");
|
||||
assert_eq!(tc["type"], "function");
|
||||
assert_eq!(tc["function"]["name"], "search");
|
||||
// content should be omitted
|
||||
assert!(json["choices"][0]["delta"].get("content").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_input_delta_chunk() {
|
||||
// Incremental arguments chunk — no id, no type, no name
|
||||
let tc = OaiToolCall {
|
||||
index: 2,
|
||||
id: None,
|
||||
call_type: None,
|
||||
function: OaiToolCallFunction {
|
||||
name: None,
|
||||
arguments: Some(r#"{"q":"rust"}"#.to_string()),
|
||||
},
|
||||
};
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert_eq!(json["index"], 2);
|
||||
// id and type should be omitted
|
||||
assert!(json.get("id").is_none());
|
||||
assert!(json.get("type").is_none());
|
||||
assert!(json["function"].get("name").is_none());
|
||||
assert_eq!(json["function"]["arguments"], r#"{"q":"rust"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backward_compat_no_tool_calls() {
|
||||
// When tool_calls is None, it should not appear in JSON at all (backward compat)
|
||||
let msg = ChoiceMessage {
|
||||
role: "assistant",
|
||||
content: Some("Hello".to_string()),
|
||||
tool_calls: None,
|
||||
};
|
||||
let json_str = serde_json::to_string(&msg).unwrap();
|
||||
assert!(!json_str.contains("tool_calls"));
|
||||
|
||||
let delta = ChunkDelta {
|
||||
role: Some("assistant"),
|
||||
content: Some("Hi".to_string()),
|
||||
tool_calls: None,
|
||||
};
|
||||
let json_str = serde_json::to_string(&delta).unwrap();
|
||||
assert!(!json_str.contains("tool_calls"));
|
||||
}
|
||||
}
|
||||
99
crates/openfang-api/src/rate_limiter.rs
Normal file
99
crates/openfang-api/src/rate_limiter.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
//! Cost-aware rate limiting using GCRA (Generic Cell Rate Algorithm).
|
||||
//!
|
||||
//! Each API operation has a token cost (e.g., health=1, spawn=50, message=30).
|
||||
//! The GCRA algorithm allows 500 tokens per minute per IP address.
|
||||
|
||||
use axum::body::Body;
|
||||
use axum::http::{Request, Response, StatusCode};
|
||||
use axum::middleware::Next;
|
||||
use governor::{clock::DefaultClock, state::keyed::DashMapStateStore, Quota, RateLimiter};
|
||||
use std::net::{IpAddr, SocketAddr};
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub fn operation_cost(method: &str, path: &str) -> NonZeroU32 {
|
||||
match (method, path) {
|
||||
(_, "/api/health") => NonZeroU32::new(1).unwrap(),
|
||||
("GET", "/api/status") => NonZeroU32::new(1).unwrap(),
|
||||
("GET", "/api/version") => NonZeroU32::new(1).unwrap(),
|
||||
("GET", "/api/tools") => NonZeroU32::new(1).unwrap(),
|
||||
("GET", "/api/agents") => NonZeroU32::new(2).unwrap(),
|
||||
("GET", "/api/skills") => NonZeroU32::new(2).unwrap(),
|
||||
("GET", "/api/peers") => NonZeroU32::new(2).unwrap(),
|
||||
("GET", "/api/config") => NonZeroU32::new(2).unwrap(),
|
||||
("GET", "/api/usage") => NonZeroU32::new(3).unwrap(),
|
||||
("GET", p) if p.starts_with("/api/audit") => NonZeroU32::new(5).unwrap(),
|
||||
("GET", p) if p.starts_with("/api/marketplace") => NonZeroU32::new(10).unwrap(),
|
||||
("POST", "/api/agents") => NonZeroU32::new(50).unwrap(),
|
||||
("POST", p) if p.contains("/message") => NonZeroU32::new(30).unwrap(),
|
||||
("POST", p) if p.contains("/run") => NonZeroU32::new(100).unwrap(),
|
||||
("POST", "/api/skills/install") => NonZeroU32::new(50).unwrap(),
|
||||
("POST", "/api/skills/uninstall") => NonZeroU32::new(10).unwrap(),
|
||||
("POST", "/api/migrate") => NonZeroU32::new(100).unwrap(),
|
||||
("PUT", p) if p.contains("/update") => NonZeroU32::new(10).unwrap(),
|
||||
_ => NonZeroU32::new(5).unwrap(),
|
||||
}
|
||||
}
|
||||
|
||||
pub type KeyedRateLimiter = RateLimiter<IpAddr, DashMapStateStore<IpAddr>, DefaultClock>;
|
||||
|
||||
/// 500 tokens per minute per IP.
|
||||
pub fn create_rate_limiter() -> Arc<KeyedRateLimiter> {
|
||||
Arc::new(RateLimiter::keyed(Quota::per_minute(
|
||||
NonZeroU32::new(500).unwrap(),
|
||||
)))
|
||||
}
|
||||
|
||||
/// GCRA rate limiting middleware.
|
||||
///
|
||||
/// Extracts the client IP from `ConnectInfo`, computes the cost for the
|
||||
/// requested operation, and checks the GCRA limiter. Returns 429 if the
|
||||
/// client has exhausted its token budget.
|
||||
pub async fn gcra_rate_limit(
|
||||
axum::extract::State(limiter): axum::extract::State<Arc<KeyedRateLimiter>>,
|
||||
request: Request<Body>,
|
||||
next: Next,
|
||||
) -> Response<Body> {
|
||||
let ip = request
|
||||
.extensions()
|
||||
.get::<axum::extract::ConnectInfo<SocketAddr>>()
|
||||
.map(|ci| ci.0.ip())
|
||||
.unwrap_or(IpAddr::from([127, 0, 0, 1]));
|
||||
|
||||
let method = request.method().as_str().to_string();
|
||||
let path = request.uri().path().to_string();
|
||||
let cost = operation_cost(&method, &path);
|
||||
|
||||
if limiter.check_key_n(&ip, cost).is_err() {
|
||||
tracing::warn!(ip = %ip, cost = cost.get(), path = %path, "GCRA rate limit exceeded");
|
||||
return Response::builder()
|
||||
.status(StatusCode::TOO_MANY_REQUESTS)
|
||||
.header("content-type", "application/json")
|
||||
.header("retry-after", "60")
|
||||
.body(Body::from(
|
||||
serde_json::json!({"error": "Rate limit exceeded"}).to_string(),
|
||||
))
|
||||
.unwrap_or_default();
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
#[test]
|
||||
fn test_costs() {
|
||||
assert_eq!(operation_cost("GET", "/api/health").get(), 1);
|
||||
assert_eq!(operation_cost("GET", "/api/tools").get(), 1);
|
||||
assert_eq!(operation_cost("POST", "/api/agents/1/message").get(), 30);
|
||||
assert_eq!(operation_cost("POST", "/api/agents").get(), 50);
|
||||
assert_eq!(operation_cost("POST", "/api/workflows/1/run").get(), 100);
|
||||
assert_eq!(operation_cost("GET", "/api/agents/1/session").get(), 5);
|
||||
assert_eq!(operation_cost("GET", "/api/skills").get(), 2);
|
||||
assert_eq!(operation_cost("GET", "/api/peers").get(), 2);
|
||||
assert_eq!(operation_cost("GET", "/api/audit/recent").get(), 5);
|
||||
assert_eq!(operation_cost("POST", "/api/skills/install").get(), 50);
|
||||
assert_eq!(operation_cost("POST", "/api/migrate").get(), 100);
|
||||
}
|
||||
}
|
||||
8979
crates/openfang-api/src/routes.rs
Normal file
8979
crates/openfang-api/src/routes.rs
Normal file
File diff suppressed because it is too large
Load Diff
849
crates/openfang-api/src/server.rs
Normal file
849
crates/openfang-api/src/server.rs
Normal file
@@ -0,0 +1,849 @@
|
||||
//! OpenFang daemon server — boots the kernel and serves the HTTP API.
|
||||
|
||||
use crate::channel_bridge;
|
||||
use crate::middleware;
|
||||
use crate::rate_limiter;
|
||||
use crate::routes::{self, AppState};
|
||||
use crate::webchat;
|
||||
use crate::ws;
|
||||
use axum::Router;
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tower_http::compression::CompressionLayer;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::info;
|
||||
|
||||
/// Daemon info written to `~/.openfang/daemon.json` so the CLI can find us.
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
pub struct DaemonInfo {
|
||||
pub pid: u32,
|
||||
pub listen_addr: String,
|
||||
pub started_at: String,
|
||||
pub version: String,
|
||||
pub platform: String,
|
||||
}
|
||||
|
||||
/// Build the full API router with all routes, middleware, and state.
|
||||
///
|
||||
/// This is extracted from `run_daemon()` so that embedders (e.g. openfang-desktop)
|
||||
/// can create the router without starting the full daemon lifecycle.
|
||||
///
|
||||
/// Returns `(router, shared_state)`. The caller can use `state.bridge_manager`
|
||||
/// to shut down the bridge on exit.
|
||||
pub async fn build_router(
|
||||
kernel: Arc<OpenFangKernel>,
|
||||
listen_addr: SocketAddr,
|
||||
) -> (Router<()>, Arc<AppState>) {
|
||||
// Start channel bridges (Telegram, etc.)
|
||||
let bridge = channel_bridge::start_channel_bridge(kernel.clone()).await;
|
||||
|
||||
let channels_config = kernel.config.channels.clone();
|
||||
let state = Arc::new(AppState {
|
||||
kernel: kernel.clone(),
|
||||
started_at: Instant::now(),
|
||||
peer_registry: kernel.peer_registry.as_ref().map(|r| Arc::new(r.clone())),
|
||||
bridge_manager: tokio::sync::Mutex::new(bridge),
|
||||
channels_config: tokio::sync::RwLock::new(channels_config),
|
||||
shutdown_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
});
|
||||
|
||||
// CORS: allow localhost origins by default. If API key is set, the API
|
||||
// is protected anyway. For development, permissive CORS is convenient.
|
||||
let cors = if state.kernel.config.api_key.is_empty() {
|
||||
// No auth → restrict CORS to localhost origins (include both 127.0.0.1 and localhost)
|
||||
let port = listen_addr.port();
|
||||
let mut origins: Vec<axum::http::HeaderValue> = vec![
|
||||
format!("http://{listen_addr}").parse().unwrap(),
|
||||
format!("http://localhost:{port}").parse().unwrap(),
|
||||
];
|
||||
// Also allow common dev ports
|
||||
for p in [3000u16, 8080] {
|
||||
if p != port {
|
||||
if let Ok(v) = format!("http://127.0.0.1:{p}").parse() {
|
||||
origins.push(v);
|
||||
}
|
||||
if let Ok(v) = format!("http://localhost:{p}").parse() {
|
||||
origins.push(v);
|
||||
}
|
||||
}
|
||||
}
|
||||
CorsLayer::new()
|
||||
.allow_origin(origins)
|
||||
.allow_methods(tower_http::cors::Any)
|
||||
.allow_headers(tower_http::cors::Any)
|
||||
} else {
|
||||
// Auth enabled → restrict CORS to localhost + configured origins.
|
||||
// SECURITY: CorsLayer::permissive() is dangerous — any website could
|
||||
// make cross-origin requests. Restrict to known origins instead.
|
||||
let mut origins: Vec<axum::http::HeaderValue> = vec![
|
||||
format!("http://{listen_addr}").parse().unwrap(),
|
||||
"http://localhost:4200".parse().unwrap(),
|
||||
"http://127.0.0.1:4200".parse().unwrap(),
|
||||
"http://localhost:8080".parse().unwrap(),
|
||||
"http://127.0.0.1:8080".parse().unwrap(),
|
||||
];
|
||||
// Add the actual listen address variants
|
||||
if listen_addr.port() != 4200 && listen_addr.port() != 8080 {
|
||||
if let Ok(v) = format!("http://localhost:{}", listen_addr.port()).parse() {
|
||||
origins.push(v);
|
||||
}
|
||||
if let Ok(v) = format!("http://127.0.0.1:{}", listen_addr.port()).parse() {
|
||||
origins.push(v);
|
||||
}
|
||||
}
|
||||
CorsLayer::new()
|
||||
.allow_origin(origins)
|
||||
.allow_methods(tower_http::cors::Any)
|
||||
.allow_headers(tower_http::cors::Any)
|
||||
};
|
||||
|
||||
let api_key = state.kernel.config.api_key.clone();
|
||||
let gcra_limiter = rate_limiter::create_rate_limiter();
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", axum::routing::get(webchat::webchat_page))
|
||||
.route("/logo.png", axum::routing::get(webchat::logo_png))
|
||||
.route("/favicon.ico", axum::routing::get(webchat::favicon_ico))
|
||||
.route(
|
||||
"/api/metrics",
|
||||
axum::routing::get(routes::prometheus_metrics),
|
||||
)
|
||||
.route("/api/health", axum::routing::get(routes::health))
|
||||
.route(
|
||||
"/api/health/detail",
|
||||
axum::routing::get(routes::health_detail),
|
||||
)
|
||||
.route("/api/status", axum::routing::get(routes::status))
|
||||
.route("/api/version", axum::routing::get(routes::version))
|
||||
.route(
|
||||
"/api/agents",
|
||||
axum::routing::get(routes::list_agents).post(routes::spawn_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}",
|
||||
axum::routing::get(routes::get_agent).delete(routes::kill_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/mode",
|
||||
axum::routing::put(routes::set_agent_mode),
|
||||
)
|
||||
.route("/api/profiles", axum::routing::get(routes::list_profiles))
|
||||
.route(
|
||||
"/api/agents/{id}/message",
|
||||
axum::routing::post(routes::send_message),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/message/stream",
|
||||
axum::routing::post(routes::send_message_stream),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/session",
|
||||
axum::routing::get(routes::get_agent_session),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/sessions",
|
||||
axum::routing::get(routes::list_agent_sessions).post(routes::create_agent_session),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/sessions/{session_id}/switch",
|
||||
axum::routing::post(routes::switch_agent_session),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/session/reset",
|
||||
axum::routing::post(routes::reset_session),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/session/compact",
|
||||
axum::routing::post(routes::compact_session),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/stop",
|
||||
axum::routing::post(routes::stop_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/model",
|
||||
axum::routing::put(routes::set_model),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/skills",
|
||||
axum::routing::get(routes::get_agent_skills).put(routes::set_agent_skills),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/mcp_servers",
|
||||
axum::routing::get(routes::get_agent_mcp_servers).put(routes::set_agent_mcp_servers),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/identity",
|
||||
axum::routing::patch(routes::update_agent_identity),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/config",
|
||||
axum::routing::patch(routes::patch_agent_config),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/clone",
|
||||
axum::routing::post(routes::clone_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/files",
|
||||
axum::routing::get(routes::list_agent_files),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/files/{filename}",
|
||||
axum::routing::get(routes::get_agent_file).put(routes::set_agent_file),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/deliveries",
|
||||
axum::routing::get(routes::get_agent_deliveries),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/upload",
|
||||
axum::routing::post(routes::upload_file),
|
||||
)
|
||||
.route("/api/agents/{id}/ws", axum::routing::get(ws::agent_ws))
|
||||
// Upload serving
|
||||
.route(
|
||||
"/api/uploads/{file_id}",
|
||||
axum::routing::get(routes::serve_upload),
|
||||
)
|
||||
// Channel endpoints
|
||||
.route("/api/channels", axum::routing::get(routes::list_channels))
|
||||
.route(
|
||||
"/api/channels/{name}/configure",
|
||||
axum::routing::post(routes::configure_channel).delete(routes::remove_channel),
|
||||
)
|
||||
.route(
|
||||
"/api/channels/{name}/test",
|
||||
axum::routing::post(routes::test_channel),
|
||||
)
|
||||
.route(
|
||||
"/api/channels/reload",
|
||||
axum::routing::post(routes::reload_channels),
|
||||
)
|
||||
// WhatsApp QR login flow
|
||||
.route(
|
||||
"/api/channels/whatsapp/qr/start",
|
||||
axum::routing::post(routes::whatsapp_qr_start),
|
||||
)
|
||||
.route(
|
||||
"/api/channels/whatsapp/qr/status",
|
||||
axum::routing::get(routes::whatsapp_qr_status),
|
||||
)
|
||||
// Template endpoints
|
||||
.route("/api/templates", axum::routing::get(routes::list_templates))
|
||||
.route(
|
||||
"/api/templates/{name}",
|
||||
axum::routing::get(routes::get_template),
|
||||
)
|
||||
// Memory endpoints
|
||||
.route(
|
||||
"/api/memory/agents/{id}/kv",
|
||||
axum::routing::get(routes::get_agent_kv),
|
||||
)
|
||||
.route(
|
||||
"/api/memory/agents/{id}/kv/{key}",
|
||||
axum::routing::get(routes::get_agent_kv_key)
|
||||
.put(routes::set_agent_kv_key)
|
||||
.delete(routes::delete_agent_kv_key),
|
||||
)
|
||||
// Trigger endpoints
|
||||
.route(
|
||||
"/api/triggers",
|
||||
axum::routing::get(routes::list_triggers).post(routes::create_trigger),
|
||||
)
|
||||
.route(
|
||||
"/api/triggers/{id}",
|
||||
axum::routing::delete(routes::delete_trigger).put(routes::update_trigger),
|
||||
)
|
||||
// Schedule (cron job) endpoints
|
||||
.route(
|
||||
"/api/schedules",
|
||||
axum::routing::get(routes::list_schedules).post(routes::create_schedule),
|
||||
)
|
||||
.route(
|
||||
"/api/schedules/{id}",
|
||||
axum::routing::delete(routes::delete_schedule).put(routes::update_schedule),
|
||||
)
|
||||
.route(
|
||||
"/api/schedules/{id}/run",
|
||||
axum::routing::post(routes::run_schedule),
|
||||
)
|
||||
// Workflow endpoints
|
||||
.route(
|
||||
"/api/workflows",
|
||||
axum::routing::get(routes::list_workflows).post(routes::create_workflow),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows/{id}/run",
|
||||
axum::routing::post(routes::run_workflow),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows/{id}/runs",
|
||||
axum::routing::get(routes::list_workflow_runs),
|
||||
)
|
||||
// Skills endpoints
|
||||
.route("/api/skills", axum::routing::get(routes::list_skills))
|
||||
.route(
|
||||
"/api/skills/install",
|
||||
axum::routing::post(routes::install_skill),
|
||||
)
|
||||
.route(
|
||||
"/api/skills/uninstall",
|
||||
axum::routing::post(routes::uninstall_skill),
|
||||
)
|
||||
.route(
|
||||
"/api/marketplace/search",
|
||||
axum::routing::get(routes::marketplace_search),
|
||||
)
|
||||
// ClawHub (OpenClaw ecosystem) endpoints
|
||||
.route(
|
||||
"/api/clawhub/search",
|
||||
axum::routing::get(routes::clawhub_search),
|
||||
)
|
||||
.route(
|
||||
"/api/clawhub/browse",
|
||||
axum::routing::get(routes::clawhub_browse),
|
||||
)
|
||||
.route(
|
||||
"/api/clawhub/skill/{slug}",
|
||||
axum::routing::get(routes::clawhub_skill_detail),
|
||||
)
|
||||
.route(
|
||||
"/api/clawhub/install",
|
||||
axum::routing::post(routes::clawhub_install),
|
||||
)
|
||||
// Hands endpoints
|
||||
.route("/api/hands", axum::routing::get(routes::list_hands))
|
||||
.route(
|
||||
"/api/hands/active",
|
||||
axum::routing::get(routes::list_active_hands),
|
||||
)
|
||||
.route("/api/hands/{hand_id}", axum::routing::get(routes::get_hand))
|
||||
.route(
|
||||
"/api/hands/{hand_id}/activate",
|
||||
axum::routing::post(routes::activate_hand),
|
||||
)
|
||||
.route(
|
||||
"/api/hands/{hand_id}/check-deps",
|
||||
axum::routing::post(routes::check_hand_deps),
|
||||
)
|
||||
.route(
|
||||
"/api/hands/{hand_id}/install-deps",
|
||||
axum::routing::post(routes::install_hand_deps),
|
||||
)
|
||||
.route(
|
||||
"/api/hands/instances/{id}/pause",
|
||||
axum::routing::post(routes::pause_hand),
|
||||
)
|
||||
.route(
|
||||
"/api/hands/instances/{id}/resume",
|
||||
axum::routing::post(routes::resume_hand),
|
||||
)
|
||||
.route(
|
||||
"/api/hands/instances/{id}",
|
||||
axum::routing::delete(routes::deactivate_hand),
|
||||
)
|
||||
.route(
|
||||
"/api/hands/instances/{id}/stats",
|
||||
axum::routing::get(routes::hand_stats),
|
||||
)
|
||||
.route(
|
||||
"/api/hands/instances/{id}/browser",
|
||||
axum::routing::get(routes::hand_instance_browser),
|
||||
)
|
||||
// MCP server endpoints
|
||||
.route(
|
||||
"/api/mcp/servers",
|
||||
axum::routing::get(routes::list_mcp_servers),
|
||||
)
|
||||
// Audit endpoints
|
||||
.route(
|
||||
"/api/audit/recent",
|
||||
axum::routing::get(routes::audit_recent),
|
||||
)
|
||||
.route(
|
||||
"/api/audit/verify",
|
||||
axum::routing::get(routes::audit_verify),
|
||||
)
|
||||
// Live log streaming (SSE)
|
||||
.route("/api/logs/stream", axum::routing::get(routes::logs_stream))
|
||||
// Peer/Network endpoints
|
||||
.route("/api/peers", axum::routing::get(routes::list_peers))
|
||||
.route(
|
||||
"/api/network/status",
|
||||
axum::routing::get(routes::network_status),
|
||||
)
|
||||
// Tools endpoint
|
||||
.route("/api/tools", axum::routing::get(routes::list_tools))
|
||||
// Config endpoints
|
||||
.route("/api/config", axum::routing::get(routes::get_config))
|
||||
.route(
|
||||
"/api/config/schema",
|
||||
axum::routing::get(routes::config_schema),
|
||||
)
|
||||
.route("/api/config/set", axum::routing::post(routes::config_set))
|
||||
// Approval endpoints
|
||||
.route(
|
||||
"/api/approvals",
|
||||
axum::routing::get(routes::list_approvals).post(routes::create_approval),
|
||||
)
|
||||
.route(
|
||||
"/api/approvals/{id}/approve",
|
||||
axum::routing::post(routes::approve_request),
|
||||
)
|
||||
.route(
|
||||
"/api/approvals/{id}/reject",
|
||||
axum::routing::post(routes::reject_request),
|
||||
)
|
||||
// Usage endpoints
|
||||
.route("/api/usage", axum::routing::get(routes::usage_stats))
|
||||
.route(
|
||||
"/api/usage/summary",
|
||||
axum::routing::get(routes::usage_summary),
|
||||
)
|
||||
.route(
|
||||
"/api/usage/by-model",
|
||||
axum::routing::get(routes::usage_by_model),
|
||||
)
|
||||
.route("/api/usage/daily", axum::routing::get(routes::usage_daily))
|
||||
// Budget endpoints
|
||||
.route(
|
||||
"/api/budget",
|
||||
axum::routing::get(routes::budget_status).put(routes::update_budget),
|
||||
)
|
||||
.route(
|
||||
"/api/budget/agents",
|
||||
axum::routing::get(routes::agent_budget_ranking),
|
||||
)
|
||||
.route(
|
||||
"/api/budget/agents/{id}",
|
||||
axum::routing::get(routes::agent_budget_status),
|
||||
)
|
||||
// Session endpoints
|
||||
.route("/api/sessions", axum::routing::get(routes::list_sessions))
|
||||
.route(
|
||||
"/api/sessions/{id}",
|
||||
axum::routing::delete(routes::delete_session),
|
||||
)
|
||||
.route(
|
||||
"/api/sessions/{id}/label",
|
||||
axum::routing::put(routes::set_session_label),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/sessions/by-label/{label}",
|
||||
axum::routing::get(routes::find_session_by_label),
|
||||
)
|
||||
// Agent update
|
||||
.route(
|
||||
"/api/agents/{id}/update",
|
||||
axum::routing::put(routes::update_agent),
|
||||
)
|
||||
// Security dashboard endpoint
|
||||
.route("/api/security", axum::routing::get(routes::security_status))
|
||||
// Model catalog endpoints
|
||||
.route("/api/models", axum::routing::get(routes::list_models))
|
||||
.route(
|
||||
"/api/models/aliases",
|
||||
axum::routing::get(routes::list_aliases),
|
||||
)
|
||||
.route(
|
||||
"/api/models/custom",
|
||||
axum::routing::post(routes::add_custom_model),
|
||||
)
|
||||
.route(
|
||||
"/api/models/custom/{*id}",
|
||||
axum::routing::delete(routes::remove_custom_model),
|
||||
)
|
||||
.route("/api/models/{*id}", axum::routing::get(routes::get_model))
|
||||
.route("/api/providers", axum::routing::get(routes::list_providers))
|
||||
// Copilot OAuth (must be before parametric {name} routes)
|
||||
.route(
|
||||
"/api/providers/github-copilot/oauth/start",
|
||||
axum::routing::post(routes::copilot_oauth_start),
|
||||
)
|
||||
.route(
|
||||
"/api/providers/github-copilot/oauth/poll/{poll_id}",
|
||||
axum::routing::get(routes::copilot_oauth_poll),
|
||||
)
|
||||
.route(
|
||||
"/api/providers/{name}/key",
|
||||
axum::routing::post(routes::set_provider_key).delete(routes::delete_provider_key),
|
||||
)
|
||||
.route(
|
||||
"/api/providers/{name}/test",
|
||||
axum::routing::post(routes::test_provider),
|
||||
)
|
||||
.route(
|
||||
"/api/providers/{name}/url",
|
||||
axum::routing::put(routes::set_provider_url),
|
||||
)
|
||||
.route(
|
||||
"/api/skills/create",
|
||||
axum::routing::post(routes::create_skill),
|
||||
)
|
||||
// Migration endpoints
|
||||
.route(
|
||||
"/api/migrate/detect",
|
||||
axum::routing::get(routes::migrate_detect),
|
||||
)
|
||||
.route(
|
||||
"/api/migrate/scan",
|
||||
axum::routing::post(routes::migrate_scan),
|
||||
)
|
||||
.route("/api/migrate", axum::routing::post(routes::run_migrate))
|
||||
// Cron job management endpoints
|
||||
.route(
|
||||
"/api/cron/jobs",
|
||||
axum::routing::get(routes::list_cron_jobs).post(routes::create_cron_job),
|
||||
)
|
||||
.route(
|
||||
"/api/cron/jobs/{id}",
|
||||
axum::routing::delete(routes::delete_cron_job),
|
||||
)
|
||||
.route(
|
||||
"/api/cron/jobs/{id}/enable",
|
||||
axum::routing::put(routes::toggle_cron_job),
|
||||
)
|
||||
.route(
|
||||
"/api/cron/jobs/{id}/status",
|
||||
axum::routing::get(routes::cron_job_status),
|
||||
)
|
||||
// Webhook trigger endpoints (external event injection)
|
||||
.route("/hooks/wake", axum::routing::post(routes::webhook_wake))
|
||||
.route("/hooks/agent", axum::routing::post(routes::webhook_agent))
|
||||
.route("/api/shutdown", axum::routing::post(routes::shutdown))
|
||||
// Chat commands endpoint (dynamic slash menu)
|
||||
.route("/api/commands", axum::routing::get(routes::list_commands))
|
||||
// Config reload endpoint
|
||||
.route(
|
||||
"/api/config/reload",
|
||||
axum::routing::post(routes::config_reload),
|
||||
)
|
||||
// Agent binding routes
|
||||
.route(
|
||||
"/api/bindings",
|
||||
axum::routing::get(routes::list_bindings).post(routes::add_binding),
|
||||
)
|
||||
.route(
|
||||
"/api/bindings/{index}",
|
||||
axum::routing::delete(routes::remove_binding),
|
||||
)
|
||||
// A2A (Agent-to-Agent) Protocol endpoints
|
||||
.route(
|
||||
"/.well-known/agent.json",
|
||||
axum::routing::get(routes::a2a_agent_card),
|
||||
)
|
||||
.route("/a2a/agents", axum::routing::get(routes::a2a_list_agents))
|
||||
.route(
|
||||
"/a2a/tasks/send",
|
||||
axum::routing::post(routes::a2a_send_task),
|
||||
)
|
||||
.route("/a2a/tasks/{id}", axum::routing::get(routes::a2a_get_task))
|
||||
.route(
|
||||
"/a2a/tasks/{id}/cancel",
|
||||
axum::routing::post(routes::a2a_cancel_task),
|
||||
)
|
||||
// A2A management (outbound) endpoints
|
||||
.route(
|
||||
"/api/a2a/agents",
|
||||
axum::routing::get(routes::a2a_list_external_agents),
|
||||
)
|
||||
.route(
|
||||
"/api/a2a/discover",
|
||||
axum::routing::post(routes::a2a_discover_external),
|
||||
)
|
||||
.route(
|
||||
"/api/a2a/send",
|
||||
axum::routing::post(routes::a2a_send_external),
|
||||
)
|
||||
.route(
|
||||
"/api/a2a/tasks/{id}/status",
|
||||
axum::routing::get(routes::a2a_external_task_status),
|
||||
)
|
||||
// Integration management endpoints
|
||||
.route(
|
||||
"/api/integrations",
|
||||
axum::routing::get(routes::list_integrations),
|
||||
)
|
||||
.route(
|
||||
"/api/integrations/available",
|
||||
axum::routing::get(routes::list_available_integrations),
|
||||
)
|
||||
.route(
|
||||
"/api/integrations/add",
|
||||
axum::routing::post(routes::add_integration),
|
||||
)
|
||||
.route(
|
||||
"/api/integrations/{id}",
|
||||
axum::routing::delete(routes::remove_integration),
|
||||
)
|
||||
.route(
|
||||
"/api/integrations/{id}/reconnect",
|
||||
axum::routing::post(routes::reconnect_integration),
|
||||
)
|
||||
.route(
|
||||
"/api/integrations/health",
|
||||
axum::routing::get(routes::integrations_health),
|
||||
)
|
||||
.route(
|
||||
"/api/integrations/reload",
|
||||
axum::routing::post(routes::reload_integrations),
|
||||
)
|
||||
// Device pairing endpoints
|
||||
.route(
|
||||
"/api/pairing/request",
|
||||
axum::routing::post(routes::pairing_request),
|
||||
)
|
||||
.route(
|
||||
"/api/pairing/complete",
|
||||
axum::routing::post(routes::pairing_complete),
|
||||
)
|
||||
.route(
|
||||
"/api/pairing/devices",
|
||||
axum::routing::get(routes::pairing_devices),
|
||||
)
|
||||
.route(
|
||||
"/api/pairing/devices/{id}",
|
||||
axum::routing::delete(routes::pairing_remove_device),
|
||||
)
|
||||
.route(
|
||||
"/api/pairing/notify",
|
||||
axum::routing::post(routes::pairing_notify),
|
||||
)
|
||||
// MCP HTTP endpoint (exposes MCP protocol over HTTP)
|
||||
.route("/mcp", axum::routing::post(routes::mcp_http))
|
||||
// OpenAI-compatible API
|
||||
.route(
|
||||
"/v1/chat/completions",
|
||||
axum::routing::post(crate::openai_compat::chat_completions),
|
||||
)
|
||||
.route(
|
||||
"/v1/models",
|
||||
axum::routing::get(crate::openai_compat::list_models),
|
||||
)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
api_key,
|
||||
middleware::auth,
|
||||
))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
gcra_limiter,
|
||||
rate_limiter::gcra_rate_limit,
|
||||
))
|
||||
.layer(axum::middleware::from_fn(middleware::security_headers))
|
||||
.layer(axum::middleware::from_fn(middleware::request_logging))
|
||||
.layer(CompressionLayer::new())
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(cors)
|
||||
.with_state(state.clone());
|
||||
|
||||
(app, state)
|
||||
}
|
||||
|
||||
/// Start the OpenFang daemon: boot kernel + HTTP API server.
|
||||
///
|
||||
/// This function blocks until Ctrl+C or a shutdown request.
|
||||
pub async fn run_daemon(
|
||||
kernel: OpenFangKernel,
|
||||
listen_addr: &str,
|
||||
daemon_info_path: Option<&Path>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let addr: SocketAddr = listen_addr.parse()?;
|
||||
|
||||
let kernel = Arc::new(kernel);
|
||||
kernel.set_self_handle();
|
||||
kernel.start_background_agents();
|
||||
|
||||
// Config file hot-reload watcher (polls every 30 seconds)
|
||||
{
|
||||
let k = kernel.clone();
|
||||
let config_path = kernel.config.home_dir.join("config.toml");
|
||||
tokio::spawn(async move {
|
||||
let mut last_modified = std::fs::metadata(&config_path)
|
||||
.and_then(|m| m.modified())
|
||||
.ok();
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
||||
let current = std::fs::metadata(&config_path)
|
||||
.and_then(|m| m.modified())
|
||||
.ok();
|
||||
if current != last_modified && current.is_some() {
|
||||
last_modified = current;
|
||||
tracing::info!("Config file changed, reloading...");
|
||||
match k.reload_config() {
|
||||
Ok(plan) => {
|
||||
if plan.has_changes() {
|
||||
tracing::info!("Config hot-reload applied: {:?}", plan.hot_actions);
|
||||
} else {
|
||||
tracing::debug!("Config hot-reload: no actionable changes");
|
||||
}
|
||||
}
|
||||
Err(e) => tracing::warn!("Config hot-reload failed: {e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let (app, state) = build_router(kernel.clone(), addr).await;
|
||||
|
||||
// Write daemon info file
|
||||
if let Some(info_path) = daemon_info_path {
|
||||
// Check if another daemon is already running with this PID file
|
||||
if info_path.exists() {
|
||||
if let Ok(existing) = std::fs::read_to_string(info_path) {
|
||||
if let Ok(info) = serde_json::from_str::<DaemonInfo>(&existing) {
|
||||
if is_process_alive(info.pid) {
|
||||
return Err(format!(
|
||||
"Another daemon (PID {}) is already running at {}",
|
||||
info.pid, info.listen_addr
|
||||
)
|
||||
.into());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Stale PID file, remove it
|
||||
let _ = std::fs::remove_file(info_path);
|
||||
}
|
||||
|
||||
let daemon_info = DaemonInfo {
|
||||
pid: std::process::id(),
|
||||
listen_addr: addr.to_string(),
|
||||
started_at: chrono::Utc::now().to_rfc3339(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
platform: std::env::consts::OS.to_string(),
|
||||
};
|
||||
if let Ok(json) = serde_json::to_string_pretty(&daemon_info) {
|
||||
let _ = std::fs::write(info_path, json);
|
||||
// SECURITY: Restrict daemon info file permissions (contains PID and port).
|
||||
restrict_permissions(info_path);
|
||||
}
|
||||
}
|
||||
|
||||
info!("OpenFang API server listening on http://{addr}");
|
||||
info!("WebChat UI available at http://{addr}/",);
|
||||
info!("WebSocket endpoint: ws://{addr}/api/agents/{{id}}/ws",);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
||||
// Run server with graceful shutdown.
|
||||
// SECURITY: `into_make_service_with_connect_info` injects the peer
|
||||
// SocketAddr so the auth middleware can check for loopback connections.
|
||||
let api_shutdown = state.shutdown_notify.clone();
|
||||
axum::serve(
|
||||
listener,
|
||||
app.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(shutdown_signal(api_shutdown))
|
||||
.await?;
|
||||
|
||||
// Clean up daemon info file
|
||||
if let Some(info_path) = daemon_info_path {
|
||||
let _ = std::fs::remove_file(info_path);
|
||||
}
|
||||
|
||||
// Stop channel bridges
|
||||
if let Some(ref mut b) = *state.bridge_manager.lock().await {
|
||||
b.stop().await;
|
||||
}
|
||||
|
||||
// Shutdown kernel
|
||||
kernel.shutdown();
|
||||
|
||||
info!("OpenFang daemon stopped");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// SECURITY: Restrict file permissions to owner-only (0600) on Unix.
|
||||
/// On non-Unix platforms this is a no-op.
|
||||
#[cfg(unix)]
|
||||
fn restrict_permissions(path: &Path) {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
fn restrict_permissions(_path: &Path) {}
|
||||
|
||||
/// Read daemon info from the standard location.
|
||||
pub fn read_daemon_info(home_dir: &Path) -> Option<DaemonInfo> {
|
||||
let info_path = home_dir.join("daemon.json");
|
||||
let contents = std::fs::read_to_string(info_path).ok()?;
|
||||
serde_json::from_str(&contents).ok()
|
||||
}
|
||||
|
||||
/// Wait for an OS termination signal OR an API shutdown request.
|
||||
///
|
||||
/// On Unix: listens for SIGINT, SIGTERM, and API notify.
|
||||
/// On Windows: listens for Ctrl+C and API notify.
|
||||
async fn shutdown_signal(api_shutdown: Arc<tokio::sync::Notify>) {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use tokio::signal::unix::{signal, SignalKind};
|
||||
let mut sigint = signal(SignalKind::interrupt()).expect("Failed to listen for SIGINT");
|
||||
let mut sigterm = signal(SignalKind::terminate()).expect("Failed to listen for SIGTERM");
|
||||
|
||||
tokio::select! {
|
||||
_ = sigint.recv() => {
|
||||
info!("Received SIGINT (Ctrl+C), shutting down...");
|
||||
}
|
||||
_ = sigterm.recv() => {
|
||||
info!("Received SIGTERM, shutting down...");
|
||||
}
|
||||
_ = api_shutdown.notified() => {
|
||||
info!("Shutdown requested via API, shutting down...");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
tokio::select! {
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("Ctrl+C received, shutting down...");
|
||||
}
|
||||
_ = api_shutdown.notified() => {
|
||||
info!("Shutdown requested via API, shutting down...");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a process with the given PID is still alive.
|
||||
fn is_process_alive(pid: u32) -> bool {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
// Use kill -0 to check if process exists without sending a signal
|
||||
std::process::Command::new("kill")
|
||||
.args(["-0", &pid.to_string()])
|
||||
.output()
|
||||
.map(|o| o.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
{
|
||||
// tasklist /FI "PID eq N" returns "INFO: No tasks..." when no match,
|
||||
// or a table row with the PID when found. Check exit code and that
|
||||
// "INFO:" is NOT in the output to confirm the process exists.
|
||||
std::process::Command::new("tasklist")
|
||||
.args(["/FI", &format!("PID eq {pid}"), "/NH"])
|
||||
.output()
|
||||
.map(|o| {
|
||||
o.status.success() && {
|
||||
let out = String::from_utf8_lossy(&o.stdout);
|
||||
!out.contains("INFO:") && out.contains(&pid.to_string())
|
||||
}
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
#[cfg(not(any(unix, windows)))]
|
||||
{
|
||||
let _ = pid;
|
||||
false
|
||||
}
|
||||
}
|
||||
224
crates/openfang-api/src/stream_chunker.rs
Normal file
224
crates/openfang-api/src/stream_chunker.rs
Normal file
@@ -0,0 +1,224 @@
|
||||
//! Markdown-aware stream chunking.
|
||||
//!
|
||||
//! Replaces naive 200-char text buffer flushing with smart chunking that
|
||||
//! never splits inside fenced code blocks and respects Markdown structure.
|
||||
|
||||
/// Markdown-aware stream chunker.
|
||||
///
|
||||
/// Buffers incoming text and flushes at natural break points:
|
||||
/// paragraph boundaries > newlines > sentence endings.
|
||||
/// Never splits inside fenced code blocks.
|
||||
pub struct StreamChunker {
|
||||
buffer: String,
|
||||
in_code_fence: bool,
|
||||
fence_marker: String,
|
||||
min_chunk_chars: usize,
|
||||
max_chunk_chars: usize,
|
||||
}
|
||||
|
||||
impl StreamChunker {
|
||||
/// Create a new chunker with custom min/max thresholds.
|
||||
pub fn new(min_chunk_chars: usize, max_chunk_chars: usize) -> Self {
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
in_code_fence: false,
|
||||
fence_marker: String::new(),
|
||||
min_chunk_chars,
|
||||
max_chunk_chars,
|
||||
}
|
||||
}
|
||||
|
||||
/// Push new text into the buffer. Updates code fence tracking.
|
||||
pub fn push(&mut self, text: &str) {
|
||||
for line in text.split_inclusive('\n') {
|
||||
self.buffer.push_str(line);
|
||||
// Track code fence state
|
||||
let trimmed = line.trim();
|
||||
if trimmed.starts_with("```") {
|
||||
if self.in_code_fence {
|
||||
// Check if this closes the current fence
|
||||
if trimmed == "```" || trimmed.starts_with(&self.fence_marker) {
|
||||
self.in_code_fence = false;
|
||||
self.fence_marker.clear();
|
||||
}
|
||||
} else {
|
||||
self.in_code_fence = true;
|
||||
self.fence_marker = "```".to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to flush a chunk from the buffer.
|
||||
///
|
||||
/// Returns `Some(chunk)` if enough content has accumulated,
|
||||
/// `None` if we should wait for more input.
|
||||
pub fn try_flush(&mut self) -> Option<String> {
|
||||
if self.buffer.len() < self.min_chunk_chars {
|
||||
return None;
|
||||
}
|
||||
|
||||
// If inside a code fence and under max, wait for fence to close
|
||||
if self.in_code_fence && self.buffer.len() < self.max_chunk_chars {
|
||||
return None;
|
||||
}
|
||||
|
||||
// If at max inside a fence, force-close and flush
|
||||
if self.in_code_fence && self.buffer.len() >= self.max_chunk_chars {
|
||||
// Close the fence, flush everything, reopen on next push
|
||||
let mut chunk = std::mem::take(&mut self.buffer);
|
||||
chunk.push_str("\n```\n");
|
||||
// Mark that we need to reopen the fence
|
||||
self.buffer = format!("```{}\n", self.fence_marker.trim_start_matches('`'));
|
||||
return Some(chunk);
|
||||
}
|
||||
|
||||
// Find best break point
|
||||
let search_range = self.min_chunk_chars..self.buffer.len().min(self.max_chunk_chars);
|
||||
|
||||
// Priority 1: Paragraph break (double newline)
|
||||
if let Some(pos) = find_last_in_range(&self.buffer, "\n\n", &search_range) {
|
||||
let break_at = pos + 2;
|
||||
let chunk = self.buffer[..break_at].to_string();
|
||||
self.buffer = self.buffer[break_at..].to_string();
|
||||
return Some(chunk);
|
||||
}
|
||||
|
||||
// Priority 2: Single newline
|
||||
if let Some(pos) = find_last_in_range(&self.buffer, "\n", &search_range) {
|
||||
let break_at = pos + 1;
|
||||
let chunk = self.buffer[..break_at].to_string();
|
||||
self.buffer = self.buffer[break_at..].to_string();
|
||||
return Some(chunk);
|
||||
}
|
||||
|
||||
// Priority 3: Sentence ending (". ", "! ", "? ")
|
||||
for ending in &[". ", "! ", "? "] {
|
||||
if let Some(pos) = find_last_in_range(&self.buffer, ending, &search_range) {
|
||||
let break_at = pos + ending.len();
|
||||
let chunk = self.buffer[..break_at].to_string();
|
||||
self.buffer = self.buffer[break_at..].to_string();
|
||||
return Some(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 4: Forced break at max_chunk_chars
|
||||
if self.buffer.len() >= self.max_chunk_chars {
|
||||
let break_at = self.max_chunk_chars;
|
||||
let chunk = self.buffer[..break_at].to_string();
|
||||
self.buffer = self.buffer[break_at..].to_string();
|
||||
return Some(chunk);
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Force-flush all remaining text.
|
||||
pub fn flush_remaining(&mut self) -> Option<String> {
|
||||
if self.buffer.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(std::mem::take(&mut self.buffer))
|
||||
}
|
||||
}
|
||||
|
||||
/// Current buffer length.
|
||||
pub fn buffered_len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
|
||||
/// Whether currently inside a code fence.
|
||||
pub fn is_in_code_fence(&self) -> bool {
|
||||
self.in_code_fence
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the last occurrence of a pattern within a byte range.
|
||||
fn find_last_in_range(text: &str, pattern: &str, range: &std::ops::Range<usize>) -> Option<usize> {
|
||||
let search_text = &text[range.start..range.end.min(text.len())];
|
||||
search_text.rfind(pattern).map(|pos| range.start + pos)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_basic_chunking() {
|
||||
let mut chunker = StreamChunker::new(10, 50);
|
||||
chunker.push("Hello world.\nThis is a test.\nAnother line.\n");
|
||||
|
||||
let chunk = chunker.try_flush();
|
||||
assert!(chunk.is_some());
|
||||
let text = chunk.unwrap();
|
||||
// Should break at a newline
|
||||
assert!(text.ends_with('\n'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_code_fence_not_split() {
|
||||
let mut chunker = StreamChunker::new(5, 200);
|
||||
chunker.push("Before\n```python\ndef foo():\n pass\n```\nAfter\n");
|
||||
|
||||
// Should not flush mid-fence
|
||||
// Since buffer is >5 chars and fence is now closed, should flush
|
||||
let chunk = chunker.try_flush();
|
||||
assert!(chunk.is_some());
|
||||
let text = chunk.unwrap();
|
||||
// If it includes the code block, the fence should be complete
|
||||
if text.contains("```python") {
|
||||
assert!(text.contains("```\n") || text.ends_with("```"));
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_code_fence_force_close_at_max() {
|
||||
let mut chunker = StreamChunker::new(5, 30);
|
||||
chunker.push("```python\nline1\nline2\nline3\nline4\nline5\nline6\n");
|
||||
|
||||
// Buffer exceeds max while in fence — should force close
|
||||
let chunk = chunker.try_flush();
|
||||
assert!(chunk.is_some());
|
||||
let text = chunk.unwrap();
|
||||
assert!(text.contains("```\n")); // force-closed fence
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_paragraph_break_priority() {
|
||||
let mut chunker = StreamChunker::new(10, 200);
|
||||
chunker.push("First paragraph text.\n\nSecond paragraph text.\n");
|
||||
|
||||
let chunk = chunker.try_flush();
|
||||
assert!(chunk.is_some());
|
||||
let text = chunk.unwrap();
|
||||
assert!(text.ends_with("\n\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flush_remaining() {
|
||||
let mut chunker = StreamChunker::new(100, 200);
|
||||
chunker.push("short");
|
||||
|
||||
// try_flush should return None (under min)
|
||||
assert!(chunker.try_flush().is_none());
|
||||
|
||||
// flush_remaining should return everything
|
||||
let remaining = chunker.flush_remaining();
|
||||
assert_eq!(remaining, Some("short".to_string()));
|
||||
|
||||
// Second flush should be None
|
||||
assert!(chunker.flush_remaining().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sentence_break() {
|
||||
let mut chunker = StreamChunker::new(10, 200);
|
||||
chunker.push("This is the first sentence. This is the second sentence. More text here.");
|
||||
|
||||
let chunk = chunker.try_flush();
|
||||
assert!(chunk.is_some());
|
||||
let text = chunk.unwrap();
|
||||
// Should break at a sentence ending
|
||||
assert!(text.ends_with(". ") || text.ends_with(".\n"));
|
||||
}
|
||||
}
|
||||
160
crates/openfang-api/src/stream_dedup.rs
Normal file
160
crates/openfang-api/src/stream_dedup.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
//! Streaming duplicate detection.
|
||||
//!
|
||||
//! Detects when the LLM repeats text that was already sent (e.g., repeating
|
||||
//! tool output verbatim). Uses exact + normalized matching with a sliding window.
|
||||
|
||||
/// Minimum text length to consider for deduplication.
|
||||
const MIN_DEDUP_LENGTH: usize = 10;
|
||||
|
||||
/// Number of recent chunks to keep in the dedup window.
|
||||
const DEDUP_WINDOW: usize = 50;
|
||||
|
||||
/// Streaming duplicate detector.
|
||||
pub struct StreamDedup {
|
||||
/// Recent chunks (exact text).
|
||||
recent_chunks: Vec<String>,
|
||||
/// Recent chunks (normalized: lowercased, whitespace-collapsed).
|
||||
recent_normalized: Vec<String>,
|
||||
}
|
||||
|
||||
impl StreamDedup {
|
||||
/// Create a new dedup detector.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
recent_chunks: Vec::with_capacity(DEDUP_WINDOW),
|
||||
recent_normalized: Vec::with_capacity(DEDUP_WINDOW),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if text is a duplicate of recently sent content.
|
||||
///
|
||||
/// Returns `true` if the text matches (exact or normalized) any
|
||||
/// recent chunk. Skips very short texts.
|
||||
pub fn is_duplicate(&self, text: &str) -> bool {
|
||||
if text.len() < MIN_DEDUP_LENGTH {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Exact match
|
||||
if self.recent_chunks.iter().any(|c| c == text) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Normalized match
|
||||
let normalized = normalize(text);
|
||||
self.recent_normalized.iter().any(|c| c == &normalized)
|
||||
}
|
||||
|
||||
/// Record text that was successfully sent to the client.
|
||||
pub fn record_sent(&mut self, text: &str) {
|
||||
if text.len() < MIN_DEDUP_LENGTH {
|
||||
return;
|
||||
}
|
||||
|
||||
// Evict oldest if at capacity
|
||||
if self.recent_chunks.len() >= DEDUP_WINDOW {
|
||||
self.recent_chunks.remove(0);
|
||||
self.recent_normalized.remove(0);
|
||||
}
|
||||
|
||||
self.recent_chunks.push(text.to_string());
|
||||
self.recent_normalized.push(normalize(text));
|
||||
}
|
||||
|
||||
/// Clear the dedup window.
|
||||
pub fn clear(&mut self) {
|
||||
self.recent_chunks.clear();
|
||||
self.recent_normalized.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StreamDedup {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Normalize text for fuzzy matching: lowercase + collapse whitespace.
|
||||
fn normalize(text: &str) -> String {
|
||||
let mut result = String::with_capacity(text.len());
|
||||
let mut last_was_space = false;
|
||||
|
||||
for ch in text.chars() {
|
||||
if ch.is_whitespace() {
|
||||
if !last_was_space {
|
||||
result.push(' ');
|
||||
last_was_space = true;
|
||||
}
|
||||
} else {
|
||||
result.push(ch.to_lowercase().next().unwrap_or(ch));
|
||||
last_was_space = false;
|
||||
}
|
||||
}
|
||||
|
||||
result.trim().to_string()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_exact_match_detected() {
|
||||
let mut dedup = StreamDedup::new();
|
||||
dedup.record_sent("This is a test chunk of text that was sent.");
|
||||
assert!(dedup.is_duplicate("This is a test chunk of text that was sent."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalized_match_detected() {
|
||||
let mut dedup = StreamDedup::new();
|
||||
dedup.record_sent("This is a test chunk");
|
||||
// Same text but different whitespace/case
|
||||
assert!(dedup.is_duplicate("this is a test chunk"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_short_text_skipped() {
|
||||
let mut dedup = StreamDedup::new();
|
||||
dedup.record_sent("short");
|
||||
assert!(!dedup.is_duplicate("short"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_window_rollover() {
|
||||
let mut dedup = StreamDedup::new();
|
||||
// Fill the window
|
||||
for i in 0..DEDUP_WINDOW {
|
||||
dedup.record_sent(&format!("chunk number {} is here", i));
|
||||
}
|
||||
// Add one more — should evict the oldest
|
||||
dedup.record_sent("new chunk that is quite long");
|
||||
// Oldest should no longer be detected
|
||||
assert!(!dedup.is_duplicate("chunk number 0 is here"));
|
||||
// Newest should be detected
|
||||
assert!(dedup.is_duplicate("new chunk that is quite long"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_false_positives() {
|
||||
let mut dedup = StreamDedup::new();
|
||||
dedup.record_sent("The quick brown fox jumps over the lazy dog");
|
||||
assert!(!dedup.is_duplicate("A completely different sentence here"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear() {
|
||||
let mut dedup = StreamDedup::new();
|
||||
dedup.record_sent("This is test content here");
|
||||
assert!(dedup.is_duplicate("This is test content here"));
|
||||
dedup.clear();
|
||||
assert!(!dedup.is_duplicate("This is test content here"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize() {
|
||||
assert_eq!(normalize("Hello World"), "hello world");
|
||||
assert_eq!(normalize(" spaced out "), "spaced out");
|
||||
assert_eq!(normalize("UPPER case"), "upper case");
|
||||
}
|
||||
}
|
||||
98
crates/openfang-api/src/types.rs
Normal file
98
crates/openfang-api/src/types.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
//! Request/response types for the OpenFang API.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Request to spawn an agent from a TOML manifest string.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SpawnRequest {
|
||||
/// Agent manifest as TOML string.
|
||||
pub manifest_toml: String,
|
||||
/// Optional Ed25519 signed manifest envelope (JSON).
|
||||
/// When present, the signature is verified before spawning.
|
||||
#[serde(default)]
|
||||
pub signed_manifest: Option<String>,
|
||||
}
|
||||
|
||||
/// Response after spawning an agent.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct SpawnResponse {
|
||||
pub agent_id: String,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// A file attachment reference (from a prior upload).
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AttachmentRef {
|
||||
pub file_id: String,
|
||||
#[serde(default)]
|
||||
pub filename: String,
|
||||
#[serde(default)]
|
||||
pub content_type: String,
|
||||
}
|
||||
|
||||
/// Request to send a message to an agent.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MessageRequest {
|
||||
pub message: String,
|
||||
/// Optional file attachments (uploaded via /upload endpoint).
|
||||
#[serde(default)]
|
||||
pub attachments: Vec<AttachmentRef>,
|
||||
}
|
||||
|
||||
/// Response from sending a message.
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct MessageResponse {
|
||||
pub response: String,
|
||||
pub input_tokens: u64,
|
||||
pub output_tokens: u64,
|
||||
pub iterations: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cost_usd: Option<f64>,
|
||||
}
|
||||
|
||||
/// Request to install a skill from the marketplace.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SkillInstallRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Request to uninstall a skill.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SkillUninstallRequest {
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
/// Request to update an agent's manifest.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct AgentUpdateRequest {
|
||||
pub manifest_toml: String,
|
||||
}
|
||||
|
||||
/// Request to change an agent's operational mode.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SetModeRequest {
|
||||
pub mode: openfang_types::agent::AgentMode,
|
||||
}
|
||||
|
||||
/// Request to run a migration.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MigrateRequest {
|
||||
pub source: String,
|
||||
pub source_dir: String,
|
||||
pub target_dir: String,
|
||||
#[serde(default)]
|
||||
pub dry_run: bool,
|
||||
}
|
||||
|
||||
/// Request to scan a directory for migration.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct MigrateScanRequest {
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
/// Request to install a skill from ClawHub.
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ClawHubInstallRequest {
|
||||
/// ClawHub skill slug (e.g., "github-helper").
|
||||
pub slug: String,
|
||||
}
|
||||
132
crates/openfang-api/src/webchat.rs
Normal file
132
crates/openfang-api/src/webchat.rs
Normal file
@@ -0,0 +1,132 @@
|
||||
//! Embedded WebChat UI served as static HTML.
|
||||
//!
|
||||
//! The production dashboard is assembled at compile time from separate
|
||||
//! HTML/CSS/JS files under `static/` using `include_str!()`. This keeps
|
||||
//! single-binary deployment while allowing organized source files.
|
||||
//!
|
||||
//! Features:
|
||||
//! - Alpine.js SPA with hash-based routing (10 panels)
|
||||
//! - Dark/light theme toggle with system preference detection
|
||||
//! - Responsive layout with collapsible sidebar
|
||||
//! - Markdown rendering + syntax highlighting (bundled locally)
|
||||
//! - WebSocket real-time chat with HTTP fallback
|
||||
//! - Agent management, workflows, memory browser, audit log, and more
|
||||
|
||||
use axum::http::header;
|
||||
use axum::response::IntoResponse;
|
||||
|
||||
/// Compile-time ETag based on the crate version.
|
||||
const ETAG: &str = concat!("\"openfang-", env!("CARGO_PKG_VERSION"), "\"");
|
||||
|
||||
/// Embedded logo PNG for single-binary deployment.
|
||||
const LOGO_PNG: &[u8] = include_bytes!("../static/logo.png");
|
||||
|
||||
/// Embedded favicon ICO for browser tabs.
|
||||
const FAVICON_ICO: &[u8] = include_bytes!("../static/favicon.ico");
|
||||
|
||||
/// GET /logo.png — Serve the OpenFang logo.
|
||||
pub async fn logo_png() -> impl IntoResponse {
|
||||
(
|
||||
[
|
||||
(header::CONTENT_TYPE, "image/png"),
|
||||
(header::CACHE_CONTROL, "public, max-age=86400, immutable"),
|
||||
],
|
||||
LOGO_PNG,
|
||||
)
|
||||
}
|
||||
|
||||
/// GET /favicon.ico — Serve the OpenFang favicon.
|
||||
pub async fn favicon_ico() -> impl IntoResponse {
|
||||
(
|
||||
[
|
||||
(header::CONTENT_TYPE, "image/x-icon"),
|
||||
(header::CACHE_CONTROL, "public, max-age=86400, immutable"),
|
||||
],
|
||||
FAVICON_ICO,
|
||||
)
|
||||
}
|
||||
|
||||
/// GET / — Serve the OpenFang Dashboard single-page application.
|
||||
///
|
||||
/// Returns the full SPA with ETag header based on package version for caching.
|
||||
pub async fn webchat_page() -> impl IntoResponse {
|
||||
(
|
||||
[
|
||||
(header::CONTENT_TYPE, "text/html; charset=utf-8"),
|
||||
(header::ETAG, ETAG),
|
||||
(
|
||||
header::CACHE_CONTROL,
|
||||
"public, max-age=3600, must-revalidate",
|
||||
),
|
||||
],
|
||||
WEBCHAT_HTML,
|
||||
)
|
||||
}
|
||||
|
||||
/// The embedded HTML/CSS/JS for the OpenFang Dashboard.
|
||||
///
|
||||
/// Assembled at compile time from organized static files.
|
||||
/// All vendor libraries (Alpine.js, marked.js, highlight.js) are bundled
|
||||
/// locally — no CDN dependency. Alpine.js is included LAST because it
|
||||
/// immediately processes x-data directives and fires alpine:init on load.
|
||||
const WEBCHAT_HTML: &str = concat!(
|
||||
include_str!("../static/index_head.html"),
|
||||
"<style>\n",
|
||||
include_str!("../static/css/theme.css"),
|
||||
"\n",
|
||||
include_str!("../static/css/layout.css"),
|
||||
"\n",
|
||||
include_str!("../static/css/components.css"),
|
||||
"\n",
|
||||
include_str!("../static/vendor/github-dark.min.css"),
|
||||
"\n</style>\n",
|
||||
include_str!("../static/index_body.html"),
|
||||
// Vendor libs: marked + highlight first (used by app.js)
|
||||
"<script>\n",
|
||||
include_str!("../static/vendor/marked.min.js"),
|
||||
"\n</script>\n",
|
||||
"<script>\n",
|
||||
include_str!("../static/vendor/highlight.min.js"),
|
||||
"\n</script>\n",
|
||||
// App code
|
||||
"<script>\n",
|
||||
include_str!("../static/js/api.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/app.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/overview.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/chat.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/agents.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/workflows.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/workflow-builder.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/channels.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/skills.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/hands.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/scheduler.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/settings.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/usage.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/sessions.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/logs.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/wizard.js"),
|
||||
"\n",
|
||||
include_str!("../static/js/pages/approvals.js"),
|
||||
"\n</script>\n",
|
||||
// Alpine.js MUST be last — it processes x-data and fires alpine:init
|
||||
"<script>\n",
|
||||
include_str!("../static/vendor/alpine.min.js"),
|
||||
"\n</script>\n",
|
||||
"</body></html>"
|
||||
);
|
||||
1226
crates/openfang-api/src/ws.rs
Normal file
1226
crates/openfang-api/src/ws.rs
Normal file
File diff suppressed because it is too large
Load Diff
3075
crates/openfang-api/static/css/components.css
Normal file
3075
crates/openfang-api/static/css/components.css
Normal file
File diff suppressed because it is too large
Load Diff
309
crates/openfang-api/static/css/layout.css
Normal file
309
crates/openfang-api/static/css/layout.css
Normal file
@@ -0,0 +1,309 @@
|
||||
/* OpenFang Layout — Grid + Sidebar + Responsive */
|
||||
|
||||
.app-layout {
|
||||
display: flex;
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* Sidebar */
|
||||
.sidebar {
|
||||
width: var(--sidebar-width);
|
||||
background: var(--bg-primary);
|
||||
border-right: 1px solid var(--border);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex-shrink: 0;
|
||||
transition: width var(--transition-normal);
|
||||
z-index: 100;
|
||||
}
|
||||
|
||||
.sidebar.collapsed {
|
||||
width: var(--sidebar-collapsed);
|
||||
}
|
||||
|
||||
.sidebar.collapsed .sidebar-label,
|
||||
.sidebar.collapsed .sidebar-header-text,
|
||||
.sidebar.collapsed .nav-label { display: none; }
|
||||
|
||||
.sidebar.collapsed .nav-item { justify-content: center; padding: 12px 0; }
|
||||
|
||||
.sidebar-header {
|
||||
padding: 16px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
min-height: 60px;
|
||||
}
|
||||
|
||||
.sidebar-logo {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.sidebar-logo img {
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
opacity: 0.8;
|
||||
transition: opacity 0.2s, transform 0.2s;
|
||||
}
|
||||
|
||||
.sidebar-logo img:hover {
|
||||
opacity: 1;
|
||||
transform: scale(1.05);
|
||||
}
|
||||
|
||||
.sidebar-header h1 {
|
||||
font-size: 14px;
|
||||
font-weight: 700;
|
||||
color: var(--accent);
|
||||
letter-spacing: 3px;
|
||||
font-family: var(--font-mono);
|
||||
}
|
||||
|
||||
.sidebar-header .version {
|
||||
font-size: 9px;
|
||||
color: var(--text-muted);
|
||||
margin-top: 1px;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.sidebar-status {
|
||||
font-size: 11px;
|
||||
color: var(--success);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
padding: 8px 16px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
}
|
||||
|
||||
.sidebar-status.offline { color: var(--error); }
|
||||
|
||||
.status-dot {
|
||||
width: 6px; height: 6px;
|
||||
border-radius: 50%;
|
||||
background: currentColor;
|
||||
flex-shrink: 0;
|
||||
box-shadow: 0 0 6px currentColor;
|
||||
}
|
||||
|
||||
.conn-badge {
|
||||
font-size: 9px;
|
||||
padding: 1px 5px;
|
||||
border-radius: 3px;
|
||||
font-weight: 600;
|
||||
letter-spacing: 0.5px;
|
||||
margin-left: auto;
|
||||
}
|
||||
.conn-badge.ws { background: var(--success); color: #000; }
|
||||
.conn-badge.http { background: var(--warning); color: #000; }
|
||||
|
||||
/* Navigation */
|
||||
.sidebar-nav {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: 8px;
|
||||
scrollbar-width: none;
|
||||
}
|
||||
.sidebar-nav::-webkit-scrollbar { width: 0; }
|
||||
|
||||
.nav-section {
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.nav-section-title {
|
||||
font-size: 9px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 1.5px;
|
||||
color: var(--text-muted);
|
||||
padding: 12px 12px 4px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.sidebar.collapsed .nav-section-title { display: none; }
|
||||
|
||||
.nav-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
padding: 9px 12px;
|
||||
border-radius: var(--radius-md);
|
||||
cursor: pointer;
|
||||
font-size: 13px;
|
||||
color: var(--text-dim);
|
||||
transition: all var(--transition-fast);
|
||||
text-decoration: none;
|
||||
border: 1px solid transparent;
|
||||
white-space: nowrap;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.nav-item:hover {
|
||||
background: var(--surface2);
|
||||
color: var(--text);
|
||||
transform: translateX(2px);
|
||||
}
|
||||
|
||||
.nav-item.active {
|
||||
background: var(--accent);
|
||||
color: var(--bg-primary);
|
||||
font-weight: 600;
|
||||
box-shadow: var(--shadow-sm), 0 2px 8px rgba(255, 92, 0, 0.2);
|
||||
}
|
||||
|
||||
.nav-icon {
|
||||
width: 18px;
|
||||
height: 18px;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.nav-icon svg {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
fill: none;
|
||||
stroke: currentColor;
|
||||
stroke-width: 2;
|
||||
stroke-linecap: round;
|
||||
stroke-linejoin: round;
|
||||
}
|
||||
|
||||
/* Sidebar toggle button */
|
||||
.sidebar-toggle {
|
||||
padding: 10px 16px;
|
||||
border-top: 1px solid var(--border);
|
||||
cursor: pointer;
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
color: var(--text-muted);
|
||||
transition: color var(--transition-fast);
|
||||
}
|
||||
|
||||
.sidebar-toggle:hover { color: var(--text); }
|
||||
|
||||
/* Main content area */
|
||||
.main-content {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
min-width: 0;
|
||||
overflow: hidden;
|
||||
background: var(--bg);
|
||||
}
|
||||
|
||||
/* Page wrapper divs (rendered by x-if) must fill the column
|
||||
and be flex containers so .page-body can scroll. */
|
||||
.main-content > div {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.page-header {
|
||||
padding: 14px 24px;
|
||||
border-bottom: 1px solid var(--border);
|
||||
background: var(--bg-primary);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
min-height: var(--header-height);
|
||||
}
|
||||
|
||||
.page-header h2 {
|
||||
font-size: 15px;
|
||||
font-weight: 600;
|
||||
letter-spacing: -0.01em;
|
||||
}
|
||||
|
||||
.page-body {
|
||||
flex: 1;
|
||||
min-height: 0;
|
||||
overflow-y: auto;
|
||||
padding: 24px;
|
||||
}
|
||||
|
||||
/* Mobile overlay */
|
||||
.sidebar-overlay {
|
||||
display: none;
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
background: rgba(0,0,0,0.6);
|
||||
z-index: 99;
|
||||
}
|
||||
|
||||
/* Wide desktop — larger card grids */
|
||||
@media (min-width: 1400px) {
|
||||
.card-grid { grid-template-columns: repeat(auto-fill, minmax(320px, 1fr)); }
|
||||
}
|
||||
|
||||
/* Responsive — tablet breakpoint */
|
||||
@media (max-width: 1024px) {
|
||||
.card-grid { grid-template-columns: repeat(auto-fill, minmax(240px, 1fr)); }
|
||||
.security-grid { grid-template-columns: 1fr; }
|
||||
.cost-charts-row { grid-template-columns: 1fr; }
|
||||
.overview-grid { grid-template-columns: repeat(auto-fill, minmax(240px, 1fr)); }
|
||||
.page-body { padding: 16px; }
|
||||
}
|
||||
|
||||
/* Responsive — mobile breakpoint */
|
||||
@media (max-width: 768px) {
|
||||
.sidebar {
|
||||
position: fixed;
|
||||
left: -300px;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
transition: left var(--transition-normal);
|
||||
}
|
||||
.sidebar.mobile-open {
|
||||
left: 0;
|
||||
}
|
||||
.sidebar.mobile-open + .sidebar-overlay {
|
||||
display: block;
|
||||
}
|
||||
.sidebar.collapsed {
|
||||
width: var(--sidebar-width);
|
||||
left: -300px;
|
||||
}
|
||||
.mobile-menu-btn { display: flex !important; }
|
||||
}
|
||||
|
||||
@media (min-width: 769px) {
|
||||
.mobile-menu-btn { display: none !important; }
|
||||
}
|
||||
|
||||
/* Mobile small screen */
|
||||
@media (max-width: 480px) {
|
||||
.page-header { flex-direction: column; gap: 8px; align-items: flex-start; padding: 12px 16px; }
|
||||
.page-body { padding: 12px; }
|
||||
.stats-row { flex-wrap: wrap; }
|
||||
.stat-card { min-width: 80px; flex: 1 1 40%; }
|
||||
.stat-card-lg { min-width: 80px; flex: 1 1 40%; padding: 12px; }
|
||||
.stat-card-lg .stat-value { font-size: 22px; }
|
||||
.card-grid { grid-template-columns: 1fr; }
|
||||
.overview-grid { grid-template-columns: 1fr; }
|
||||
.input-area { padding: 8px 12px; }
|
||||
.main-content { padding: 0; }
|
||||
.table-wrap { font-size: 10px; }
|
||||
.modal { margin: 8px; max-height: calc(100vh - 16px); }
|
||||
}
|
||||
|
||||
/* Touch-friendly tap targets */
|
||||
@media (pointer: coarse) {
|
||||
.btn { min-height: 44px; min-width: 44px; }
|
||||
.nav-item { min-height: 44px; }
|
||||
.form-input, .form-select, .form-textarea { min-height: 44px; }
|
||||
.toggle { min-width: 44px; min-height: 28px; }
|
||||
}
|
||||
|
||||
/* Focus mode — hide sidebar for distraction-free chat */
|
||||
.app-layout.focus-mode .sidebar { display: none; }
|
||||
.app-layout.focus-mode .sidebar-overlay { display: none; }
|
||||
.app-layout.focus-mode .main-content { max-width: 100%; margin-left: 0; }
|
||||
.app-layout.focus-mode .mobile-menu-btn { display: none !important; }
|
||||
276
crates/openfang-api/static/css/theme.css
Normal file
276
crates/openfang-api/static/css/theme.css
Normal file
@@ -0,0 +1,276 @@
|
||||
/* OpenFang Theme — Premium design system */
|
||||
|
||||
/* Font imports in index_head.html: Inter (body) + Geist Mono (code) */
|
||||
|
||||
[data-theme="light"], :root {
|
||||
/* Backgrounds — layered depth */
|
||||
--bg: #F5F4F2;
|
||||
--bg-primary: #EDECEB;
|
||||
--bg-elevated: #F8F7F6;
|
||||
--surface: #FFFFFF;
|
||||
--surface2: #F0EEEC;
|
||||
--surface3: #E8E6E3;
|
||||
--border: #D5D2CF;
|
||||
--border-light: #C8C4C0;
|
||||
--border-subtle: #E0DEDA;
|
||||
|
||||
/* Text hierarchy */
|
||||
--text: #1A1817;
|
||||
--text-secondary: #3D3935;
|
||||
--text-dim: #6B6560;
|
||||
--text-muted: #9A958F;
|
||||
|
||||
/* Brand — Orange accent */
|
||||
--accent: #FF5C00;
|
||||
--accent-light: #FF7A2E;
|
||||
--accent-dim: #E05200;
|
||||
--accent-glow: rgba(255, 92, 0, 0.1);
|
||||
--accent-subtle: rgba(255, 92, 0, 0.05);
|
||||
|
||||
/* Status colors */
|
||||
--success: #22C55E;
|
||||
--success-dim: #16A34A;
|
||||
--success-subtle: rgba(34, 197, 94, 0.08);
|
||||
--error: #EF4444;
|
||||
--error-dim: #DC2626;
|
||||
--error-subtle: rgba(239, 68, 68, 0.06);
|
||||
--warning: #F59E0B;
|
||||
--warning-dim: #D97706;
|
||||
--warning-subtle: rgba(245, 158, 11, 0.08);
|
||||
--info: #3B82F6;
|
||||
--info-dim: #2563EB;
|
||||
--info-subtle: rgba(59, 130, 246, 0.06);
|
||||
--success-muted: rgba(34, 197, 94, 0.15);
|
||||
--error-muted: rgba(239, 68, 68, 0.15);
|
||||
--warning-muted: rgba(245, 158, 11, 0.15);
|
||||
--info-muted: rgba(59, 130, 246, 0.15);
|
||||
--border-strong: #B0ACA8;
|
||||
--card-highlight: rgba(0, 0, 0, 0.02);
|
||||
|
||||
/* Chat-specific */
|
||||
--agent-bg: #F5F4F2;
|
||||
--user-bg: #FFF3E6;
|
||||
|
||||
/* Layout */
|
||||
--sidebar-width: 240px;
|
||||
--sidebar-collapsed: 56px;
|
||||
--header-height: 48px;
|
||||
|
||||
/* Radius — slightly larger for premium feel */
|
||||
--radius-xs: 4px;
|
||||
--radius-sm: 6px;
|
||||
--radius-md: 8px;
|
||||
--radius-lg: 12px;
|
||||
--radius-xl: 16px;
|
||||
|
||||
/* Shadows — 6-level depth system */
|
||||
--shadow-xs: 0 1px 2px rgba(0,0,0,0.04);
|
||||
--shadow-sm: 0 1px 3px rgba(0,0,0,0.06), 0 1px 2px rgba(0,0,0,0.04);
|
||||
--shadow-md: 0 4px 12px rgba(0,0,0,0.07), 0 2px 4px rgba(0,0,0,0.04);
|
||||
--shadow-lg: 0 12px 28px rgba(0,0,0,0.08), 0 4px 10px rgba(0,0,0,0.05);
|
||||
--shadow-xl: 0 20px 40px rgba(0,0,0,0.1), 0 8px 16px rgba(0,0,0,0.06);
|
||||
--shadow-glow: 0 0 40px rgba(0,0,0,0.05);
|
||||
--shadow-accent: 0 4px 16px rgba(255, 92, 0, 0.12);
|
||||
--shadow-inset: inset 0 1px 0 rgba(255,255,255,0.5);
|
||||
|
||||
/* Typography — dual font system */
|
||||
--font-sans: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', system-ui, sans-serif;
|
||||
--font-mono: 'Geist Mono', 'SF Mono', 'Fira Code', 'Cascadia Code', 'JetBrains Mono', monospace;
|
||||
|
||||
/* Motion — spring curves for premium feel */
|
||||
--ease-spring: cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||
--ease-smooth: cubic-bezier(0.4, 0, 0.2, 1);
|
||||
--ease-out: cubic-bezier(0, 0, 0.2, 1);
|
||||
--ease-in: cubic-bezier(0.4, 0, 1, 1);
|
||||
--transition-fast: 0.15s var(--ease-smooth);
|
||||
--transition-normal: 0.25s var(--ease-smooth);
|
||||
--transition-spring: 0.4s var(--ease-spring);
|
||||
}
|
||||
|
||||
[data-theme="dark"] {
|
||||
--bg: #080706;
|
||||
--bg-primary: #0F0E0E;
|
||||
--bg-elevated: #161413;
|
||||
--surface: #1F1D1C;
|
||||
--surface2: #2A2725;
|
||||
--surface3: #1A1817;
|
||||
--border: #2D2A28;
|
||||
--border-light: #3D3A38;
|
||||
--border-subtle: #232120;
|
||||
--text: #F0EFEE;
|
||||
--text-secondary: #C4C0BC;
|
||||
--text-dim: #8A8380;
|
||||
--text-muted: #5C5754;
|
||||
--accent: #FF5C00;
|
||||
--accent-light: #FF7A2E;
|
||||
--accent-dim: #E05200;
|
||||
--accent-glow: rgba(255, 92, 0, 0.15);
|
||||
--accent-subtle: rgba(255, 92, 0, 0.08);
|
||||
--success: #4ADE80;
|
||||
--success-dim: #22C55E;
|
||||
--success-subtle: rgba(74, 222, 128, 0.1);
|
||||
--error: #EF4444;
|
||||
--error-dim: #B91C1C;
|
||||
--error-subtle: rgba(239, 68, 68, 0.1);
|
||||
--warning: #F59E0B;
|
||||
--warning-dim: #D97706;
|
||||
--warning-subtle: rgba(245, 158, 11, 0.1);
|
||||
--info: #3B82F6;
|
||||
--info-dim: #2563EB;
|
||||
--info-subtle: rgba(59, 130, 246, 0.1);
|
||||
--success-muted: rgba(74, 222, 128, 0.25);
|
||||
--error-muted: rgba(239, 68, 68, 0.25);
|
||||
--warning-muted: rgba(245, 158, 11, 0.25);
|
||||
--info-muted: rgba(59, 130, 246, 0.25);
|
||||
--border-strong: #4A4644;
|
||||
--card-highlight: rgba(255, 255, 255, 0.04);
|
||||
--agent-bg: #1A1817;
|
||||
--user-bg: #2A1A08;
|
||||
--shadow-xs: 0 1px 2px rgba(0,0,0,0.3);
|
||||
--shadow-sm: 0 1px 3px rgba(0,0,0,0.4), 0 1px 2px rgba(0,0,0,0.3);
|
||||
--shadow-md: 0 4px 12px rgba(0,0,0,0.4), 0 2px 4px rgba(0,0,0,0.3);
|
||||
--shadow-lg: 0 12px 28px rgba(0,0,0,0.35), 0 4px 10px rgba(0,0,0,0.3);
|
||||
--shadow-xl: 0 20px 40px rgba(0,0,0,0.4), 0 8px 16px rgba(0,0,0,0.3);
|
||||
--shadow-glow: 0 0 80px rgba(0,0,0,0.6);
|
||||
--shadow-accent: 0 4px 16px rgba(255, 92, 0, 0.2);
|
||||
--shadow-inset: inset 0 1px 0 rgba(255,255,255,0.03);
|
||||
}
|
||||
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
|
||||
html { scroll-behavior: smooth; }
|
||||
|
||||
body {
|
||||
font-family: var(--font-sans);
|
||||
background: var(--bg);
|
||||
color: var(--text);
|
||||
height: 100vh;
|
||||
overflow: hidden;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-moz-osx-font-smoothing: grayscale;
|
||||
font-size: 14px;
|
||||
line-height: 1.5;
|
||||
letter-spacing: -0.01em;
|
||||
}
|
||||
|
||||
/* Mono text utility — only for code/data */
|
||||
.font-mono, code, pre, .tool-pre, .tool-card-name, .detail-value,
|
||||
.stat-value, .conn-badge, .version { font-family: var(--font-mono); }
|
||||
|
||||
/* Scrollbar — Webkit (Chrome, Edge, Safari) */
|
||||
::-webkit-scrollbar { width: 6px; height: 6px; }
|
||||
::-webkit-scrollbar-track { background: transparent; }
|
||||
::-webkit-scrollbar-thumb { background: var(--border); border-radius: 3px; }
|
||||
::-webkit-scrollbar-thumb:hover { background: var(--border-light); }
|
||||
|
||||
/* Scrollbar — Firefox */
|
||||
* {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: var(--border) transparent;
|
||||
}
|
||||
|
||||
::selection {
|
||||
background: var(--accent);
|
||||
color: var(--bg-primary);
|
||||
}
|
||||
|
||||
/* Theme transition — smooth switch between light/dark */
|
||||
body {
|
||||
transition: background-color 0.3s ease, color 0.3s ease;
|
||||
}
|
||||
.sidebar, .main-content, .card, .modal, .tool-card, .toast, .page-header {
|
||||
transition: background-color 0.3s ease, border-color 0.3s ease, color 0.3s ease, box-shadow 0.3s ease;
|
||||
}
|
||||
|
||||
/* Tighter letter spacing for headings */
|
||||
h1, h2, h3, .card-header, .stat-value, .page-header h2 { letter-spacing: -0.02em; }
|
||||
.nav-section-title, .badge, th { letter-spacing: 0.04em; }
|
||||
|
||||
/* Focus styles — accessible double-ring with glow */
|
||||
:focus-visible {
|
||||
outline: 2px solid var(--accent);
|
||||
outline-offset: 2px;
|
||||
box-shadow: 0 0 0 4px var(--accent-glow);
|
||||
}
|
||||
button:focus-visible, a:focus-visible, input:focus-visible, select:focus-visible, textarea:focus-visible {
|
||||
outline: 2px solid var(--accent);
|
||||
outline-offset: 2px;
|
||||
box-shadow: 0 0 0 4px var(--accent-glow);
|
||||
}
|
||||
|
||||
/* Entrance animations */
|
||||
@keyframes fadeIn {
|
||||
from { opacity: 0; }
|
||||
to { opacity: 1; }
|
||||
}
|
||||
|
||||
@keyframes slideUp {
|
||||
from { opacity: 0; transform: translateY(8px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
@keyframes slideDown {
|
||||
from { opacity: 0; transform: translateY(-8px); }
|
||||
to { opacity: 1; transform: translateY(0); }
|
||||
}
|
||||
|
||||
@keyframes scaleIn {
|
||||
from { opacity: 0; transform: scale(0.95); }
|
||||
to { opacity: 1; transform: scale(1); }
|
||||
}
|
||||
|
||||
@keyframes shimmer {
|
||||
0% { background-position: -200% 0; }
|
||||
100% { background-position: 200% 0; }
|
||||
}
|
||||
|
||||
@keyframes pulse-ring {
|
||||
0% { box-shadow: 0 0 0 0 currentColor; }
|
||||
70% { box-shadow: 0 0 0 4px transparent; }
|
||||
100% { box-shadow: 0 0 0 0 transparent; }
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
to { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
/* Staggered card entry animation */
|
||||
@keyframes cardEntry {
|
||||
from { opacity: 0; transform: translateY(12px) scale(0.98); }
|
||||
to { opacity: 1; transform: translateY(0) scale(1); }
|
||||
}
|
||||
.animate-entry { animation: cardEntry 0.35s var(--ease-spring) both; }
|
||||
.stagger-1 { animation-delay: 0.05s; }
|
||||
.stagger-2 { animation-delay: 0.10s; }
|
||||
.stagger-3 { animation-delay: 0.15s; }
|
||||
.stagger-4 { animation-delay: 0.20s; }
|
||||
.stagger-5 { animation-delay: 0.25s; }
|
||||
.stagger-6 { animation-delay: 0.30s; }
|
||||
|
||||
/* Skeleton loading animation */
|
||||
.skeleton {
|
||||
background: linear-gradient(90deg, var(--surface) 25%, var(--surface2) 50%, var(--surface) 75%);
|
||||
background-size: 200% 100%;
|
||||
animation: shimmer 1.5s ease-in-out infinite;
|
||||
border-radius: var(--radius-sm);
|
||||
}
|
||||
|
||||
.skeleton-text { height: 14px; margin-bottom: 8px; }
|
||||
.skeleton-text:last-child { width: 60%; }
|
||||
.skeleton-heading { height: 20px; width: 40%; margin-bottom: 12px; }
|
||||
.skeleton-card { height: 100px; border-radius: var(--radius-lg); }
|
||||
.skeleton-avatar { width: 32px; height: 32px; border-radius: 50%; }
|
||||
|
||||
/* Print styles */
|
||||
@media print {
|
||||
.sidebar, .sidebar-overlay, .mobile-menu-btn, .toast-container, .btn { display: none !important; }
|
||||
.main-content { margin: 0; max-width: 100%; }
|
||||
body { background: #fff; color: #000; }
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
*, *::before, *::after {
|
||||
animation-duration: 0.01ms !important;
|
||||
transition-duration: 0.01ms !important;
|
||||
}
|
||||
}
|
||||
BIN
crates/openfang-api/static/favicon.ico
Normal file
BIN
crates/openfang-api/static/favicon.ico
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.0 KiB |
4391
crates/openfang-api/static/index_body.html
Normal file
4391
crates/openfang-api/static/index_body.html
Normal file
File diff suppressed because it is too large
Load Diff
12
crates/openfang-api/static/index_head.html
Normal file
12
crates/openfang-api/static/index_head.html
Normal file
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1">
|
||||
<title>OpenFang Dashboard</title>
|
||||
<link rel="icon" type="image/x-icon" href="/favicon.ico">
|
||||
<link rel="icon" type="image/png" href="/logo.png">
|
||||
<link rel="preconnect" href="https://fonts.googleapis.com">
|
||||
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
||||
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Geist+Mono:wght@400;500;600;700&display=swap" rel="stylesheet">
|
||||
</head>
|
||||
321
crates/openfang-api/static/js/api.js
Normal file
321
crates/openfang-api/static/js/api.js
Normal file
@@ -0,0 +1,321 @@
|
||||
// OpenFang API Client — Fetch wrapper, WebSocket manager, auth injection, toast notifications
|
||||
'use strict';
|
||||
|
||||
// ── Toast Notification System ──
|
||||
var OpenFangToast = (function() {
|
||||
var _container = null;
|
||||
var _toastId = 0;
|
||||
|
||||
function getContainer() {
|
||||
if (!_container) {
|
||||
_container = document.getElementById('toast-container');
|
||||
if (!_container) {
|
||||
_container = document.createElement('div');
|
||||
_container.id = 'toast-container';
|
||||
_container.className = 'toast-container';
|
||||
document.body.appendChild(_container);
|
||||
}
|
||||
}
|
||||
return _container;
|
||||
}
|
||||
|
||||
function toast(message, type, duration) {
|
||||
type = type || 'info';
|
||||
duration = duration || 4000;
|
||||
var id = ++_toastId;
|
||||
var el = document.createElement('div');
|
||||
el.className = 'toast toast-' + type;
|
||||
el.setAttribute('data-toast-id', id);
|
||||
|
||||
var msgSpan = document.createElement('span');
|
||||
msgSpan.className = 'toast-msg';
|
||||
msgSpan.textContent = message;
|
||||
el.appendChild(msgSpan);
|
||||
|
||||
var closeBtn = document.createElement('button');
|
||||
closeBtn.className = 'toast-close';
|
||||
closeBtn.textContent = '\u00D7';
|
||||
closeBtn.onclick = function() { dismissToast(el); };
|
||||
el.appendChild(closeBtn);
|
||||
|
||||
el.onclick = function(e) { if (e.target === el) dismissToast(el); };
|
||||
getContainer().appendChild(el);
|
||||
|
||||
// Auto-dismiss
|
||||
if (duration > 0) {
|
||||
setTimeout(function() { dismissToast(el); }, duration);
|
||||
}
|
||||
return id;
|
||||
}
|
||||
|
||||
function dismissToast(el) {
|
||||
if (!el || el.classList.contains('toast-dismiss')) return;
|
||||
el.classList.add('toast-dismiss');
|
||||
setTimeout(function() { if (el.parentNode) el.parentNode.removeChild(el); }, 300);
|
||||
}
|
||||
|
||||
function success(msg, duration) { return toast(msg, 'success', duration); }
|
||||
function error(msg, duration) { return toast(msg, 'error', duration || 6000); }
|
||||
function warn(msg, duration) { return toast(msg, 'warn', duration || 5000); }
|
||||
function info(msg, duration) { return toast(msg, 'info', duration); }
|
||||
|
||||
// Styled confirmation modal — replaces native confirm()
|
||||
function confirm(title, message, onConfirm) {
|
||||
var overlay = document.createElement('div');
|
||||
overlay.className = 'confirm-overlay';
|
||||
|
||||
var modal = document.createElement('div');
|
||||
modal.className = 'confirm-modal';
|
||||
|
||||
var titleEl = document.createElement('div');
|
||||
titleEl.className = 'confirm-title';
|
||||
titleEl.textContent = title;
|
||||
modal.appendChild(titleEl);
|
||||
|
||||
var msgEl = document.createElement('div');
|
||||
msgEl.className = 'confirm-message';
|
||||
msgEl.textContent = message;
|
||||
modal.appendChild(msgEl);
|
||||
|
||||
var actions = document.createElement('div');
|
||||
actions.className = 'confirm-actions';
|
||||
|
||||
var cancelBtn = document.createElement('button');
|
||||
cancelBtn.className = 'btn btn-ghost confirm-cancel';
|
||||
cancelBtn.textContent = 'Cancel';
|
||||
actions.appendChild(cancelBtn);
|
||||
|
||||
var okBtn = document.createElement('button');
|
||||
okBtn.className = 'btn btn-danger confirm-ok';
|
||||
okBtn.textContent = 'Confirm';
|
||||
actions.appendChild(okBtn);
|
||||
|
||||
modal.appendChild(actions);
|
||||
overlay.appendChild(modal);
|
||||
|
||||
function close() { if (overlay.parentNode) overlay.parentNode.removeChild(overlay); document.removeEventListener('keydown', onKey); }
|
||||
cancelBtn.onclick = close;
|
||||
okBtn.onclick = function() { close(); if (onConfirm) onConfirm(); };
|
||||
overlay.addEventListener('click', function(e) { if (e.target === overlay) close(); });
|
||||
|
||||
function onKey(e) { if (e.key === 'Escape') close(); }
|
||||
document.addEventListener('keydown', onKey);
|
||||
|
||||
document.body.appendChild(overlay);
|
||||
okBtn.focus();
|
||||
}
|
||||
|
||||
return {
|
||||
toast: toast,
|
||||
success: success,
|
||||
error: error,
|
||||
warn: warn,
|
||||
info: info,
|
||||
confirm: confirm
|
||||
};
|
||||
})();
|
||||
|
||||
// ── Friendly Error Messages ──
|
||||
function friendlyError(status, serverMsg) {
|
||||
if (status === 0 || !status) return 'Cannot reach daemon — is openfang running?';
|
||||
if (status === 401) return 'Not authorized — check your API key';
|
||||
if (status === 403) return 'Permission denied';
|
||||
if (status === 404) return serverMsg || 'Resource not found';
|
||||
if (status === 429) return 'Rate limited — slow down and try again';
|
||||
if (status === 413) return 'Request too large';
|
||||
if (status === 500) return 'Server error — check daemon logs';
|
||||
if (status === 502 || status === 503) return 'Daemon unavailable — is it running?';
|
||||
return serverMsg || 'Unexpected error (' + status + ')';
|
||||
}
|
||||
|
||||
// ── API Client ──
|
||||
var OpenFangAPI = (function() {
|
||||
var BASE = window.location.origin;
|
||||
var WS_BASE = BASE.replace(/^http/, 'ws');
|
||||
var _authToken = '';
|
||||
|
||||
// Connection state tracking
|
||||
var _connectionState = 'connected';
|
||||
var _reconnectAttempt = 0;
|
||||
var _connectionListeners = [];
|
||||
|
||||
function setAuthToken(token) { _authToken = token; }
|
||||
|
||||
function headers() {
|
||||
var h = { 'Content-Type': 'application/json' };
|
||||
if (_authToken) h['Authorization'] = 'Bearer ' + _authToken;
|
||||
return h;
|
||||
}
|
||||
|
||||
function setConnectionState(state) {
|
||||
if (_connectionState === state) return;
|
||||
_connectionState = state;
|
||||
_connectionListeners.forEach(function(fn) { fn(state); });
|
||||
}
|
||||
|
||||
function onConnectionChange(fn) { _connectionListeners.push(fn); }
|
||||
|
||||
function request(method, path, body) {
|
||||
var opts = { method: method, headers: headers() };
|
||||
if (body !== undefined) opts.body = JSON.stringify(body);
|
||||
return fetch(BASE + path, opts).then(function(r) {
|
||||
if (_connectionState !== 'connected') setConnectionState('connected');
|
||||
if (!r.ok) {
|
||||
return r.text().then(function(text) {
|
||||
var msg = '';
|
||||
try {
|
||||
var json = JSON.parse(text);
|
||||
msg = json.error || r.statusText;
|
||||
} catch(e) {
|
||||
msg = r.statusText;
|
||||
}
|
||||
throw new Error(friendlyError(r.status, msg));
|
||||
});
|
||||
}
|
||||
var ct = r.headers.get('content-type') || '';
|
||||
if (ct.indexOf('application/json') >= 0) return r.json();
|
||||
return r.text().then(function(t) {
|
||||
try { return JSON.parse(t); } catch(e) { return { text: t }; }
|
||||
});
|
||||
}).catch(function(e) {
|
||||
if (e.name === 'TypeError' && e.message.includes('Failed to fetch')) {
|
||||
setConnectionState('disconnected');
|
||||
throw new Error('Cannot connect to daemon — is openfang running?');
|
||||
}
|
||||
throw e;
|
||||
});
|
||||
}
|
||||
|
||||
function get(path) { return request('GET', path); }
|
||||
function post(path, body) { return request('POST', path, body); }
|
||||
function put(path, body) { return request('PUT', path, body); }
|
||||
function patch(path, body) { return request('PATCH', path, body); }
|
||||
function del(path) { return request('DELETE', path); }
|
||||
|
||||
// WebSocket manager with auto-reconnect
|
||||
var _ws = null;
|
||||
var _wsCallbacks = {};
|
||||
var _wsConnected = false;
|
||||
var _wsAgentId = null;
|
||||
var _reconnectTimer = null;
|
||||
var _reconnectAttempts = 0;
|
||||
var MAX_RECONNECT = 5;
|
||||
|
||||
function wsConnect(agentId, callbacks) {
|
||||
wsDisconnect();
|
||||
_wsCallbacks = callbacks || {};
|
||||
_wsAgentId = agentId;
|
||||
_reconnectAttempts = 0;
|
||||
_doConnect(agentId);
|
||||
}
|
||||
|
||||
function _doConnect(agentId) {
|
||||
try {
|
||||
var url = WS_BASE + '/api/agents/' + agentId + '/ws';
|
||||
if (_authToken) url += '?token=' + encodeURIComponent(_authToken);
|
||||
_ws = new WebSocket(url);
|
||||
|
||||
_ws.onopen = function() {
|
||||
_wsConnected = true;
|
||||
_reconnectAttempts = 0;
|
||||
setConnectionState('connected');
|
||||
if (_reconnectAttempt > 0) {
|
||||
OpenFangToast.success('Reconnected');
|
||||
_reconnectAttempt = 0;
|
||||
}
|
||||
if (_wsCallbacks.onOpen) _wsCallbacks.onOpen();
|
||||
};
|
||||
|
||||
_ws.onmessage = function(e) {
|
||||
try {
|
||||
var data = JSON.parse(e.data);
|
||||
if (_wsCallbacks.onMessage) _wsCallbacks.onMessage(data);
|
||||
} catch(err) { /* ignore parse errors */ }
|
||||
};
|
||||
|
||||
_ws.onclose = function(e) {
|
||||
_wsConnected = false;
|
||||
_ws = null;
|
||||
if (_wsAgentId && _reconnectAttempts < MAX_RECONNECT && e.code !== 1000) {
|
||||
_reconnectAttempts++;
|
||||
_reconnectAttempt = _reconnectAttempts;
|
||||
setConnectionState('reconnecting');
|
||||
if (_reconnectAttempts === 1) {
|
||||
OpenFangToast.warn('Connection lost, reconnecting...');
|
||||
}
|
||||
var delay = Math.min(1000 * Math.pow(2, _reconnectAttempts - 1), 10000);
|
||||
_reconnectTimer = setTimeout(function() { _doConnect(_wsAgentId); }, delay);
|
||||
return;
|
||||
}
|
||||
if (_wsAgentId && _reconnectAttempts >= MAX_RECONNECT) {
|
||||
setConnectionState('disconnected');
|
||||
OpenFangToast.error('Connection lost — switched to HTTP mode', 0);
|
||||
}
|
||||
if (_wsCallbacks.onClose) _wsCallbacks.onClose();
|
||||
};
|
||||
|
||||
_ws.onerror = function() {
|
||||
_wsConnected = false;
|
||||
if (_wsCallbacks.onError) _wsCallbacks.onError();
|
||||
};
|
||||
} catch(e) {
|
||||
_wsConnected = false;
|
||||
}
|
||||
}
|
||||
|
||||
function wsDisconnect() {
|
||||
_wsAgentId = null;
|
||||
_reconnectAttempts = MAX_RECONNECT;
|
||||
if (_reconnectTimer) { clearTimeout(_reconnectTimer); _reconnectTimer = null; }
|
||||
if (_ws) { _ws.close(1000); _ws = null; }
|
||||
_wsConnected = false;
|
||||
}
|
||||
|
||||
function wsSend(data) {
|
||||
if (_ws && _ws.readyState === WebSocket.OPEN) {
|
||||
_ws.send(JSON.stringify(data));
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
function isWsConnected() { return _wsConnected; }
|
||||
|
||||
function getConnectionState() { return _connectionState; }
|
||||
|
||||
function getToken() { return _authToken; }
|
||||
|
||||
function upload(agentId, file) {
|
||||
var hdrs = {
|
||||
'Content-Type': file.type || 'application/octet-stream',
|
||||
'X-Filename': file.name
|
||||
};
|
||||
if (_authToken) hdrs['Authorization'] = 'Bearer ' + _authToken;
|
||||
return fetch(BASE + '/api/agents/' + agentId + '/upload', {
|
||||
method: 'POST',
|
||||
headers: hdrs,
|
||||
body: file
|
||||
}).then(function(r) {
|
||||
if (!r.ok) throw new Error('Upload failed');
|
||||
return r.json();
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
setAuthToken: setAuthToken,
|
||||
getToken: getToken,
|
||||
get: get,
|
||||
post: post,
|
||||
put: put,
|
||||
patch: patch,
|
||||
del: del,
|
||||
delete: del,
|
||||
upload: upload,
|
||||
wsConnect: wsConnect,
|
||||
wsDisconnect: wsDisconnect,
|
||||
wsSend: wsSend,
|
||||
isWsConnected: isWsConnected,
|
||||
getConnectionState: getConnectionState,
|
||||
onConnectionChange: onConnectionChange
|
||||
};
|
||||
})();
|
||||
319
crates/openfang-api/static/js/app.js
Normal file
319
crates/openfang-api/static/js/app.js
Normal file
@@ -0,0 +1,319 @@
|
||||
// OpenFang App — Alpine.js init, hash router, global store
|
||||
'use strict';
|
||||
|
||||
// Marked.js configuration
|
||||
if (typeof marked !== 'undefined') {
|
||||
marked.setOptions({
|
||||
breaks: true,
|
||||
gfm: true,
|
||||
highlight: function(code, lang) {
|
||||
if (typeof hljs !== 'undefined' && lang && hljs.getLanguage(lang)) {
|
||||
try { return hljs.highlight(code, { language: lang }).value; } catch(e) {}
|
||||
}
|
||||
return code;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
var div = document.createElement('div');
|
||||
div.textContent = text || '';
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
function renderMarkdown(text) {
|
||||
if (!text) return '';
|
||||
if (typeof marked !== 'undefined') {
|
||||
var html = marked.parse(text);
|
||||
// Add copy buttons to code blocks
|
||||
html = html.replace(/<pre><code/g, '<pre><button class="copy-btn" onclick="copyCode(this)">Copy</button><code');
|
||||
return html;
|
||||
}
|
||||
return escapeHtml(text);
|
||||
}
|
||||
|
||||
function copyCode(btn) {
|
||||
var code = btn.nextElementSibling;
|
||||
if (code) {
|
||||
navigator.clipboard.writeText(code.textContent).then(function() {
|
||||
btn.textContent = 'Copied!';
|
||||
btn.classList.add('copied');
|
||||
setTimeout(function() { btn.textContent = 'Copy'; btn.classList.remove('copied'); }, 1500);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Tool category icon SVGs — returns inline SVG for each tool category
|
||||
function toolIcon(toolName) {
|
||||
if (!toolName) return '';
|
||||
var n = toolName.toLowerCase();
|
||||
var s = 'width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"';
|
||||
// File/directory operations
|
||||
if (n.indexOf('file_') === 0 || n.indexOf('directory_') === 0)
|
||||
return '<svg ' + s + '><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><path d="M14 2v6h6"/><path d="M16 13H8"/><path d="M16 17H8"/></svg>';
|
||||
// Web/fetch
|
||||
if (n.indexOf('web_') === 0 || n.indexOf('link_') === 0)
|
||||
return '<svg ' + s + '><circle cx="12" cy="12" r="10"/><path d="M2 12h20"/><path d="M12 2a15 15 0 0 1 4 10 15 15 0 0 1-4 10 15 15 0 0 1-4-10 15 15 0 0 1 4-10z"/></svg>';
|
||||
// Shell/exec
|
||||
if (n.indexOf('shell') === 0 || n.indexOf('exec_') === 0)
|
||||
return '<svg ' + s + '><polyline points="4 17 10 11 4 5"/><line x1="12" y1="19" x2="20" y2="19"/></svg>';
|
||||
// Agent operations
|
||||
if (n.indexOf('agent_') === 0)
|
||||
return '<svg ' + s + '><path d="M17 21v-2a4 4 0 0 0-4-4H5a4 4 0 0 0-4 4v2"/><circle cx="9" cy="7" r="4"/><path d="M23 21v-2a4 4 0 0 0-3-3.87"/><path d="M16 3.13a4 4 0 0 1 0 7.75"/></svg>';
|
||||
// Memory/knowledge
|
||||
if (n.indexOf('memory_') === 0 || n.indexOf('knowledge_') === 0)
|
||||
return '<svg ' + s + '><path d="M2 3h6a4 4 0 0 1 4 4v14a3 3 0 0 0-3-3H2z"/><path d="M22 3h-6a4 4 0 0 0-4 4v14a3 3 0 0 1 3-3h7z"/></svg>';
|
||||
// Cron/schedule
|
||||
if (n.indexOf('cron_') === 0 || n.indexOf('schedule_') === 0)
|
||||
return '<svg ' + s + '><circle cx="12" cy="12" r="10"/><polyline points="12 6 12 12 16 14"/></svg>';
|
||||
// Browser/playwright
|
||||
if (n.indexOf('browser_') === 0 || n.indexOf('playwright_') === 0)
|
||||
return '<svg ' + s + '><rect x="2" y="3" width="20" height="14" rx="2"/><path d="M8 21h8"/><path d="M12 17v4"/></svg>';
|
||||
// Container/docker
|
||||
if (n.indexOf('container_') === 0 || n.indexOf('docker_') === 0)
|
||||
return '<svg ' + s + '><path d="M22 12H2"/><path d="M5.45 5.11L2 12v6a2 2 0 0 0 2 2h16a2 2 0 0 0 2-2v-6l-3.45-6.89A2 2 0 0 0 16.76 4H7.24a2 2 0 0 0-1.79 1.11z"/></svg>';
|
||||
// Image/media
|
||||
if (n.indexOf('image_') === 0 || n.indexOf('tts_') === 0)
|
||||
return '<svg ' + s + '><rect x="3" y="3" width="18" height="18" rx="2"/><circle cx="8.5" cy="8.5" r="1.5"/><polyline points="21 15 16 10 5 21"/></svg>';
|
||||
// Hand tools
|
||||
if (n.indexOf('hand_') === 0)
|
||||
return '<svg ' + s + '><path d="M18 11V6a2 2 0 0 0-2-2 2 2 0 0 0-2 2"/><path d="M14 10V4a2 2 0 0 0-2-2 2 2 0 0 0-2 2v6"/><path d="M10 10.5V6a2 2 0 0 0-2-2 2 2 0 0 0-2 2v8"/><path d="M18 8a2 2 0 1 1 4 0v6a8 8 0 0 1-8 8h-2c-2.8 0-4.5-.9-5.7-2.4L3.4 16a2 2 0 0 1 3.2-2.4L8 15"/></svg>';
|
||||
// Task/collab
|
||||
if (n.indexOf('task_') === 0)
|
||||
return '<svg ' + s + '><path d="M9 11l3 3L22 4"/><path d="M21 12v7a2 2 0 01-2 2H5a2 2 0 01-2-2V5a2 2 0 012-2h11"/></svg>';
|
||||
// Default — wrench
|
||||
return '<svg ' + s + '><path d="M14.7 6.3a1 1 0 0 0 0 1.4l1.6 1.6a1 1 0 0 0 1.4 0l3.77-3.77a6 6 0 0 1-7.94 7.94l-6.91 6.91a2.12 2.12 0 0 1-3-3l6.91-6.91a6 6 0 0 1 7.94-7.94l-3.76 3.76z"/></svg>';
|
||||
}
|
||||
|
||||
// Alpine.js global store
|
||||
document.addEventListener('alpine:init', function() {
|
||||
// Restore saved API key on load
|
||||
var savedKey = localStorage.getItem('openfang-api-key');
|
||||
if (savedKey) OpenFangAPI.setAuthToken(savedKey);
|
||||
|
||||
Alpine.store('app', {
|
||||
agents: [],
|
||||
connected: false,
|
||||
booting: true,
|
||||
wsConnected: false,
|
||||
connectionState: 'connected',
|
||||
lastError: '',
|
||||
version: '0.1.0',
|
||||
agentCount: 0,
|
||||
pendingAgent: null,
|
||||
focusMode: localStorage.getItem('openfang-focus') === 'true',
|
||||
showOnboarding: false,
|
||||
showAuthPrompt: false,
|
||||
|
||||
toggleFocusMode() {
|
||||
this.focusMode = !this.focusMode;
|
||||
localStorage.setItem('openfang-focus', this.focusMode);
|
||||
},
|
||||
|
||||
async refreshAgents() {
|
||||
try {
|
||||
var agents = await OpenFangAPI.get('/api/agents');
|
||||
this.agents = Array.isArray(agents) ? agents : [];
|
||||
this.agentCount = this.agents.length;
|
||||
} catch(e) { /* silent */ }
|
||||
},
|
||||
|
||||
async checkStatus() {
|
||||
try {
|
||||
var s = await OpenFangAPI.get('/api/status');
|
||||
this.connected = true;
|
||||
this.booting = false;
|
||||
this.lastError = '';
|
||||
this.version = s.version || '0.1.0';
|
||||
this.agentCount = s.agent_count || 0;
|
||||
} catch(e) {
|
||||
this.connected = false;
|
||||
this.lastError = e.message || 'Unknown error';
|
||||
console.warn('[OpenFang] Status check failed:', e.message);
|
||||
}
|
||||
},
|
||||
|
||||
async checkOnboarding() {
|
||||
if (localStorage.getItem('openfang-onboarded')) return;
|
||||
try {
|
||||
var config = await OpenFangAPI.get('/api/config');
|
||||
var apiKey = config && config.api_key;
|
||||
var noKey = !apiKey || apiKey === 'not set' || apiKey === '';
|
||||
if (noKey && this.agentCount === 0) {
|
||||
this.showOnboarding = true;
|
||||
}
|
||||
} catch(e) {
|
||||
// If config endpoint fails, still show onboarding if no agents
|
||||
if (this.agentCount === 0) this.showOnboarding = true;
|
||||
}
|
||||
},
|
||||
|
||||
dismissOnboarding() {
|
||||
this.showOnboarding = false;
|
||||
localStorage.setItem('openfang-onboarded', 'true');
|
||||
},
|
||||
|
||||
async checkAuth() {
|
||||
try {
|
||||
// Use a protected endpoint (not in the public allowlist) to detect
|
||||
// whether the server requires an API key.
|
||||
await OpenFangAPI.get('/api/tools');
|
||||
this.showAuthPrompt = false;
|
||||
} catch(e) {
|
||||
if (e.message && (e.message.indexOf('Not authorized') >= 0 || e.message.indexOf('401') >= 0 || e.message.indexOf('Missing Authorization') >= 0 || e.message.indexOf('Unauthorized') >= 0)) {
|
||||
// Only show prompt if we don't already have a saved key
|
||||
var saved = localStorage.getItem('openfang-api-key');
|
||||
if (saved) {
|
||||
// Saved key might be stale — clear it and show prompt
|
||||
OpenFangAPI.setAuthToken('');
|
||||
localStorage.removeItem('openfang-api-key');
|
||||
}
|
||||
this.showAuthPrompt = true;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
submitApiKey(key) {
|
||||
if (!key || !key.trim()) return;
|
||||
OpenFangAPI.setAuthToken(key.trim());
|
||||
localStorage.setItem('openfang-api-key', key.trim());
|
||||
this.showAuthPrompt = false;
|
||||
this.refreshAgents();
|
||||
},
|
||||
|
||||
clearApiKey() {
|
||||
OpenFangAPI.setAuthToken('');
|
||||
localStorage.removeItem('openfang-api-key');
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Main app component
|
||||
function app() {
|
||||
return {
|
||||
page: 'agents',
|
||||
themeMode: localStorage.getItem('openfang-theme-mode') || 'system',
|
||||
theme: (() => {
|
||||
var mode = localStorage.getItem('openfang-theme-mode') || 'system';
|
||||
if (mode === 'system') return window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
|
||||
return mode;
|
||||
})(),
|
||||
sidebarCollapsed: localStorage.getItem('openfang-sidebar') === 'collapsed',
|
||||
mobileMenuOpen: false,
|
||||
connected: false,
|
||||
wsConnected: false,
|
||||
version: '0.1.0',
|
||||
agentCount: 0,
|
||||
|
||||
get agents() { return Alpine.store('app').agents; },
|
||||
|
||||
init() {
|
||||
var self = this;
|
||||
|
||||
// Listen for OS theme changes (only matters when mode is 'system')
|
||||
window.matchMedia('(prefers-color-scheme: dark)').addEventListener('change', function(e) {
|
||||
if (self.themeMode === 'system') {
|
||||
self.theme = e.matches ? 'dark' : 'light';
|
||||
}
|
||||
});
|
||||
|
||||
// Hash routing
|
||||
var validPages = ['overview','agents','sessions','approvals','workflows','scheduler','channels','skills','hands','analytics','logs','settings','wizard'];
|
||||
var pageRedirects = {
|
||||
'chat': 'agents',
|
||||
'templates': 'agents',
|
||||
'triggers': 'workflows',
|
||||
'cron': 'scheduler',
|
||||
'schedules': 'scheduler',
|
||||
'memory': 'sessions',
|
||||
'audit': 'logs',
|
||||
'security': 'settings',
|
||||
'peers': 'settings',
|
||||
'migration': 'settings',
|
||||
'usage': 'analytics',
|
||||
'approval': 'approvals'
|
||||
};
|
||||
function handleHash() {
|
||||
var hash = window.location.hash.replace('#', '') || 'agents';
|
||||
if (pageRedirects[hash]) {
|
||||
hash = pageRedirects[hash];
|
||||
window.location.hash = hash;
|
||||
}
|
||||
if (validPages.indexOf(hash) >= 0) self.page = hash;
|
||||
}
|
||||
window.addEventListener('hashchange', handleHash);
|
||||
handleHash();
|
||||
|
||||
// Keyboard shortcuts
|
||||
document.addEventListener('keydown', function(e) {
|
||||
// Ctrl+K — focus agent switch / go to agents
|
||||
if ((e.ctrlKey || e.metaKey) && e.key === 'k') {
|
||||
e.preventDefault();
|
||||
self.navigate('agents');
|
||||
}
|
||||
// Ctrl+N — new agent
|
||||
if ((e.ctrlKey || e.metaKey) && e.key === 'n' && !e.shiftKey) {
|
||||
e.preventDefault();
|
||||
self.navigate('agents');
|
||||
}
|
||||
// Ctrl+Shift+F — toggle focus mode
|
||||
if ((e.ctrlKey || e.metaKey) && e.shiftKey && e.key === 'F') {
|
||||
e.preventDefault();
|
||||
Alpine.store('app').toggleFocusMode();
|
||||
}
|
||||
// Escape — close mobile menu
|
||||
if (e.key === 'Escape') {
|
||||
self.mobileMenuOpen = false;
|
||||
}
|
||||
});
|
||||
|
||||
// Connection state listener
|
||||
OpenFangAPI.onConnectionChange(function(state) {
|
||||
Alpine.store('app').connectionState = state;
|
||||
});
|
||||
|
||||
// Initial data load
|
||||
this.pollStatus();
|
||||
Alpine.store('app').checkOnboarding();
|
||||
Alpine.store('app').checkAuth();
|
||||
setInterval(function() { self.pollStatus(); }, 5000);
|
||||
},
|
||||
|
||||
navigate(p) {
|
||||
this.page = p;
|
||||
window.location.hash = p;
|
||||
this.mobileMenuOpen = false;
|
||||
},
|
||||
|
||||
setTheme(mode) {
|
||||
this.themeMode = mode;
|
||||
localStorage.setItem('openfang-theme-mode', mode);
|
||||
if (mode === 'system') {
|
||||
this.theme = window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
|
||||
} else {
|
||||
this.theme = mode;
|
||||
}
|
||||
},
|
||||
|
||||
toggleTheme() {
|
||||
var modes = ['light', 'system', 'dark'];
|
||||
var next = modes[(modes.indexOf(this.themeMode) + 1) % modes.length];
|
||||
this.setTheme(next);
|
||||
},
|
||||
|
||||
toggleSidebar() {
|
||||
this.sidebarCollapsed = !this.sidebarCollapsed;
|
||||
localStorage.setItem('openfang-sidebar', this.sidebarCollapsed ? 'collapsed' : 'expanded');
|
||||
},
|
||||
|
||||
async pollStatus() {
|
||||
var store = Alpine.store('app');
|
||||
await store.checkStatus();
|
||||
await store.refreshAgents();
|
||||
this.connected = store.connected;
|
||||
this.version = store.version;
|
||||
this.agentCount = store.agentCount;
|
||||
this.wsConnected = OpenFangAPI.isWsConnected();
|
||||
}
|
||||
};
|
||||
}
|
||||
582
crates/openfang-api/static/js/pages/agents.js
Normal file
582
crates/openfang-api/static/js/pages/agents.js
Normal file
@@ -0,0 +1,582 @@
|
||||
// OpenFang Agents Page — Multi-step spawn wizard, detail view with tabs, file editor, personality presets
|
||||
'use strict';
|
||||
|
||||
function agentsPage() {
|
||||
return {
|
||||
tab: 'agents',
|
||||
activeChatAgent: null,
|
||||
// -- Agents state --
|
||||
showSpawnModal: false,
|
||||
showDetailModal: false,
|
||||
detailAgent: null,
|
||||
spawnMode: 'wizard',
|
||||
spawning: false,
|
||||
spawnToml: '',
|
||||
filterState: 'all',
|
||||
loading: true,
|
||||
loadError: '',
|
||||
spawnForm: {
|
||||
name: '',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
systemPrompt: 'You are a helpful assistant.',
|
||||
profile: 'full',
|
||||
caps: { memory_read: true, memory_write: true, network: false, shell: false, agent_spawn: false }
|
||||
},
|
||||
|
||||
// -- Multi-step wizard state --
|
||||
spawnStep: 1,
|
||||
spawnIdentity: { emoji: '', color: '#FF5C00', archetype: '' },
|
||||
selectedPreset: '',
|
||||
soulContent: '',
|
||||
emojiOptions: [
|
||||
'\u{1F916}', '\u{1F4BB}', '\u{1F50D}', '\u{270D}\uFE0F', '\u{1F4CA}', '\u{1F6E0}\uFE0F',
|
||||
'\u{1F4AC}', '\u{1F393}', '\u{1F310}', '\u{1F512}', '\u{26A1}', '\u{1F680}',
|
||||
'\u{1F9EA}', '\u{1F3AF}', '\u{1F4D6}', '\u{1F9D1}\u200D\u{1F4BB}', '\u{1F4E7}', '\u{1F3E2}',
|
||||
'\u{2764}\uFE0F', '\u{1F31F}', '\u{1F527}', '\u{1F4DD}', '\u{1F4A1}', '\u{1F3A8}'
|
||||
],
|
||||
archetypeOptions: ['Assistant', 'Researcher', 'Coder', 'Writer', 'DevOps', 'Support', 'Analyst', 'Custom'],
|
||||
personalityPresets: [
|
||||
{ id: 'professional', label: 'Professional', soul: 'Communicate in a clear, professional tone. Be direct and structured. Use formal language and data-driven reasoning. Prioritize accuracy over personality.' },
|
||||
{ id: 'friendly', label: 'Friendly', soul: 'Be warm, approachable, and conversational. Use casual language and show genuine interest in the user. Add personality to your responses while staying helpful.' },
|
||||
{ id: 'technical', label: 'Technical', soul: 'Focus on technical accuracy and depth. Use precise terminology. Show your work and reasoning. Prefer code examples and structured explanations.' },
|
||||
{ id: 'creative', label: 'Creative', soul: 'Be imaginative and expressive. Use vivid language, analogies, and unexpected connections. Encourage creative thinking and explore multiple perspectives.' },
|
||||
{ id: 'concise', label: 'Concise', soul: 'Be extremely brief and to the point. No filler, no pleasantries. Answer in the fewest words possible while remaining accurate and complete.' },
|
||||
{ id: 'mentor', label: 'Mentor', soul: 'Be patient and encouraging like a great teacher. Break down complex topics step by step. Ask guiding questions. Celebrate progress and build confidence.' }
|
||||
],
|
||||
|
||||
// -- Detail modal tabs --
|
||||
detailTab: 'info',
|
||||
agentFiles: [],
|
||||
editingFile: null,
|
||||
fileContent: '',
|
||||
fileSaving: false,
|
||||
filesLoading: false,
|
||||
configForm: {},
|
||||
configSaving: false,
|
||||
|
||||
// -- Templates state --
|
||||
tplTemplates: [],
|
||||
tplProviders: [],
|
||||
tplLoading: false,
|
||||
tplLoadError: '',
|
||||
selectedCategory: 'All',
|
||||
searchQuery: '',
|
||||
|
||||
builtinTemplates: [
|
||||
{
|
||||
name: 'General Assistant',
|
||||
description: 'A versatile conversational agent that can help with everyday tasks, answer questions, and provide recommendations.',
|
||||
category: 'General',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'full',
|
||||
system_prompt: 'You are a helpful, friendly assistant. Provide clear, accurate, and concise responses. Ask clarifying questions when needed.'
|
||||
},
|
||||
{
|
||||
name: 'Code Helper',
|
||||
description: 'A programming-focused agent that writes, reviews, and debugs code across multiple languages.',
|
||||
category: 'Development',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'coding',
|
||||
system_prompt: 'You are an expert programmer. Help users write clean, efficient code. Explain your reasoning. Follow best practices and conventions for the language being used.'
|
||||
},
|
||||
{
|
||||
name: 'Researcher',
|
||||
description: 'An analytical agent that breaks down complex topics, synthesizes information, and provides cited summaries.',
|
||||
category: 'Research',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'research',
|
||||
system_prompt: 'You are a research analyst. Break down complex topics into clear explanations. Provide structured analysis with key findings. Cite sources when available.'
|
||||
},
|
||||
{
|
||||
name: 'Writer',
|
||||
description: 'A creative writing agent that helps with drafting, editing, and improving written content of all kinds.',
|
||||
category: 'Writing',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'full',
|
||||
system_prompt: 'You are a skilled writer and editor. Help users create polished content. Adapt your tone and style to match the intended audience. Offer constructive suggestions for improvement.'
|
||||
},
|
||||
{
|
||||
name: 'Data Analyst',
|
||||
description: 'A data-focused agent that helps analyze datasets, create queries, and interpret statistical results.',
|
||||
category: 'Development',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'coding',
|
||||
system_prompt: 'You are a data analysis expert. Help users understand their data, write SQL/Python queries, and interpret results. Present findings clearly with actionable insights.'
|
||||
},
|
||||
{
|
||||
name: 'DevOps Engineer',
|
||||
description: 'A systems-focused agent for CI/CD, infrastructure, Docker, and deployment troubleshooting.',
|
||||
category: 'Development',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'automation',
|
||||
system_prompt: 'You are a DevOps engineer. Help with CI/CD pipelines, Docker, Kubernetes, infrastructure as code, and deployment. Prioritize reliability and security.'
|
||||
},
|
||||
{
|
||||
name: 'Customer Support',
|
||||
description: 'A professional, empathetic agent for handling customer inquiries and resolving issues.',
|
||||
category: 'Business',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'messaging',
|
||||
system_prompt: 'You are a professional customer support representative. Be empathetic, patient, and solution-oriented. Acknowledge concerns before offering solutions. Escalate complex issues appropriately.'
|
||||
},
|
||||
{
|
||||
name: 'Tutor',
|
||||
description: 'A patient educational agent that explains concepts step-by-step and adapts to the learner\'s level.',
|
||||
category: 'General',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'full',
|
||||
system_prompt: 'You are a patient and encouraging tutor. Explain concepts step by step, starting from fundamentals. Use analogies and examples. Check understanding before moving on. Adapt to the learner\'s pace.'
|
||||
},
|
||||
{
|
||||
name: 'API Designer',
|
||||
description: 'An agent specialized in RESTful API design, OpenAPI specs, and integration architecture.',
|
||||
category: 'Development',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'coding',
|
||||
system_prompt: 'You are an API design expert. Help users design clean, consistent RESTful APIs following best practices. Cover endpoint naming, request/response schemas, error handling, and versioning.'
|
||||
},
|
||||
{
|
||||
name: 'Meeting Notes',
|
||||
description: 'Summarizes meeting transcripts into structured notes with action items and key decisions.',
|
||||
category: 'Business',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'minimal',
|
||||
system_prompt: 'You are a meeting summarizer. When given a meeting transcript or notes, produce a structured summary with: key decisions, action items (with owners), discussion highlights, and follow-up questions.'
|
||||
}
|
||||
],
|
||||
|
||||
// ── Profile Descriptions ──
|
||||
profileDescriptions: {
|
||||
minimal: { label: 'Minimal', desc: 'Read-only file access' },
|
||||
coding: { label: 'Coding', desc: 'Files + shell + web fetch' },
|
||||
research: { label: 'Research', desc: 'Web search + file read/write' },
|
||||
messaging: { label: 'Messaging', desc: 'Agents + memory access' },
|
||||
automation: { label: 'Automation', desc: 'All tools except custom' },
|
||||
balanced: { label: 'Balanced', desc: 'General-purpose tool set' },
|
||||
precise: { label: 'Precise', desc: 'Focused tool set for accuracy' },
|
||||
creative: { label: 'Creative', desc: 'Full tools with creative emphasis' },
|
||||
full: { label: 'Full', desc: 'All 35+ tools' }
|
||||
},
|
||||
profileInfo: function(name) {
|
||||
return this.profileDescriptions[name] || { label: name, desc: '' };
|
||||
},
|
||||
|
||||
// ── Tool Preview in Spawn Modal ──
|
||||
spawnProfiles: [],
|
||||
spawnProfilesLoaded: false,
|
||||
async loadSpawnProfiles() {
|
||||
if (this.spawnProfilesLoaded) return;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/profiles');
|
||||
this.spawnProfiles = data.profiles || [];
|
||||
this.spawnProfilesLoaded = true;
|
||||
} catch(e) { this.spawnProfiles = []; }
|
||||
},
|
||||
get selectedProfileTools() {
|
||||
var pname = this.spawnForm.profile;
|
||||
var match = this.spawnProfiles.find(function(p) { return p.name === pname; });
|
||||
if (match && match.tools) return match.tools.slice(0, 15);
|
||||
return [];
|
||||
},
|
||||
|
||||
get agents() { return Alpine.store('app').agents; },
|
||||
|
||||
get filteredAgents() {
|
||||
var f = this.filterState;
|
||||
if (f === 'all') return this.agents;
|
||||
return this.agents.filter(function(a) { return a.state.toLowerCase() === f; });
|
||||
},
|
||||
|
||||
get runningCount() {
|
||||
return this.agents.filter(function(a) { return a.state === 'Running'; }).length;
|
||||
},
|
||||
|
||||
get stoppedCount() {
|
||||
return this.agents.filter(function(a) { return a.state !== 'Running'; }).length;
|
||||
},
|
||||
|
||||
// -- Templates computed --
|
||||
get categories() {
|
||||
var cats = { 'All': true };
|
||||
this.builtinTemplates.forEach(function(t) { cats[t.category] = true; });
|
||||
this.tplTemplates.forEach(function(t) { if (t.category) cats[t.category] = true; });
|
||||
return Object.keys(cats);
|
||||
},
|
||||
|
||||
get filteredBuiltins() {
|
||||
var self = this;
|
||||
return this.builtinTemplates.filter(function(t) {
|
||||
if (self.selectedCategory !== 'All' && t.category !== self.selectedCategory) return false;
|
||||
if (self.searchQuery) {
|
||||
var q = self.searchQuery.toLowerCase();
|
||||
if (t.name.toLowerCase().indexOf(q) === -1 &&
|
||||
t.description.toLowerCase().indexOf(q) === -1) return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
},
|
||||
|
||||
get filteredCustom() {
|
||||
var self = this;
|
||||
return this.tplTemplates.filter(function(t) {
|
||||
if (self.searchQuery) {
|
||||
var q = self.searchQuery.toLowerCase();
|
||||
if ((t.name || '').toLowerCase().indexOf(q) === -1 &&
|
||||
(t.description || '').toLowerCase().indexOf(q) === -1) return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
},
|
||||
|
||||
isProviderConfigured(providerName) {
|
||||
if (!providerName) return false;
|
||||
var p = this.tplProviders.find(function(pr) { return pr.id === providerName; });
|
||||
return p ? p.auth_status === 'configured' : false;
|
||||
},
|
||||
|
||||
async init() {
|
||||
var self = this;
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
await Alpine.store('app').refreshAgents();
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load agents. Is the daemon running?';
|
||||
}
|
||||
this.loading = false;
|
||||
|
||||
// If a pending agent was set (e.g. from wizard or redirect), open chat inline
|
||||
var store = Alpine.store('app');
|
||||
if (store.pendingAgent) {
|
||||
this.activeChatAgent = store.pendingAgent;
|
||||
}
|
||||
// Watch for future pendingAgent changes
|
||||
this.$watch('$store.app.pendingAgent', function(agent) {
|
||||
if (agent) {
|
||||
self.activeChatAgent = agent;
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
async loadData() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
await Alpine.store('app').refreshAgents();
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load agents.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadTemplates() {
|
||||
this.tplLoading = true;
|
||||
this.tplLoadError = '';
|
||||
try {
|
||||
var results = await Promise.all([
|
||||
OpenFangAPI.get('/api/templates'),
|
||||
OpenFangAPI.get('/api/providers').catch(function() { return { providers: [] }; })
|
||||
]);
|
||||
this.tplTemplates = results[0].templates || [];
|
||||
this.tplProviders = results[1].providers || [];
|
||||
} catch(e) {
|
||||
this.tplTemplates = [];
|
||||
this.tplLoadError = e.message || 'Could not load templates.';
|
||||
}
|
||||
this.tplLoading = false;
|
||||
},
|
||||
|
||||
chatWithAgent(agent) {
|
||||
Alpine.store('app').pendingAgent = agent;
|
||||
this.activeChatAgent = agent;
|
||||
},
|
||||
|
||||
closeChat() {
|
||||
this.activeChatAgent = null;
|
||||
OpenFangAPI.wsDisconnect();
|
||||
},
|
||||
|
||||
showDetail(agent) {
|
||||
this.detailAgent = agent;
|
||||
this.detailTab = 'info';
|
||||
this.agentFiles = [];
|
||||
this.editingFile = null;
|
||||
this.fileContent = '';
|
||||
this.configForm = {
|
||||
name: agent.name || '',
|
||||
system_prompt: agent.system_prompt || '',
|
||||
emoji: (agent.identity && agent.identity.emoji) || '',
|
||||
color: (agent.identity && agent.identity.color) || '#FF5C00',
|
||||
archetype: (agent.identity && agent.identity.archetype) || '',
|
||||
vibe: (agent.identity && agent.identity.vibe) || ''
|
||||
};
|
||||
this.showDetailModal = true;
|
||||
},
|
||||
|
||||
killAgent(agent) {
|
||||
var self = this;
|
||||
OpenFangToast.confirm('Stop Agent', 'Stop agent "' + agent.name + '"? The agent will be shut down.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.del('/api/agents/' + agent.id);
|
||||
OpenFangToast.success('Agent "' + agent.name + '" stopped');
|
||||
self.showDetailModal = false;
|
||||
await Alpine.store('app').refreshAgents();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to stop agent: ' + e.message);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
killAllAgents() {
|
||||
var list = this.filteredAgents;
|
||||
if (!list.length) return;
|
||||
OpenFangToast.confirm('Stop All Agents', 'Stop ' + list.length + ' agent(s)? All agents will be shut down.', async function() {
|
||||
var errors = [];
|
||||
for (var i = 0; i < list.length; i++) {
|
||||
try {
|
||||
await OpenFangAPI.del('/api/agents/' + list[i].id);
|
||||
} catch(e) { errors.push(list[i].name + ': ' + e.message); }
|
||||
}
|
||||
await Alpine.store('app').refreshAgents();
|
||||
if (errors.length) {
|
||||
OpenFangToast.error('Some agents failed to stop: ' + errors.join(', '));
|
||||
} else {
|
||||
OpenFangToast.success(list.length + ' agent(s) stopped');
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// ── Multi-step wizard navigation ──
|
||||
openSpawnWizard() {
|
||||
this.showSpawnModal = true;
|
||||
this.spawnStep = 1;
|
||||
this.spawnMode = 'wizard';
|
||||
this.spawnIdentity = { emoji: '', color: '#FF5C00', archetype: '' };
|
||||
this.selectedPreset = '';
|
||||
this.soulContent = '';
|
||||
this.spawnForm.name = '';
|
||||
this.spawnForm.systemPrompt = 'You are a helpful assistant.';
|
||||
this.spawnForm.profile = 'full';
|
||||
},
|
||||
|
||||
nextStep() {
|
||||
if (this.spawnStep === 1 && !this.spawnForm.name.trim()) {
|
||||
OpenFangToast.warn('Please enter an agent name');
|
||||
return;
|
||||
}
|
||||
if (this.spawnStep < 5) this.spawnStep++;
|
||||
},
|
||||
|
||||
prevStep() {
|
||||
if (this.spawnStep > 1) this.spawnStep--;
|
||||
},
|
||||
|
||||
selectPreset(preset) {
|
||||
this.selectedPreset = preset.id;
|
||||
this.soulContent = preset.soul;
|
||||
},
|
||||
|
||||
generateToml() {
|
||||
var f = this.spawnForm;
|
||||
var si = this.spawnIdentity;
|
||||
var lines = [
|
||||
'name = "' + f.name + '"',
|
||||
'module = "builtin:chat"'
|
||||
];
|
||||
if (f.profile && f.profile !== 'custom') {
|
||||
lines.push('profile = "' + f.profile + '"');
|
||||
}
|
||||
lines.push('', '[model]');
|
||||
lines.push('provider = "' + f.provider + '"');
|
||||
lines.push('model = "' + f.model + '"');
|
||||
lines.push('system_prompt = "' + f.systemPrompt.replace(/"/g, '\\"') + '"');
|
||||
if (f.profile === 'custom') {
|
||||
lines.push('', '[capabilities]');
|
||||
if (f.caps.memory_read) lines.push('memory_read = ["*"]');
|
||||
if (f.caps.memory_write) lines.push('memory_write = ["self.*"]');
|
||||
if (f.caps.network) lines.push('network = ["*"]');
|
||||
if (f.caps.shell) lines.push('shell = ["*"]');
|
||||
if (f.caps.agent_spawn) lines.push('agent_spawn = true');
|
||||
}
|
||||
return lines.join('\n');
|
||||
},
|
||||
|
||||
async setMode(agent, mode) {
|
||||
try {
|
||||
await OpenFangAPI.put('/api/agents/' + agent.id + '/mode', { mode: mode });
|
||||
agent.mode = mode;
|
||||
OpenFangToast.success('Mode set to ' + mode);
|
||||
await Alpine.store('app').refreshAgents();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to set mode: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
async spawnAgent() {
|
||||
this.spawning = true;
|
||||
var toml = this.spawnMode === 'wizard' ? this.generateToml() : this.spawnToml;
|
||||
if (!toml.trim()) {
|
||||
this.spawning = false;
|
||||
OpenFangToast.warn('Manifest is empty \u2014 enter agent config first');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
var res = await OpenFangAPI.post('/api/agents', { manifest_toml: toml });
|
||||
if (res.agent_id) {
|
||||
// Post-spawn: update identity + write SOUL.md if personality preset selected
|
||||
var patchBody = {};
|
||||
if (this.spawnIdentity.emoji) patchBody.emoji = this.spawnIdentity.emoji;
|
||||
if (this.spawnIdentity.color) patchBody.color = this.spawnIdentity.color;
|
||||
if (this.spawnIdentity.archetype) patchBody.archetype = this.spawnIdentity.archetype;
|
||||
if (this.selectedPreset) patchBody.vibe = this.selectedPreset;
|
||||
|
||||
if (Object.keys(patchBody).length) {
|
||||
OpenFangAPI.patch('/api/agents/' + res.agent_id + '/config', patchBody).catch(function(e) { console.warn('Post-spawn config patch failed:', e.message); });
|
||||
}
|
||||
if (this.soulContent.trim()) {
|
||||
OpenFangAPI.put('/api/agents/' + res.agent_id + '/files/SOUL.md', { content: '# Soul\n' + this.soulContent }).catch(function(e) { console.warn('SOUL.md write failed:', e.message); });
|
||||
}
|
||||
|
||||
this.showSpawnModal = false;
|
||||
this.spawnForm.name = '';
|
||||
this.spawnToml = '';
|
||||
this.spawnStep = 1;
|
||||
OpenFangToast.success('Agent "' + (res.name || 'new') + '" spawned');
|
||||
await Alpine.store('app').refreshAgents();
|
||||
this.chatWithAgent({ id: res.agent_id, name: res.name, model_provider: '?', model_name: '?' });
|
||||
} else {
|
||||
OpenFangToast.error('Spawn failed: ' + (res.error || 'Unknown error'));
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to spawn agent: ' + e.message);
|
||||
}
|
||||
this.spawning = false;
|
||||
},
|
||||
|
||||
// ── Detail modal: Files tab ──
|
||||
async loadAgentFiles() {
|
||||
if (!this.detailAgent) return;
|
||||
this.filesLoading = true;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/agents/' + this.detailAgent.id + '/files');
|
||||
this.agentFiles = data.files || [];
|
||||
} catch(e) {
|
||||
this.agentFiles = [];
|
||||
OpenFangToast.error('Failed to load files: ' + e.message);
|
||||
}
|
||||
this.filesLoading = false;
|
||||
},
|
||||
|
||||
async openFile(file) {
|
||||
if (!file.exists) {
|
||||
// Create with empty content
|
||||
this.editingFile = file.name;
|
||||
this.fileContent = '';
|
||||
return;
|
||||
}
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/agents/' + this.detailAgent.id + '/files/' + encodeURIComponent(file.name));
|
||||
this.editingFile = file.name;
|
||||
this.fileContent = data.content || '';
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to read file: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
async saveFile() {
|
||||
if (!this.editingFile || !this.detailAgent) return;
|
||||
this.fileSaving = true;
|
||||
try {
|
||||
await OpenFangAPI.put('/api/agents/' + this.detailAgent.id + '/files/' + encodeURIComponent(this.editingFile), { content: this.fileContent });
|
||||
OpenFangToast.success(this.editingFile + ' saved');
|
||||
await this.loadAgentFiles();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save file: ' + e.message);
|
||||
}
|
||||
this.fileSaving = false;
|
||||
},
|
||||
|
||||
closeFileEditor() {
|
||||
this.editingFile = null;
|
||||
this.fileContent = '';
|
||||
},
|
||||
|
||||
// ── Detail modal: Config tab ──
|
||||
async saveConfig() {
|
||||
if (!this.detailAgent) return;
|
||||
this.configSaving = true;
|
||||
try {
|
||||
await OpenFangAPI.patch('/api/agents/' + this.detailAgent.id + '/config', this.configForm);
|
||||
OpenFangToast.success('Config updated');
|
||||
await Alpine.store('app').refreshAgents();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save config: ' + e.message);
|
||||
}
|
||||
this.configSaving = false;
|
||||
},
|
||||
|
||||
// ── Clone agent ──
|
||||
async cloneAgent(agent) {
|
||||
var newName = (agent.name || 'agent') + '-copy';
|
||||
try {
|
||||
var res = await OpenFangAPI.post('/api/agents/' + agent.id + '/clone', { new_name: newName });
|
||||
if (res.agent_id) {
|
||||
OpenFangToast.success('Cloned as "' + res.name + '"');
|
||||
await Alpine.store('app').refreshAgents();
|
||||
this.showDetailModal = false;
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Clone failed: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
// -- Template methods --
|
||||
async spawnFromTemplate(name) {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/templates/' + encodeURIComponent(name));
|
||||
if (data.manifest_toml) {
|
||||
var res = await OpenFangAPI.post('/api/agents', { manifest_toml: data.manifest_toml });
|
||||
if (res.agent_id) {
|
||||
OpenFangToast.success('Agent "' + (res.name || name) + '" spawned from template');
|
||||
await Alpine.store('app').refreshAgents();
|
||||
this.chatWithAgent({ id: res.agent_id, name: res.name || name, model_provider: '?', model_name: '?' });
|
||||
}
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to spawn from template: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
async spawnBuiltin(t) {
|
||||
var toml = 'name = "' + t.name + '"\n';
|
||||
toml += 'description = "' + t.description.replace(/"/g, '\\"') + '"\n';
|
||||
toml += 'module = "builtin:chat"\n';
|
||||
toml += 'profile = "' + t.profile + '"\n\n';
|
||||
toml += '[model]\nprovider = "' + t.provider + '"\nmodel = "' + t.model + '"\n';
|
||||
toml += 'system_prompt = """\n' + t.system_prompt + '\n"""\n';
|
||||
|
||||
try {
|
||||
var res = await OpenFangAPI.post('/api/agents', { manifest_toml: toml });
|
||||
if (res.agent_id) {
|
||||
OpenFangToast.success('Agent "' + t.name + '" spawned');
|
||||
await Alpine.store('app').refreshAgents();
|
||||
this.chatWithAgent({ id: res.agent_id, name: t.name, model_provider: t.provider, model_name: t.model });
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to spawn agent: ' + e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
66
crates/openfang-api/static/js/pages/approvals.js
Normal file
66
crates/openfang-api/static/js/pages/approvals.js
Normal file
@@ -0,0 +1,66 @@
|
||||
// OpenFang Approvals Page — Execution approval queue for sensitive agent actions
|
||||
'use strict';
|
||||
|
||||
function approvalsPage() {
|
||||
return {
|
||||
approvals: [],
|
||||
filterStatus: 'all',
|
||||
loading: true,
|
||||
loadError: '',
|
||||
|
||||
get filtered() {
|
||||
var f = this.filterStatus;
|
||||
if (f === 'all') return this.approvals;
|
||||
return this.approvals.filter(function(a) { return a.status === f; });
|
||||
},
|
||||
|
||||
get pendingCount() {
|
||||
return this.approvals.filter(function(a) { return a.status === 'pending'; }).length;
|
||||
},
|
||||
|
||||
async loadData() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/approvals');
|
||||
this.approvals = data.approvals || [];
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load approvals.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async approve(id) {
|
||||
try {
|
||||
await OpenFangAPI.post('/api/approvals/' + id + '/approve', {});
|
||||
OpenFangToast.success('Approved');
|
||||
await this.loadData();
|
||||
} catch(e) {
|
||||
OpenFangToast.error(e.message);
|
||||
}
|
||||
},
|
||||
|
||||
async reject(id) {
|
||||
var self = this;
|
||||
OpenFangToast.confirm('Reject Action', 'Are you sure you want to reject this action?', async function() {
|
||||
try {
|
||||
await OpenFangAPI.post('/api/approvals/' + id + '/reject', {});
|
||||
OpenFangToast.success('Rejected');
|
||||
await self.loadData();
|
||||
} catch(e) {
|
||||
OpenFangToast.error(e.message);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
timeAgo(dateStr) {
|
||||
if (!dateStr) return '';
|
||||
var d = new Date(dateStr);
|
||||
var secs = Math.floor((Date.now() - d.getTime()) / 1000);
|
||||
if (secs < 60) return secs + 's ago';
|
||||
if (secs < 3600) return Math.floor(secs / 60) + 'm ago';
|
||||
if (secs < 86400) return Math.floor(secs / 3600) + 'h ago';
|
||||
return Math.floor(secs / 86400) + 'd ago';
|
||||
}
|
||||
};
|
||||
}
|
||||
300
crates/openfang-api/static/js/pages/channels.js
Normal file
300
crates/openfang-api/static/js/pages/channels.js
Normal file
@@ -0,0 +1,300 @@
|
||||
// OpenFang Channels Page — OpenClaw-style setup UX with QR code support
|
||||
'use strict';
|
||||
|
||||
function channelsPage() {
|
||||
return {
|
||||
allChannels: [],
|
||||
categoryFilter: 'all',
|
||||
searchQuery: '',
|
||||
setupModal: null,
|
||||
configuring: false,
|
||||
testing: {},
|
||||
formValues: {},
|
||||
showAdvanced: false,
|
||||
showBusinessApi: false,
|
||||
loading: true,
|
||||
loadError: '',
|
||||
pollTimer: null,
|
||||
|
||||
// Setup flow step tracking
|
||||
setupStep: 1, // 1=Configure, 2=Verify, 3=Ready
|
||||
testPassed: false,
|
||||
|
||||
// WhatsApp QR state
|
||||
qr: {
|
||||
loading: false,
|
||||
available: false,
|
||||
dataUrl: '',
|
||||
sessionId: '',
|
||||
message: '',
|
||||
help: '',
|
||||
connected: false,
|
||||
expired: false,
|
||||
error: ''
|
||||
},
|
||||
qrPollTimer: null,
|
||||
|
||||
categories: [
|
||||
{ key: 'all', label: 'All' },
|
||||
{ key: 'messaging', label: 'Messaging' },
|
||||
{ key: 'social', label: 'Social' },
|
||||
{ key: 'enterprise', label: 'Enterprise' },
|
||||
{ key: 'developer', label: 'Developer' },
|
||||
{ key: 'notifications', label: 'Notifications' }
|
||||
],
|
||||
|
||||
get filteredChannels() {
|
||||
var self = this;
|
||||
return this.allChannels.filter(function(ch) {
|
||||
if (self.categoryFilter !== 'all' && ch.category !== self.categoryFilter) return false;
|
||||
if (self.searchQuery) {
|
||||
var q = self.searchQuery.toLowerCase();
|
||||
return ch.name.toLowerCase().indexOf(q) !== -1 ||
|
||||
ch.display_name.toLowerCase().indexOf(q) !== -1 ||
|
||||
ch.description.toLowerCase().indexOf(q) !== -1;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
},
|
||||
|
||||
get configuredCount() {
|
||||
return this.allChannels.filter(function(ch) { return ch.configured; }).length;
|
||||
},
|
||||
|
||||
categoryCount(cat) {
|
||||
var all = this.allChannels.filter(function(ch) { return cat === 'all' || ch.category === cat; });
|
||||
var configured = all.filter(function(ch) { return ch.configured; });
|
||||
return configured.length + '/' + all.length;
|
||||
},
|
||||
|
||||
basicFields() {
|
||||
if (!this.setupModal || !this.setupModal.fields) return [];
|
||||
return this.setupModal.fields.filter(function(f) { return !f.advanced; });
|
||||
},
|
||||
|
||||
advancedFields() {
|
||||
if (!this.setupModal || !this.setupModal.fields) return [];
|
||||
return this.setupModal.fields.filter(function(f) { return f.advanced; });
|
||||
},
|
||||
|
||||
hasAdvanced() {
|
||||
return this.advancedFields().length > 0;
|
||||
},
|
||||
|
||||
isQrChannel() {
|
||||
return this.setupModal && this.setupModal.setup_type === 'qr';
|
||||
},
|
||||
|
||||
async loadChannels() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/channels');
|
||||
this.allChannels = (data.channels || []).map(function(ch) {
|
||||
ch.connected = ch.configured && ch.has_token;
|
||||
return ch;
|
||||
});
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load channels.';
|
||||
}
|
||||
this.loading = false;
|
||||
this.startPolling();
|
||||
},
|
||||
|
||||
async loadData() { return this.loadChannels(); },
|
||||
|
||||
startPolling() {
|
||||
var self = this;
|
||||
if (this.pollTimer) clearInterval(this.pollTimer);
|
||||
this.pollTimer = setInterval(function() { self.refreshStatus(); }, 15000);
|
||||
},
|
||||
|
||||
async refreshStatus() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/channels');
|
||||
var byName = {};
|
||||
(data.channels || []).forEach(function(ch) { byName[ch.name] = ch; });
|
||||
this.allChannels.forEach(function(c) {
|
||||
var fresh = byName[c.name];
|
||||
if (fresh) {
|
||||
c.configured = fresh.configured;
|
||||
c.has_token = fresh.has_token;
|
||||
c.connected = fresh.configured && fresh.has_token;
|
||||
c.fields = fresh.fields;
|
||||
}
|
||||
});
|
||||
} catch(e) { console.warn('Channel refresh failed:', e.message); }
|
||||
},
|
||||
|
||||
statusBadge(ch) {
|
||||
if (!ch.configured) return { text: 'Not Configured', cls: 'badge-muted' };
|
||||
if (!ch.has_token) return { text: 'Missing Token', cls: 'badge-warn' };
|
||||
if (ch.connected) return { text: 'Ready', cls: 'badge-success' };
|
||||
return { text: 'Configured', cls: 'badge-info' };
|
||||
},
|
||||
|
||||
difficultyClass(d) {
|
||||
if (d === 'Easy') return 'difficulty-easy';
|
||||
if (d === 'Hard') return 'difficulty-hard';
|
||||
return 'difficulty-medium';
|
||||
},
|
||||
|
||||
openSetup(ch) {
|
||||
this.setupModal = ch;
|
||||
this.formValues = {};
|
||||
this.showAdvanced = false;
|
||||
this.showBusinessApi = false;
|
||||
this.setupStep = ch.configured ? 3 : 1;
|
||||
this.testPassed = !!ch.configured;
|
||||
this.resetQR();
|
||||
// Auto-start QR flow for QR-type channels
|
||||
if (ch.setup_type === 'qr') {
|
||||
this.startQR();
|
||||
}
|
||||
},
|
||||
|
||||
// ── QR Code Flow (WhatsApp Web style) ──────────────────────────
|
||||
|
||||
resetQR() {
|
||||
this.qr = {
|
||||
loading: false, available: false, dataUrl: '', sessionId: '',
|
||||
message: '', help: '', connected: false, expired: false, error: ''
|
||||
};
|
||||
if (this.qrPollTimer) { clearInterval(this.qrPollTimer); this.qrPollTimer = null; }
|
||||
},
|
||||
|
||||
async startQR() {
|
||||
this.qr.loading = true;
|
||||
this.qr.error = '';
|
||||
this.qr.connected = false;
|
||||
this.qr.expired = false;
|
||||
try {
|
||||
var result = await OpenFangAPI.post('/api/channels/whatsapp/qr/start', {});
|
||||
this.qr.available = result.available || false;
|
||||
this.qr.dataUrl = result.qr_data_url || '';
|
||||
this.qr.sessionId = result.session_id || '';
|
||||
this.qr.message = result.message || '';
|
||||
this.qr.help = result.help || '';
|
||||
this.qr.connected = result.connected || false;
|
||||
if (this.qr.available && this.qr.dataUrl && !this.qr.connected) {
|
||||
this.pollQR();
|
||||
}
|
||||
if (this.qr.connected) {
|
||||
OpenFangToast.success('WhatsApp connected!');
|
||||
await this.refreshStatus();
|
||||
}
|
||||
} catch(e) {
|
||||
this.qr.error = e.message || 'Could not start QR login';
|
||||
}
|
||||
this.qr.loading = false;
|
||||
},
|
||||
|
||||
pollQR() {
|
||||
var self = this;
|
||||
if (this.qrPollTimer) clearInterval(this.qrPollTimer);
|
||||
this.qrPollTimer = setInterval(async function() {
|
||||
try {
|
||||
var result = await OpenFangAPI.get('/api/channels/whatsapp/qr/status?session_id=' + encodeURIComponent(self.qr.sessionId));
|
||||
if (result.connected) {
|
||||
clearInterval(self.qrPollTimer);
|
||||
self.qrPollTimer = null;
|
||||
self.qr.connected = true;
|
||||
self.qr.message = result.message || 'Connected!';
|
||||
OpenFangToast.success('WhatsApp linked successfully!');
|
||||
await self.refreshStatus();
|
||||
} else if (result.expired) {
|
||||
clearInterval(self.qrPollTimer);
|
||||
self.qrPollTimer = null;
|
||||
self.qr.expired = true;
|
||||
self.qr.message = 'QR code expired. Click to generate a new one.';
|
||||
} else {
|
||||
self.qr.message = result.message || 'Waiting for scan...';
|
||||
}
|
||||
} catch(e) { /* silent retry */ }
|
||||
}, 3000);
|
||||
},
|
||||
|
||||
// ── Standard Form Flow ─────────────────────────────────────────
|
||||
|
||||
async saveChannel() {
|
||||
if (!this.setupModal) return;
|
||||
var name = this.setupModal.name;
|
||||
this.configuring = true;
|
||||
try {
|
||||
await OpenFangAPI.post('/api/channels/' + name + '/configure', {
|
||||
fields: this.formValues
|
||||
});
|
||||
this.setupStep = 2;
|
||||
// Auto-test after save
|
||||
try {
|
||||
var testResult = await OpenFangAPI.post('/api/channels/' + name + '/test', {});
|
||||
if (testResult.status === 'ok') {
|
||||
this.testPassed = true;
|
||||
this.setupStep = 3;
|
||||
OpenFangToast.success(this.setupModal.display_name + ' activated!');
|
||||
} else {
|
||||
OpenFangToast.success(this.setupModal.display_name + ' saved. ' + (testResult.message || ''));
|
||||
}
|
||||
} catch(te) {
|
||||
OpenFangToast.success(this.setupModal.display_name + ' saved. Test to verify connection.');
|
||||
}
|
||||
await this.refreshStatus();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed: ' + (e.message || 'Unknown error'));
|
||||
}
|
||||
this.configuring = false;
|
||||
},
|
||||
|
||||
async removeChannel() {
|
||||
if (!this.setupModal) return;
|
||||
var name = this.setupModal.name;
|
||||
var displayName = this.setupModal.display_name;
|
||||
var self = this;
|
||||
OpenFangToast.confirm('Remove Channel', 'Remove ' + displayName + ' configuration? This will deactivate the channel.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.delete('/api/channels/' + name + '/configure');
|
||||
OpenFangToast.success(displayName + ' removed and deactivated.');
|
||||
await self.refreshStatus();
|
||||
self.setupModal = null;
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed: ' + (e.message || 'Unknown error'));
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
async testChannel() {
|
||||
if (!this.setupModal) return;
|
||||
var name = this.setupModal.name;
|
||||
this.testing[name] = true;
|
||||
try {
|
||||
var result = await OpenFangAPI.post('/api/channels/' + name + '/test', {});
|
||||
if (result.status === 'ok') {
|
||||
this.testPassed = true;
|
||||
this.setupStep = 3;
|
||||
OpenFangToast.success(result.message);
|
||||
} else {
|
||||
OpenFangToast.error(result.message);
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Test failed: ' + (e.message || 'Unknown error'));
|
||||
}
|
||||
this.testing[name] = false;
|
||||
},
|
||||
|
||||
async copyConfig(ch) {
|
||||
var tpl = ch ? ch.config_template : (this.setupModal ? this.setupModal.config_template : '');
|
||||
if (!tpl) return;
|
||||
try {
|
||||
await navigator.clipboard.writeText(tpl);
|
||||
OpenFangToast.success('Copied to clipboard');
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Copy failed');
|
||||
}
|
||||
},
|
||||
|
||||
destroy() {
|
||||
if (this.pollTimer) { clearInterval(this.pollTimer); this.pollTimer = null; }
|
||||
if (this.qrPollTimer) { clearInterval(this.qrPollTimer); this.qrPollTimer = null; }
|
||||
}
|
||||
};
|
||||
}
|
||||
1072
crates/openfang-api/static/js/pages/chat.js
Normal file
1072
crates/openfang-api/static/js/pages/chat.js
Normal file
File diff suppressed because it is too large
Load Diff
504
crates/openfang-api/static/js/pages/hands.js
Normal file
504
crates/openfang-api/static/js/pages/hands.js
Normal file
@@ -0,0 +1,504 @@
|
||||
// OpenFang Hands Page — curated autonomous capability packages
|
||||
'use strict';
|
||||
|
||||
function handsPage() {
|
||||
return {
|
||||
tab: 'available',
|
||||
hands: [],
|
||||
instances: [],
|
||||
loading: true,
|
||||
activeLoading: false,
|
||||
loadError: '',
|
||||
activatingId: null,
|
||||
activateResult: null,
|
||||
detailHand: null,
|
||||
settingsValues: {},
|
||||
_toastTimer: null,
|
||||
browserViewer: null,
|
||||
browserViewerOpen: false,
|
||||
_browserPollTimer: null,
|
||||
|
||||
// ── Setup Wizard State ──────────────────────────────────────────────
|
||||
setupWizard: null,
|
||||
setupStep: 1,
|
||||
setupLoading: false,
|
||||
setupChecking: false,
|
||||
clipboardMsg: null,
|
||||
_clipboardTimer: null,
|
||||
detectedPlatform: 'linux',
|
||||
installPlatforms: {},
|
||||
|
||||
async loadData() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/hands');
|
||||
this.hands = data.hands || [];
|
||||
} catch(e) {
|
||||
this.hands = [];
|
||||
this.loadError = e.message || 'Could not load hands.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadActive() {
|
||||
this.activeLoading = true;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/hands/active');
|
||||
this.instances = (data.instances || []).map(function(i) {
|
||||
i._stats = null;
|
||||
return i;
|
||||
});
|
||||
} catch(e) {
|
||||
this.instances = [];
|
||||
}
|
||||
this.activeLoading = false;
|
||||
},
|
||||
|
||||
getHandIcon(handId) {
|
||||
for (var i = 0; i < this.hands.length; i++) {
|
||||
if (this.hands[i].id === handId) return this.hands[i].icon;
|
||||
}
|
||||
return '\u{1F91A}';
|
||||
},
|
||||
|
||||
async showDetail(handId) {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/hands/' + handId);
|
||||
this.detailHand = data;
|
||||
} catch(e) {
|
||||
for (var i = 0; i < this.hands.length; i++) {
|
||||
if (this.hands[i].id === handId) {
|
||||
this.detailHand = this.hands[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
// ── Setup Wizard ────────────────────────────────────────────────────
|
||||
|
||||
async activate(handId) {
|
||||
this.openSetupWizard(handId);
|
||||
},
|
||||
|
||||
async openSetupWizard(handId) {
|
||||
this.setupLoading = true;
|
||||
this.setupWizard = null;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/hands/' + handId);
|
||||
// Pre-populate settings defaults
|
||||
this.settingsValues = {};
|
||||
if (data.settings && data.settings.length > 0) {
|
||||
for (var i = 0; i < data.settings.length; i++) {
|
||||
var s = data.settings[i];
|
||||
this.settingsValues[s.key] = s.default || '';
|
||||
}
|
||||
}
|
||||
// Detect platform from server response, fallback to client-side
|
||||
if (data.server_platform) {
|
||||
this.detectedPlatform = data.server_platform;
|
||||
} else {
|
||||
this._detectClientPlatform();
|
||||
}
|
||||
// Initialize per-requirement platform selections
|
||||
this.installPlatforms = {};
|
||||
if (data.requirements) {
|
||||
for (var j = 0; j < data.requirements.length; j++) {
|
||||
this.installPlatforms[data.requirements[j].key] = this.detectedPlatform;
|
||||
}
|
||||
}
|
||||
this.setupWizard = data;
|
||||
// Skip deps step if no requirements
|
||||
var hasReqs = data.requirements && data.requirements.length > 0;
|
||||
this.setupStep = hasReqs ? 1 : 2;
|
||||
} catch(e) {
|
||||
this.showToast('Could not load hand details: ' + (e.message || 'unknown error'));
|
||||
}
|
||||
this.setupLoading = false;
|
||||
},
|
||||
|
||||
_detectClientPlatform() {
|
||||
var ua = (navigator.userAgent || '').toLowerCase();
|
||||
if (ua.indexOf('mac') !== -1) {
|
||||
this.detectedPlatform = 'macos';
|
||||
} else if (ua.indexOf('win') !== -1) {
|
||||
this.detectedPlatform = 'windows';
|
||||
} else {
|
||||
this.detectedPlatform = 'linux';
|
||||
}
|
||||
},
|
||||
|
||||
// ── Auto-Install Dependencies ───────────────────────────────────
|
||||
installProgress: null, // null = idle, object = { status, current, total, results, error }
|
||||
|
||||
async installDeps() {
|
||||
if (!this.setupWizard) return;
|
||||
var handId = this.setupWizard.id;
|
||||
var missing = (this.setupWizard.requirements || []).filter(function(r) { return !r.satisfied; });
|
||||
if (missing.length === 0) {
|
||||
this.showToast('All dependencies already installed!');
|
||||
return;
|
||||
}
|
||||
|
||||
this.installProgress = {
|
||||
status: 'installing',
|
||||
current: 0,
|
||||
total: missing.length,
|
||||
currentLabel: missing[0] ? missing[0].label : '',
|
||||
results: [],
|
||||
error: null
|
||||
};
|
||||
|
||||
try {
|
||||
var data = await OpenFangAPI.post('/api/hands/' + handId + '/install-deps', {});
|
||||
var results = data.results || [];
|
||||
this.installProgress.results = results;
|
||||
this.installProgress.current = results.length;
|
||||
this.installProgress.status = 'done';
|
||||
|
||||
// Update requirements from server response
|
||||
if (data.requirements && this.setupWizard.requirements) {
|
||||
for (var i = 0; i < this.setupWizard.requirements.length; i++) {
|
||||
var existing = this.setupWizard.requirements[i];
|
||||
for (var j = 0; j < data.requirements.length; j++) {
|
||||
if (data.requirements[j].key === existing.key) {
|
||||
existing.satisfied = data.requirements[j].satisfied;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.setupWizard.requirements_met = data.requirements_met;
|
||||
}
|
||||
|
||||
var installed = results.filter(function(r) { return r.status === 'installed' || r.status === 'already_installed'; }).length;
|
||||
var failed = results.filter(function(r) { return r.status === 'error' || r.status === 'timeout'; }).length;
|
||||
|
||||
if (data.requirements_met) {
|
||||
this.showToast('All dependencies installed successfully!');
|
||||
// Auto-advance to step 2 after a short delay
|
||||
var self = this;
|
||||
setTimeout(function() {
|
||||
self.installProgress = null;
|
||||
self.setupNextStep();
|
||||
}, 1500);
|
||||
} else if (failed > 0) {
|
||||
this.installProgress.error = failed + ' dependency(ies) failed to install. Check the details below.';
|
||||
}
|
||||
} catch(e) {
|
||||
this.installProgress = {
|
||||
status: 'error',
|
||||
current: 0,
|
||||
total: missing.length,
|
||||
currentLabel: '',
|
||||
results: [],
|
||||
error: e.message || 'Installation request failed'
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
getInstallResultIcon(status) {
|
||||
if (status === 'installed' || status === 'already_installed') return '\u2713';
|
||||
if (status === 'error' || status === 'timeout') return '\u2717';
|
||||
return '\u2022';
|
||||
},
|
||||
|
||||
getInstallResultClass(status) {
|
||||
if (status === 'installed' || status === 'already_installed') return 'dep-met';
|
||||
if (status === 'error' || status === 'timeout') return 'dep-missing';
|
||||
return '';
|
||||
},
|
||||
|
||||
async recheckDeps() {
|
||||
if (!this.setupWizard) return;
|
||||
this.setupChecking = true;
|
||||
try {
|
||||
var data = await OpenFangAPI.post('/api/hands/' + this.setupWizard.id + '/check-deps', {});
|
||||
if (data.requirements && this.setupWizard.requirements) {
|
||||
for (var i = 0; i < this.setupWizard.requirements.length; i++) {
|
||||
var existing = this.setupWizard.requirements[i];
|
||||
for (var j = 0; j < data.requirements.length; j++) {
|
||||
if (data.requirements[j].key === existing.key) {
|
||||
existing.satisfied = data.requirements[j].satisfied;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
this.setupWizard.requirements_met = data.requirements_met;
|
||||
}
|
||||
if (data.requirements_met) {
|
||||
this.showToast('All dependencies satisfied!');
|
||||
}
|
||||
} catch(e) {
|
||||
this.showToast('Check failed: ' + (e.message || 'unknown'));
|
||||
}
|
||||
this.setupChecking = false;
|
||||
},
|
||||
|
||||
getInstallCmd(req) {
|
||||
if (!req || !req.install) return null;
|
||||
var inst = req.install;
|
||||
var plat = this.installPlatforms[req.key] || this.detectedPlatform;
|
||||
if (plat === 'macos' && inst.macos) return inst.macos;
|
||||
if (plat === 'windows' && inst.windows) return inst.windows;
|
||||
if (plat === 'linux') {
|
||||
return inst.linux_apt || inst.linux_dnf || inst.linux_pacman || inst.pip || null;
|
||||
}
|
||||
return inst.pip || inst.macos || inst.windows || inst.linux_apt || null;
|
||||
},
|
||||
|
||||
getLinuxVariant(req) {
|
||||
if (!req || !req.install) return null;
|
||||
var inst = req.install;
|
||||
var plat = this.installPlatforms[req.key] || this.detectedPlatform;
|
||||
if (plat !== 'linux') return null;
|
||||
// Return all available Linux variants
|
||||
var variants = [];
|
||||
if (inst.linux_apt) variants.push({ label: 'apt', cmd: inst.linux_apt });
|
||||
if (inst.linux_dnf) variants.push({ label: 'dnf', cmd: inst.linux_dnf });
|
||||
if (inst.linux_pacman) variants.push({ label: 'pacman', cmd: inst.linux_pacman });
|
||||
if (inst.pip) variants.push({ label: 'pip', cmd: inst.pip });
|
||||
return variants.length > 1 ? variants : null;
|
||||
},
|
||||
|
||||
copyToClipboard(text) {
|
||||
var self = this;
|
||||
navigator.clipboard.writeText(text).then(function() {
|
||||
self.clipboardMsg = text;
|
||||
if (self._clipboardTimer) clearTimeout(self._clipboardTimer);
|
||||
self._clipboardTimer = setTimeout(function() { self.clipboardMsg = null; }, 2000);
|
||||
});
|
||||
},
|
||||
|
||||
get setupReqsMet() {
|
||||
if (!this.setupWizard || !this.setupWizard.requirements) return 0;
|
||||
var count = 0;
|
||||
for (var i = 0; i < this.setupWizard.requirements.length; i++) {
|
||||
if (this.setupWizard.requirements[i].satisfied) count++;
|
||||
}
|
||||
return count;
|
||||
},
|
||||
|
||||
get setupReqsTotal() {
|
||||
if (!this.setupWizard || !this.setupWizard.requirements) return 0;
|
||||
return this.setupWizard.requirements.length;
|
||||
},
|
||||
|
||||
get setupAllReqsMet() {
|
||||
return this.setupReqsTotal > 0 && this.setupReqsMet === this.setupReqsTotal;
|
||||
},
|
||||
|
||||
get setupHasReqs() {
|
||||
return this.setupReqsTotal > 0;
|
||||
},
|
||||
|
||||
get setupHasSettings() {
|
||||
return this.setupWizard && this.setupWizard.settings && this.setupWizard.settings.length > 0;
|
||||
},
|
||||
|
||||
setupNextStep() {
|
||||
if (this.setupStep === 1 && this.setupHasSettings) {
|
||||
this.setupStep = 2;
|
||||
} else if (this.setupStep === 1) {
|
||||
this.setupStep = 3;
|
||||
} else if (this.setupStep === 2) {
|
||||
this.setupStep = 3;
|
||||
}
|
||||
},
|
||||
|
||||
setupPrevStep() {
|
||||
if (this.setupStep === 3 && this.setupHasSettings) {
|
||||
this.setupStep = 2;
|
||||
} else if (this.setupStep === 3) {
|
||||
this.setupStep = this.setupHasReqs ? 1 : 2;
|
||||
} else if (this.setupStep === 2 && this.setupHasReqs) {
|
||||
this.setupStep = 1;
|
||||
}
|
||||
},
|
||||
|
||||
closeSetupWizard() {
|
||||
this.setupWizard = null;
|
||||
this.setupStep = 1;
|
||||
this.setupLoading = false;
|
||||
this.setupChecking = false;
|
||||
this.clipboardMsg = null;
|
||||
this.installPlatforms = {};
|
||||
},
|
||||
|
||||
async launchHand() {
|
||||
if (!this.setupWizard) return;
|
||||
var handId = this.setupWizard.id;
|
||||
var config = {};
|
||||
for (var key in this.settingsValues) {
|
||||
config[key] = this.settingsValues[key];
|
||||
}
|
||||
this.activatingId = handId;
|
||||
try {
|
||||
var data = await OpenFangAPI.post('/api/hands/' + handId + '/activate', { config: config });
|
||||
this.showToast('Hand "' + handId + '" activated as ' + (data.agent_name || data.instance_id));
|
||||
this.closeSetupWizard();
|
||||
await this.loadActive();
|
||||
this.tab = 'active';
|
||||
} catch(e) {
|
||||
this.showToast('Activation failed: ' + (e.message || 'unknown error'));
|
||||
}
|
||||
this.activatingId = null;
|
||||
},
|
||||
|
||||
selectOption(settingKey, value) {
|
||||
this.settingsValues[settingKey] = value;
|
||||
},
|
||||
|
||||
getSettingDisplayValue(setting) {
|
||||
var val = this.settingsValues[setting.key] || setting.default || '';
|
||||
if (setting.setting_type === 'toggle') {
|
||||
return val === 'true' ? 'Enabled' : 'Disabled';
|
||||
}
|
||||
if (setting.setting_type === 'select' && setting.options) {
|
||||
for (var i = 0; i < setting.options.length; i++) {
|
||||
if (setting.options[i].value === val) return setting.options[i].label;
|
||||
}
|
||||
}
|
||||
return val || '-';
|
||||
},
|
||||
|
||||
// ── Existing methods ────────────────────────────────────────────────
|
||||
|
||||
async pauseHand(inst) {
|
||||
try {
|
||||
await OpenFangAPI.post('/api/hands/instances/' + inst.instance_id + '/pause', {});
|
||||
inst.status = 'Paused';
|
||||
} catch(e) {
|
||||
this.showToast('Pause failed: ' + (e.message || 'unknown error'));
|
||||
}
|
||||
},
|
||||
|
||||
async resumeHand(inst) {
|
||||
try {
|
||||
await OpenFangAPI.post('/api/hands/instances/' + inst.instance_id + '/resume', {});
|
||||
inst.status = 'Active';
|
||||
} catch(e) {
|
||||
this.showToast('Resume failed: ' + (e.message || 'unknown error'));
|
||||
}
|
||||
},
|
||||
|
||||
async deactivate(inst) {
|
||||
var self = this;
|
||||
var handName = inst.agent_name || inst.hand_id;
|
||||
OpenFangToast.confirm('Deactivate Hand', 'Deactivate hand "' + handName + '"? This will kill its agent.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.delete('/api/hands/instances/' + inst.instance_id);
|
||||
self.instances = self.instances.filter(function(i) { return i.instance_id !== inst.instance_id; });
|
||||
OpenFangToast.success('Hand deactivated.');
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Deactivation failed: ' + (e.message || 'unknown error'));
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
async loadStats(inst) {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/hands/instances/' + inst.instance_id + '/stats');
|
||||
inst._stats = data.metrics || {};
|
||||
} catch(e) {
|
||||
inst._stats = { 'Error': { value: e.message || 'Could not load stats', format: 'text' } };
|
||||
}
|
||||
},
|
||||
|
||||
formatMetric(m) {
|
||||
if (!m || m.value === null || m.value === undefined) return '-';
|
||||
if (m.format === 'duration') {
|
||||
var secs = parseInt(m.value, 10);
|
||||
if (isNaN(secs)) return String(m.value);
|
||||
var h = Math.floor(secs / 3600);
|
||||
var min = Math.floor((secs % 3600) / 60);
|
||||
var s = secs % 60;
|
||||
if (h > 0) return h + 'h ' + min + 'm';
|
||||
if (min > 0) return min + 'm ' + s + 's';
|
||||
return s + 's';
|
||||
}
|
||||
if (m.format === 'number') {
|
||||
var n = parseFloat(m.value);
|
||||
if (isNaN(n)) return String(m.value);
|
||||
return n.toLocaleString();
|
||||
}
|
||||
return String(m.value);
|
||||
},
|
||||
|
||||
showToast(msg) {
|
||||
var self = this;
|
||||
this.activateResult = msg;
|
||||
if (this._toastTimer) clearTimeout(this._toastTimer);
|
||||
this._toastTimer = setTimeout(function() { self.activateResult = null; }, 4000);
|
||||
},
|
||||
|
||||
// ── Browser Viewer ───────────────────────────────────────────────────
|
||||
|
||||
isBrowserHand(inst) {
|
||||
return inst.hand_id === 'browser';
|
||||
},
|
||||
|
||||
async openBrowserViewer(inst) {
|
||||
this.browserViewer = {
|
||||
instance_id: inst.instance_id,
|
||||
hand_id: inst.hand_id,
|
||||
agent_name: inst.agent_name,
|
||||
url: '',
|
||||
title: '',
|
||||
screenshot: '',
|
||||
content: '',
|
||||
loading: true,
|
||||
error: ''
|
||||
};
|
||||
this.browserViewerOpen = true;
|
||||
await this.refreshBrowserView();
|
||||
this.startBrowserPolling();
|
||||
},
|
||||
|
||||
async refreshBrowserView() {
|
||||
if (!this.browserViewer) return;
|
||||
var id = this.browserViewer.instance_id;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/hands/instances/' + id + '/browser');
|
||||
if (data.active) {
|
||||
this.browserViewer.url = data.url || '';
|
||||
this.browserViewer.title = data.title || '';
|
||||
this.browserViewer.screenshot = data.screenshot_base64 || '';
|
||||
this.browserViewer.content = data.content || '';
|
||||
this.browserViewer.error = '';
|
||||
} else {
|
||||
this.browserViewer.error = 'No active browser session';
|
||||
this.browserViewer.screenshot = '';
|
||||
}
|
||||
} catch(e) {
|
||||
this.browserViewer.error = e.message || 'Could not load browser state';
|
||||
}
|
||||
this.browserViewer.loading = false;
|
||||
},
|
||||
|
||||
startBrowserPolling() {
|
||||
var self = this;
|
||||
this.stopBrowserPolling();
|
||||
this._browserPollTimer = setInterval(function() {
|
||||
if (self.browserViewerOpen) {
|
||||
self.refreshBrowserView();
|
||||
} else {
|
||||
self.stopBrowserPolling();
|
||||
}
|
||||
}, 3000);
|
||||
},
|
||||
|
||||
stopBrowserPolling() {
|
||||
if (this._browserPollTimer) {
|
||||
clearInterval(this._browserPollTimer);
|
||||
this._browserPollTimer = null;
|
||||
}
|
||||
},
|
||||
|
||||
closeBrowserViewer() {
|
||||
this.stopBrowserPolling();
|
||||
this.browserViewerOpen = false;
|
||||
this.browserViewer = null;
|
||||
}
|
||||
};
|
||||
}
|
||||
255
crates/openfang-api/static/js/pages/logs.js
Normal file
255
crates/openfang-api/static/js/pages/logs.js
Normal file
@@ -0,0 +1,255 @@
|
||||
// OpenFang Logs Page — Real-time log viewer (SSE streaming + polling fallback) + Audit Trail tab
|
||||
'use strict';
|
||||
|
||||
function logsPage() {
|
||||
return {
|
||||
tab: 'live',
|
||||
// -- Live logs state --
|
||||
entries: [],
|
||||
levelFilter: '',
|
||||
textFilter: '',
|
||||
autoRefresh: true,
|
||||
hovering: false,
|
||||
loading: true,
|
||||
loadError: '',
|
||||
_pollTimer: null,
|
||||
|
||||
// -- SSE streaming state --
|
||||
_eventSource: null,
|
||||
streamConnected: false,
|
||||
streamPaused: false,
|
||||
|
||||
// -- Audit state --
|
||||
auditEntries: [],
|
||||
tipHash: '',
|
||||
chainValid: null,
|
||||
filterAction: '',
|
||||
auditLoading: false,
|
||||
auditLoadError: '',
|
||||
|
||||
startStreaming: function() {
|
||||
var self = this;
|
||||
if (this._eventSource) { this._eventSource.close(); this._eventSource = null; }
|
||||
|
||||
var url = '/api/logs/stream';
|
||||
var sep = '?';
|
||||
var token = OpenFangAPI.getToken();
|
||||
if (token) { url += sep + 'token=' + encodeURIComponent(token); sep = '&'; }
|
||||
|
||||
try {
|
||||
this._eventSource = new EventSource(url);
|
||||
} catch(e) {
|
||||
// EventSource not supported or blocked; fall back to polling
|
||||
this.streamConnected = false;
|
||||
this.startPolling();
|
||||
return;
|
||||
}
|
||||
|
||||
this._eventSource.onopen = function() {
|
||||
self.streamConnected = true;
|
||||
self.loading = false;
|
||||
self.loadError = '';
|
||||
};
|
||||
|
||||
this._eventSource.onmessage = function(event) {
|
||||
if (self.streamPaused) return;
|
||||
try {
|
||||
var entry = JSON.parse(event.data);
|
||||
// Avoid duplicate entries by checking seq
|
||||
var dominated = false;
|
||||
for (var i = 0; i < self.entries.length; i++) {
|
||||
if (self.entries[i].seq === entry.seq) { dominated = true; break; }
|
||||
}
|
||||
if (!dominated) {
|
||||
self.entries.push(entry);
|
||||
// Cap at 500 entries (remove oldest)
|
||||
if (self.entries.length > 500) {
|
||||
self.entries.splice(0, self.entries.length - 500);
|
||||
}
|
||||
// Auto-scroll to bottom
|
||||
if (self.autoRefresh && !self.hovering) {
|
||||
self.$nextTick(function() {
|
||||
var el = document.getElementById('log-container');
|
||||
if (el) el.scrollTop = el.scrollHeight;
|
||||
});
|
||||
}
|
||||
}
|
||||
} catch(e) {
|
||||
// Ignore parse errors (heartbeat comments are not delivered to onmessage)
|
||||
}
|
||||
};
|
||||
|
||||
this._eventSource.onerror = function() {
|
||||
self.streamConnected = false;
|
||||
if (self._eventSource) {
|
||||
self._eventSource.close();
|
||||
self._eventSource = null;
|
||||
}
|
||||
// Fall back to polling
|
||||
self.startPolling();
|
||||
};
|
||||
},
|
||||
|
||||
startPolling: function() {
|
||||
var self = this;
|
||||
this.streamConnected = false;
|
||||
this.fetchLogs();
|
||||
if (this._pollTimer) clearInterval(this._pollTimer);
|
||||
this._pollTimer = setInterval(function() {
|
||||
if (self.autoRefresh && !self.hovering && self.tab === 'live' && !self.streamPaused) {
|
||||
self.fetchLogs();
|
||||
}
|
||||
}, 2000);
|
||||
},
|
||||
|
||||
async fetchLogs() {
|
||||
if (this.loading) this.loadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/audit/recent?n=200');
|
||||
this.entries = data.entries || [];
|
||||
if (this.autoRefresh && !this.hovering) {
|
||||
this.$nextTick(function() {
|
||||
var el = document.getElementById('log-container');
|
||||
if (el) el.scrollTop = el.scrollHeight;
|
||||
});
|
||||
}
|
||||
if (this.loading) this.loading = false;
|
||||
} catch(e) {
|
||||
if (this.loading) {
|
||||
this.loadError = e.message || 'Could not load logs.';
|
||||
this.loading = false;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
async loadData() {
|
||||
this.loading = true;
|
||||
return this.fetchLogs();
|
||||
},
|
||||
|
||||
togglePause: function() {
|
||||
this.streamPaused = !this.streamPaused;
|
||||
if (!this.streamPaused && this.streamConnected) {
|
||||
// Resume: scroll to bottom
|
||||
var self = this;
|
||||
this.$nextTick(function() {
|
||||
var el = document.getElementById('log-container');
|
||||
if (el) el.scrollTop = el.scrollHeight;
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
clearLogs: function() {
|
||||
this.entries = [];
|
||||
},
|
||||
|
||||
classifyLevel: function(action) {
|
||||
if (!action) return 'info';
|
||||
var a = action.toLowerCase();
|
||||
if (a.indexOf('error') !== -1 || a.indexOf('fail') !== -1 || a.indexOf('crash') !== -1) return 'error';
|
||||
if (a.indexOf('warn') !== -1 || a.indexOf('deny') !== -1 || a.indexOf('block') !== -1) return 'warn';
|
||||
return 'info';
|
||||
},
|
||||
|
||||
get filteredEntries() {
|
||||
var self = this;
|
||||
var levelF = this.levelFilter;
|
||||
var textF = this.textFilter.toLowerCase();
|
||||
return this.entries.filter(function(e) {
|
||||
if (levelF && self.classifyLevel(e.action) !== levelF) return false;
|
||||
if (textF) {
|
||||
var haystack = ((e.action || '') + ' ' + (e.detail || '') + ' ' + (e.agent_id || '')).toLowerCase();
|
||||
if (haystack.indexOf(textF) === -1) return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
},
|
||||
|
||||
get connectionLabel() {
|
||||
if (this.streamPaused) return 'Paused';
|
||||
if (this.streamConnected) return 'Live';
|
||||
if (this._pollTimer) return 'Polling';
|
||||
return 'Disconnected';
|
||||
},
|
||||
|
||||
get connectionClass() {
|
||||
if (this.streamPaused) return 'paused';
|
||||
if (this.streamConnected) return 'live';
|
||||
if (this._pollTimer) return 'polling';
|
||||
return 'disconnected';
|
||||
},
|
||||
|
||||
exportLogs: function() {
|
||||
var lines = this.filteredEntries.map(function(e) {
|
||||
return new Date(e.timestamp).toISOString() + ' [' + e.action + '] ' + (e.detail || '');
|
||||
});
|
||||
var blob = new Blob([lines.join('\n')], { type: 'text/plain' });
|
||||
var url = URL.createObjectURL(blob);
|
||||
var a = document.createElement('a');
|
||||
a.href = url;
|
||||
a.download = 'openfang-logs-' + new Date().toISOString().slice(0, 10) + '.txt';
|
||||
a.click();
|
||||
URL.revokeObjectURL(url);
|
||||
},
|
||||
|
||||
// -- Audit methods --
|
||||
get filteredAuditEntries() {
|
||||
var self = this;
|
||||
if (!self.filterAction) return self.auditEntries;
|
||||
return self.auditEntries.filter(function(e) { return e.action === self.filterAction; });
|
||||
},
|
||||
|
||||
async loadAudit() {
|
||||
this.auditLoading = true;
|
||||
this.auditLoadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/audit/recent?n=200');
|
||||
this.auditEntries = data.entries || [];
|
||||
this.tipHash = data.tip_hash || '';
|
||||
} catch(e) {
|
||||
this.auditEntries = [];
|
||||
this.auditLoadError = e.message || 'Could not load audit log.';
|
||||
}
|
||||
this.auditLoading = false;
|
||||
},
|
||||
|
||||
auditAgentName: function(agentId) {
|
||||
if (!agentId) return '-';
|
||||
var agents = Alpine.store('app').agents || [];
|
||||
var agent = agents.find(function(a) { return a.id === agentId; });
|
||||
return agent ? agent.name : agentId.substring(0, 8) + '...';
|
||||
},
|
||||
|
||||
friendlyAction: function(action) {
|
||||
if (!action) return 'Unknown';
|
||||
var map = {
|
||||
'AgentSpawn': 'Agent Created', 'AgentKill': 'Agent Stopped', 'AgentTerminated': 'Agent Stopped',
|
||||
'ToolInvoke': 'Tool Used', 'ToolResult': 'Tool Completed', 'AgentMessage': 'Message',
|
||||
'NetworkAccess': 'Network Access', 'ShellExec': 'Shell Command', 'FileAccess': 'File Access',
|
||||
'MemoryAccess': 'Memory Access', 'AuthAttempt': 'Login Attempt', 'AuthSuccess': 'Login Success',
|
||||
'AuthFailure': 'Login Failed', 'CapabilityDenied': 'Permission Denied', 'RateLimited': 'Rate Limited'
|
||||
};
|
||||
return map[action] || action.replace(/([A-Z])/g, ' $1').trim();
|
||||
},
|
||||
|
||||
async verifyChain() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/audit/verify');
|
||||
this.chainValid = data.valid === true;
|
||||
if (this.chainValid) {
|
||||
OpenFangToast.success('Audit chain verified — ' + (data.entries || 0) + ' entries valid');
|
||||
} else {
|
||||
OpenFangToast.error('Audit chain broken!');
|
||||
}
|
||||
} catch(e) {
|
||||
this.chainValid = false;
|
||||
OpenFangToast.error('Chain verification failed: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
destroy: function() {
|
||||
if (this._eventSource) { this._eventSource.close(); this._eventSource = null; }
|
||||
if (this._pollTimer) { clearInterval(this._pollTimer); this._pollTimer = null; }
|
||||
}
|
||||
};
|
||||
}
|
||||
292
crates/openfang-api/static/js/pages/overview.js
Normal file
292
crates/openfang-api/static/js/pages/overview.js
Normal file
@@ -0,0 +1,292 @@
|
||||
// OpenFang Overview Dashboard — Landing page with system stats + provider status
|
||||
'use strict';
|
||||
|
||||
function overviewPage() {
|
||||
return {
|
||||
health: {},
|
||||
status: {},
|
||||
usageSummary: {},
|
||||
recentAudit: [],
|
||||
channels: [],
|
||||
providers: [],
|
||||
mcpServers: [],
|
||||
skillCount: 0,
|
||||
loading: true,
|
||||
loadError: '',
|
||||
refreshTimer: null,
|
||||
lastRefresh: null,
|
||||
|
||||
async loadOverview() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
await Promise.all([
|
||||
this.loadHealth(),
|
||||
this.loadStatus(),
|
||||
this.loadUsage(),
|
||||
this.loadAudit(),
|
||||
this.loadChannels(),
|
||||
this.loadProviders(),
|
||||
this.loadMcpServers(),
|
||||
this.loadSkills()
|
||||
]);
|
||||
this.lastRefresh = Date.now();
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load overview data.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadData() { return this.loadOverview(); },
|
||||
|
||||
// Silent background refresh (no loading spinner)
|
||||
async silentRefresh() {
|
||||
try {
|
||||
await Promise.all([
|
||||
this.loadHealth(),
|
||||
this.loadStatus(),
|
||||
this.loadUsage(),
|
||||
this.loadAudit(),
|
||||
this.loadChannels(),
|
||||
this.loadProviders(),
|
||||
this.loadMcpServers(),
|
||||
this.loadSkills()
|
||||
]);
|
||||
this.lastRefresh = Date.now();
|
||||
} catch(e) { /* silent */ }
|
||||
},
|
||||
|
||||
startAutoRefresh() {
|
||||
this.stopAutoRefresh();
|
||||
this.refreshTimer = setInterval(() => this.silentRefresh(), 30000);
|
||||
},
|
||||
|
||||
stopAutoRefresh() {
|
||||
if (this.refreshTimer) {
|
||||
clearInterval(this.refreshTimer);
|
||||
this.refreshTimer = null;
|
||||
}
|
||||
},
|
||||
|
||||
async loadHealth() {
|
||||
try {
|
||||
this.health = await OpenFangAPI.get('/api/health');
|
||||
} catch(e) { this.health = { status: 'unreachable' }; }
|
||||
},
|
||||
|
||||
async loadStatus() {
|
||||
try {
|
||||
this.status = await OpenFangAPI.get('/api/status');
|
||||
} catch(e) { this.status = {}; throw e; }
|
||||
},
|
||||
|
||||
async loadUsage() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/usage');
|
||||
var agents = data.agents || [];
|
||||
var totalTokens = 0;
|
||||
var totalTools = 0;
|
||||
var totalCost = 0;
|
||||
agents.forEach(function(a) {
|
||||
totalTokens += (a.total_tokens || 0);
|
||||
totalTools += (a.tool_calls || 0);
|
||||
totalCost += (a.cost_usd || 0);
|
||||
});
|
||||
this.usageSummary = {
|
||||
total_tokens: totalTokens,
|
||||
total_tools: totalTools,
|
||||
total_cost: totalCost,
|
||||
agent_count: agents.length
|
||||
};
|
||||
} catch(e) {
|
||||
this.usageSummary = { total_tokens: 0, total_tools: 0, total_cost: 0, agent_count: 0 };
|
||||
}
|
||||
},
|
||||
|
||||
async loadAudit() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/audit/recent?n=8');
|
||||
this.recentAudit = data.entries || [];
|
||||
} catch(e) { this.recentAudit = []; }
|
||||
},
|
||||
|
||||
async loadChannels() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/channels');
|
||||
this.channels = (data.channels || []).filter(function(ch) { return ch.has_token; });
|
||||
} catch(e) { this.channels = []; }
|
||||
},
|
||||
|
||||
async loadProviders() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/providers');
|
||||
this.providers = data.providers || [];
|
||||
} catch(e) { this.providers = []; }
|
||||
},
|
||||
|
||||
async loadMcpServers() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/mcp/servers');
|
||||
this.mcpServers = data.servers || [];
|
||||
} catch(e) { this.mcpServers = []; }
|
||||
},
|
||||
|
||||
async loadSkills() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/skills');
|
||||
this.skillCount = (data.skills || []).length;
|
||||
} catch(e) { this.skillCount = 0; }
|
||||
},
|
||||
|
||||
get configuredProviders() {
|
||||
return this.providers.filter(function(p) { return p.auth_status === 'configured'; });
|
||||
},
|
||||
|
||||
get unconfiguredProviders() {
|
||||
return this.providers.filter(function(p) { return p.auth_status === 'not_set' || p.auth_status === 'missing'; });
|
||||
},
|
||||
|
||||
get connectedMcp() {
|
||||
return this.mcpServers.filter(function(s) { return s.status === 'connected'; });
|
||||
},
|
||||
|
||||
// Provider health badge color
|
||||
providerBadgeClass(p) {
|
||||
if (p.auth_status === 'configured') {
|
||||
if (p.health === 'cooldown' || p.health === 'open') return 'badge-warn';
|
||||
return 'badge-success';
|
||||
}
|
||||
if (p.auth_status === 'not_set' || p.auth_status === 'missing') return 'badge-muted';
|
||||
return 'badge-dim';
|
||||
},
|
||||
|
||||
// Provider health tooltip
|
||||
providerTooltip(p) {
|
||||
if (p.health === 'cooldown') return p.display_name + ' \u2014 cooling down (rate limited)';
|
||||
if (p.health === 'open') return p.display_name + ' \u2014 circuit breaker open';
|
||||
if (p.auth_status === 'configured') return p.display_name + ' \u2014 ready';
|
||||
return p.display_name + ' \u2014 not configured';
|
||||
},
|
||||
|
||||
// Audit action badge color
|
||||
actionBadgeClass(action) {
|
||||
if (!action) return 'badge-dim';
|
||||
if (action === 'AgentSpawn' || action === 'AuthSuccess') return 'badge-success';
|
||||
if (action === 'AgentKill' || action === 'AgentTerminated' || action === 'AuthFailure' || action === 'CapabilityDenied') return 'badge-error';
|
||||
if (action === 'RateLimited' || action === 'ToolInvoke') return 'badge-warn';
|
||||
return 'badge-created';
|
||||
},
|
||||
|
||||
// ── Setup Checklist ──
|
||||
checklistDismissed: localStorage.getItem('of-checklist-dismissed') === 'true',
|
||||
|
||||
get setupChecklist() {
|
||||
return [
|
||||
{ key: 'provider', label: 'Configure an LLM provider', done: this.configuredProviders.length > 0, action: '#settings' },
|
||||
{ key: 'agent', label: 'Create your first agent', done: (Alpine.store('app').agents || []).length > 0, action: '#agents' },
|
||||
{ key: 'chat', label: 'Send your first message', done: localStorage.getItem('of-first-msg') === 'true', action: '#chat' },
|
||||
{ key: 'channel', label: 'Connect a messaging channel', done: this.channels.length > 0, action: '#channels' },
|
||||
{ key: 'skill', label: 'Browse or install a skill', done: localStorage.getItem('of-skill-browsed') === 'true', action: '#skills' }
|
||||
];
|
||||
},
|
||||
|
||||
get setupProgress() {
|
||||
var done = this.setupChecklist.filter(function(item) { return item.done; }).length;
|
||||
return (done / 5) * 100;
|
||||
},
|
||||
|
||||
get setupDoneCount() {
|
||||
return this.setupChecklist.filter(function(item) { return item.done; }).length;
|
||||
},
|
||||
|
||||
dismissChecklist() {
|
||||
this.checklistDismissed = true;
|
||||
localStorage.setItem('of-checklist-dismissed', 'true');
|
||||
},
|
||||
|
||||
formatUptime(secs) {
|
||||
if (!secs) return '-';
|
||||
var d = Math.floor(secs / 86400);
|
||||
var h = Math.floor((secs % 86400) / 3600);
|
||||
var m = Math.floor((secs % 3600) / 60);
|
||||
if (d > 0) return d + 'd ' + h + 'h';
|
||||
if (h > 0) return h + 'h ' + m + 'm';
|
||||
return m + 'm';
|
||||
},
|
||||
|
||||
formatNumber(n) {
|
||||
if (!n) return '0';
|
||||
if (n >= 1000000) return (n / 1000000).toFixed(1) + 'M';
|
||||
if (n >= 1000) return (n / 1000).toFixed(1) + 'K';
|
||||
return String(n);
|
||||
},
|
||||
|
||||
formatCost(n) {
|
||||
if (!n || n === 0) return '$0.00';
|
||||
if (n < 0.01) return '<$0.01';
|
||||
return '$' + n.toFixed(2);
|
||||
},
|
||||
|
||||
// Relative time formatting ("2m ago", "1h ago", "just now")
|
||||
timeAgo(timestamp) {
|
||||
if (!timestamp) return '';
|
||||
var now = Date.now();
|
||||
var ts = new Date(timestamp).getTime();
|
||||
var diff = Math.floor((now - ts) / 1000);
|
||||
if (diff < 10) return 'just now';
|
||||
if (diff < 60) return diff + 's ago';
|
||||
if (diff < 3600) return Math.floor(diff / 60) + 'm ago';
|
||||
if (diff < 86400) return Math.floor(diff / 3600) + 'h ago';
|
||||
return Math.floor(diff / 86400) + 'd ago';
|
||||
},
|
||||
|
||||
// Map raw audit action names to user-friendly labels
|
||||
friendlyAction(action) {
|
||||
if (!action) return 'Unknown';
|
||||
var map = {
|
||||
'AgentSpawn': 'Agent Created',
|
||||
'AgentKill': 'Agent Stopped',
|
||||
'AgentTerminated': 'Agent Stopped',
|
||||
'ToolInvoke': 'Tool Used',
|
||||
'ToolResult': 'Tool Completed',
|
||||
'MessageReceived': 'Message In',
|
||||
'MessageSent': 'Response Sent',
|
||||
'SessionReset': 'Session Reset',
|
||||
'SessionCompact': 'Compacted',
|
||||
'ModelSwitch': 'Model Changed',
|
||||
'AuthAttempt': 'Login Attempt',
|
||||
'AuthSuccess': 'Login OK',
|
||||
'AuthFailure': 'Login Failed',
|
||||
'CapabilityDenied': 'Denied',
|
||||
'RateLimited': 'Rate Limited',
|
||||
'WorkflowRun': 'Workflow Run',
|
||||
'TriggerFired': 'Trigger Fired',
|
||||
'SkillInstalled': 'Skill Installed',
|
||||
'McpConnected': 'MCP Connected'
|
||||
};
|
||||
return map[action] || action.replace(/([A-Z])/g, ' $1').trim();
|
||||
},
|
||||
|
||||
// Audit action icon (small inline SVG)
|
||||
actionIcon(action) {
|
||||
if (!action) return '';
|
||||
var icons = {
|
||||
'AgentSpawn': '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><circle cx="12" cy="12" r="10"/><path d="M12 8v8M8 12h8"/></svg>',
|
||||
'AgentKill': '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><circle cx="12" cy="12" r="10"/><path d="M15 9l-6 6M9 9l6 6"/></svg>',
|
||||
'AgentTerminated': '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><circle cx="12" cy="12" r="10"/><path d="M15 9l-6 6M9 9l6 6"/></svg>',
|
||||
'ToolInvoke': '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M14.7 6.3a1 1 0 0 0 0 1.4l1.6 1.6a1 1 0 0 0 1.4 0l3.77-3.77a6 6 0 0 1-7.94 7.94l-6.91 6.91a2.12 2.12 0 0 1-3-3l6.91-6.91a6 6 0 0 1 7.94-7.94l-3.76 3.76z"/></svg>',
|
||||
'MessageReceived': '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M21 15a2 2 0 0 1-2 2H7l-4 4V5a2 2 0 0 1 2-2h14a2 2 0 0 1 2 2z"/></svg>',
|
||||
'MessageSent': '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><path d="M22 2L11 13M22 2l-7 20-4-9-9-4 20-7z"/></svg>'
|
||||
};
|
||||
return icons[action] || '<svg width="12" height="12" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2"><circle cx="12" cy="12" r="10"/></svg>';
|
||||
},
|
||||
|
||||
// Resolve agent UUID to name if possible
|
||||
agentName(agentId) {
|
||||
if (!agentId) return '-';
|
||||
var agents = Alpine.store('app').agents || [];
|
||||
var agent = agents.find(function(a) { return a.id === agentId; });
|
||||
return agent ? agent.name : agentId.substring(0, 8) + '\u2026';
|
||||
}
|
||||
};
|
||||
}
|
||||
393
crates/openfang-api/static/js/pages/scheduler.js
Normal file
393
crates/openfang-api/static/js/pages/scheduler.js
Normal file
@@ -0,0 +1,393 @@
|
||||
// OpenFang Scheduler Page — Cron job management + event triggers unified view
|
||||
'use strict';
|
||||
|
||||
function schedulerPage() {
|
||||
return {
|
||||
tab: 'jobs',
|
||||
|
||||
// -- Scheduled Jobs state --
|
||||
jobs: [],
|
||||
loading: true,
|
||||
loadError: '',
|
||||
|
||||
// -- Event Triggers state --
|
||||
triggers: [],
|
||||
trigLoading: false,
|
||||
trigLoadError: '',
|
||||
|
||||
// -- Run History state --
|
||||
history: [],
|
||||
historyLoading: false,
|
||||
|
||||
// -- Create Job form --
|
||||
showCreateForm: false,
|
||||
newJob: {
|
||||
name: '',
|
||||
cron: '',
|
||||
agent_id: '',
|
||||
message: '',
|
||||
enabled: true
|
||||
},
|
||||
creating: false,
|
||||
|
||||
// -- Run Now state --
|
||||
runningJobId: '',
|
||||
|
||||
// Cron presets
|
||||
cronPresets: [
|
||||
{ label: 'Every minute', cron: '* * * * *' },
|
||||
{ label: 'Every 5 minutes', cron: '*/5 * * * *' },
|
||||
{ label: 'Every 15 minutes', cron: '*/15 * * * *' },
|
||||
{ label: 'Every 30 minutes', cron: '*/30 * * * *' },
|
||||
{ label: 'Every hour', cron: '0 * * * *' },
|
||||
{ label: 'Every 6 hours', cron: '0 */6 * * *' },
|
||||
{ label: 'Daily at midnight', cron: '0 0 * * *' },
|
||||
{ label: 'Daily at 9am', cron: '0 9 * * *' },
|
||||
{ label: 'Weekdays at 9am', cron: '0 9 * * 1-5' },
|
||||
{ label: 'Every Monday 9am', cron: '0 9 * * 1' },
|
||||
{ label: 'First of month', cron: '0 0 1 * *' }
|
||||
],
|
||||
|
||||
// ── Lifecycle ──
|
||||
|
||||
async loadData() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
await this.loadJobs();
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load scheduler data.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadJobs() {
|
||||
var data = await OpenFangAPI.get('/api/cron/jobs');
|
||||
var raw = data.jobs || [];
|
||||
// Normalize cron API response to flat fields the UI expects
|
||||
this.jobs = raw.map(function(j) {
|
||||
var cron = '';
|
||||
if (j.schedule) {
|
||||
if (j.schedule.kind === 'cron') cron = j.schedule.expr || '';
|
||||
else if (j.schedule.kind === 'every') cron = 'every ' + j.schedule.every_secs + 's';
|
||||
else if (j.schedule.kind === 'at') cron = 'at ' + (j.schedule.at || '');
|
||||
}
|
||||
return {
|
||||
id: j.id,
|
||||
name: j.name,
|
||||
cron: cron,
|
||||
agent_id: j.agent_id,
|
||||
message: j.action ? j.action.message || '' : '',
|
||||
enabled: j.enabled,
|
||||
last_run: j.last_run,
|
||||
next_run: j.next_run,
|
||||
delivery: j.delivery ? j.delivery.kind || '' : '',
|
||||
created_at: j.created_at
|
||||
};
|
||||
});
|
||||
},
|
||||
|
||||
async loadTriggers() {
|
||||
this.trigLoading = true;
|
||||
this.trigLoadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/triggers');
|
||||
this.triggers = Array.isArray(data) ? data : [];
|
||||
} catch(e) {
|
||||
this.triggers = [];
|
||||
this.trigLoadError = e.message || 'Could not load triggers.';
|
||||
}
|
||||
this.trigLoading = false;
|
||||
},
|
||||
|
||||
async loadHistory() {
|
||||
this.historyLoading = true;
|
||||
try {
|
||||
var historyItems = [];
|
||||
var jobs = this.jobs || [];
|
||||
for (var i = 0; i < jobs.length; i++) {
|
||||
var job = jobs[i];
|
||||
if (job.last_run) {
|
||||
historyItems.push({
|
||||
timestamp: job.last_run,
|
||||
name: job.name || '(unnamed)',
|
||||
type: 'schedule',
|
||||
status: 'completed',
|
||||
run_count: 0
|
||||
});
|
||||
}
|
||||
}
|
||||
var triggers = this.triggers || [];
|
||||
for (var j = 0; j < triggers.length; j++) {
|
||||
var t = triggers[j];
|
||||
if (t.fire_count > 0) {
|
||||
historyItems.push({
|
||||
timestamp: t.created_at,
|
||||
name: 'Trigger: ' + this.triggerType(t.pattern),
|
||||
type: 'trigger',
|
||||
status: 'fired',
|
||||
run_count: t.fire_count
|
||||
});
|
||||
}
|
||||
}
|
||||
historyItems.sort(function(a, b) {
|
||||
return new Date(b.timestamp).getTime() - new Date(a.timestamp).getTime();
|
||||
});
|
||||
this.history = historyItems;
|
||||
} catch(e) {
|
||||
this.history = [];
|
||||
}
|
||||
this.historyLoading = false;
|
||||
},
|
||||
|
||||
// ── Job CRUD ──
|
||||
|
||||
async createJob() {
|
||||
if (!this.newJob.name.trim()) {
|
||||
OpenFangToast.warn('Please enter a job name');
|
||||
return;
|
||||
}
|
||||
if (!this.newJob.cron.trim()) {
|
||||
OpenFangToast.warn('Please enter a cron expression');
|
||||
return;
|
||||
}
|
||||
this.creating = true;
|
||||
try {
|
||||
var jobName = this.newJob.name;
|
||||
var body = {
|
||||
agent_id: this.newJob.agent_id,
|
||||
name: this.newJob.name,
|
||||
schedule: { kind: 'cron', expr: this.newJob.cron },
|
||||
action: { kind: 'agent_turn', message: this.newJob.message || 'Scheduled task: ' + this.newJob.name },
|
||||
delivery: { kind: 'last_channel' },
|
||||
enabled: this.newJob.enabled
|
||||
};
|
||||
await OpenFangAPI.post('/api/cron/jobs', body);
|
||||
this.showCreateForm = false;
|
||||
this.newJob = { name: '', cron: '', agent_id: '', message: '', enabled: true };
|
||||
OpenFangToast.success('Schedule "' + jobName + '" created');
|
||||
await this.loadJobs();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to create schedule: ' + (e.message || e));
|
||||
}
|
||||
this.creating = false;
|
||||
},
|
||||
|
||||
async toggleJob(job) {
|
||||
try {
|
||||
var newState = !job.enabled;
|
||||
await OpenFangAPI.put('/api/cron/jobs/' + job.id + '/enable', { enabled: newState });
|
||||
job.enabled = newState;
|
||||
OpenFangToast.success('Schedule ' + (newState ? 'enabled' : 'paused'));
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to toggle schedule: ' + (e.message || e));
|
||||
}
|
||||
},
|
||||
|
||||
deleteJob(job) {
|
||||
var self = this;
|
||||
var jobName = job.name || job.id;
|
||||
OpenFangToast.confirm('Delete Schedule', 'Delete "' + jobName + '"? This cannot be undone.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.del('/api/cron/jobs/' + job.id);
|
||||
self.jobs = self.jobs.filter(function(j) { return j.id !== job.id; });
|
||||
OpenFangToast.success('Schedule "' + jobName + '" deleted');
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to delete schedule: ' + (e.message || e));
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
async runNow(job) {
|
||||
this.runningJobId = job.id;
|
||||
try {
|
||||
var result = await OpenFangAPI.post('/api/schedules/' + job.id + '/run', {});
|
||||
if (result.status === 'completed') {
|
||||
OpenFangToast.success('Schedule "' + (job.name || 'job') + '" executed successfully');
|
||||
job.last_run = new Date().toISOString();
|
||||
} else {
|
||||
OpenFangToast.error('Schedule run failed: ' + (result.error || 'Unknown error'));
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Run Now is not yet available for cron jobs');
|
||||
}
|
||||
this.runningJobId = '';
|
||||
},
|
||||
|
||||
// ── Trigger helpers ──
|
||||
|
||||
triggerType(pattern) {
|
||||
if (!pattern) return 'unknown';
|
||||
if (typeof pattern === 'string') return pattern;
|
||||
var keys = Object.keys(pattern);
|
||||
if (keys.length === 0) return 'unknown';
|
||||
var key = keys[0];
|
||||
var names = {
|
||||
lifecycle: 'Lifecycle',
|
||||
agent_spawned: 'Agent Spawned',
|
||||
agent_terminated: 'Agent Terminated',
|
||||
system: 'System',
|
||||
system_keyword: 'System Keyword',
|
||||
memory_update: 'Memory Update',
|
||||
memory_key_pattern: 'Memory Key',
|
||||
all: 'All Events',
|
||||
content_match: 'Content Match'
|
||||
};
|
||||
return names[key] || key.replace(/_/g, ' ');
|
||||
},
|
||||
|
||||
async toggleTrigger(trigger) {
|
||||
try {
|
||||
var newState = !trigger.enabled;
|
||||
await OpenFangAPI.put('/api/triggers/' + trigger.id, { enabled: newState });
|
||||
trigger.enabled = newState;
|
||||
OpenFangToast.success('Trigger ' + (newState ? 'enabled' : 'disabled'));
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to toggle trigger: ' + (e.message || e));
|
||||
}
|
||||
},
|
||||
|
||||
deleteTrigger(trigger) {
|
||||
var self = this;
|
||||
OpenFangToast.confirm('Delete Trigger', 'Delete this trigger? This cannot be undone.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.del('/api/triggers/' + trigger.id);
|
||||
self.triggers = self.triggers.filter(function(t) { return t.id !== trigger.id; });
|
||||
OpenFangToast.success('Trigger deleted');
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to delete trigger: ' + (e.message || e));
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// ── Utility ──
|
||||
|
||||
get availableAgents() {
|
||||
return Alpine.store('app').agents || [];
|
||||
},
|
||||
|
||||
agentName(agentId) {
|
||||
if (!agentId) return '(any)';
|
||||
var agents = this.availableAgents;
|
||||
for (var i = 0; i < agents.length; i++) {
|
||||
if (agents[i].id === agentId) return agents[i].name;
|
||||
}
|
||||
if (agentId.length > 12) return agentId.substring(0, 8) + '...';
|
||||
return agentId;
|
||||
},
|
||||
|
||||
describeCron(expr) {
|
||||
if (!expr) return '';
|
||||
// Handle non-cron schedule descriptions
|
||||
if (expr.indexOf('every ') === 0) return expr;
|
||||
if (expr.indexOf('at ') === 0) return 'One-time: ' + expr.substring(3);
|
||||
|
||||
var map = {
|
||||
'* * * * *': 'Every minute',
|
||||
'*/2 * * * *': 'Every 2 minutes',
|
||||
'*/5 * * * *': 'Every 5 minutes',
|
||||
'*/10 * * * *': 'Every 10 minutes',
|
||||
'*/15 * * * *': 'Every 15 minutes',
|
||||
'*/30 * * * *': 'Every 30 minutes',
|
||||
'0 * * * *': 'Every hour',
|
||||
'0 */2 * * *': 'Every 2 hours',
|
||||
'0 */4 * * *': 'Every 4 hours',
|
||||
'0 */6 * * *': 'Every 6 hours',
|
||||
'0 */12 * * *': 'Every 12 hours',
|
||||
'0 0 * * *': 'Daily at midnight',
|
||||
'0 6 * * *': 'Daily at 6:00 AM',
|
||||
'0 9 * * *': 'Daily at 9:00 AM',
|
||||
'0 12 * * *': 'Daily at noon',
|
||||
'0 18 * * *': 'Daily at 6:00 PM',
|
||||
'0 9 * * 1-5': 'Weekdays at 9:00 AM',
|
||||
'0 9 * * 1': 'Mondays at 9:00 AM',
|
||||
'0 0 * * 0': 'Sundays at midnight',
|
||||
'0 0 1 * *': '1st of every month',
|
||||
'0 0 * * 1': 'Mondays at midnight'
|
||||
};
|
||||
if (map[expr]) return map[expr];
|
||||
|
||||
var parts = expr.split(' ');
|
||||
if (parts.length !== 5) return expr;
|
||||
|
||||
var min = parts[0];
|
||||
var hour = parts[1];
|
||||
var dom = parts[2];
|
||||
var mon = parts[3];
|
||||
var dow = parts[4];
|
||||
|
||||
if (min.indexOf('*/') === 0 && hour === '*' && dom === '*' && mon === '*' && dow === '*') {
|
||||
return 'Every ' + min.substring(2) + ' minutes';
|
||||
}
|
||||
if (min === '0' && hour.indexOf('*/') === 0 && dom === '*' && mon === '*' && dow === '*') {
|
||||
return 'Every ' + hour.substring(2) + ' hours';
|
||||
}
|
||||
|
||||
var dowNames = { '0': 'Sun', '1': 'Mon', '2': 'Tue', '3': 'Wed', '4': 'Thu', '5': 'Fri', '6': 'Sat', '7': 'Sun',
|
||||
'1-5': 'Weekdays', '0,6': 'Weekends', '6,0': 'Weekends' };
|
||||
|
||||
if (dom === '*' && mon === '*' && min.match(/^\d+$/) && hour.match(/^\d+$/)) {
|
||||
var h = parseInt(hour, 10);
|
||||
var m = parseInt(min, 10);
|
||||
var ampm = h >= 12 ? 'PM' : 'AM';
|
||||
var h12 = h === 0 ? 12 : (h > 12 ? h - 12 : h);
|
||||
var mStr = m < 10 ? '0' + m : '' + m;
|
||||
var timeStr = h12 + ':' + mStr + ' ' + ampm;
|
||||
if (dow === '*') return 'Daily at ' + timeStr;
|
||||
var dowLabel = dowNames[dow] || ('DoW ' + dow);
|
||||
return dowLabel + ' at ' + timeStr;
|
||||
}
|
||||
|
||||
return expr;
|
||||
},
|
||||
|
||||
applyCronPreset(preset) {
|
||||
this.newJob.cron = preset.cron;
|
||||
},
|
||||
|
||||
formatTime(ts) {
|
||||
if (!ts) return '-';
|
||||
try {
|
||||
var d = new Date(ts);
|
||||
if (isNaN(d.getTime())) return '-';
|
||||
return d.toLocaleString();
|
||||
} catch(e) { return '-'; }
|
||||
},
|
||||
|
||||
relativeTime(ts) {
|
||||
if (!ts) return 'never';
|
||||
try {
|
||||
var diff = Date.now() - new Date(ts).getTime();
|
||||
if (isNaN(diff)) return 'never';
|
||||
if (diff < 0) {
|
||||
// Future time
|
||||
var absDiff = Math.abs(diff);
|
||||
if (absDiff < 60000) return 'in <1m';
|
||||
if (absDiff < 3600000) return 'in ' + Math.floor(absDiff / 60000) + 'm';
|
||||
if (absDiff < 86400000) return 'in ' + Math.floor(absDiff / 3600000) + 'h';
|
||||
return 'in ' + Math.floor(absDiff / 86400000) + 'd';
|
||||
}
|
||||
if (diff < 60000) return 'just now';
|
||||
if (diff < 3600000) return Math.floor(diff / 60000) + 'm ago';
|
||||
if (diff < 86400000) return Math.floor(diff / 3600000) + 'h ago';
|
||||
return Math.floor(diff / 86400000) + 'd ago';
|
||||
} catch(e) { return 'never'; }
|
||||
},
|
||||
|
||||
jobCount() {
|
||||
var enabled = 0;
|
||||
for (var i = 0; i < this.jobs.length; i++) {
|
||||
if (this.jobs[i].enabled) enabled++;
|
||||
}
|
||||
return enabled;
|
||||
},
|
||||
|
||||
triggerCount() {
|
||||
var enabled = 0;
|
||||
for (var i = 0; i < this.triggers.length; i++) {
|
||||
if (this.triggers[i].enabled) enabled++;
|
||||
}
|
||||
return enabled;
|
||||
}
|
||||
};
|
||||
}
|
||||
147
crates/openfang-api/static/js/pages/sessions.js
Normal file
147
crates/openfang-api/static/js/pages/sessions.js
Normal file
@@ -0,0 +1,147 @@
|
||||
// OpenFang Sessions Page — Session listing + Memory tab
|
||||
'use strict';
|
||||
|
||||
function sessionsPage() {
|
||||
return {
|
||||
tab: 'sessions',
|
||||
// -- Sessions state --
|
||||
sessions: [],
|
||||
searchFilter: '',
|
||||
loading: true,
|
||||
loadError: '',
|
||||
|
||||
// -- Memory state --
|
||||
memAgentId: '',
|
||||
kvPairs: [],
|
||||
showAdd: false,
|
||||
newKey: '',
|
||||
newValue: '""',
|
||||
editingKey: null,
|
||||
editingValue: '',
|
||||
memLoading: false,
|
||||
memLoadError: '',
|
||||
|
||||
// -- Sessions methods --
|
||||
async loadSessions() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/sessions');
|
||||
var sessions = data.sessions || [];
|
||||
var agents = Alpine.store('app').agents;
|
||||
var agentMap = {};
|
||||
agents.forEach(function(a) { agentMap[a.id] = a.name; });
|
||||
sessions.forEach(function(s) {
|
||||
s.agent_name = agentMap[s.agent_id] || '';
|
||||
});
|
||||
this.sessions = sessions;
|
||||
} catch(e) {
|
||||
this.sessions = [];
|
||||
this.loadError = e.message || 'Could not load sessions.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadData() { return this.loadSessions(); },
|
||||
|
||||
get filteredSessions() {
|
||||
var f = this.searchFilter.toLowerCase();
|
||||
if (!f) return this.sessions;
|
||||
return this.sessions.filter(function(s) {
|
||||
return (s.agent_name || '').toLowerCase().indexOf(f) !== -1 ||
|
||||
(s.agent_id || '').toLowerCase().indexOf(f) !== -1;
|
||||
});
|
||||
},
|
||||
|
||||
openInChat(session) {
|
||||
var agents = Alpine.store('app').agents;
|
||||
var agent = agents.find(function(a) { return a.id === session.agent_id; });
|
||||
if (agent) {
|
||||
Alpine.store('app').pendingAgent = agent;
|
||||
}
|
||||
location.hash = 'agents';
|
||||
},
|
||||
|
||||
deleteSession(sessionId) {
|
||||
var self = this;
|
||||
OpenFangToast.confirm('Delete Session', 'This will permanently remove the session and its messages.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.del('/api/sessions/' + sessionId);
|
||||
self.sessions = self.sessions.filter(function(s) { return s.session_id !== sessionId; });
|
||||
OpenFangToast.success('Session deleted');
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to delete session: ' + e.message);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// -- Memory methods --
|
||||
async loadKv() {
|
||||
if (!this.memAgentId) { this.kvPairs = []; return; }
|
||||
this.memLoading = true;
|
||||
this.memLoadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/memory/agents/' + this.memAgentId + '/kv');
|
||||
this.kvPairs = data.kv_pairs || [];
|
||||
} catch(e) {
|
||||
this.kvPairs = [];
|
||||
this.memLoadError = e.message || 'Could not load memory data.';
|
||||
}
|
||||
this.memLoading = false;
|
||||
},
|
||||
|
||||
async addKey() {
|
||||
if (!this.memAgentId || !this.newKey.trim()) return;
|
||||
var value;
|
||||
try { value = JSON.parse(this.newValue); } catch(e) { value = this.newValue; }
|
||||
try {
|
||||
await OpenFangAPI.put('/api/memory/agents/' + this.memAgentId + '/kv/' + encodeURIComponent(this.newKey), { value: value });
|
||||
this.showAdd = false;
|
||||
OpenFangToast.success('Key "' + this.newKey + '" saved');
|
||||
this.newKey = '';
|
||||
this.newValue = '""';
|
||||
await this.loadKv();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save key: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
deleteKey(key) {
|
||||
var self = this;
|
||||
OpenFangToast.confirm('Delete Key', 'Delete key "' + key + '"? This cannot be undone.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.del('/api/memory/agents/' + self.memAgentId + '/kv/' + encodeURIComponent(key));
|
||||
OpenFangToast.success('Key "' + key + '" deleted');
|
||||
await self.loadKv();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to delete key: ' + e.message);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
startEdit(kv) {
|
||||
this.editingKey = kv.key;
|
||||
this.editingValue = typeof kv.value === 'object' ? JSON.stringify(kv.value, null, 2) : String(kv.value);
|
||||
},
|
||||
|
||||
cancelEdit() {
|
||||
this.editingKey = null;
|
||||
this.editingValue = '';
|
||||
},
|
||||
|
||||
async saveEdit() {
|
||||
if (!this.editingKey || !this.memAgentId) return;
|
||||
var value;
|
||||
try { value = JSON.parse(this.editingValue); } catch(e) { value = this.editingValue; }
|
||||
try {
|
||||
await OpenFangAPI.put('/api/memory/agents/' + this.memAgentId + '/kv/' + encodeURIComponent(this.editingKey), { value: value });
|
||||
OpenFangToast.success('Key "' + this.editingKey + '" updated');
|
||||
this.editingKey = null;
|
||||
this.editingValue = '';
|
||||
await this.loadKv();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save: ' + e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
669
crates/openfang-api/static/js/pages/settings.js
Normal file
669
crates/openfang-api/static/js/pages/settings.js
Normal file
@@ -0,0 +1,669 @@
|
||||
// OpenFang Settings Page — Provider Hub, Model Catalog, Config, Tools + Security, Network, Migration tabs
|
||||
'use strict';
|
||||
|
||||
function settingsPage() {
|
||||
return {
|
||||
tab: 'providers',
|
||||
sysInfo: {},
|
||||
usageData: [],
|
||||
tools: [],
|
||||
config: {},
|
||||
providers: [],
|
||||
models: [],
|
||||
toolSearch: '',
|
||||
modelSearch: '',
|
||||
modelProviderFilter: '',
|
||||
modelTierFilter: '',
|
||||
showCustomModelForm: false,
|
||||
customModelId: '',
|
||||
customModelProvider: 'openrouter',
|
||||
customModelContext: 128000,
|
||||
customModelMaxOutput: 8192,
|
||||
customModelStatus: '',
|
||||
providerKeyInputs: {},
|
||||
providerUrlInputs: {},
|
||||
providerUrlSaving: {},
|
||||
providerTesting: {},
|
||||
providerTestResults: {},
|
||||
copilotOAuth: { polling: false, userCode: '', verificationUri: '', pollId: '', interval: 5 },
|
||||
loading: true,
|
||||
loadError: '',
|
||||
|
||||
// -- Dynamic config state --
|
||||
configSchema: null,
|
||||
configValues: {},
|
||||
configDirty: {},
|
||||
configSaving: {},
|
||||
|
||||
// -- Security state --
|
||||
securityData: null,
|
||||
secLoading: false,
|
||||
verifyingChain: false,
|
||||
chainResult: null,
|
||||
|
||||
coreFeatures: [
|
||||
{
|
||||
name: 'Path Traversal Prevention', key: 'path_traversal',
|
||||
description: 'Blocks directory escape attacks (../) in all file operations. Two-phase validation: syntactic rejection of path components, then canonicalization to normalize symlinks.',
|
||||
threat: 'Directory escape, privilege escalation via symlinks',
|
||||
impl: 'host_functions.rs — safe_resolve_path() + safe_resolve_parent()'
|
||||
},
|
||||
{
|
||||
name: 'SSRF Protection', key: 'ssrf_protection',
|
||||
description: 'Blocks outbound requests to private IPs, localhost, and cloud metadata endpoints (AWS/GCP/Azure). Validates DNS resolution results to defeat rebinding attacks.',
|
||||
threat: 'Internal network reconnaissance, cloud credential theft',
|
||||
impl: 'host_functions.rs — is_ssrf_target() + is_private_ip()'
|
||||
},
|
||||
{
|
||||
name: 'Capability-Based Access Control', key: 'capability_system',
|
||||
description: 'Deny-by-default permission system. Every agent operation (file I/O, network, shell, memory, spawn) requires an explicit capability grant in the manifest.',
|
||||
threat: 'Unauthorized resource access, sandbox escape',
|
||||
impl: 'host_functions.rs — check_capability() on every host function'
|
||||
},
|
||||
{
|
||||
name: 'Privilege Escalation Prevention', key: 'privilege_escalation_prevention',
|
||||
description: 'When a parent agent spawns a child, the kernel enforces child capabilities are a subset of parent capabilities. No agent can grant rights it does not have.',
|
||||
threat: 'Capability escalation through agent spawning chains',
|
||||
impl: 'kernel_handle.rs — spawn_agent_checked()'
|
||||
},
|
||||
{
|
||||
name: 'Subprocess Environment Isolation', key: 'subprocess_isolation',
|
||||
description: 'Child processes (shell tools) inherit only a safe allow-list of environment variables. API keys, database passwords, and secrets are never leaked to subprocesses.',
|
||||
threat: 'Secret exfiltration via child process environment',
|
||||
impl: 'subprocess_sandbox.rs — env_clear() + SAFE_ENV_VARS'
|
||||
},
|
||||
{
|
||||
name: 'Security Headers', key: 'security_headers',
|
||||
description: 'Every HTTP response includes CSP, X-Frame-Options: DENY, X-Content-Type-Options: nosniff, Referrer-Policy, and X-XSS-Protection headers.',
|
||||
threat: 'XSS, clickjacking, MIME sniffing, content injection',
|
||||
impl: 'middleware.rs — security_headers()'
|
||||
},
|
||||
{
|
||||
name: 'Wire Protocol Authentication', key: 'wire_hmac_auth',
|
||||
description: 'Agent-to-agent OFP connections use HMAC-SHA256 mutual authentication with nonce-based handshake and constant-time signature comparison (subtle crate).',
|
||||
threat: 'Man-in-the-middle attacks on mesh network',
|
||||
impl: 'peer.rs — hmac_sign() + hmac_verify()'
|
||||
},
|
||||
{
|
||||
name: 'Request ID Tracking', key: 'request_id_tracking',
|
||||
description: 'Every API request receives a unique UUID (x-request-id header) and is logged with method, path, status code, and latency for full traceability.',
|
||||
threat: 'Untraceable actions, forensic blind spots',
|
||||
impl: 'middleware.rs — request_logging()'
|
||||
}
|
||||
],
|
||||
|
||||
configurableFeatures: [
|
||||
{
|
||||
name: 'API Rate Limiting', key: 'rate_limiter',
|
||||
description: 'GCRA (Generic Cell Rate Algorithm) with cost-aware tokens. Different endpoints cost different amounts — spawning an agent costs 50 tokens, health check costs 1.',
|
||||
configHint: 'Hard-coded: 500 tokens/minute per IP. Edit rate_limiter.rs to tune.',
|
||||
valueKey: 'rate_limiter'
|
||||
},
|
||||
{
|
||||
name: 'WebSocket Connection Limits', key: 'websocket_limits',
|
||||
description: 'Per-IP connection cap prevents connection exhaustion. Idle timeout closes abandoned connections. Message rate limiting prevents flooding.',
|
||||
configHint: 'Hard-coded: 5 connections/IP, 30min idle timeout, 64KB max message. Edit ws.rs to tune.',
|
||||
valueKey: 'websocket_limits'
|
||||
},
|
||||
{
|
||||
name: 'WASM Dual Metering', key: 'wasm_sandbox',
|
||||
description: 'WASM modules run with two independent resource limits: fuel metering (CPU instruction count) and epoch interruption (wall-clock timeout with watchdog thread).',
|
||||
configHint: 'Default: 1M fuel units, 30s timeout. Configurable per-agent via SandboxConfig.',
|
||||
valueKey: 'wasm_sandbox'
|
||||
},
|
||||
{
|
||||
name: 'Bearer Token Authentication', key: 'auth',
|
||||
description: 'All non-health endpoints require Authorization: Bearer header. When no API key is configured, all requests are restricted to localhost only.',
|
||||
configHint: 'Set api_key in ~/.openfang/config.toml for remote access. Empty = localhost only.',
|
||||
valueKey: 'auth'
|
||||
}
|
||||
],
|
||||
|
||||
monitoringFeatures: [
|
||||
{
|
||||
name: 'Merkle Audit Trail', key: 'audit_trail',
|
||||
description: 'Every security-critical action is appended to an immutable, tamper-evident log. Each entry is cryptographically linked to the previous via SHA-256 hash chain.',
|
||||
configHint: 'Always active. Verify chain integrity from the Audit Log page.',
|
||||
valueKey: 'audit_trail'
|
||||
},
|
||||
{
|
||||
name: 'Information Flow Taint Tracking', key: 'taint_tracking',
|
||||
description: 'Labels data by provenance (ExternalNetwork, UserInput, PII, Secret, UntrustedAgent) and blocks unsafe flows: external data cannot reach shell_exec, secrets cannot reach network.',
|
||||
configHint: 'Always active. Prevents data flow attacks automatically.',
|
||||
valueKey: 'taint_tracking'
|
||||
},
|
||||
{
|
||||
name: 'Ed25519 Manifest Signing', key: 'manifest_signing',
|
||||
description: 'Agent manifests can be cryptographically signed with Ed25519. Verify manifest integrity before loading to prevent supply chain tampering.',
|
||||
configHint: 'Available for use. Sign manifests with ed25519-dalek for verification.',
|
||||
valueKey: 'manifest_signing'
|
||||
}
|
||||
],
|
||||
|
||||
// -- Peers state --
|
||||
peers: [],
|
||||
peersLoading: false,
|
||||
peersLoadError: '',
|
||||
_peerPollTimer: null,
|
||||
|
||||
// -- Migration state --
|
||||
migStep: 'intro',
|
||||
detecting: false,
|
||||
scanning: false,
|
||||
migrating: false,
|
||||
sourcePath: '',
|
||||
targetPath: '',
|
||||
scanResult: null,
|
||||
migResult: null,
|
||||
|
||||
// -- Settings load --
|
||||
async loadSettings() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
await Promise.all([
|
||||
this.loadSysInfo(),
|
||||
this.loadUsage(),
|
||||
this.loadTools(),
|
||||
this.loadConfig(),
|
||||
this.loadProviders(),
|
||||
this.loadModels()
|
||||
]);
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load settings.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadData() { return this.loadSettings(); },
|
||||
|
||||
async loadSysInfo() {
|
||||
try {
|
||||
var ver = await OpenFangAPI.get('/api/version');
|
||||
var status = await OpenFangAPI.get('/api/status');
|
||||
this.sysInfo = {
|
||||
version: ver.version || '-',
|
||||
platform: ver.platform || '-',
|
||||
arch: ver.arch || '-',
|
||||
uptime_seconds: status.uptime_seconds || 0,
|
||||
agent_count: status.agent_count || 0,
|
||||
default_provider: status.default_provider || '-',
|
||||
default_model: status.default_model || '-'
|
||||
};
|
||||
} catch(e) { throw e; }
|
||||
},
|
||||
|
||||
async loadUsage() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/usage');
|
||||
this.usageData = data.agents || [];
|
||||
} catch(e) { this.usageData = []; }
|
||||
},
|
||||
|
||||
async loadTools() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/tools');
|
||||
this.tools = data.tools || [];
|
||||
} catch(e) { this.tools = []; }
|
||||
},
|
||||
|
||||
async loadConfig() {
|
||||
try {
|
||||
this.config = await OpenFangAPI.get('/api/config');
|
||||
} catch(e) { this.config = {}; }
|
||||
},
|
||||
|
||||
async loadProviders() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/providers');
|
||||
this.providers = data.providers || [];
|
||||
for (var i = 0; i < this.providers.length; i++) {
|
||||
var p = this.providers[i];
|
||||
if (p.is_local && p.base_url && !this.providerUrlInputs[p.id]) {
|
||||
this.providerUrlInputs[p.id] = p.base_url;
|
||||
}
|
||||
}
|
||||
} catch(e) { this.providers = []; }
|
||||
},
|
||||
|
||||
async loadModels() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/models');
|
||||
this.models = data.models || [];
|
||||
} catch(e) { this.models = []; }
|
||||
},
|
||||
|
||||
async addCustomModel() {
|
||||
var id = this.customModelId.trim();
|
||||
if (!id) return;
|
||||
this.customModelStatus = 'Adding...';
|
||||
try {
|
||||
await OpenFangAPI.post('/api/models/custom', {
|
||||
id: id,
|
||||
provider: this.customModelProvider || 'openrouter',
|
||||
context_window: this.customModelContext || 128000,
|
||||
max_output_tokens: this.customModelMaxOutput || 8192,
|
||||
});
|
||||
this.customModelStatus = 'Added!';
|
||||
this.customModelId = '';
|
||||
this.showCustomModelForm = false;
|
||||
await this.loadModels();
|
||||
} catch(e) {
|
||||
this.customModelStatus = 'Error: ' + (e.message || 'Failed');
|
||||
}
|
||||
},
|
||||
|
||||
async loadConfigSchema() {
|
||||
try {
|
||||
var results = await Promise.all([
|
||||
OpenFangAPI.get('/api/config/schema').catch(function() { return {}; }),
|
||||
OpenFangAPI.get('/api/config')
|
||||
]);
|
||||
this.configSchema = results[0].sections || null;
|
||||
this.configValues = results[1] || {};
|
||||
} catch(e) { /* silent */ }
|
||||
},
|
||||
|
||||
isConfigDirty(section, field) {
|
||||
return this.configDirty[section + '.' + field] === true;
|
||||
},
|
||||
|
||||
markConfigDirty(section, field) {
|
||||
this.configDirty[section + '.' + field] = true;
|
||||
},
|
||||
|
||||
async saveConfigField(section, field, value) {
|
||||
var key = section + '.' + field;
|
||||
this.configSaving[key] = true;
|
||||
try {
|
||||
await OpenFangAPI.post('/api/config/set', { path: key, value: value });
|
||||
this.configDirty[key] = false;
|
||||
OpenFangToast.success('Saved ' + key);
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save: ' + e.message);
|
||||
}
|
||||
this.configSaving[key] = false;
|
||||
},
|
||||
|
||||
get filteredTools() {
|
||||
var q = this.toolSearch.toLowerCase().trim();
|
||||
if (!q) return this.tools;
|
||||
return this.tools.filter(function(t) {
|
||||
return t.name.toLowerCase().indexOf(q) !== -1 ||
|
||||
(t.description || '').toLowerCase().indexOf(q) !== -1;
|
||||
});
|
||||
},
|
||||
|
||||
get filteredModels() {
|
||||
var self = this;
|
||||
return this.models.filter(function(m) {
|
||||
if (self.modelProviderFilter && m.provider !== self.modelProviderFilter) return false;
|
||||
if (self.modelTierFilter && m.tier !== self.modelTierFilter) return false;
|
||||
if (self.modelSearch) {
|
||||
var q = self.modelSearch.toLowerCase();
|
||||
if (m.id.toLowerCase().indexOf(q) === -1 &&
|
||||
(m.display_name || '').toLowerCase().indexOf(q) === -1) return false;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
},
|
||||
|
||||
get uniqueProviderNames() {
|
||||
var seen = {};
|
||||
this.models.forEach(function(m) { seen[m.provider] = true; });
|
||||
return Object.keys(seen).sort();
|
||||
},
|
||||
|
||||
get uniqueTiers() {
|
||||
var seen = {};
|
||||
this.models.forEach(function(m) { if (m.tier) seen[m.tier] = true; });
|
||||
return Object.keys(seen).sort();
|
||||
},
|
||||
|
||||
providerAuthClass(p) {
|
||||
if (p.auth_status === 'configured') return 'auth-configured';
|
||||
if (p.auth_status === 'not_set' || p.auth_status === 'missing') return 'auth-not-set';
|
||||
return 'auth-no-key';
|
||||
},
|
||||
|
||||
providerAuthText(p) {
|
||||
if (p.auth_status === 'configured') return 'Configured';
|
||||
if (p.auth_status === 'not_set' || p.auth_status === 'missing') return 'Not Set';
|
||||
return 'No Key Needed';
|
||||
},
|
||||
|
||||
providerCardClass(p) {
|
||||
if (p.auth_status === 'configured') return 'configured';
|
||||
if (p.auth_status === 'not_set' || p.auth_status === 'missing') return 'not-configured';
|
||||
return 'no-key';
|
||||
},
|
||||
|
||||
tierBadgeClass(tier) {
|
||||
if (!tier) return '';
|
||||
var t = tier.toLowerCase();
|
||||
if (t === 'frontier') return 'tier-frontier';
|
||||
if (t === 'smart') return 'tier-smart';
|
||||
if (t === 'balanced') return 'tier-balanced';
|
||||
if (t === 'fast') return 'tier-fast';
|
||||
return '';
|
||||
},
|
||||
|
||||
formatCost(cost) {
|
||||
if (!cost && cost !== 0) return '-';
|
||||
return '$' + cost.toFixed(4);
|
||||
},
|
||||
|
||||
formatContext(ctx) {
|
||||
if (!ctx) return '-';
|
||||
if (ctx >= 1000000) return (ctx / 1000000).toFixed(1) + 'M';
|
||||
if (ctx >= 1000) return Math.round(ctx / 1000) + 'K';
|
||||
return String(ctx);
|
||||
},
|
||||
|
||||
formatUptime(secs) {
|
||||
if (!secs) return '-';
|
||||
var h = Math.floor(secs / 3600);
|
||||
var m = Math.floor((secs % 3600) / 60);
|
||||
var s = secs % 60;
|
||||
if (h > 0) return h + 'h ' + m + 'm';
|
||||
if (m > 0) return m + 'm ' + s + 's';
|
||||
return s + 's';
|
||||
},
|
||||
|
||||
async saveProviderKey(provider) {
|
||||
var key = this.providerKeyInputs[provider.id];
|
||||
if (!key || !key.trim()) { OpenFangToast.error('Please enter an API key'); return; }
|
||||
try {
|
||||
await OpenFangAPI.post('/api/providers/' + encodeURIComponent(provider.id) + '/key', { key: key.trim() });
|
||||
OpenFangToast.success('API key saved for ' + provider.display_name);
|
||||
this.providerKeyInputs[provider.id] = '';
|
||||
await this.loadProviders();
|
||||
await this.loadModels();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save key: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
async removeProviderKey(provider) {
|
||||
try {
|
||||
await OpenFangAPI.del('/api/providers/' + encodeURIComponent(provider.id) + '/key');
|
||||
OpenFangToast.success('API key removed for ' + provider.display_name);
|
||||
await this.loadProviders();
|
||||
await this.loadModels();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to remove key: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
async startCopilotOAuth() {
|
||||
this.copilotOAuth.polling = true;
|
||||
this.copilotOAuth.userCode = '';
|
||||
try {
|
||||
var resp = await OpenFangAPI.post('/api/providers/github-copilot/oauth/start', {});
|
||||
this.copilotOAuth.userCode = resp.user_code;
|
||||
this.copilotOAuth.verificationUri = resp.verification_uri;
|
||||
this.copilotOAuth.pollId = resp.poll_id;
|
||||
this.copilotOAuth.interval = resp.interval || 5;
|
||||
window.open(resp.verification_uri, '_blank');
|
||||
this.pollCopilotOAuth();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to start Copilot login: ' + e.message);
|
||||
this.copilotOAuth.polling = false;
|
||||
}
|
||||
},
|
||||
|
||||
pollCopilotOAuth() {
|
||||
var self = this;
|
||||
setTimeout(async function() {
|
||||
if (!self.copilotOAuth.pollId) return;
|
||||
try {
|
||||
var resp = await OpenFangAPI.get('/api/providers/github-copilot/oauth/poll/' + self.copilotOAuth.pollId);
|
||||
if (resp.status === 'complete') {
|
||||
OpenFangToast.success('GitHub Copilot authenticated successfully!');
|
||||
self.copilotOAuth = { polling: false, userCode: '', verificationUri: '', pollId: '', interval: 5 };
|
||||
await self.loadProviders();
|
||||
await self.loadModels();
|
||||
} else if (resp.status === 'pending') {
|
||||
if (resp.interval) self.copilotOAuth.interval = resp.interval;
|
||||
self.pollCopilotOAuth();
|
||||
} else if (resp.status === 'expired') {
|
||||
OpenFangToast.error('Device code expired. Please try again.');
|
||||
self.copilotOAuth = { polling: false, userCode: '', verificationUri: '', pollId: '', interval: 5 };
|
||||
} else if (resp.status === 'denied') {
|
||||
OpenFangToast.error('Access denied by user.');
|
||||
self.copilotOAuth = { polling: false, userCode: '', verificationUri: '', pollId: '', interval: 5 };
|
||||
} else {
|
||||
OpenFangToast.error('OAuth error: ' + (resp.error || resp.status));
|
||||
self.copilotOAuth = { polling: false, userCode: '', verificationUri: '', pollId: '', interval: 5 };
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Poll error: ' + e.message);
|
||||
self.copilotOAuth = { polling: false, userCode: '', verificationUri: '', pollId: '', interval: 5 };
|
||||
}
|
||||
}, self.copilotOAuth.interval * 1000);
|
||||
},
|
||||
|
||||
async testProvider(provider) {
|
||||
this.providerTesting[provider.id] = true;
|
||||
this.providerTestResults[provider.id] = null;
|
||||
try {
|
||||
var result = await OpenFangAPI.post('/api/providers/' + encodeURIComponent(provider.id) + '/test', {});
|
||||
this.providerTestResults[provider.id] = result;
|
||||
if (result.status === 'ok') {
|
||||
OpenFangToast.success(provider.display_name + ' connected (' + (result.latency_ms || '?') + 'ms)');
|
||||
} else {
|
||||
OpenFangToast.error(provider.display_name + ': ' + (result.error || 'Connection failed'));
|
||||
}
|
||||
} catch(e) {
|
||||
this.providerTestResults[provider.id] = { status: 'error', error: e.message };
|
||||
OpenFangToast.error('Test failed: ' + e.message);
|
||||
}
|
||||
this.providerTesting[provider.id] = false;
|
||||
},
|
||||
|
||||
async saveProviderUrl(provider) {
|
||||
var url = this.providerUrlInputs[provider.id];
|
||||
if (!url || !url.trim()) { OpenFangToast.error('Please enter a base URL'); return; }
|
||||
url = url.trim();
|
||||
if (url.indexOf('http://') !== 0 && url.indexOf('https://') !== 0) {
|
||||
OpenFangToast.error('URL must start with http:// or https://'); return;
|
||||
}
|
||||
this.providerUrlSaving[provider.id] = true;
|
||||
try {
|
||||
var result = await OpenFangAPI.put('/api/providers/' + encodeURIComponent(provider.id) + '/url', { base_url: url });
|
||||
if (result.reachable) {
|
||||
OpenFangToast.success(provider.display_name + ' URL saved — reachable (' + (result.latency_ms || '?') + 'ms)');
|
||||
} else {
|
||||
OpenFangToast.warning(provider.display_name + ' URL saved but not reachable');
|
||||
}
|
||||
await this.loadProviders();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save URL: ' + e.message);
|
||||
}
|
||||
this.providerUrlSaving[provider.id] = false;
|
||||
},
|
||||
|
||||
// -- Security methods --
|
||||
async loadSecurity() {
|
||||
this.secLoading = true;
|
||||
try {
|
||||
this.securityData = await OpenFangAPI.get('/api/security');
|
||||
} catch(e) {
|
||||
this.securityData = null;
|
||||
}
|
||||
this.secLoading = false;
|
||||
},
|
||||
|
||||
isActive(key) {
|
||||
if (!this.securityData) return true;
|
||||
var core = this.securityData.core_protections || {};
|
||||
if (core[key] !== undefined) return core[key];
|
||||
return true;
|
||||
},
|
||||
|
||||
getConfigValue(key) {
|
||||
if (!this.securityData) return null;
|
||||
var cfg = this.securityData.configurable || {};
|
||||
return cfg[key] || null;
|
||||
},
|
||||
|
||||
getMonitoringValue(key) {
|
||||
if (!this.securityData) return null;
|
||||
var mon = this.securityData.monitoring || {};
|
||||
return mon[key] || null;
|
||||
},
|
||||
|
||||
formatConfigValue(feature) {
|
||||
var val = this.getConfigValue(feature.valueKey);
|
||||
if (!val) return feature.configHint;
|
||||
switch (feature.valueKey) {
|
||||
case 'rate_limiter':
|
||||
return 'Algorithm: ' + (val.algorithm || 'GCRA') + ' | ' + (val.tokens_per_minute || 500) + ' tokens/min per IP';
|
||||
case 'websocket_limits':
|
||||
return 'Max ' + (val.max_per_ip || 5) + ' conn/IP | ' + Math.round((val.idle_timeout_secs || 1800) / 60) + 'min idle timeout | ' + Math.round((val.max_message_size || 65536) / 1024) + 'KB max msg';
|
||||
case 'wasm_sandbox':
|
||||
return 'Fuel: ' + (val.fuel_metering ? 'ON' : 'OFF') + ' | Epoch: ' + (val.epoch_interruption ? 'ON' : 'OFF') + ' | Timeout: ' + (val.default_timeout_secs || 30) + 's';
|
||||
case 'auth':
|
||||
return 'Mode: ' + (val.mode || 'unknown') + (val.api_key_set ? ' (key configured)' : ' (no key set)');
|
||||
default:
|
||||
return feature.configHint;
|
||||
}
|
||||
},
|
||||
|
||||
formatMonitoringValue(feature) {
|
||||
var val = this.getMonitoringValue(feature.valueKey);
|
||||
if (!val) return feature.configHint;
|
||||
switch (feature.valueKey) {
|
||||
case 'audit_trail':
|
||||
return (val.enabled ? 'Active' : 'Disabled') + ' | ' + (val.algorithm || 'SHA-256') + ' | ' + (val.entry_count || 0) + ' entries logged';
|
||||
case 'taint_tracking':
|
||||
var labels = val.tracked_labels || [];
|
||||
return (val.enabled ? 'Active' : 'Disabled') + ' | Tracking: ' + labels.join(', ');
|
||||
case 'manifest_signing':
|
||||
return 'Algorithm: ' + (val.algorithm || 'Ed25519') + ' | ' + (val.available ? 'Available' : 'Not available');
|
||||
default:
|
||||
return feature.configHint;
|
||||
}
|
||||
},
|
||||
|
||||
async verifyAuditChain() {
|
||||
this.verifyingChain = true;
|
||||
this.chainResult = null;
|
||||
try {
|
||||
var res = await OpenFangAPI.get('/api/audit/verify');
|
||||
this.chainResult = res;
|
||||
} catch(e) {
|
||||
this.chainResult = { valid: false, error: e.message };
|
||||
}
|
||||
this.verifyingChain = false;
|
||||
},
|
||||
|
||||
// -- Peers methods --
|
||||
async loadPeers() {
|
||||
this.peersLoading = true;
|
||||
this.peersLoadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/peers');
|
||||
this.peers = (data.peers || []).map(function(p) {
|
||||
return {
|
||||
node_id: p.node_id,
|
||||
node_name: p.node_name,
|
||||
address: p.address,
|
||||
state: p.state,
|
||||
agent_count: (p.agents || []).length,
|
||||
protocol_version: p.protocol_version || 1
|
||||
};
|
||||
});
|
||||
} catch(e) {
|
||||
this.peers = [];
|
||||
this.peersLoadError = e.message || 'Could not load peers.';
|
||||
}
|
||||
this.peersLoading = false;
|
||||
},
|
||||
|
||||
startPeerPolling() {
|
||||
var self = this;
|
||||
this.stopPeerPolling();
|
||||
this._peerPollTimer = setInterval(async function() {
|
||||
if (self.tab !== 'network') { self.stopPeerPolling(); return; }
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/peers');
|
||||
self.peers = (data.peers || []).map(function(p) {
|
||||
return {
|
||||
node_id: p.node_id,
|
||||
node_name: p.node_name,
|
||||
address: p.address,
|
||||
state: p.state,
|
||||
agent_count: (p.agents || []).length,
|
||||
protocol_version: p.protocol_version || 1
|
||||
};
|
||||
});
|
||||
} catch(e) { /* silent */ }
|
||||
}, 15000);
|
||||
},
|
||||
|
||||
stopPeerPolling() {
|
||||
if (this._peerPollTimer) { clearInterval(this._peerPollTimer); this._peerPollTimer = null; }
|
||||
},
|
||||
|
||||
// -- Migration methods --
|
||||
async autoDetect() {
|
||||
this.detecting = true;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/migrate/detect');
|
||||
if (data.detected && data.scan) {
|
||||
this.sourcePath = data.path;
|
||||
this.scanResult = data.scan;
|
||||
this.migStep = 'preview';
|
||||
} else {
|
||||
this.migStep = 'not_found';
|
||||
}
|
||||
} catch(e) {
|
||||
this.migStep = 'not_found';
|
||||
}
|
||||
this.detecting = false;
|
||||
},
|
||||
|
||||
async scanPath() {
|
||||
if (!this.sourcePath) return;
|
||||
this.scanning = true;
|
||||
try {
|
||||
var data = await OpenFangAPI.post('/api/migrate/scan', { path: this.sourcePath });
|
||||
if (data.error) {
|
||||
OpenFangToast.error('Scan error: ' + data.error);
|
||||
this.scanning = false;
|
||||
return;
|
||||
}
|
||||
this.scanResult = data;
|
||||
this.migStep = 'preview';
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Scan failed: ' + e.message);
|
||||
}
|
||||
this.scanning = false;
|
||||
},
|
||||
|
||||
async runMigration(dryRun) {
|
||||
this.migrating = true;
|
||||
try {
|
||||
var target = this.targetPath;
|
||||
if (!target) target = '';
|
||||
var data = await OpenFangAPI.post('/api/migrate', {
|
||||
source: 'openclaw',
|
||||
source_dir: this.sourcePath || (this.scanResult ? this.scanResult.path : ''),
|
||||
target_dir: target,
|
||||
dry_run: dryRun
|
||||
});
|
||||
this.migResult = data;
|
||||
this.migStep = 'result';
|
||||
} catch(e) {
|
||||
this.migResult = { status: 'failed', error: e.message };
|
||||
this.migStep = 'result';
|
||||
}
|
||||
this.migrating = false;
|
||||
},
|
||||
|
||||
destroy() {
|
||||
this.stopPeerPolling();
|
||||
}
|
||||
};
|
||||
}
|
||||
299
crates/openfang-api/static/js/pages/skills.js
Normal file
299
crates/openfang-api/static/js/pages/skills.js
Normal file
@@ -0,0 +1,299 @@
|
||||
// OpenFang Skills Page — OpenClaw/ClawHub ecosystem + local skills + MCP servers
|
||||
'use strict';
|
||||
|
||||
function skillsPage() {
|
||||
return {
|
||||
tab: 'installed',
|
||||
skills: [],
|
||||
loading: true,
|
||||
loadError: '',
|
||||
|
||||
// ClawHub state
|
||||
clawhubSearch: '',
|
||||
clawhubResults: [],
|
||||
clawhubBrowseResults: [],
|
||||
clawhubLoading: false,
|
||||
clawhubError: '',
|
||||
clawhubSort: 'trending',
|
||||
clawhubNextCursor: null,
|
||||
installingSlug: null,
|
||||
installResult: null,
|
||||
_searchTimer: null,
|
||||
|
||||
// Skill detail modal
|
||||
skillDetail: null,
|
||||
detailLoading: false,
|
||||
|
||||
// MCP servers
|
||||
mcpServers: [],
|
||||
mcpLoading: false,
|
||||
|
||||
// Category definitions from the OpenClaw ecosystem
|
||||
categories: [
|
||||
{ id: 'coding', name: 'Coding & IDEs' },
|
||||
{ id: 'git', name: 'Git & GitHub' },
|
||||
{ id: 'web', name: 'Web & Frontend' },
|
||||
{ id: 'devops', name: 'DevOps & Cloud' },
|
||||
{ id: 'browser', name: 'Browser & Automation' },
|
||||
{ id: 'search', name: 'Search & Research' },
|
||||
{ id: 'ai', name: 'AI & LLMs' },
|
||||
{ id: 'data', name: 'Data & Analytics' },
|
||||
{ id: 'productivity', name: 'Productivity' },
|
||||
{ id: 'communication', name: 'Communication' },
|
||||
{ id: 'media', name: 'Media & Streaming' },
|
||||
{ id: 'notes', name: 'Notes & PKM' },
|
||||
{ id: 'security', name: 'Security' },
|
||||
{ id: 'cli', name: 'CLI Utilities' },
|
||||
{ id: 'marketing', name: 'Marketing & Sales' },
|
||||
{ id: 'finance', name: 'Finance' },
|
||||
{ id: 'smart-home', name: 'Smart Home & IoT' },
|
||||
{ id: 'docs', name: 'PDF & Documents' },
|
||||
],
|
||||
|
||||
runtimeBadge: function(rt) {
|
||||
var r = (rt || '').toLowerCase();
|
||||
if (r === 'python' || r === 'py') return { text: 'PY', cls: 'runtime-badge-py' };
|
||||
if (r === 'node' || r === 'nodejs' || r === 'js' || r === 'javascript') return { text: 'JS', cls: 'runtime-badge-js' };
|
||||
if (r === 'wasm' || r === 'webassembly') return { text: 'WASM', cls: 'runtime-badge-wasm' };
|
||||
if (r === 'prompt_only' || r === 'prompt' || r === 'promptonly') return { text: 'PROMPT', cls: 'runtime-badge-prompt' };
|
||||
return { text: r.toUpperCase().substring(0, 4), cls: 'runtime-badge-prompt' };
|
||||
},
|
||||
|
||||
sourceBadge: function(source) {
|
||||
if (!source) return { text: 'Local', cls: 'badge-dim' };
|
||||
switch (source.type) {
|
||||
case 'clawhub': return { text: 'ClawHub', cls: 'badge-info' };
|
||||
case 'openclaw': return { text: 'OpenClaw', cls: 'badge-info' };
|
||||
case 'bundled': return { text: 'Built-in', cls: 'badge-success' };
|
||||
default: return { text: 'Local', cls: 'badge-dim' };
|
||||
}
|
||||
},
|
||||
|
||||
formatDownloads: function(n) {
|
||||
if (!n) return '0';
|
||||
if (n >= 1000000) return (n / 1000000).toFixed(1) + 'M';
|
||||
if (n >= 1000) return (n / 1000).toFixed(1) + 'K';
|
||||
return n.toString();
|
||||
},
|
||||
|
||||
async loadSkills() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/skills');
|
||||
this.skills = (data.skills || []).map(function(s) {
|
||||
return {
|
||||
name: s.name,
|
||||
description: s.description || '',
|
||||
version: s.version || '',
|
||||
author: s.author || '',
|
||||
runtime: s.runtime || 'unknown',
|
||||
tools_count: s.tools_count || 0,
|
||||
tags: s.tags || [],
|
||||
enabled: s.enabled !== false,
|
||||
source: s.source || { type: 'local' },
|
||||
has_prompt_context: !!s.has_prompt_context
|
||||
};
|
||||
});
|
||||
} catch(e) {
|
||||
this.skills = [];
|
||||
this.loadError = e.message || 'Could not load skills.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadData() {
|
||||
await this.loadSkills();
|
||||
},
|
||||
|
||||
// Debounced search — fires 350ms after user stops typing
|
||||
onSearchInput() {
|
||||
if (this._searchTimer) clearTimeout(this._searchTimer);
|
||||
var q = this.clawhubSearch.trim();
|
||||
if (!q) {
|
||||
this.clawhubResults = [];
|
||||
this.clawhubError = '';
|
||||
return;
|
||||
}
|
||||
var self = this;
|
||||
this._searchTimer = setTimeout(function() { self.searchClawHub(); }, 350);
|
||||
},
|
||||
|
||||
// ClawHub search
|
||||
async searchClawHub() {
|
||||
if (!this.clawhubSearch.trim()) {
|
||||
this.clawhubResults = [];
|
||||
return;
|
||||
}
|
||||
this.clawhubLoading = true;
|
||||
this.clawhubError = '';
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/clawhub/search?q=' + encodeURIComponent(this.clawhubSearch.trim()) + '&limit=20');
|
||||
this.clawhubResults = data.items || [];
|
||||
if (data.error) this.clawhubError = data.error;
|
||||
} catch(e) {
|
||||
this.clawhubResults = [];
|
||||
this.clawhubError = e.message || 'Search failed';
|
||||
}
|
||||
this.clawhubLoading = false;
|
||||
},
|
||||
|
||||
// Clear search and go back to browse
|
||||
clearSearch() {
|
||||
this.clawhubSearch = '';
|
||||
this.clawhubResults = [];
|
||||
this.clawhubError = '';
|
||||
if (this._searchTimer) clearTimeout(this._searchTimer);
|
||||
},
|
||||
|
||||
// ClawHub browse by sort
|
||||
async browseClawHub(sort) {
|
||||
this.clawhubSort = sort || 'trending';
|
||||
this.clawhubLoading = true;
|
||||
this.clawhubError = '';
|
||||
this.clawhubNextCursor = null;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/clawhub/browse?sort=' + this.clawhubSort + '&limit=20');
|
||||
this.clawhubBrowseResults = data.items || [];
|
||||
this.clawhubNextCursor = data.next_cursor || null;
|
||||
if (data.error) this.clawhubError = data.error;
|
||||
} catch(e) {
|
||||
this.clawhubBrowseResults = [];
|
||||
this.clawhubError = e.message || 'Browse failed';
|
||||
}
|
||||
this.clawhubLoading = false;
|
||||
},
|
||||
|
||||
// ClawHub load more results
|
||||
async loadMoreClawHub() {
|
||||
if (!this.clawhubNextCursor || this.clawhubLoading) return;
|
||||
this.clawhubLoading = true;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/clawhub/browse?sort=' + this.clawhubSort + '&limit=20&cursor=' + encodeURIComponent(this.clawhubNextCursor));
|
||||
this.clawhubBrowseResults = this.clawhubBrowseResults.concat(data.items || []);
|
||||
this.clawhubNextCursor = data.next_cursor || null;
|
||||
} catch(e) {
|
||||
// silently fail on load more
|
||||
}
|
||||
this.clawhubLoading = false;
|
||||
},
|
||||
|
||||
// Show skill detail
|
||||
async showSkillDetail(slug) {
|
||||
this.detailLoading = true;
|
||||
this.skillDetail = null;
|
||||
this.installResult = null;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/clawhub/skill/' + encodeURIComponent(slug));
|
||||
this.skillDetail = data;
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to load skill details');
|
||||
}
|
||||
this.detailLoading = false;
|
||||
},
|
||||
|
||||
closeDetail() {
|
||||
this.skillDetail = null;
|
||||
this.installResult = null;
|
||||
},
|
||||
|
||||
// Install from ClawHub
|
||||
async installFromClawHub(slug) {
|
||||
this.installingSlug = slug;
|
||||
this.installResult = null;
|
||||
try {
|
||||
var data = await OpenFangAPI.post('/api/clawhub/install', { slug: slug });
|
||||
this.installResult = data;
|
||||
if (data.warnings && data.warnings.length > 0) {
|
||||
OpenFangToast.success('Skill "' + data.name + '" installed with ' + data.warnings.length + ' warning(s)');
|
||||
} else {
|
||||
OpenFangToast.success('Skill "' + data.name + '" installed successfully');
|
||||
}
|
||||
// Update installed state in detail modal if open
|
||||
if (this.skillDetail && this.skillDetail.slug === slug) {
|
||||
this.skillDetail.installed = true;
|
||||
}
|
||||
await this.loadSkills();
|
||||
} catch(e) {
|
||||
var msg = e.message || 'Install failed';
|
||||
if (msg.includes('already_installed')) {
|
||||
OpenFangToast.error('Skill is already installed');
|
||||
} else if (msg.includes('SecurityBlocked')) {
|
||||
OpenFangToast.error('Skill blocked by security scan');
|
||||
} else {
|
||||
OpenFangToast.error('Install failed: ' + msg);
|
||||
}
|
||||
}
|
||||
this.installingSlug = null;
|
||||
},
|
||||
|
||||
// Uninstall
|
||||
uninstallSkill: function(name) {
|
||||
var self = this;
|
||||
OpenFangToast.confirm('Uninstall Skill', 'Uninstall skill "' + name + '"? This cannot be undone.', async function() {
|
||||
try {
|
||||
await OpenFangAPI.post('/api/skills/uninstall', { name: name });
|
||||
OpenFangToast.success('Skill "' + name + '" uninstalled');
|
||||
await self.loadSkills();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to uninstall skill: ' + e.message);
|
||||
}
|
||||
});
|
||||
},
|
||||
|
||||
// Create prompt-only skill
|
||||
async createDemoSkill(skill) {
|
||||
try {
|
||||
await OpenFangAPI.post('/api/skills/create', {
|
||||
name: skill.name,
|
||||
description: skill.description,
|
||||
runtime: 'prompt_only',
|
||||
prompt_context: skill.prompt_context || skill.description
|
||||
});
|
||||
OpenFangToast.success('Skill "' + skill.name + '" created');
|
||||
this.tab = 'installed';
|
||||
await this.loadSkills();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to create skill: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
// Load MCP servers
|
||||
async loadMcpServers() {
|
||||
this.mcpLoading = true;
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/mcp/servers');
|
||||
this.mcpServers = data;
|
||||
} catch(e) {
|
||||
this.mcpServers = { configured: [], connected: [], total_configured: 0, total_connected: 0 };
|
||||
}
|
||||
this.mcpLoading = false;
|
||||
},
|
||||
|
||||
// Category search on ClawHub
|
||||
searchCategory: function(cat) {
|
||||
this.clawhubSearch = cat.name;
|
||||
this.searchClawHub();
|
||||
},
|
||||
|
||||
// Quick start skills (prompt-only, zero deps)
|
||||
quickStartSkills: [
|
||||
{ name: 'code-review-guide', description: 'Adds code review best practices and checklist to agent context.', prompt_context: 'You are an expert code reviewer. When reviewing code:\n1. Check for bugs and logic errors\n2. Evaluate code style and readability\n3. Look for security vulnerabilities\n4. Suggest performance improvements\n5. Verify error handling\n6. Check test coverage' },
|
||||
{ name: 'writing-style', description: 'Configurable writing style guide for content generation.', prompt_context: 'Follow these writing guidelines:\n- Use clear, concise language\n- Prefer active voice over passive voice\n- Keep paragraphs short (3-4 sentences)\n- Use bullet points for lists\n- Maintain consistent tone throughout' },
|
||||
{ name: 'api-design', description: 'REST API design patterns and conventions.', prompt_context: 'When designing REST APIs:\n- Use nouns for resources, not verbs\n- Use HTTP methods correctly (GET, POST, PUT, DELETE)\n- Return appropriate status codes\n- Use pagination for list endpoints\n- Version your API\n- Document all endpoints' },
|
||||
{ name: 'security-checklist', description: 'OWASP-aligned security review checklist.', prompt_context: 'Security review checklist (OWASP aligned):\n- Input validation on all user inputs\n- Output encoding to prevent XSS\n- Parameterized queries to prevent SQL injection\n- Authentication and session management\n- Access control checks\n- CSRF protection\n- Security headers\n- Error handling without information leakage' },
|
||||
],
|
||||
|
||||
// Check if skill is installed by slug
|
||||
isSkillInstalled: function(slug) {
|
||||
return this.skills.some(function(s) {
|
||||
return s.source && s.source.type === 'clawhub' && s.source.slug === slug;
|
||||
});
|
||||
},
|
||||
|
||||
// Check if skill is installed by name
|
||||
isSkillInstalledByName: function(name) {
|
||||
return this.skills.some(function(s) { return s.name === name; });
|
||||
},
|
||||
};
|
||||
}
|
||||
251
crates/openfang-api/static/js/pages/usage.js
Normal file
251
crates/openfang-api/static/js/pages/usage.js
Normal file
@@ -0,0 +1,251 @@
|
||||
// OpenFang Analytics Page — Full usage analytics with per-model and per-agent breakdowns
|
||||
// Includes Cost Dashboard with donut chart, bar chart, projections, and provider breakdown.
|
||||
'use strict';
|
||||
|
||||
function analyticsPage() {
|
||||
return {
|
||||
tab: 'summary',
|
||||
summary: {},
|
||||
byModel: [],
|
||||
byAgent: [],
|
||||
loading: true,
|
||||
loadError: '',
|
||||
|
||||
// Cost tab state
|
||||
dailyCosts: [],
|
||||
todayCost: 0,
|
||||
firstEventDate: null,
|
||||
|
||||
// Chart colors for providers (stable palette)
|
||||
_chartColors: [
|
||||
'#FF5C00', '#3B82F6', '#10B981', '#F59E0B', '#8B5CF6',
|
||||
'#EC4899', '#06B6D4', '#EF4444', '#84CC16', '#F97316',
|
||||
'#6366F1', '#14B8A6', '#E11D48', '#A855F7', '#22D3EE'
|
||||
],
|
||||
|
||||
async loadUsage() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
await Promise.all([
|
||||
this.loadSummary(),
|
||||
this.loadByModel(),
|
||||
this.loadByAgent(),
|
||||
this.loadDailyCosts()
|
||||
]);
|
||||
} catch(e) {
|
||||
this.loadError = e.message || 'Could not load usage data.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadData() { return this.loadUsage(); },
|
||||
|
||||
async loadSummary() {
|
||||
try {
|
||||
this.summary = await OpenFangAPI.get('/api/usage/summary');
|
||||
} catch(e) {
|
||||
this.summary = { total_input_tokens: 0, total_output_tokens: 0, total_cost_usd: 0, call_count: 0, total_tool_calls: 0 };
|
||||
throw e;
|
||||
}
|
||||
},
|
||||
|
||||
async loadByModel() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/usage/by-model');
|
||||
this.byModel = data.models || [];
|
||||
} catch(e) { this.byModel = []; }
|
||||
},
|
||||
|
||||
async loadByAgent() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/usage');
|
||||
this.byAgent = data.agents || [];
|
||||
} catch(e) { this.byAgent = []; }
|
||||
},
|
||||
|
||||
async loadDailyCosts() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/usage/daily');
|
||||
this.dailyCosts = data.days || [];
|
||||
this.todayCost = data.today_cost_usd || 0;
|
||||
this.firstEventDate = data.first_event_date || null;
|
||||
} catch(e) {
|
||||
this.dailyCosts = [];
|
||||
this.todayCost = 0;
|
||||
this.firstEventDate = null;
|
||||
}
|
||||
},
|
||||
|
||||
formatTokens(n) {
|
||||
if (!n) return '0';
|
||||
if (n >= 1000000) return (n / 1000000).toFixed(2) + 'M';
|
||||
if (n >= 1000) return (n / 1000).toFixed(1) + 'K';
|
||||
return String(n);
|
||||
},
|
||||
|
||||
formatCost(c) {
|
||||
if (!c) return '$0.00';
|
||||
if (c < 0.01) return '$' + c.toFixed(4);
|
||||
return '$' + c.toFixed(2);
|
||||
},
|
||||
|
||||
maxTokens() {
|
||||
var max = 0;
|
||||
this.byModel.forEach(function(m) {
|
||||
var t = (m.total_input_tokens || 0) + (m.total_output_tokens || 0);
|
||||
if (t > max) max = t;
|
||||
});
|
||||
return max || 1;
|
||||
},
|
||||
|
||||
barWidth(m) {
|
||||
var t = (m.total_input_tokens || 0) + (m.total_output_tokens || 0);
|
||||
return Math.max(2, Math.round((t / this.maxTokens()) * 100)) + '%';
|
||||
},
|
||||
|
||||
// ── Cost tab helpers ──
|
||||
|
||||
avgCostPerMessage() {
|
||||
var count = this.summary.call_count || 0;
|
||||
if (count === 0) return 0;
|
||||
return (this.summary.total_cost_usd || 0) / count;
|
||||
},
|
||||
|
||||
projectedMonthlyCost() {
|
||||
if (!this.firstEventDate || !this.summary.total_cost_usd) return 0;
|
||||
var first = new Date(this.firstEventDate);
|
||||
var now = new Date();
|
||||
var diffMs = now.getTime() - first.getTime();
|
||||
var diffDays = diffMs / (1000 * 60 * 60 * 24);
|
||||
if (diffDays < 1) diffDays = 1;
|
||||
return (this.summary.total_cost_usd / diffDays) * 30;
|
||||
},
|
||||
|
||||
// ── Provider aggregation from byModel data ──
|
||||
|
||||
costByProvider() {
|
||||
var providerMap = {};
|
||||
var self = this;
|
||||
this.byModel.forEach(function(m) {
|
||||
var provider = self._extractProvider(m.model);
|
||||
if (!providerMap[provider]) {
|
||||
providerMap[provider] = { provider: provider, cost: 0, tokens: 0, calls: 0 };
|
||||
}
|
||||
providerMap[provider].cost += (m.total_cost_usd || 0);
|
||||
providerMap[provider].tokens += (m.total_input_tokens || 0) + (m.total_output_tokens || 0);
|
||||
providerMap[provider].calls += (m.call_count || 0);
|
||||
});
|
||||
var result = [];
|
||||
for (var key in providerMap) {
|
||||
if (providerMap.hasOwnProperty(key)) {
|
||||
result.push(providerMap[key]);
|
||||
}
|
||||
}
|
||||
result.sort(function(a, b) { return b.cost - a.cost; });
|
||||
return result;
|
||||
},
|
||||
|
||||
_extractProvider(modelName) {
|
||||
if (!modelName) return 'Unknown';
|
||||
var lower = modelName.toLowerCase();
|
||||
if (lower.indexOf('claude') !== -1 || lower.indexOf('haiku') !== -1 || lower.indexOf('sonnet') !== -1 || lower.indexOf('opus') !== -1) return 'Anthropic';
|
||||
if (lower.indexOf('gemini') !== -1 || lower.indexOf('gemma') !== -1) return 'Google';
|
||||
if (lower.indexOf('gpt') !== -1 || lower.indexOf('o1') !== -1 || lower.indexOf('o3') !== -1 || lower.indexOf('o4') !== -1) return 'OpenAI';
|
||||
if (lower.indexOf('llama') !== -1 || lower.indexOf('mixtral') !== -1 || lower.indexOf('groq') !== -1) return 'Groq';
|
||||
if (lower.indexOf('deepseek') !== -1) return 'DeepSeek';
|
||||
if (lower.indexOf('mistral') !== -1) return 'Mistral';
|
||||
if (lower.indexOf('command') !== -1 || lower.indexOf('cohere') !== -1) return 'Cohere';
|
||||
if (lower.indexOf('grok') !== -1) return 'xAI';
|
||||
if (lower.indexOf('jamba') !== -1) return 'AI21';
|
||||
if (lower.indexOf('qwen') !== -1) return 'Together';
|
||||
return 'Other';
|
||||
},
|
||||
|
||||
// ── Donut chart (stroke-dasharray on circles) ──
|
||||
|
||||
donutSegments() {
|
||||
var providers = this.costByProvider();
|
||||
var total = 0;
|
||||
var colors = this._chartColors;
|
||||
providers.forEach(function(p) { total += p.cost; });
|
||||
if (total === 0) return [];
|
||||
|
||||
var segments = [];
|
||||
var offset = 0;
|
||||
var circumference = 2 * Math.PI * 60; // r=60
|
||||
for (var i = 0; i < providers.length; i++) {
|
||||
var pct = providers[i].cost / total;
|
||||
var dashLen = pct * circumference;
|
||||
segments.push({
|
||||
provider: providers[i].provider,
|
||||
cost: providers[i].cost,
|
||||
percent: Math.round(pct * 100),
|
||||
color: colors[i % colors.length],
|
||||
dasharray: dashLen + ' ' + (circumference - dashLen),
|
||||
dashoffset: -offset,
|
||||
circumference: circumference
|
||||
});
|
||||
offset += dashLen;
|
||||
}
|
||||
return segments;
|
||||
},
|
||||
|
||||
// ── Bar chart (last 7 days) ──
|
||||
|
||||
barChartData() {
|
||||
var days = this.dailyCosts;
|
||||
if (!days || days.length === 0) return [];
|
||||
var maxCost = 0;
|
||||
days.forEach(function(d) { if (d.cost_usd > maxCost) maxCost = d.cost_usd; });
|
||||
if (maxCost === 0) maxCost = 1;
|
||||
|
||||
var dayNames = ['Sun', 'Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat'];
|
||||
var result = [];
|
||||
for (var i = 0; i < days.length; i++) {
|
||||
var d = new Date(days[i].date + 'T12:00:00');
|
||||
var dayName = dayNames[d.getDay()] || '?';
|
||||
var heightPct = Math.max(2, Math.round((days[i].cost_usd / maxCost) * 120));
|
||||
result.push({
|
||||
date: days[i].date,
|
||||
dayName: dayName,
|
||||
cost: days[i].cost_usd,
|
||||
tokens: days[i].tokens,
|
||||
calls: days[i].calls,
|
||||
barHeight: heightPct
|
||||
});
|
||||
}
|
||||
return result;
|
||||
},
|
||||
|
||||
// ── Cost by model table (sorted by cost descending) ──
|
||||
|
||||
costByModelSorted() {
|
||||
var models = this.byModel.slice();
|
||||
models.sort(function(a, b) { return (b.total_cost_usd || 0) - (a.total_cost_usd || 0); });
|
||||
return models;
|
||||
},
|
||||
|
||||
maxModelCost() {
|
||||
var max = 0;
|
||||
this.byModel.forEach(function(m) {
|
||||
if ((m.total_cost_usd || 0) > max) max = m.total_cost_usd;
|
||||
});
|
||||
return max || 1;
|
||||
},
|
||||
|
||||
costBarWidth(m) {
|
||||
return Math.max(2, Math.round(((m.total_cost_usd || 0) / this.maxModelCost()) * 100)) + '%';
|
||||
},
|
||||
|
||||
modelTier(modelName) {
|
||||
if (!modelName) return 'unknown';
|
||||
var lower = modelName.toLowerCase();
|
||||
if (lower.indexOf('opus') !== -1 || lower.indexOf('o1') !== -1 || lower.indexOf('o3') !== -1 || lower.indexOf('deepseek-r1') !== -1) return 'frontier';
|
||||
if (lower.indexOf('sonnet') !== -1 || lower.indexOf('gpt-4') !== -1 || lower.indexOf('gemini-2.5') !== -1 || lower.indexOf('gemini-1.5-pro') !== -1) return 'smart';
|
||||
if (lower.indexOf('haiku') !== -1 || lower.indexOf('gpt-3.5') !== -1 || lower.indexOf('flash') !== -1 || lower.indexOf('mixtral') !== -1) return 'balanced';
|
||||
if (lower.indexOf('llama') !== -1 || lower.indexOf('groq') !== -1 || lower.indexOf('gemma') !== -1) return 'fast';
|
||||
return 'balanced';
|
||||
}
|
||||
};
|
||||
}
|
||||
544
crates/openfang-api/static/js/pages/wizard.js
Normal file
544
crates/openfang-api/static/js/pages/wizard.js
Normal file
@@ -0,0 +1,544 @@
|
||||
// OpenFang Setup Wizard — First-run guided setup (Provider + Agent + Channel)
|
||||
'use strict';
|
||||
|
||||
function wizardPage() {
|
||||
return {
|
||||
step: 1,
|
||||
totalSteps: 6,
|
||||
loading: false,
|
||||
error: '',
|
||||
|
||||
// Step 2: Provider setup
|
||||
providers: [],
|
||||
selectedProvider: '',
|
||||
apiKeyInput: '',
|
||||
testingProvider: false,
|
||||
testResult: null,
|
||||
savingKey: false,
|
||||
keySaved: false,
|
||||
|
||||
// Step 3: Agent creation
|
||||
templates: [
|
||||
{
|
||||
id: 'assistant',
|
||||
name: 'General Assistant',
|
||||
description: 'A versatile helper for everyday tasks, answering questions, and providing recommendations.',
|
||||
icon: 'GA',
|
||||
category: 'General',
|
||||
provider: 'deepseek',
|
||||
model: 'deepseek-chat',
|
||||
profile: 'balanced',
|
||||
system_prompt: 'You are a helpful, friendly assistant. Provide clear, accurate, and concise responses. Ask clarifying questions when needed.'
|
||||
},
|
||||
{
|
||||
id: 'coder',
|
||||
name: 'Code Helper',
|
||||
description: 'A programming-focused agent that writes, reviews, and debugs code across multiple languages.',
|
||||
icon: 'CH',
|
||||
category: 'Development',
|
||||
provider: 'deepseek',
|
||||
model: 'deepseek-chat',
|
||||
profile: 'precise',
|
||||
system_prompt: 'You are an expert programmer. Help users write clean, efficient code. Explain your reasoning. Follow best practices and conventions for the language being used.'
|
||||
},
|
||||
{
|
||||
id: 'researcher',
|
||||
name: 'Researcher',
|
||||
description: 'An analytical agent that breaks down complex topics, synthesizes information, and provides cited summaries.',
|
||||
icon: 'RS',
|
||||
category: 'Research',
|
||||
provider: 'gemini',
|
||||
model: 'gemini-2.5-flash',
|
||||
profile: 'balanced',
|
||||
system_prompt: 'You are a research analyst. Break down complex topics into clear explanations. Provide structured analysis with key findings. Cite sources when available.'
|
||||
},
|
||||
{
|
||||
id: 'writer',
|
||||
name: 'Writer',
|
||||
description: 'A creative writing agent that helps with drafting, editing, and improving written content of all kinds.',
|
||||
icon: 'WR',
|
||||
category: 'Writing',
|
||||
provider: 'deepseek',
|
||||
model: 'deepseek-chat',
|
||||
profile: 'creative',
|
||||
system_prompt: 'You are a skilled writer and editor. Help users create polished content. Adapt your tone and style to match the intended audience. Offer constructive suggestions for improvement.'
|
||||
},
|
||||
{
|
||||
id: 'data-analyst',
|
||||
name: 'Data Analyst',
|
||||
description: 'A data-focused agent that helps analyze datasets, create queries, and interpret statistical results.',
|
||||
icon: 'DA',
|
||||
category: 'Development',
|
||||
provider: 'gemini',
|
||||
model: 'gemini-2.5-flash',
|
||||
profile: 'precise',
|
||||
system_prompt: 'You are a data analysis expert. Help users understand their data, write SQL/Python queries, and interpret results. Present findings clearly with actionable insights.'
|
||||
},
|
||||
{
|
||||
id: 'devops',
|
||||
name: 'DevOps Engineer',
|
||||
description: 'A systems-focused agent for CI/CD, infrastructure, Docker, and deployment troubleshooting.',
|
||||
icon: 'DO',
|
||||
category: 'Development',
|
||||
provider: 'deepseek',
|
||||
model: 'deepseek-chat',
|
||||
profile: 'precise',
|
||||
system_prompt: 'You are a DevOps engineer. Help with CI/CD pipelines, Docker, Kubernetes, infrastructure as code, and deployment. Prioritize reliability and security.'
|
||||
},
|
||||
{
|
||||
id: 'support',
|
||||
name: 'Customer Support',
|
||||
description: 'A professional, empathetic agent for handling customer inquiries and resolving issues.',
|
||||
icon: 'CS',
|
||||
category: 'Business',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'balanced',
|
||||
system_prompt: 'You are a professional customer support representative. Be empathetic, patient, and solution-oriented. Acknowledge concerns before offering solutions. Escalate complex issues appropriately.'
|
||||
},
|
||||
{
|
||||
id: 'tutor',
|
||||
name: 'Tutor',
|
||||
description: 'A patient educational agent that explains concepts step-by-step and adapts to the learner\'s level.',
|
||||
icon: 'TU',
|
||||
category: 'General',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'balanced',
|
||||
system_prompt: 'You are a patient and encouraging tutor. Explain concepts step by step, starting from fundamentals. Use analogies and examples. Check understanding before moving on. Adapt to the learner\'s pace.'
|
||||
},
|
||||
{
|
||||
id: 'api-designer',
|
||||
name: 'API Designer',
|
||||
description: 'An agent specialized in RESTful API design, OpenAPI specs, and integration architecture.',
|
||||
icon: 'AD',
|
||||
category: 'Development',
|
||||
provider: 'deepseek',
|
||||
model: 'deepseek-chat',
|
||||
profile: 'precise',
|
||||
system_prompt: 'You are an API design expert. Help users design clean, consistent RESTful APIs following best practices. Cover endpoint naming, request/response schemas, error handling, and versioning.'
|
||||
},
|
||||
{
|
||||
id: 'meeting-notes',
|
||||
name: 'Meeting Notes',
|
||||
description: 'Summarizes meeting transcripts into structured notes with action items and key decisions.',
|
||||
icon: 'MN',
|
||||
category: 'Business',
|
||||
provider: 'groq',
|
||||
model: 'llama-3.3-70b-versatile',
|
||||
profile: 'precise',
|
||||
system_prompt: 'You are a meeting summarizer. When given a meeting transcript or notes, produce a structured summary with: key decisions, action items (with owners), discussion highlights, and follow-up questions.'
|
||||
}
|
||||
],
|
||||
selectedTemplate: 0,
|
||||
agentName: 'my-assistant',
|
||||
creatingAgent: false,
|
||||
createdAgent: null,
|
||||
|
||||
// Step 3: Category filtering
|
||||
templateCategory: 'All',
|
||||
get templateCategories() {
|
||||
var cats = { 'All': true };
|
||||
this.templates.forEach(function(t) { if (t.category) cats[t.category] = true; });
|
||||
return Object.keys(cats);
|
||||
},
|
||||
get filteredTemplates() {
|
||||
var cat = this.templateCategory;
|
||||
if (cat === 'All') return this.templates;
|
||||
return this.templates.filter(function(t) { return t.category === cat; });
|
||||
},
|
||||
|
||||
// Step 3: Profile/tool descriptions
|
||||
profileDescriptions: {
|
||||
minimal: { label: 'Minimal', desc: 'Read-only file access' },
|
||||
coding: { label: 'Coding', desc: 'Files + shell + web fetch' },
|
||||
research: { label: 'Research', desc: 'Web search + file read/write' },
|
||||
balanced: { label: 'Balanced', desc: 'General-purpose tool set' },
|
||||
precise: { label: 'Precise', desc: 'Focused tool set for accuracy' },
|
||||
creative: { label: 'Creative', desc: 'Full tools with creative emphasis' },
|
||||
full: { label: 'Full', desc: 'All 35+ tools' }
|
||||
},
|
||||
profileInfo: function(name) { return this.profileDescriptions[name] || { label: name, desc: '' }; },
|
||||
|
||||
// Step 4: Try It chat
|
||||
tryItMessages: [],
|
||||
tryItInput: '',
|
||||
tryItSending: false,
|
||||
suggestedMessages: {
|
||||
'General': ['What can you help me with?', 'Tell me a fun fact', 'Summarize the latest AI news'],
|
||||
'Development': ['Write a Python hello world', 'Explain async/await', 'Review this code snippet'],
|
||||
'Research': ['Explain quantum computing simply', 'Compare React vs Vue', 'What are the latest trends in AI?'],
|
||||
'Writing': ['Help me write a professional email', 'Improve this paragraph', 'Write a blog intro about AI'],
|
||||
'Business': ['Draft a meeting agenda', 'How do I handle a complaint?', 'Create a project status update']
|
||||
},
|
||||
get currentSuggestions() {
|
||||
var tpl = this.templates[this.selectedTemplate];
|
||||
var cat = tpl ? tpl.category : 'General';
|
||||
return this.suggestedMessages[cat] || this.suggestedMessages['General'];
|
||||
},
|
||||
async sendTryItMessage(text) {
|
||||
if (!text || !text.trim() || !this.createdAgent || this.tryItSending) return;
|
||||
text = text.trim();
|
||||
this.tryItInput = '';
|
||||
this.tryItMessages.push({ role: 'user', text: text });
|
||||
this.tryItSending = true;
|
||||
try {
|
||||
var res = await OpenFangAPI.post('/api/agents/' + this.createdAgent.id + '/message', { message: text });
|
||||
this.tryItMessages.push({ role: 'agent', text: res.response || '(no response)' });
|
||||
localStorage.setItem('of-first-msg', 'true');
|
||||
} catch(e) {
|
||||
this.tryItMessages.push({ role: 'agent', text: 'Error: ' + (e.message || 'Could not reach agent') });
|
||||
}
|
||||
this.tryItSending = false;
|
||||
},
|
||||
|
||||
// Step 5: Channel setup (optional)
|
||||
channelType: '',
|
||||
channelOptions: [
|
||||
{
|
||||
name: 'telegram',
|
||||
display_name: 'Telegram',
|
||||
icon: 'TG',
|
||||
description: 'Connect your agent to a Telegram bot for messaging.',
|
||||
token_label: 'Bot Token',
|
||||
token_placeholder: '123456:ABC-DEF1234ghIkl-zyx57W2v1u123ew11',
|
||||
token_env: 'TELEGRAM_BOT_TOKEN',
|
||||
help: 'Create a bot via @BotFather on Telegram to get your token.'
|
||||
},
|
||||
{
|
||||
name: 'discord',
|
||||
display_name: 'Discord',
|
||||
icon: 'DC',
|
||||
description: 'Connect your agent to a Discord server via bot token.',
|
||||
token_label: 'Bot Token',
|
||||
token_placeholder: 'MTIz...abc',
|
||||
token_env: 'DISCORD_BOT_TOKEN',
|
||||
help: 'Create a Discord application at discord.com/developers and add a bot.'
|
||||
},
|
||||
{
|
||||
name: 'slack',
|
||||
display_name: 'Slack',
|
||||
icon: 'SL',
|
||||
description: 'Connect your agent to a Slack workspace.',
|
||||
token_label: 'Bot Token',
|
||||
token_placeholder: 'xoxb-...',
|
||||
token_env: 'SLACK_BOT_TOKEN',
|
||||
help: 'Create a Slack app at api.slack.com/apps and install it to your workspace.'
|
||||
}
|
||||
],
|
||||
channelToken: '',
|
||||
configuringChannel: false,
|
||||
channelConfigured: false,
|
||||
|
||||
// Step 5: Summary
|
||||
setupSummary: {
|
||||
provider: '',
|
||||
agent: '',
|
||||
channel: ''
|
||||
},
|
||||
|
||||
// ── Lifecycle ──
|
||||
|
||||
async loadData() {
|
||||
this.loading = true;
|
||||
this.error = '';
|
||||
try {
|
||||
await this.loadProviders();
|
||||
} catch(e) {
|
||||
this.error = e.message || 'Could not load setup data.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
// ── Navigation ──
|
||||
|
||||
nextStep() {
|
||||
if (this.step === 3 && !this.createdAgent) {
|
||||
// Skip "Try It" if no agent was created
|
||||
this.step = 5;
|
||||
} else if (this.step < this.totalSteps) {
|
||||
this.step++;
|
||||
}
|
||||
},
|
||||
|
||||
prevStep() {
|
||||
if (this.step === 5 && !this.createdAgent) {
|
||||
// Skip back past "Try It" if no agent was created
|
||||
this.step = 3;
|
||||
} else if (this.step > 1) {
|
||||
this.step--;
|
||||
}
|
||||
},
|
||||
|
||||
goToStep(n) {
|
||||
if (n >= 1 && n <= this.totalSteps) {
|
||||
if (n === 4 && !this.createdAgent) return; // Can't go to Try It without agent
|
||||
this.step = n;
|
||||
}
|
||||
},
|
||||
|
||||
stepLabel(n) {
|
||||
var labels = ['Welcome', 'Provider', 'Agent', 'Try It', 'Channel', 'Done'];
|
||||
return labels[n - 1] || '';
|
||||
},
|
||||
|
||||
get canGoNext() {
|
||||
if (this.step === 2) return this.keySaved || this.hasConfiguredProvider;
|
||||
if (this.step === 3) return this.agentName.trim().length > 0;
|
||||
return true;
|
||||
},
|
||||
|
||||
get hasConfiguredProvider() {
|
||||
var self = this;
|
||||
return this.providers.some(function(p) {
|
||||
return p.auth_status === 'configured';
|
||||
});
|
||||
},
|
||||
|
||||
// ── Step 2: Providers ──
|
||||
|
||||
async loadProviders() {
|
||||
try {
|
||||
var data = await OpenFangAPI.get('/api/providers');
|
||||
this.providers = data.providers || [];
|
||||
// Pre-select first unconfigured provider, or first one
|
||||
var unconfigured = this.providers.filter(function(p) {
|
||||
return p.auth_status !== 'configured' && p.api_key_env;
|
||||
});
|
||||
if (unconfigured.length > 0) {
|
||||
this.selectedProvider = unconfigured[0].id;
|
||||
} else if (this.providers.length > 0) {
|
||||
this.selectedProvider = this.providers[0].id;
|
||||
}
|
||||
} catch(e) { this.providers = []; }
|
||||
},
|
||||
|
||||
get selectedProviderObj() {
|
||||
var self = this;
|
||||
var match = this.providers.filter(function(p) { return p.id === self.selectedProvider; });
|
||||
return match.length > 0 ? match[0] : null;
|
||||
},
|
||||
|
||||
get popularProviders() {
|
||||
var popular = ['anthropic', 'openai', 'gemini', 'groq', 'deepseek', 'openrouter'];
|
||||
return this.providers.filter(function(p) {
|
||||
return popular.indexOf(p.id) >= 0;
|
||||
}).sort(function(a, b) {
|
||||
return popular.indexOf(a.id) - popular.indexOf(b.id);
|
||||
});
|
||||
},
|
||||
|
||||
get otherProviders() {
|
||||
var popular = ['anthropic', 'openai', 'gemini', 'groq', 'deepseek', 'openrouter'];
|
||||
return this.providers.filter(function(p) {
|
||||
return popular.indexOf(p.id) < 0;
|
||||
});
|
||||
},
|
||||
|
||||
selectProvider(id) {
|
||||
this.selectedProvider = id;
|
||||
this.apiKeyInput = '';
|
||||
this.testResult = null;
|
||||
this.keySaved = false;
|
||||
},
|
||||
|
||||
providerHelp: function(id) {
|
||||
var help = {
|
||||
anthropic: { url: 'https://console.anthropic.com/settings/keys', text: 'Get your key from the Anthropic Console' },
|
||||
openai: { url: 'https://platform.openai.com/api-keys', text: 'Get your key from the OpenAI Platform' },
|
||||
gemini: { url: 'https://aistudio.google.com/apikey', text: 'Get your key from Google AI Studio' },
|
||||
groq: { url: 'https://console.groq.com/keys', text: 'Get your key from the Groq Console (free tier available)' },
|
||||
deepseek: { url: 'https://platform.deepseek.com/api_keys', text: 'Get your key from the DeepSeek Platform (very affordable)' },
|
||||
openrouter: { url: 'https://openrouter.ai/keys', text: 'Get your key from OpenRouter (access 100+ models with one key)' },
|
||||
mistral: { url: 'https://console.mistral.ai/api-keys', text: 'Get your key from the Mistral Console' },
|
||||
together: { url: 'https://api.together.xyz/settings/api-keys', text: 'Get your key from Together AI' },
|
||||
fireworks: { url: 'https://fireworks.ai/account/api-keys', text: 'Get your key from Fireworks AI' },
|
||||
perplexity: { url: 'https://www.perplexity.ai/settings/api', text: 'Get your key from Perplexity Settings' },
|
||||
cohere: { url: 'https://dashboard.cohere.com/api-keys', text: 'Get your key from the Cohere Dashboard' },
|
||||
xai: { url: 'https://console.x.ai/', text: 'Get your key from the xAI Console' }
|
||||
};
|
||||
return help[id] || null;
|
||||
},
|
||||
|
||||
providerIsConfigured(p) {
|
||||
return p && p.auth_status === 'configured';
|
||||
},
|
||||
|
||||
async saveKey() {
|
||||
var provider = this.selectedProviderObj;
|
||||
if (!provider) return;
|
||||
var key = this.apiKeyInput.trim();
|
||||
if (!key) {
|
||||
OpenFangToast.error('Please enter an API key');
|
||||
return;
|
||||
}
|
||||
this.savingKey = true;
|
||||
try {
|
||||
await OpenFangAPI.post('/api/providers/' + encodeURIComponent(provider.id) + '/key', { key: key });
|
||||
this.apiKeyInput = '';
|
||||
this.keySaved = true;
|
||||
this.setupSummary.provider = provider.display_name;
|
||||
OpenFangToast.success('API key saved for ' + provider.display_name);
|
||||
await this.loadProviders();
|
||||
// Auto-test after saving
|
||||
await this.testKey();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save key: ' + e.message);
|
||||
}
|
||||
this.savingKey = false;
|
||||
},
|
||||
|
||||
async testKey() {
|
||||
var provider = this.selectedProviderObj;
|
||||
if (!provider) return;
|
||||
this.testingProvider = true;
|
||||
this.testResult = null;
|
||||
try {
|
||||
var result = await OpenFangAPI.post('/api/providers/' + encodeURIComponent(provider.id) + '/test', {});
|
||||
this.testResult = result;
|
||||
if (result.status === 'ok') {
|
||||
OpenFangToast.success(provider.display_name + ' connected (' + (result.latency_ms || '?') + 'ms)');
|
||||
} else {
|
||||
OpenFangToast.error(provider.display_name + ': ' + (result.error || 'Connection failed'));
|
||||
}
|
||||
} catch(e) {
|
||||
this.testResult = { status: 'error', error: e.message };
|
||||
OpenFangToast.error('Test failed: ' + e.message);
|
||||
}
|
||||
this.testingProvider = false;
|
||||
},
|
||||
|
||||
// ── Step 3: Agent creation ──
|
||||
|
||||
selectTemplate(index) {
|
||||
this.selectedTemplate = index;
|
||||
var tpl = this.templates[index];
|
||||
if (tpl) {
|
||||
this.agentName = tpl.name.toLowerCase().replace(/\s+/g, '-');
|
||||
}
|
||||
},
|
||||
|
||||
async createAgent() {
|
||||
var tpl = this.templates[this.selectedTemplate];
|
||||
if (!tpl) return;
|
||||
var name = this.agentName.trim();
|
||||
if (!name) {
|
||||
OpenFangToast.error('Please enter a name for your agent');
|
||||
return;
|
||||
}
|
||||
|
||||
// Use the provider the user just configured, or the template default
|
||||
var provider = tpl.provider;
|
||||
var model = tpl.model;
|
||||
if (this.selectedProviderObj && this.providerIsConfigured(this.selectedProviderObj)) {
|
||||
provider = this.selectedProviderObj.id;
|
||||
// Use a sensible default model for the provider
|
||||
model = this.defaultModelForProvider(provider) || tpl.model;
|
||||
}
|
||||
|
||||
var toml = '[agent]\n';
|
||||
toml += 'name = "' + name.replace(/"/g, '\\"') + '"\n';
|
||||
toml += 'description = "' + tpl.description.replace(/"/g, '\\"') + '"\n';
|
||||
toml += 'profile = "' + tpl.profile + '"\n\n';
|
||||
toml += '[model]\nprovider = "' + provider + '"\n';
|
||||
toml += 'name = "' + model + '"\n\n';
|
||||
toml += '[prompt]\nsystem = """\n' + tpl.system_prompt + '\n"""\n';
|
||||
|
||||
this.creatingAgent = true;
|
||||
try {
|
||||
var res = await OpenFangAPI.post('/api/agents', { manifest_toml: toml });
|
||||
if (res.agent_id) {
|
||||
this.createdAgent = { id: res.agent_id, name: res.name || name };
|
||||
this.setupSummary.agent = res.name || name;
|
||||
OpenFangToast.success('Agent "' + (res.name || name) + '" created');
|
||||
await Alpine.store('app').refreshAgents();
|
||||
} else {
|
||||
OpenFangToast.error('Failed: ' + (res.error || 'Unknown error'));
|
||||
}
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to create agent: ' + e.message);
|
||||
}
|
||||
this.creatingAgent = false;
|
||||
},
|
||||
|
||||
defaultModelForProvider(providerId) {
|
||||
var defaults = {
|
||||
anthropic: 'claude-sonnet-4-20250514',
|
||||
openai: 'gpt-4o',
|
||||
gemini: 'gemini-2.5-flash',
|
||||
groq: 'llama-3.3-70b-versatile',
|
||||
deepseek: 'deepseek-chat',
|
||||
openrouter: 'openrouter/auto',
|
||||
mistral: 'mistral-large-latest',
|
||||
together: 'meta-llama/Llama-3-70b-chat-hf',
|
||||
fireworks: 'accounts/fireworks/models/llama-v3p1-70b-instruct',
|
||||
perplexity: 'llama-3.1-sonar-large-128k-online',
|
||||
cohere: 'command-r-plus',
|
||||
xai: 'grok-2'
|
||||
};
|
||||
return defaults[providerId] || '';
|
||||
},
|
||||
|
||||
// ── Step 5: Channel setup ──
|
||||
|
||||
selectChannel(name) {
|
||||
if (this.channelType === name) {
|
||||
this.channelType = '';
|
||||
this.channelToken = '';
|
||||
} else {
|
||||
this.channelType = name;
|
||||
this.channelToken = '';
|
||||
}
|
||||
},
|
||||
|
||||
get selectedChannelObj() {
|
||||
var self = this;
|
||||
var match = this.channelOptions.filter(function(ch) { return ch.name === self.channelType; });
|
||||
return match.length > 0 ? match[0] : null;
|
||||
},
|
||||
|
||||
async configureChannel() {
|
||||
var ch = this.selectedChannelObj;
|
||||
if (!ch) return;
|
||||
var token = this.channelToken.trim();
|
||||
if (!token) {
|
||||
OpenFangToast.error('Please enter the ' + ch.token_label);
|
||||
return;
|
||||
}
|
||||
this.configuringChannel = true;
|
||||
try {
|
||||
var fields = {};
|
||||
fields[ch.token_env.toLowerCase()] = token;
|
||||
fields.token = token;
|
||||
await OpenFangAPI.post('/api/channels/' + ch.name + '/configure', { fields: fields });
|
||||
this.channelConfigured = true;
|
||||
this.setupSummary.channel = ch.display_name;
|
||||
OpenFangToast.success(ch.display_name + ' configured and activated.');
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed: ' + (e.message || 'Unknown error'));
|
||||
}
|
||||
this.configuringChannel = false;
|
||||
},
|
||||
|
||||
// ── Step 6: Finish ──
|
||||
|
||||
finish() {
|
||||
localStorage.setItem('openfang-onboarded', 'true');
|
||||
Alpine.store('app').showOnboarding = false;
|
||||
// Navigate to agents with chat if an agent was created, otherwise overview
|
||||
if (this.createdAgent) {
|
||||
var agent = this.createdAgent;
|
||||
Alpine.store('app').pendingAgent = { id: agent.id, name: agent.name, model_provider: '?', model_name: '?' };
|
||||
window.location.hash = 'agents';
|
||||
} else {
|
||||
window.location.hash = 'overview';
|
||||
}
|
||||
},
|
||||
|
||||
finishAndDismiss() {
|
||||
localStorage.setItem('openfang-onboarded', 'true');
|
||||
Alpine.store('app').showOnboarding = false;
|
||||
window.location.hash = 'overview';
|
||||
}
|
||||
};
|
||||
}
|
||||
435
crates/openfang-api/static/js/pages/workflow-builder.js
Normal file
435
crates/openfang-api/static/js/pages/workflow-builder.js
Normal file
@@ -0,0 +1,435 @@
|
||||
// OpenFang Visual Workflow Builder — Drag-and-drop workflow designer
|
||||
'use strict';
|
||||
|
||||
function workflowBuilder() {
|
||||
return {
|
||||
// -- Canvas state --
|
||||
nodes: [],
|
||||
connections: [],
|
||||
selectedNode: null,
|
||||
selectedConnection: null,
|
||||
dragging: null,
|
||||
dragOffset: { x: 0, y: 0 },
|
||||
connecting: null, // { fromId, fromPort }
|
||||
connectPreview: null, // { x, y } mouse position during connect drag
|
||||
canvasOffset: { x: 0, y: 0 },
|
||||
canvasDragging: false,
|
||||
canvasDragStart: { x: 0, y: 0 },
|
||||
zoom: 1,
|
||||
nextId: 1,
|
||||
workflowName: '',
|
||||
workflowDescription: '',
|
||||
showSaveModal: false,
|
||||
showNodeEditor: false,
|
||||
showTomlPreview: false,
|
||||
tomlOutput: '',
|
||||
agents: [],
|
||||
_canvasEl: null,
|
||||
|
||||
// Node types with their configs
|
||||
nodeTypes: [
|
||||
{ type: 'agent', label: 'Agent Step', color: '#6366f1', icon: 'A', ports: { in: 1, out: 1 } },
|
||||
{ type: 'parallel', label: 'Parallel Fan-out', color: '#f59e0b', icon: 'P', ports: { in: 1, out: 3 } },
|
||||
{ type: 'condition', label: 'Condition', color: '#10b981', icon: '?', ports: { in: 1, out: 2 } },
|
||||
{ type: 'loop', label: 'Loop', color: '#ef4444', icon: 'L', ports: { in: 1, out: 1 } },
|
||||
{ type: 'collect', label: 'Collect', color: '#8b5cf6', icon: 'C', ports: { in: 3, out: 1 } },
|
||||
{ type: 'start', label: 'Start', color: '#22c55e', icon: 'S', ports: { in: 0, out: 1 } },
|
||||
{ type: 'end', label: 'End', color: '#ef4444', icon: 'E', ports: { in: 1, out: 0 } }
|
||||
],
|
||||
|
||||
async init() {
|
||||
var self = this;
|
||||
// Load agents for the agent step dropdown
|
||||
try {
|
||||
var list = await OpenFangAPI.get('/api/agents');
|
||||
self.agents = Array.isArray(list) ? list : [];
|
||||
} catch(_) {
|
||||
self.agents = [];
|
||||
}
|
||||
// Add default start node
|
||||
self.addNode('start', 60, 200);
|
||||
},
|
||||
|
||||
// ── Node Management ──────────────────────────────────
|
||||
|
||||
addNode: function(type, x, y) {
|
||||
var def = null;
|
||||
for (var i = 0; i < this.nodeTypes.length; i++) {
|
||||
if (this.nodeTypes[i].type === type) { def = this.nodeTypes[i]; break; }
|
||||
}
|
||||
if (!def) return;
|
||||
var node = {
|
||||
id: 'node-' + this.nextId++,
|
||||
type: type,
|
||||
label: def.label,
|
||||
color: def.color,
|
||||
icon: def.icon,
|
||||
x: x || 200,
|
||||
y: y || 200,
|
||||
width: 180,
|
||||
height: 70,
|
||||
ports: { in: def.ports.in, out: def.ports.out },
|
||||
config: {}
|
||||
};
|
||||
if (type === 'agent') {
|
||||
node.config = { agent_name: '', prompt: '{{input}}', model: '' };
|
||||
} else if (type === 'condition') {
|
||||
node.config = { expression: '', true_label: 'Yes', false_label: 'No' };
|
||||
} else if (type === 'loop') {
|
||||
node.config = { max_iterations: 5, until: '' };
|
||||
} else if (type === 'parallel') {
|
||||
node.config = { fan_count: 3 };
|
||||
} else if (type === 'collect') {
|
||||
node.config = { strategy: 'all' };
|
||||
}
|
||||
this.nodes.push(node);
|
||||
return node;
|
||||
},
|
||||
|
||||
deleteNode: function(nodeId) {
|
||||
this.connections = this.connections.filter(function(c) {
|
||||
return c.from !== nodeId && c.to !== nodeId;
|
||||
});
|
||||
this.nodes = this.nodes.filter(function(n) { return n.id !== nodeId; });
|
||||
if (this.selectedNode && this.selectedNode.id === nodeId) {
|
||||
this.selectedNode = null;
|
||||
this.showNodeEditor = false;
|
||||
}
|
||||
},
|
||||
|
||||
duplicateNode: function(node) {
|
||||
var newNode = this.addNode(node.type, node.x + 30, node.y + 30);
|
||||
if (newNode) {
|
||||
newNode.config = JSON.parse(JSON.stringify(node.config));
|
||||
newNode.label = node.label + ' copy';
|
||||
}
|
||||
},
|
||||
|
||||
getNode: function(id) {
|
||||
for (var i = 0; i < this.nodes.length; i++) {
|
||||
if (this.nodes[i].id === id) return this.nodes[i];
|
||||
}
|
||||
return null;
|
||||
},
|
||||
|
||||
// ── Port Positions ───────────────────────────────────
|
||||
|
||||
getInputPortPos: function(node, portIndex) {
|
||||
var total = node.ports.in;
|
||||
var spacing = node.width / (total + 1);
|
||||
return { x: node.x + spacing * (portIndex + 1), y: node.y };
|
||||
},
|
||||
|
||||
getOutputPortPos: function(node, portIndex) {
|
||||
var total = node.ports.out;
|
||||
var spacing = node.width / (total + 1);
|
||||
return { x: node.x + spacing * (portIndex + 1), y: node.y + node.height };
|
||||
},
|
||||
|
||||
// ── Connection Management ────────────────────────────
|
||||
|
||||
startConnect: function(nodeId, portIndex, e) {
|
||||
e.stopPropagation();
|
||||
this.connecting = { fromId: nodeId, fromPort: portIndex };
|
||||
var node = this.getNode(nodeId);
|
||||
var pos = this.getOutputPortPos(node, portIndex);
|
||||
this.connectPreview = { x: pos.x, y: pos.y };
|
||||
},
|
||||
|
||||
endConnect: function(nodeId, portIndex, e) {
|
||||
e.stopPropagation();
|
||||
if (!this.connecting) return;
|
||||
if (this.connecting.fromId === nodeId) {
|
||||
this.connecting = null;
|
||||
this.connectPreview = null;
|
||||
return;
|
||||
}
|
||||
// Check for duplicate
|
||||
var fromId = this.connecting.fromId;
|
||||
var fromPort = this.connecting.fromPort;
|
||||
var dup = false;
|
||||
for (var i = 0; i < this.connections.length; i++) {
|
||||
var c = this.connections[i];
|
||||
if (c.from === fromId && c.fromPort === fromPort && c.to === nodeId && c.toPort === portIndex) {
|
||||
dup = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!dup) {
|
||||
this.connections.push({
|
||||
id: 'conn-' + this.nextId++,
|
||||
from: fromId,
|
||||
fromPort: fromPort,
|
||||
to: nodeId,
|
||||
toPort: portIndex
|
||||
});
|
||||
}
|
||||
this.connecting = null;
|
||||
this.connectPreview = null;
|
||||
},
|
||||
|
||||
deleteConnection: function(connId) {
|
||||
this.connections = this.connections.filter(function(c) { return c.id !== connId; });
|
||||
this.selectedConnection = null;
|
||||
},
|
||||
|
||||
// ── Drag Handling ────────────────────────────────────
|
||||
|
||||
onNodeMouseDown: function(node, e) {
|
||||
e.stopPropagation();
|
||||
this.selectedNode = node;
|
||||
this.selectedConnection = null;
|
||||
this.dragging = node.id;
|
||||
var rect = this._getCanvasRect();
|
||||
this.dragOffset = {
|
||||
x: (e.clientX - rect.left) / this.zoom - this.canvasOffset.x - node.x,
|
||||
y: (e.clientY - rect.top) / this.zoom - this.canvasOffset.y - node.y
|
||||
};
|
||||
},
|
||||
|
||||
onCanvasMouseDown: function(e) {
|
||||
if (e.target.closest('.wf-node') || e.target.closest('.wf-port')) return;
|
||||
this.selectedNode = null;
|
||||
this.selectedConnection = null;
|
||||
this.showNodeEditor = false;
|
||||
// Start canvas pan
|
||||
this.canvasDragging = true;
|
||||
this.canvasDragStart = { x: e.clientX - this.canvasOffset.x * this.zoom, y: e.clientY - this.canvasOffset.y * this.zoom };
|
||||
},
|
||||
|
||||
onCanvasMouseMove: function(e) {
|
||||
var rect = this._getCanvasRect();
|
||||
if (this.dragging) {
|
||||
var node = this.getNode(this.dragging);
|
||||
if (node) {
|
||||
node.x = Math.max(0, (e.clientX - rect.left) / this.zoom - this.canvasOffset.x - this.dragOffset.x);
|
||||
node.y = Math.max(0, (e.clientY - rect.top) / this.zoom - this.canvasOffset.y - this.dragOffset.y);
|
||||
}
|
||||
} else if (this.connecting) {
|
||||
this.connectPreview = {
|
||||
x: (e.clientX - rect.left) / this.zoom - this.canvasOffset.x,
|
||||
y: (e.clientY - rect.top) / this.zoom - this.canvasOffset.y
|
||||
};
|
||||
} else if (this.canvasDragging) {
|
||||
this.canvasOffset = {
|
||||
x: (e.clientX - this.canvasDragStart.x) / this.zoom,
|
||||
y: (e.clientY - this.canvasDragStart.y) / this.zoom
|
||||
};
|
||||
}
|
||||
},
|
||||
|
||||
onCanvasMouseUp: function() {
|
||||
this.dragging = null;
|
||||
this.connecting = null;
|
||||
this.connectPreview = null;
|
||||
this.canvasDragging = false;
|
||||
},
|
||||
|
||||
onCanvasWheel: function(e) {
|
||||
e.preventDefault();
|
||||
var delta = e.deltaY > 0 ? -0.05 : 0.05;
|
||||
this.zoom = Math.max(0.3, Math.min(2, this.zoom + delta));
|
||||
},
|
||||
|
||||
_getCanvasRect: function() {
|
||||
if (!this._canvasEl) {
|
||||
this._canvasEl = document.getElementById('wf-canvas');
|
||||
}
|
||||
return this._canvasEl ? this._canvasEl.getBoundingClientRect() : { left: 0, top: 0 };
|
||||
},
|
||||
|
||||
// ── Connection Path ──────────────────────────────────
|
||||
|
||||
getConnectionPath: function(conn) {
|
||||
var fromNode = this.getNode(conn.from);
|
||||
var toNode = this.getNode(conn.to);
|
||||
if (!fromNode || !toNode) return '';
|
||||
var from = this.getOutputPortPos(fromNode, conn.fromPort);
|
||||
var to = this.getInputPortPos(toNode, conn.toPort);
|
||||
var dy = Math.abs(to.y - from.y);
|
||||
var cp = Math.max(40, dy * 0.5);
|
||||
return 'M ' + from.x + ' ' + from.y + ' C ' + from.x + ' ' + (from.y + cp) + ' ' + to.x + ' ' + (to.y - cp) + ' ' + to.x + ' ' + to.y;
|
||||
},
|
||||
|
||||
getPreviewPath: function() {
|
||||
if (!this.connecting || !this.connectPreview) return '';
|
||||
var fromNode = this.getNode(this.connecting.fromId);
|
||||
if (!fromNode) return '';
|
||||
var from = this.getOutputPortPos(fromNode, this.connecting.fromPort);
|
||||
var to = this.connectPreview;
|
||||
var dy = Math.abs(to.y - from.y);
|
||||
var cp = Math.max(40, dy * 0.5);
|
||||
return 'M ' + from.x + ' ' + from.y + ' C ' + from.x + ' ' + (from.y + cp) + ' ' + to.x + ' ' + (to.y - cp) + ' ' + to.x + ' ' + to.y;
|
||||
},
|
||||
|
||||
// ── Node editor ──────────────────────────────────────
|
||||
|
||||
editNode: function(node) {
|
||||
this.selectedNode = node;
|
||||
this.showNodeEditor = true;
|
||||
},
|
||||
|
||||
// ── TOML Generation ──────────────────────────────────
|
||||
|
||||
generateToml: function() {
|
||||
var self = this;
|
||||
var lines = [];
|
||||
lines.push('[workflow]');
|
||||
lines.push('name = "' + (this.workflowName || 'untitled') + '"');
|
||||
lines.push('description = "' + (this.workflowDescription || '') + '"');
|
||||
lines.push('');
|
||||
|
||||
// Topological sort the nodes (skip start/end for step generation)
|
||||
var stepNodes = this.nodes.filter(function(n) {
|
||||
return n.type !== 'start' && n.type !== 'end';
|
||||
});
|
||||
|
||||
for (var i = 0; i < stepNodes.length; i++) {
|
||||
var node = stepNodes[i];
|
||||
lines.push('[[workflow.steps]]');
|
||||
lines.push('name = "' + (node.label || 'step-' + (i + 1)) + '"');
|
||||
|
||||
if (node.type === 'agent') {
|
||||
lines.push('type = "agent"');
|
||||
if (node.config.agent_name) lines.push('agent_name = "' + node.config.agent_name + '"');
|
||||
lines.push('prompt = "' + (node.config.prompt || '{{input}}') + '"');
|
||||
if (node.config.model) lines.push('model = "' + node.config.model + '"');
|
||||
} else if (node.type === 'parallel') {
|
||||
lines.push('type = "fan_out"');
|
||||
lines.push('fan_count = ' + (node.config.fan_count || 3));
|
||||
} else if (node.type === 'condition') {
|
||||
lines.push('type = "conditional"');
|
||||
lines.push('expression = "' + (node.config.expression || '') + '"');
|
||||
} else if (node.type === 'loop') {
|
||||
lines.push('type = "loop"');
|
||||
lines.push('max_iterations = ' + (node.config.max_iterations || 5));
|
||||
if (node.config.until) lines.push('until = "' + node.config.until + '"');
|
||||
} else if (node.type === 'collect') {
|
||||
lines.push('type = "collect"');
|
||||
lines.push('strategy = "' + (node.config.strategy || 'all') + '"');
|
||||
}
|
||||
|
||||
// Find what this node connects to
|
||||
var outConns = self.connections.filter(function(c) { return c.from === node.id; });
|
||||
if (outConns.length === 1) {
|
||||
var target = self.getNode(outConns[0].to);
|
||||
if (target && target.type !== 'end') {
|
||||
lines.push('next = "' + target.label + '"');
|
||||
}
|
||||
} else if (outConns.length > 1 && node.type === 'condition') {
|
||||
for (var j = 0; j < outConns.length; j++) {
|
||||
var t2 = self.getNode(outConns[j].to);
|
||||
if (t2 && t2.type !== 'end') {
|
||||
var branchLabel = j === 0 ? 'true' : 'false';
|
||||
lines.push('next_' + branchLabel + ' = "' + t2.label + '"');
|
||||
}
|
||||
}
|
||||
} else if (outConns.length > 1 && node.type === 'parallel') {
|
||||
var targets = [];
|
||||
for (var k = 0; k < outConns.length; k++) {
|
||||
var t3 = self.getNode(outConns[k].to);
|
||||
if (t3 && t3.type !== 'end') targets.push('"' + t3.label + '"');
|
||||
}
|
||||
if (targets.length) lines.push('fan_targets = [' + targets.join(', ') + ']');
|
||||
}
|
||||
|
||||
lines.push('');
|
||||
}
|
||||
|
||||
this.tomlOutput = lines.join('\n');
|
||||
this.showTomlPreview = true;
|
||||
},
|
||||
|
||||
// ── Save Workflow ────────────────────────────────────
|
||||
|
||||
async saveWorkflow() {
|
||||
var steps = [];
|
||||
var stepNodes = this.nodes.filter(function(n) {
|
||||
return n.type !== 'start' && n.type !== 'end';
|
||||
});
|
||||
for (var i = 0; i < stepNodes.length; i++) {
|
||||
var node = stepNodes[i];
|
||||
var step = {
|
||||
name: node.label || 'step-' + (i + 1),
|
||||
mode: node.type === 'parallel' ? 'fan_out' : node.type === 'loop' ? 'loop' : 'sequential'
|
||||
};
|
||||
if (node.type === 'agent') {
|
||||
step.agent_name = node.config.agent_name || '';
|
||||
step.prompt = node.config.prompt || '{{input}}';
|
||||
}
|
||||
steps.push(step);
|
||||
}
|
||||
try {
|
||||
await OpenFangAPI.post('/api/workflows', {
|
||||
name: this.workflowName || 'untitled',
|
||||
description: this.workflowDescription || '',
|
||||
steps: steps
|
||||
});
|
||||
OpenFangToast.success('Workflow saved!');
|
||||
this.showSaveModal = false;
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to save: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
// ── Palette drop ─────────────────────────────────────
|
||||
|
||||
onPaletteDragStart: function(type, e) {
|
||||
e.dataTransfer.setData('text/plain', type);
|
||||
e.dataTransfer.effectAllowed = 'copy';
|
||||
},
|
||||
|
||||
onCanvasDrop: function(e) {
|
||||
e.preventDefault();
|
||||
var type = e.dataTransfer.getData('text/plain');
|
||||
if (!type) return;
|
||||
var rect = this._getCanvasRect();
|
||||
var x = (e.clientX - rect.left) / this.zoom - this.canvasOffset.x;
|
||||
var y = (e.clientY - rect.top) / this.zoom - this.canvasOffset.y;
|
||||
this.addNode(type, x - 90, y - 35);
|
||||
},
|
||||
|
||||
onCanvasDragOver: function(e) {
|
||||
e.preventDefault();
|
||||
e.dataTransfer.dropEffect = 'copy';
|
||||
},
|
||||
|
||||
// ── Auto Layout ──────────────────────────────────────
|
||||
|
||||
autoLayout: function() {
|
||||
// Simple top-to-bottom layout
|
||||
var y = 40;
|
||||
var x = 200;
|
||||
for (var i = 0; i < this.nodes.length; i++) {
|
||||
this.nodes[i].x = x;
|
||||
this.nodes[i].y = y;
|
||||
y += 120;
|
||||
}
|
||||
},
|
||||
|
||||
// ── Clear ────────────────────────────────────────────
|
||||
|
||||
clearCanvas: function() {
|
||||
this.nodes = [];
|
||||
this.connections = [];
|
||||
this.selectedNode = null;
|
||||
this.nextId = 1;
|
||||
this.addNode('start', 60, 200);
|
||||
},
|
||||
|
||||
// ── Zoom controls ────────────────────────────────────
|
||||
|
||||
zoomIn: function() {
|
||||
this.zoom = Math.min(2, this.zoom + 0.1);
|
||||
},
|
||||
|
||||
zoomOut: function() {
|
||||
this.zoom = Math.max(0.3, this.zoom - 0.1);
|
||||
},
|
||||
|
||||
zoomReset: function() {
|
||||
this.zoom = 1;
|
||||
this.canvasOffset = { x: 0, y: 0 };
|
||||
}
|
||||
};
|
||||
}
|
||||
79
crates/openfang-api/static/js/pages/workflows.js
Normal file
79
crates/openfang-api/static/js/pages/workflows.js
Normal file
@@ -0,0 +1,79 @@
|
||||
// OpenFang Workflows Page — Workflow builder + run history
|
||||
'use strict';
|
||||
|
||||
function workflowsPage() {
|
||||
return {
|
||||
// -- Workflows state --
|
||||
workflows: [],
|
||||
showCreateModal: false,
|
||||
runModal: null,
|
||||
runInput: '',
|
||||
runResult: '',
|
||||
running: false,
|
||||
loading: true,
|
||||
loadError: '',
|
||||
newWf: { name: '', description: '', steps: [{ name: '', agent_name: '', mode: 'sequential', prompt: '{{input}}' }] },
|
||||
|
||||
// -- Workflows methods --
|
||||
async loadWorkflows() {
|
||||
this.loading = true;
|
||||
this.loadError = '';
|
||||
try {
|
||||
this.workflows = await OpenFangAPI.get('/api/workflows');
|
||||
} catch(e) {
|
||||
this.workflows = [];
|
||||
this.loadError = e.message || 'Could not load workflows.';
|
||||
}
|
||||
this.loading = false;
|
||||
},
|
||||
|
||||
async loadData() { return this.loadWorkflows(); },
|
||||
|
||||
async createWorkflow() {
|
||||
var steps = this.newWf.steps.map(function(s) {
|
||||
return { name: s.name || 'step', agent_name: s.agent_name, mode: s.mode, prompt: s.prompt || '{{input}}' };
|
||||
});
|
||||
try {
|
||||
var wfName = this.newWf.name;
|
||||
await OpenFangAPI.post('/api/workflows', { name: wfName, description: this.newWf.description, steps: steps });
|
||||
this.showCreateModal = false;
|
||||
this.newWf = { name: '', description: '', steps: [{ name: '', agent_name: '', mode: 'sequential', prompt: '{{input}}' }] };
|
||||
OpenFangToast.success('Workflow "' + wfName + '" created');
|
||||
await this.loadWorkflows();
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to create workflow: ' + e.message);
|
||||
}
|
||||
},
|
||||
|
||||
showRunModal(wf) {
|
||||
this.runModal = wf;
|
||||
this.runInput = '';
|
||||
this.runResult = '';
|
||||
},
|
||||
|
||||
async executeWorkflow() {
|
||||
if (!this.runModal) return;
|
||||
this.running = true;
|
||||
this.runResult = '';
|
||||
try {
|
||||
var res = await OpenFangAPI.post('/api/workflows/' + this.runModal.id + '/run', { input: this.runInput });
|
||||
this.runResult = res.output || JSON.stringify(res, null, 2);
|
||||
OpenFangToast.success('Workflow completed');
|
||||
} catch(e) {
|
||||
this.runResult = 'Error: ' + e.message;
|
||||
OpenFangToast.error('Workflow failed: ' + e.message);
|
||||
}
|
||||
this.running = false;
|
||||
},
|
||||
|
||||
async viewRuns(wf) {
|
||||
try {
|
||||
var runs = await OpenFangAPI.get('/api/workflows/' + wf.id + '/runs');
|
||||
this.runResult = JSON.stringify(runs, null, 2);
|
||||
this.runModal = wf;
|
||||
} catch(e) {
|
||||
OpenFangToast.error('Failed to load run history: ' + e.message);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
BIN
crates/openfang-api/static/logo.png
Normal file
BIN
crates/openfang-api/static/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 9.8 KiB |
5
crates/openfang-api/static/vendor/alpine.min.js
vendored
Normal file
5
crates/openfang-api/static/vendor/alpine.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
10
crates/openfang-api/static/vendor/github-dark.min.css
vendored
Normal file
10
crates/openfang-api/static/vendor/github-dark.min.css
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
pre code.hljs{display:block;overflow-x:auto;padding:1em}code.hljs{padding:3px 5px}/*!
|
||||
Theme: GitHub Dark
|
||||
Description: Dark theme as seen on github.com
|
||||
Author: github.com
|
||||
Maintainer: @Hirse
|
||||
Updated: 2021-05-15
|
||||
|
||||
Outdated base version: https://github.com/primer/github-syntax-dark
|
||||
Current colors taken from GitHub's CSS
|
||||
*/.hljs{color:#c9d1d9;background:#0d1117}.hljs-doctag,.hljs-keyword,.hljs-meta .hljs-keyword,.hljs-template-tag,.hljs-template-variable,.hljs-type,.hljs-variable.language_{color:#ff7b72}.hljs-title,.hljs-title.class_,.hljs-title.class_.inherited__,.hljs-title.function_{color:#d2a8ff}.hljs-attr,.hljs-attribute,.hljs-literal,.hljs-meta,.hljs-number,.hljs-operator,.hljs-selector-attr,.hljs-selector-class,.hljs-selector-id,.hljs-variable{color:#79c0ff}.hljs-meta .hljs-string,.hljs-regexp,.hljs-string{color:#a5d6ff}.hljs-built_in,.hljs-symbol{color:#ffa657}.hljs-code,.hljs-comment,.hljs-formula{color:#8b949e}.hljs-name,.hljs-quote,.hljs-selector-pseudo,.hljs-selector-tag{color:#7ee787}.hljs-subst{color:#c9d1d9}.hljs-section{color:#1f6feb;font-weight:700}.hljs-bullet{color:#f2cc60}.hljs-emphasis{color:#c9d1d9;font-style:italic}.hljs-strong{color:#c9d1d9;font-weight:700}.hljs-addition{color:#aff5b4;background-color:#033a16}.hljs-deletion{color:#ffdcd7;background-color:#67060c}
|
||||
1244
crates/openfang-api/static/vendor/highlight.min.js
vendored
Normal file
1244
crates/openfang-api/static/vendor/highlight.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
69
crates/openfang-api/static/vendor/marked.min.js
vendored
Normal file
69
crates/openfang-api/static/vendor/marked.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
854
crates/openfang-api/tests/api_integration_test.rs
Normal file
854
crates/openfang-api/tests/api_integration_test.rs
Normal file
@@ -0,0 +1,854 @@
|
||||
//! Real HTTP integration tests for the OpenFang API.
|
||||
//!
|
||||
//! These tests boot a real kernel, start a real axum HTTP server on a random
|
||||
//! port, and hit actual endpoints with reqwest. No mocking.
|
||||
//!
|
||||
//! Tests that require an LLM API call are gated behind GROQ_API_KEY.
|
||||
//!
|
||||
//! Run: cargo test -p openfang-api --test api_integration_test -- --nocapture
|
||||
|
||||
use axum::Router;
|
||||
use openfang_api::middleware;
|
||||
use openfang_api::routes::{self, AppState};
|
||||
use openfang_api::ws;
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use openfang_types::config::{DefaultModelConfig, KernelConfig};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test infrastructure
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct TestServer {
|
||||
base_url: String,
|
||||
state: Arc<AppState>,
|
||||
_tmp: tempfile::TempDir,
|
||||
}
|
||||
|
||||
impl Drop for TestServer {
|
||||
fn drop(&mut self) {
|
||||
self.state.kernel.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a test server using ollama as default provider (no API key needed).
|
||||
/// This lets the kernel boot without any real LLM credentials.
|
||||
/// Tests that need actual LLM calls should use `start_test_server_with_llm()`.
|
||||
async fn start_test_server() -> TestServer {
|
||||
start_test_server_with_provider("ollama", "test-model", "OLLAMA_API_KEY").await
|
||||
}
|
||||
|
||||
/// Start a test server with Groq as the LLM provider (requires GROQ_API_KEY).
|
||||
async fn start_test_server_with_llm() -> TestServer {
|
||||
start_test_server_with_provider("groq", "llama-3.3-70b-versatile", "GROQ_API_KEY").await
|
||||
}
|
||||
|
||||
async fn start_test_server_with_provider(
|
||||
provider: &str,
|
||||
model: &str,
|
||||
api_key_env: &str,
|
||||
) -> TestServer {
|
||||
let tmp = tempfile::tempdir().expect("Failed to create temp dir");
|
||||
|
||||
let config = KernelConfig {
|
||||
home_dir: tmp.path().to_path_buf(),
|
||||
data_dir: tmp.path().join("data"),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: provider.to_string(),
|
||||
model: model.to_string(),
|
||||
api_key_env: api_key_env.to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
};
|
||||
|
||||
let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");
|
||||
let kernel = Arc::new(kernel);
|
||||
kernel.set_self_handle();
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
kernel,
|
||||
started_at: Instant::now(),
|
||||
peer_registry: None,
|
||||
bridge_manager: tokio::sync::Mutex::new(None),
|
||||
channels_config: tokio::sync::RwLock::new(Default::default()),
|
||||
shutdown_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/api/health", axum::routing::get(routes::health))
|
||||
.route("/api/status", axum::routing::get(routes::status))
|
||||
.route(
|
||||
"/api/agents",
|
||||
axum::routing::get(routes::list_agents).post(routes::spawn_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/message",
|
||||
axum::routing::post(routes::send_message),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/session",
|
||||
axum::routing::get(routes::get_agent_session),
|
||||
)
|
||||
.route("/api/agents/{id}/ws", axum::routing::get(ws::agent_ws))
|
||||
.route(
|
||||
"/api/agents/{id}",
|
||||
axum::routing::delete(routes::kill_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/triggers",
|
||||
axum::routing::get(routes::list_triggers).post(routes::create_trigger),
|
||||
)
|
||||
.route(
|
||||
"/api/triggers/{id}",
|
||||
axum::routing::delete(routes::delete_trigger),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows",
|
||||
axum::routing::get(routes::list_workflows).post(routes::create_workflow),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows/{id}/run",
|
||||
axum::routing::post(routes::run_workflow),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows/{id}/runs",
|
||||
axum::routing::get(routes::list_workflow_runs),
|
||||
)
|
||||
.route("/api/shutdown", axum::routing::post(routes::shutdown))
|
||||
.layer(axum::middleware::from_fn(middleware::request_logging))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("Failed to bind test server");
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
TestServer {
|
||||
base_url: format!("http://{}", addr),
|
||||
state,
|
||||
_tmp: tmp,
|
||||
}
|
||||
}
|
||||
|
||||
/// Manifest that uses ollama (no API key required, won't make real LLM calls).
|
||||
const TEST_MANIFEST: &str = r#"
|
||||
name = "test-agent"
|
||||
version = "0.1.0"
|
||||
description = "Integration test agent"
|
||||
author = "test"
|
||||
module = "builtin:chat"
|
||||
|
||||
[model]
|
||||
provider = "ollama"
|
||||
model = "test-model"
|
||||
system_prompt = "You are a test agent. Reply concisely."
|
||||
|
||||
[capabilities]
|
||||
tools = ["file_read"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#;
|
||||
|
||||
/// Manifest that uses Groq for real LLM tests.
|
||||
const LLM_MANIFEST: &str = r#"
|
||||
name = "test-agent"
|
||||
version = "0.1.0"
|
||||
description = "Integration test agent"
|
||||
author = "test"
|
||||
module = "builtin:chat"
|
||||
|
||||
[model]
|
||||
provider = "groq"
|
||||
model = "llama-3.3-70b-versatile"
|
||||
system_prompt = "You are a test agent. Reply concisely."
|
||||
|
||||
[capabilities]
|
||||
tools = ["file_read"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_endpoint() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let resp = client
|
||||
.get(format!("{}/api/health", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
|
||||
// Middleware injects x-request-id
|
||||
assert!(resp.headers().contains_key("x-request-id"));
|
||||
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
// Public health endpoint returns minimal info (redacted for security)
|
||||
assert_eq!(body["status"], "ok");
|
||||
assert!(body["version"].is_string());
|
||||
// Detailed fields should NOT appear in public health endpoint
|
||||
assert!(body["database"].is_null());
|
||||
assert!(body["agent_count"].is_null());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_status_endpoint() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let resp = client
|
||||
.get(format!("{}/api/status", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["status"], "running");
|
||||
assert_eq!(body["agent_count"], 0);
|
||||
assert!(body["uptime_seconds"].is_number());
|
||||
assert_eq!(body["default_provider"], "ollama");
|
||||
assert_eq!(body["agents"].as_array().unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spawn_list_kill_agent() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// --- Spawn ---
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": TEST_MANIFEST}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(resp.status(), 201);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["name"], "test-agent");
|
||||
let agent_id = body["agent_id"].as_str().unwrap().to_string();
|
||||
assert!(!agent_id.is_empty());
|
||||
|
||||
// --- List (1 agent) ---
|
||||
let resp = client
|
||||
.get(format!("{}/api/agents", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let agents: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(agents.len(), 1);
|
||||
assert_eq!(agents[0]["name"], "test-agent");
|
||||
assert_eq!(agents[0]["id"], agent_id);
|
||||
assert_eq!(agents[0]["model_provider"], "ollama");
|
||||
|
||||
// --- Kill ---
|
||||
let resp = client
|
||||
.delete(format!("{}/api/agents/{}", server.base_url, agent_id))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["status"], "killed");
|
||||
|
||||
// --- List (empty) ---
|
||||
let resp = client
|
||||
.get(format!("{}/api/agents", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let agents: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(agents.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_agent_session_empty() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn agent
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": TEST_MANIFEST}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let agent_id = body["agent_id"].as_str().unwrap();
|
||||
|
||||
// Session should be empty — no messages sent yet
|
||||
let resp = client
|
||||
.get(format!(
|
||||
"{}/api/agents/{}/session",
|
||||
server.base_url, agent_id
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["message_count"], 0);
|
||||
assert_eq!(body["messages"].as_array().unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_send_message_with_llm() {
|
||||
if std::env::var("GROQ_API_KEY").is_err() {
|
||||
eprintln!("GROQ_API_KEY not set, skipping LLM integration test");
|
||||
return;
|
||||
}
|
||||
|
||||
let server = start_test_server_with_llm().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": LLM_MANIFEST}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let agent_id = body["agent_id"].as_str().unwrap().to_string();
|
||||
|
||||
// Send message through the real HTTP endpoint → kernel → Groq LLM
|
||||
let resp = client
|
||||
.post(format!(
|
||||
"{}/api/agents/{}/message",
|
||||
server.base_url, agent_id
|
||||
))
|
||||
.json(&serde_json::json!({"message": "Say hello in exactly 3 words."}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let response_text = body["response"].as_str().unwrap();
|
||||
assert!(
|
||||
!response_text.is_empty(),
|
||||
"LLM response should not be empty"
|
||||
);
|
||||
assert!(body["input_tokens"].as_u64().unwrap() > 0);
|
||||
assert!(body["output_tokens"].as_u64().unwrap() > 0);
|
||||
|
||||
// Session should now have messages
|
||||
let resp = client
|
||||
.get(format!(
|
||||
"{}/api/agents/{}/session",
|
||||
server.base_url, agent_id
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let session: serde_json::Value = resp.json().await.unwrap();
|
||||
assert!(session["message_count"].as_u64().unwrap() > 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_workflow_crud() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn agent for workflow
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": TEST_MANIFEST}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let agent_name = body["name"].as_str().unwrap().to_string();
|
||||
|
||||
// Create workflow
|
||||
let resp = client
|
||||
.post(format!("{}/api/workflows", server.base_url))
|
||||
.json(&serde_json::json!({
|
||||
"name": "test-workflow",
|
||||
"description": "Integration test workflow",
|
||||
"steps": [
|
||||
{
|
||||
"name": "step1",
|
||||
"agent_name": agent_name,
|
||||
"prompt": "Echo: {{input}}",
|
||||
"mode": "sequential",
|
||||
"timeout_secs": 30
|
||||
}
|
||||
]
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 201);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let workflow_id = body["workflow_id"].as_str().unwrap().to_string();
|
||||
assert!(!workflow_id.is_empty());
|
||||
|
||||
// List workflows
|
||||
let resp = client
|
||||
.get(format!("{}/api/workflows", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let workflows: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(workflows.len(), 1);
|
||||
assert_eq!(workflows[0]["name"], "test-workflow");
|
||||
assert_eq!(workflows[0]["steps"], 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_trigger_crud() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn agent for trigger
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": TEST_MANIFEST}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let agent_id = body["agent_id"].as_str().unwrap().to_string();
|
||||
|
||||
// Create trigger (Lifecycle pattern — simplest variant)
|
||||
let resp = client
|
||||
.post(format!("{}/api/triggers", server.base_url))
|
||||
.json(&serde_json::json!({
|
||||
"agent_id": agent_id,
|
||||
"pattern": "lifecycle",
|
||||
"prompt_template": "Handle: {{event}}",
|
||||
"max_fires": 5
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 201);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
let trigger_id = body["trigger_id"].as_str().unwrap().to_string();
|
||||
assert_eq!(body["agent_id"], agent_id);
|
||||
|
||||
// List triggers (unfiltered)
|
||||
let resp = client
|
||||
.get(format!("{}/api/triggers", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let triggers: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(triggers.len(), 1);
|
||||
assert_eq!(triggers[0]["agent_id"], agent_id);
|
||||
assert_eq!(triggers[0]["enabled"], true);
|
||||
assert_eq!(triggers[0]["max_fires"], 5);
|
||||
|
||||
// List triggers (filtered by agent_id)
|
||||
let resp = client
|
||||
.get(format!(
|
||||
"{}/api/triggers?agent_id={}",
|
||||
server.base_url, agent_id
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let triggers: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(triggers.len(), 1);
|
||||
|
||||
// Delete trigger
|
||||
let resp = client
|
||||
.delete(format!("{}/api/triggers/{}", server.base_url, trigger_id))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
|
||||
// List triggers (should be empty)
|
||||
let resp = client
|
||||
.get(format!("{}/api/triggers", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let triggers: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(triggers.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_agent_id_returns_400() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Send message to invalid ID
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents/not-a-uuid/message", server.base_url))
|
||||
.json(&serde_json::json!({"message": "hello"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert!(body["error"].as_str().unwrap().contains("Invalid"));
|
||||
|
||||
// Kill invalid ID
|
||||
let resp = client
|
||||
.delete(format!("{}/api/agents/not-a-uuid", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
|
||||
// Session for invalid ID
|
||||
let resp = client
|
||||
.get(format!("{}/api/agents/not-a-uuid/session", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kill_nonexistent_agent_returns_404() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let fake_id = uuid::Uuid::new_v4();
|
||||
let resp = client
|
||||
.delete(format!("{}/api/agents/{}", server.base_url, fake_id))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 404);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_spawn_invalid_manifest_returns_400() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": "this is {{ not valid toml"}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 400);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert!(body["error"].as_str().unwrap().contains("Invalid manifest"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_request_id_header_is_uuid() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let resp = client
|
||||
.get(format!("{}/api/health", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let request_id = resp
|
||||
.headers()
|
||||
.get("x-request-id")
|
||||
.expect("x-request-id header should be present");
|
||||
let id_str = request_id.to_str().unwrap();
|
||||
assert!(
|
||||
uuid::Uuid::parse_str(id_str).is_ok(),
|
||||
"x-request-id should be a valid UUID, got: {}",
|
||||
id_str
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_agents_lifecycle() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn 3 agents
|
||||
let mut ids = Vec::new();
|
||||
for i in 0..3 {
|
||||
let manifest = format!(
|
||||
r#"
|
||||
name = "agent-{i}"
|
||||
version = "0.1.0"
|
||||
description = "Multi-agent test {i}"
|
||||
author = "test"
|
||||
module = "builtin:chat"
|
||||
|
||||
[model]
|
||||
provider = "ollama"
|
||||
model = "test-model"
|
||||
system_prompt = "Agent {i}."
|
||||
|
||||
[capabilities]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": manifest}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 201);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
ids.push(body["agent_id"].as_str().unwrap().to_string());
|
||||
}
|
||||
|
||||
// List should show 3
|
||||
let resp = client
|
||||
.get(format!("{}/api/agents", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let agents: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(agents.len(), 3);
|
||||
|
||||
// Status should agree
|
||||
let resp = client
|
||||
.get(format!("{}/api/status", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let status: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(status["agent_count"], 3);
|
||||
|
||||
// Kill one
|
||||
let resp = client
|
||||
.delete(format!("{}/api/agents/{}", server.base_url, ids[1]))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
|
||||
// List should show 2
|
||||
let resp = client
|
||||
.get(format!("{}/api/agents", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let agents: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(agents.len(), 2);
|
||||
|
||||
// Kill the rest
|
||||
for id in [&ids[0], &ids[2]] {
|
||||
client
|
||||
.delete(format!("{}/api/agents/{}", server.base_url, id))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// List should be empty
|
||||
let resp = client
|
||||
.get(format!("{}/api/agents", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let agents: Vec<serde_json::Value> = resp.json().await.unwrap();
|
||||
assert_eq!(agents.len(), 0);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Auth integration tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Start a test server with Bearer-token authentication enabled.
|
||||
async fn start_test_server_with_auth(api_key: &str) -> TestServer {
|
||||
let tmp = tempfile::tempdir().expect("Failed to create temp dir");
|
||||
|
||||
let config = KernelConfig {
|
||||
home_dir: tmp.path().to_path_buf(),
|
||||
data_dir: tmp.path().join("data"),
|
||||
api_key: api_key.to_string(),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: "ollama".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
api_key_env: "OLLAMA_API_KEY".to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
};
|
||||
|
||||
let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");
|
||||
let kernel = Arc::new(kernel);
|
||||
kernel.set_self_handle();
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
kernel,
|
||||
started_at: Instant::now(),
|
||||
peer_registry: None,
|
||||
bridge_manager: tokio::sync::Mutex::new(None),
|
||||
channels_config: tokio::sync::RwLock::new(Default::default()),
|
||||
shutdown_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
});
|
||||
|
||||
let api_key_state = state.kernel.config.api_key.clone();
|
||||
|
||||
let app = Router::new()
|
||||
.route("/api/health", axum::routing::get(routes::health))
|
||||
.route("/api/status", axum::routing::get(routes::status))
|
||||
.route(
|
||||
"/api/agents",
|
||||
axum::routing::get(routes::list_agents).post(routes::spawn_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/message",
|
||||
axum::routing::post(routes::send_message),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/session",
|
||||
axum::routing::get(routes::get_agent_session),
|
||||
)
|
||||
.route("/api/agents/{id}/ws", axum::routing::get(ws::agent_ws))
|
||||
.route(
|
||||
"/api/agents/{id}",
|
||||
axum::routing::delete(routes::kill_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/triggers",
|
||||
axum::routing::get(routes::list_triggers).post(routes::create_trigger),
|
||||
)
|
||||
.route(
|
||||
"/api/triggers/{id}",
|
||||
axum::routing::delete(routes::delete_trigger),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows",
|
||||
axum::routing::get(routes::list_workflows).post(routes::create_workflow),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows/{id}/run",
|
||||
axum::routing::post(routes::run_workflow),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows/{id}/runs",
|
||||
axum::routing::get(routes::list_workflow_runs),
|
||||
)
|
||||
.route("/api/shutdown", axum::routing::post(routes::shutdown))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
api_key_state,
|
||||
middleware::auth,
|
||||
))
|
||||
.layer(axum::middleware::from_fn(middleware::request_logging))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("Failed to bind test server");
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
TestServer {
|
||||
base_url: format!("http://{}", addr),
|
||||
state,
|
||||
_tmp: tmp,
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_health_is_public() {
|
||||
let server = start_test_server_with_auth("secret-key-123").await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// /api/health should be accessible without auth
|
||||
let resp = client
|
||||
.get(format!("{}/api/health", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_rejects_no_token() {
|
||||
let server = start_test_server_with_auth("secret-key-123").await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Protected endpoint without auth header → 401
|
||||
// Note: /api/status is public (dashboard needs it), so use a protected endpoint
|
||||
let resp = client
|
||||
.get(format!("{}/api/commands", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 401);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert!(body["error"].as_str().unwrap().contains("Missing"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_rejects_wrong_token() {
|
||||
let server = start_test_server_with_auth("secret-key-123").await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Wrong bearer token → 401
|
||||
// Note: /api/status is public (dashboard needs it), so use a protected endpoint
|
||||
let resp = client
|
||||
.get(format!("{}/api/commands", server.base_url))
|
||||
.header("authorization", "Bearer wrong-key")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 401);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert!(body["error"].as_str().unwrap().contains("Invalid"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_accepts_correct_token() {
|
||||
let server = start_test_server_with_auth("secret-key-123").await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Correct bearer token → 200
|
||||
let resp = client
|
||||
.get(format!("{}/api/status", server.base_url))
|
||||
.header("authorization", "Bearer secret-key-123")
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["status"], "running");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_auth_disabled_when_no_key() {
|
||||
// Empty API key = auth disabled
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Protected endpoint accessible without auth when no key is configured
|
||||
let resp = client
|
||||
.get(format!("{}/api/status", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
}
|
||||
270
crates/openfang-api/tests/daemon_lifecycle_test.rs
Normal file
270
crates/openfang-api/tests/daemon_lifecycle_test.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
//! Daemon lifecycle integration tests.
|
||||
//!
|
||||
//! Tests the real daemon startup, PID file management, health serving,
|
||||
//! and graceful shutdown sequence.
|
||||
|
||||
use axum::Router;
|
||||
use openfang_api::middleware;
|
||||
use openfang_api::routes::{self, AppState};
|
||||
use openfang_api::server::{read_daemon_info, DaemonInfo};
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use openfang_types::config::{DefaultModelConfig, KernelConfig};
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Test DaemonInfo serialization and deserialization round-trip.
|
||||
#[test]
|
||||
fn test_daemon_info_serde_roundtrip() {
|
||||
let info = DaemonInfo {
|
||||
pid: 12345,
|
||||
listen_addr: "127.0.0.1:4200".to_string(),
|
||||
started_at: "2024-01-01T00:00:00Z".to_string(),
|
||||
version: "0.1.0".to_string(),
|
||||
platform: "linux".to_string(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string_pretty(&info).unwrap();
|
||||
let parsed: DaemonInfo = serde_json::from_str(&json).unwrap();
|
||||
|
||||
assert_eq!(parsed.pid, 12345);
|
||||
assert_eq!(parsed.listen_addr, "127.0.0.1:4200");
|
||||
assert_eq!(parsed.version, "0.1.0");
|
||||
assert_eq!(parsed.platform, "linux");
|
||||
}
|
||||
|
||||
/// Test read_daemon_info from a file on disk.
|
||||
#[test]
|
||||
fn test_read_daemon_info_from_file() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
|
||||
// Write a daemon.json
|
||||
let info = DaemonInfo {
|
||||
pid: std::process::id(),
|
||||
listen_addr: "127.0.0.1:9999".to_string(),
|
||||
started_at: chrono::Utc::now().to_rfc3339(),
|
||||
version: "0.1.0".to_string(),
|
||||
platform: "test".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string_pretty(&info).unwrap();
|
||||
std::fs::write(tmp.path().join("daemon.json"), json).unwrap();
|
||||
|
||||
// Read it back
|
||||
let loaded = read_daemon_info(tmp.path());
|
||||
assert!(loaded.is_some());
|
||||
let loaded = loaded.unwrap();
|
||||
assert_eq!(loaded.pid, std::process::id());
|
||||
assert_eq!(loaded.listen_addr, "127.0.0.1:9999");
|
||||
}
|
||||
|
||||
/// Test read_daemon_info returns None when file doesn't exist.
|
||||
#[test]
|
||||
fn test_read_daemon_info_missing_file() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let loaded = read_daemon_info(tmp.path());
|
||||
assert!(loaded.is_none());
|
||||
}
|
||||
|
||||
/// Test read_daemon_info returns None for corrupt JSON.
|
||||
#[test]
|
||||
fn test_read_daemon_info_corrupt_json() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
std::fs::write(tmp.path().join("daemon.json"), "not json at all").unwrap();
|
||||
let loaded = read_daemon_info(tmp.path());
|
||||
assert!(loaded.is_none());
|
||||
}
|
||||
|
||||
/// Test the full daemon lifecycle:
|
||||
/// 1. Boot kernel + start server on random port
|
||||
/// 2. Write daemon info file
|
||||
/// 3. Verify health endpoint
|
||||
/// 4. Verify daemon info file contents match
|
||||
/// 5. Shut down and verify cleanup
|
||||
#[tokio::test]
|
||||
async fn test_full_daemon_lifecycle() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let daemon_info_path = tmp.path().join("daemon.json");
|
||||
|
||||
let config = KernelConfig {
|
||||
home_dir: tmp.path().to_path_buf(),
|
||||
data_dir: tmp.path().join("data"),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: "ollama".to_string(),
|
||||
model: "test".to_string(),
|
||||
api_key_env: "OLLAMA_API_KEY".to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
};
|
||||
|
||||
let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");
|
||||
let kernel = Arc::new(kernel);
|
||||
kernel.set_self_handle();
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
kernel: kernel.clone(),
|
||||
started_at: Instant::now(),
|
||||
peer_registry: None,
|
||||
bridge_manager: tokio::sync::Mutex::new(None),
|
||||
channels_config: tokio::sync::RwLock::new(Default::default()),
|
||||
shutdown_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/api/health", axum::routing::get(routes::health))
|
||||
.route("/api/status", axum::routing::get(routes::status))
|
||||
.route("/api/shutdown", axum::routing::post(routes::shutdown))
|
||||
.layer(axum::middleware::from_fn(middleware::request_logging))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state.clone());
|
||||
|
||||
// Bind to random port
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
// Spawn server
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
// Write daemon info file (like run_daemon does)
|
||||
let daemon_info = DaemonInfo {
|
||||
pid: std::process::id(),
|
||||
listen_addr: addr.to_string(),
|
||||
started_at: chrono::Utc::now().to_rfc3339(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
platform: std::env::consts::OS.to_string(),
|
||||
};
|
||||
let json = serde_json::to_string_pretty(&daemon_info).unwrap();
|
||||
std::fs::write(&daemon_info_path, &json).unwrap();
|
||||
|
||||
// --- Verify daemon info file ---
|
||||
assert!(daemon_info_path.exists());
|
||||
let loaded = read_daemon_info(tmp.path()).unwrap();
|
||||
assert_eq!(loaded.pid, std::process::id());
|
||||
assert_eq!(loaded.listen_addr, addr.to_string());
|
||||
|
||||
// --- Verify health endpoint ---
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.get(format!("http://{}/api/health", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["status"], "ok");
|
||||
|
||||
// --- Verify status endpoint ---
|
||||
let resp = client
|
||||
.get(format!("http://{}/api/status", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
let body: serde_json::Value = resp.json().await.unwrap();
|
||||
assert_eq!(body["status"], "running");
|
||||
|
||||
// --- Shutdown ---
|
||||
let resp = client
|
||||
.post(format!("http://{}/api/shutdown", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(resp.status(), 200);
|
||||
|
||||
// Clean up daemon info file (like run_daemon does)
|
||||
let _ = std::fs::remove_file(&daemon_info_path);
|
||||
assert!(!daemon_info_path.exists());
|
||||
|
||||
kernel.shutdown();
|
||||
}
|
||||
|
||||
/// Test that stale daemon info is detected when no process is running at that PID.
|
||||
#[test]
|
||||
fn test_stale_daemon_info_detection() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
|
||||
// Write daemon.json with a PID that almost certainly doesn't exist
|
||||
// (using a very high PID number)
|
||||
let info = DaemonInfo {
|
||||
pid: 99999999, // unlikely to be running
|
||||
listen_addr: "127.0.0.1:9999".to_string(),
|
||||
started_at: "2024-01-01T00:00:00Z".to_string(),
|
||||
version: "0.1.0".to_string(),
|
||||
platform: "test".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string_pretty(&info).unwrap();
|
||||
std::fs::write(tmp.path().join("daemon.json"), json).unwrap();
|
||||
|
||||
// read_daemon_info just reads the file — it doesn't check if the PID is alive
|
||||
// (that check happens in run_daemon). So the file is readable:
|
||||
let loaded = read_daemon_info(tmp.path());
|
||||
assert!(loaded.is_some());
|
||||
assert_eq!(loaded.unwrap().pid, 99999999);
|
||||
}
|
||||
|
||||
/// Test that the server starts and immediately responds to requests.
|
||||
#[tokio::test]
|
||||
async fn test_server_immediate_responsiveness() {
|
||||
let tmp = tempfile::tempdir().unwrap();
|
||||
let config = KernelConfig {
|
||||
home_dir: tmp.path().to_path_buf(),
|
||||
data_dir: tmp.path().join("data"),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: "ollama".to_string(),
|
||||
model: "test".to_string(),
|
||||
api_key_env: "OLLAMA_API_KEY".to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
};
|
||||
|
||||
let kernel = OpenFangKernel::boot_with_config(config).unwrap();
|
||||
let kernel = Arc::new(kernel);
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
kernel: kernel.clone(),
|
||||
started_at: Instant::now(),
|
||||
peer_registry: None,
|
||||
bridge_manager: tokio::sync::Mutex::new(None),
|
||||
channels_config: tokio::sync::RwLock::new(Default::default()),
|
||||
shutdown_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/api/health", axum::routing::get(routes::health))
|
||||
.with_state(state);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
// Hit health endpoint immediately — should respond fast
|
||||
let client = reqwest::Client::new();
|
||||
let start = Instant::now();
|
||||
let resp = client
|
||||
.get(format!("http://{}/api/health", addr))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
let latency = start.elapsed();
|
||||
|
||||
assert_eq!(resp.status(), 200);
|
||||
assert!(
|
||||
latency.as_millis() < 1000,
|
||||
"Health endpoint should respond in <1s, took {}ms",
|
||||
latency.as_millis()
|
||||
);
|
||||
|
||||
kernel.shutdown();
|
||||
}
|
||||
584
crates/openfang-api/tests/load_test.rs
Normal file
584
crates/openfang-api/tests/load_test.rs
Normal file
@@ -0,0 +1,584 @@
|
||||
//! Load & performance tests for the OpenFang API.
|
||||
//!
|
||||
//! Measures throughput under concurrent access: agent spawning, API endpoint
|
||||
//! latency, session management, and memory usage.
|
||||
//!
|
||||
//! Run: cargo test -p openfang-api --test load_test -- --nocapture
|
||||
|
||||
use axum::Router;
|
||||
use openfang_api::middleware;
|
||||
use openfang_api::routes::{self, AppState};
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use openfang_types::config::{DefaultModelConfig, KernelConfig};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tower_http::trace::TraceLayer;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Test infrastructure (mirrors api_integration_test.rs)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct TestServer {
|
||||
base_url: String,
|
||||
state: Arc<AppState>,
|
||||
_tmp: tempfile::TempDir,
|
||||
}
|
||||
|
||||
impl Drop for TestServer {
|
||||
fn drop(&mut self) {
|
||||
self.state.kernel.shutdown();
|
||||
}
|
||||
}
|
||||
|
||||
async fn start_test_server() -> TestServer {
|
||||
let tmp = tempfile::tempdir().expect("Failed to create temp dir");
|
||||
|
||||
let config = KernelConfig {
|
||||
home_dir: tmp.path().to_path_buf(),
|
||||
data_dir: tmp.path().join("data"),
|
||||
default_model: DefaultModelConfig {
|
||||
provider: "ollama".to_string(),
|
||||
model: "test-model".to_string(),
|
||||
api_key_env: "OLLAMA_API_KEY".to_string(),
|
||||
base_url: None,
|
||||
},
|
||||
..KernelConfig::default()
|
||||
};
|
||||
|
||||
let kernel = OpenFangKernel::boot_with_config(config).expect("Kernel should boot");
|
||||
let kernel = Arc::new(kernel);
|
||||
kernel.set_self_handle();
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
kernel,
|
||||
started_at: Instant::now(),
|
||||
peer_registry: None,
|
||||
bridge_manager: tokio::sync::Mutex::new(None),
|
||||
channels_config: tokio::sync::RwLock::new(Default::default()),
|
||||
shutdown_notify: Arc::new(tokio::sync::Notify::new()),
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/api/health", axum::routing::get(routes::health))
|
||||
.route("/api/status", axum::routing::get(routes::status))
|
||||
.route("/api/version", axum::routing::get(routes::version))
|
||||
.route(
|
||||
"/api/metrics",
|
||||
axum::routing::get(routes::prometheus_metrics),
|
||||
)
|
||||
.route(
|
||||
"/api/agents",
|
||||
axum::routing::get(routes::list_agents).post(routes::spawn_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}",
|
||||
axum::routing::get(routes::get_agent).delete(routes::kill_agent),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/session",
|
||||
axum::routing::get(routes::get_agent_session),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/session/reset",
|
||||
axum::routing::post(routes::reset_session),
|
||||
)
|
||||
.route(
|
||||
"/api/agents/{id}/sessions",
|
||||
axum::routing::get(routes::list_agent_sessions).post(routes::create_agent_session),
|
||||
)
|
||||
.route("/api/tools", axum::routing::get(routes::list_tools))
|
||||
.route("/api/models", axum::routing::get(routes::list_models))
|
||||
.route("/api/providers", axum::routing::get(routes::list_providers))
|
||||
.route("/api/usage", axum::routing::get(routes::usage_stats))
|
||||
.route(
|
||||
"/api/workflows",
|
||||
axum::routing::get(routes::list_workflows).post(routes::create_workflow),
|
||||
)
|
||||
.route(
|
||||
"/api/workflows/{id}/run",
|
||||
axum::routing::post(routes::run_workflow),
|
||||
)
|
||||
.route("/api/config", axum::routing::get(routes::get_config))
|
||||
.layer(axum::middleware::from_fn(middleware::request_logging))
|
||||
.layer(TraceLayer::new_for_http())
|
||||
.layer(CorsLayer::permissive())
|
||||
.with_state(state.clone());
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.expect("Failed to bind test server");
|
||||
let addr = listener.local_addr().unwrap();
|
||||
|
||||
tokio::spawn(async move {
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
});
|
||||
|
||||
TestServer {
|
||||
base_url: format!("http://{}", addr),
|
||||
state,
|
||||
_tmp: tmp,
|
||||
}
|
||||
}
|
||||
|
||||
const TEST_MANIFEST: &str = r#"
|
||||
name = "load-test-agent"
|
||||
version = "0.1.0"
|
||||
description = "Load test agent"
|
||||
author = "test"
|
||||
module = "builtin:chat"
|
||||
|
||||
[model]
|
||||
provider = "ollama"
|
||||
model = "test-model"
|
||||
system_prompt = "You are a test agent."
|
||||
|
||||
[capabilities]
|
||||
tools = ["file_read"]
|
||||
memory_read = ["*"]
|
||||
memory_write = ["self.*"]
|
||||
"#;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Load tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Test: Concurrent agent spawns — verify kernel handles parallel agent creation.
|
||||
#[tokio::test]
|
||||
async fn load_concurrent_agent_spawns() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
let n = 20; // 20 concurrent spawns
|
||||
|
||||
let start = Instant::now();
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for i in 0..n {
|
||||
let c = client.clone();
|
||||
let url = format!("{}/api/agents", server.base_url);
|
||||
let manifest = TEST_MANIFEST.replace("load-test-agent", &format!("load-agent-{i}"));
|
||||
handles.push(tokio::spawn(async move {
|
||||
let res = c
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({"manifest_toml": manifest}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request failed");
|
||||
(res.status().as_u16(), i)
|
||||
}));
|
||||
}
|
||||
|
||||
let mut success = 0;
|
||||
for h in handles {
|
||||
let (status, _i) = h.await.unwrap();
|
||||
if status == 200 || status == 201 {
|
||||
success += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
" [LOAD] Concurrent spawns: {success}/{n} succeeded in {:.0}ms ({:.0} spawns/sec)",
|
||||
elapsed.as_millis(),
|
||||
n as f64 / elapsed.as_secs_f64()
|
||||
);
|
||||
assert!(success >= n - 2, "Most agents should spawn successfully");
|
||||
|
||||
// Verify via list
|
||||
let agents: serde_json::Value = client
|
||||
.get(format!("{}/api/agents", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
let count = agents.as_array().map(|a| a.len()).unwrap_or(0);
|
||||
eprintln!(" [LOAD] Total agents after spawn: {count}");
|
||||
assert!(count >= success);
|
||||
}
|
||||
|
||||
/// Test: API endpoint latency — measure p50/p95/p99 for health, status, list agents.
|
||||
#[tokio::test]
|
||||
async fn load_endpoint_latency() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn a few agents for the list endpoint to return
|
||||
for i in 0..5 {
|
||||
let manifest = TEST_MANIFEST.replace("load-test-agent", &format!("latency-agent-{i}"));
|
||||
client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": manifest}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let endpoints = vec![
|
||||
("GET", "/api/health"),
|
||||
("GET", "/api/status"),
|
||||
("GET", "/api/agents"),
|
||||
("GET", "/api/tools"),
|
||||
("GET", "/api/models"),
|
||||
("GET", "/api/metrics"),
|
||||
("GET", "/api/config"),
|
||||
("GET", "/api/usage"),
|
||||
];
|
||||
|
||||
for (method, path) in &endpoints {
|
||||
let mut latencies = Vec::new();
|
||||
let n = 100;
|
||||
|
||||
for _ in 0..n {
|
||||
let start = Instant::now();
|
||||
let url = format!("{}{}", server.base_url, path);
|
||||
let res = match *method {
|
||||
"GET" => client.get(&url).send().await,
|
||||
_ => client.post(&url).send().await,
|
||||
};
|
||||
let elapsed = start.elapsed();
|
||||
assert!(res.is_ok(), "{method} {path} failed");
|
||||
latencies.push(elapsed);
|
||||
}
|
||||
|
||||
latencies.sort();
|
||||
let p50 = latencies[n / 2];
|
||||
let p95 = latencies[(n as f64 * 0.95) as usize];
|
||||
let p99 = latencies[(n as f64 * 0.99) as usize];
|
||||
|
||||
eprintln!(
|
||||
" [LOAD] {method} {path:30} p50={:>5.1}ms p95={:>5.1}ms p99={:>5.1}ms",
|
||||
p50.as_secs_f64() * 1000.0,
|
||||
p95.as_secs_f64() * 1000.0,
|
||||
p99.as_secs_f64() * 1000.0,
|
||||
);
|
||||
|
||||
// p99 should be under 100ms for read endpoints
|
||||
assert!(
|
||||
p99 < Duration::from_millis(500),
|
||||
"{method} {path} p99 too high: {p99:?}"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test: Concurrent reads — many clients hitting the same endpoints simultaneously.
|
||||
#[tokio::test]
|
||||
async fn load_concurrent_reads() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn some agents first
|
||||
for i in 0..3 {
|
||||
let manifest = TEST_MANIFEST.replace("load-test-agent", &format!("concurrent-agent-{i}"));
|
||||
client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": manifest}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let n = 50;
|
||||
let start = Instant::now();
|
||||
let mut handles = Vec::new();
|
||||
|
||||
for i in 0..n {
|
||||
let c = client.clone();
|
||||
let base = server.base_url.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
// Cycle through different endpoints
|
||||
let path = match i % 4 {
|
||||
0 => "/api/health",
|
||||
1 => "/api/agents",
|
||||
2 => "/api/status",
|
||||
_ => "/api/metrics",
|
||||
};
|
||||
let res = c
|
||||
.get(format!("{base}{path}"))
|
||||
.send()
|
||||
.await
|
||||
.expect("request failed");
|
||||
res.status().as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut success = 0;
|
||||
for h in handles {
|
||||
let status = h.await.unwrap();
|
||||
if status == 200 {
|
||||
success += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
" [LOAD] Concurrent reads: {success}/{n} succeeded in {:.0}ms ({:.0} req/sec)",
|
||||
elapsed.as_millis(),
|
||||
n as f64 / elapsed.as_secs_f64()
|
||||
);
|
||||
assert_eq!(success, n, "All concurrent reads should succeed");
|
||||
}
|
||||
|
||||
/// Test: Session management under load — create, list, and switch sessions.
|
||||
#[tokio::test]
|
||||
async fn load_session_management() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn an agent
|
||||
let res: serde_json::Value = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": TEST_MANIFEST}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
let agent_id = res["agent_id"].as_str().unwrap().to_string();
|
||||
|
||||
// Create multiple sessions
|
||||
let n = 10;
|
||||
let start = Instant::now();
|
||||
let mut session_ids = Vec::new();
|
||||
|
||||
for i in 0..n {
|
||||
let res: serde_json::Value = client
|
||||
.post(format!(
|
||||
"{}/api/agents/{}/sessions",
|
||||
server.base_url, agent_id
|
||||
))
|
||||
.json(&serde_json::json!({"label": format!("session-{i}")}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
if let Some(id) = res.get("session_id").and_then(|v| v.as_str()) {
|
||||
session_ids.push(id.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
" [LOAD] Created {n} sessions in {:.0}ms",
|
||||
elapsed.as_millis()
|
||||
);
|
||||
|
||||
// List sessions
|
||||
let start = Instant::now();
|
||||
let sessions_resp: serde_json::Value = client
|
||||
.get(format!(
|
||||
"{}/api/agents/{}/sessions",
|
||||
server.base_url, agent_id
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
// Response is {"sessions": [...]} — extract the array
|
||||
let session_count = sessions_resp
|
||||
.get("sessions")
|
||||
.and_then(|v| v.as_array())
|
||||
.map(|a| a.len())
|
||||
.unwrap_or_else(|| {
|
||||
// Fallback: maybe it's a direct array
|
||||
sessions_resp.as_array().map(|a| a.len()).unwrap_or(0)
|
||||
});
|
||||
eprintln!(
|
||||
" [LOAD] Listed {session_count} sessions in {:.1}ms",
|
||||
start.elapsed().as_secs_f64() * 1000.0
|
||||
);
|
||||
|
||||
// We expect at least some sessions (the original + our new ones)
|
||||
// Note: create_session might fail silently for some if agent was spawned without session
|
||||
eprintln!(" [LOAD] Session IDs collected: {}", session_ids.len());
|
||||
assert!(
|
||||
!session_ids.is_empty() || session_count > 0,
|
||||
"Should have created some sessions"
|
||||
);
|
||||
|
||||
// Switch between sessions rapidly
|
||||
let start = Instant::now();
|
||||
for sid in &session_ids {
|
||||
client
|
||||
.post(format!(
|
||||
"{}/api/agents/{}/sessions/{}/switch",
|
||||
server.base_url, agent_id, sid
|
||||
))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
eprintln!(
|
||||
" [LOAD] Switched through {} sessions in {:.0}ms",
|
||||
session_ids.len(),
|
||||
start.elapsed().as_millis()
|
||||
);
|
||||
}
|
||||
|
||||
/// Test: Workflow creation and listing under load.
|
||||
#[tokio::test]
|
||||
async fn load_workflow_operations() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let n = 15;
|
||||
let start = Instant::now();
|
||||
|
||||
// Create workflows concurrently
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..n {
|
||||
let c = client.clone();
|
||||
let url = format!("{}/api/workflows", server.base_url);
|
||||
handles.push(tokio::spawn(async move {
|
||||
let res = c
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({
|
||||
"name": format!("wf-{i}"),
|
||||
"description": format!("Load test workflow {i}"),
|
||||
"steps": [{
|
||||
"name": "step1",
|
||||
"agent_name": "test-agent",
|
||||
"mode": "sequential",
|
||||
"prompt": "{{input}}"
|
||||
}]
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.expect("request failed");
|
||||
res.status().as_u16()
|
||||
}));
|
||||
}
|
||||
|
||||
let mut created = 0;
|
||||
for h in handles {
|
||||
let status = h.await.unwrap();
|
||||
if status == 200 || status == 201 {
|
||||
created += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
" [LOAD] Created {created}/{n} workflows in {:.0}ms",
|
||||
elapsed.as_millis()
|
||||
);
|
||||
|
||||
// List all workflows
|
||||
let start = Instant::now();
|
||||
let workflows: serde_json::Value = client
|
||||
.get(format!("{}/api/workflows", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
let wf_count = workflows.as_array().map(|a| a.len()).unwrap_or(0);
|
||||
eprintln!(
|
||||
" [LOAD] Listed {wf_count} workflows in {:.1}ms",
|
||||
start.elapsed().as_secs_f64() * 1000.0
|
||||
);
|
||||
assert!(wf_count >= created);
|
||||
}
|
||||
|
||||
/// Test: Agent spawn + kill cycle — stress the registry.
|
||||
#[tokio::test]
|
||||
async fn load_spawn_kill_cycle() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
let cycles = 10;
|
||||
let start = Instant::now();
|
||||
let mut ids = Vec::new();
|
||||
|
||||
// Spawn
|
||||
for i in 0..cycles {
|
||||
let manifest = TEST_MANIFEST.replace("load-test-agent", &format!("cycle-agent-{i}"));
|
||||
let res: serde_json::Value = client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": manifest}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
if let Some(id) = res.get("agent_id").and_then(|v| v.as_str()) {
|
||||
ids.push(id.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// Kill
|
||||
for id in &ids {
|
||||
client
|
||||
.delete(format!("{}/api/agents/{}", server.base_url, id))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
" [LOAD] Spawn+kill {cycles} agents in {:.0}ms ({:.0}ms per cycle)",
|
||||
elapsed.as_millis(),
|
||||
elapsed.as_millis() as f64 / cycles as f64
|
||||
);
|
||||
|
||||
// Verify all cleaned up
|
||||
let agents: serde_json::Value = client
|
||||
.get(format!("{}/api/agents", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap()
|
||||
.json()
|
||||
.await
|
||||
.unwrap();
|
||||
let remaining = agents.as_array().map(|a| a.len()).unwrap_or(0);
|
||||
assert_eq!(remaining, 0, "All agents should be killed");
|
||||
}
|
||||
|
||||
/// Test: Prometheus metrics endpoint under sustained load.
|
||||
#[tokio::test]
|
||||
async fn load_metrics_sustained() {
|
||||
let server = start_test_server().await;
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Spawn a few agents first so metrics have data
|
||||
for i in 0..3 {
|
||||
let manifest = TEST_MANIFEST.replace("load-test-agent", &format!("metrics-agent-{i}"));
|
||||
client
|
||||
.post(format!("{}/api/agents", server.base_url))
|
||||
.json(&serde_json::json!({"manifest_toml": manifest}))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Hit metrics endpoint 200 times
|
||||
let n = 200;
|
||||
let start = Instant::now();
|
||||
for _ in 0..n {
|
||||
let res = client
|
||||
.get(format!("{}/api/metrics", server.base_url))
|
||||
.send()
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(res.status().as_u16(), 200);
|
||||
let body = res.text().await.unwrap();
|
||||
assert!(body.contains("openfang_agents_active"));
|
||||
}
|
||||
|
||||
let elapsed = start.elapsed();
|
||||
eprintln!(
|
||||
" [LOAD] Metrics {n} requests in {:.0}ms ({:.0} req/sec, {:.1}ms avg)",
|
||||
elapsed.as_millis(),
|
||||
n as f64 / elapsed.as_secs_f64(),
|
||||
elapsed.as_secs_f64() * 1000.0 / n as f64
|
||||
);
|
||||
}
|
||||
36
crates/openfang-channels/Cargo.toml
Normal file
36
crates/openfang-channels/Cargo.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[package]
|
||||
name = "openfang-channels"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Channel Bridge Layer — pluggable messaging integrations for OpenFang"
|
||||
|
||||
[dependencies]
|
||||
openfang-types = { path = "../openfang-types" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
dashmap = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
url = { workspace = true }
|
||||
zeroize = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
hmac = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
hex = { workspace = true }
|
||||
|
||||
lettre = { workspace = true }
|
||||
imap = { workspace = true }
|
||||
native-tls = { workspace = true }
|
||||
mailparse = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
694
crates/openfang-channels/src/bluesky.rs
Normal file
694
crates/openfang-channels/src/bluesky.rs
Normal file
@@ -0,0 +1,694 @@
|
||||
//! AT Protocol (Bluesky) channel adapter.
|
||||
//!
|
||||
//! Uses the AT Protocol (atproto) XRPC API for authentication, posting, and
|
||||
//! polling notifications. Session creation uses `com.atproto.server.createSession`
|
||||
//! with identifier + app password. Posts are created via
|
||||
//! `com.atproto.repo.createRecord` with the `app.bsky.feed.post` lexicon.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Default Bluesky PDS service URL.
|
||||
const DEFAULT_SERVICE_URL: &str = "https://bsky.social";
|
||||
|
||||
/// Maximum Bluesky post length (grapheme clusters).
|
||||
const MAX_MESSAGE_LEN: usize = 300;
|
||||
|
||||
/// Notification poll interval in seconds.
|
||||
const POLL_INTERVAL_SECS: u64 = 5;
|
||||
|
||||
/// Session refresh buffer — refresh 5 minutes before actual expiry.
|
||||
const SESSION_REFRESH_BUFFER_SECS: u64 = 300;
|
||||
|
||||
/// AT Protocol (Bluesky) adapter.
|
||||
///
|
||||
/// Inbound mentions are received by polling the `app.bsky.notification.listNotifications`
|
||||
/// endpoint. Outbound posts are created via `com.atproto.repo.createRecord` with
|
||||
/// the `app.bsky.feed.post` record type. Session tokens are cached and refreshed
|
||||
/// automatically.
|
||||
pub struct BlueskyAdapter {
|
||||
/// AT Protocol identifier (handle or DID, e.g., "alice.bsky.social").
|
||||
identifier: String,
|
||||
/// SECURITY: App password for session creation, zeroized on drop.
|
||||
app_password: Zeroizing<String>,
|
||||
/// PDS service URL (default: `"https://bsky.social"`).
|
||||
service_url: String,
|
||||
/// HTTP client for API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Cached session (access_jwt, refresh_jwt, did, expiry).
|
||||
session: Arc<RwLock<Option<BlueskySession>>>,
|
||||
}
|
||||
|
||||
/// Cached Bluesky session data.
|
||||
struct BlueskySession {
|
||||
/// JWT access token for authenticated requests.
|
||||
access_jwt: String,
|
||||
/// JWT refresh token for session renewal.
|
||||
refresh_jwt: String,
|
||||
/// The DID of the authenticated account.
|
||||
did: String,
|
||||
/// When this session was created (for expiry tracking).
|
||||
created_at: Instant,
|
||||
}
|
||||
|
||||
impl BlueskyAdapter {
|
||||
/// Create a new Bluesky adapter with the default service URL.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `identifier` - AT Protocol handle (e.g., "alice.bsky.social") or DID.
|
||||
/// * `app_password` - App password (not the main account password).
|
||||
pub fn new(identifier: String, app_password: String) -> Self {
|
||||
Self::with_service_url(identifier, app_password, DEFAULT_SERVICE_URL.to_string())
|
||||
}
|
||||
|
||||
/// Create a new Bluesky adapter with a custom PDS service URL.
|
||||
pub fn with_service_url(identifier: String, app_password: String, service_url: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let service_url = service_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
identifier,
|
||||
app_password: Zeroizing::new(app_password),
|
||||
service_url,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
session: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new session via `com.atproto.server.createSession`.
|
||||
async fn create_session(&self) -> Result<BlueskySession, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/xrpc/com.atproto.server.createSession", self.service_url);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"identifier": self.identifier,
|
||||
"password": self.app_password.as_str(),
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Bluesky createSession failed {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
let access_jwt = resp_body["accessJwt"]
|
||||
.as_str()
|
||||
.ok_or("Missing accessJwt")?
|
||||
.to_string();
|
||||
let refresh_jwt = resp_body["refreshJwt"]
|
||||
.as_str()
|
||||
.ok_or("Missing refreshJwt")?
|
||||
.to_string();
|
||||
let did = resp_body["did"].as_str().ok_or("Missing did")?.to_string();
|
||||
|
||||
Ok(BlueskySession {
|
||||
access_jwt,
|
||||
refresh_jwt,
|
||||
did,
|
||||
created_at: Instant::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Refresh an existing session via `com.atproto.server.refreshSession`.
|
||||
async fn refresh_session(
|
||||
&self,
|
||||
refresh_jwt: &str,
|
||||
) -> Result<BlueskySession, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/xrpc/com.atproto.server.refreshSession",
|
||||
self.service_url
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(refresh_jwt)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
// Refresh failed, create new session
|
||||
return self.create_session().await;
|
||||
}
|
||||
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
let access_jwt = resp_body["accessJwt"]
|
||||
.as_str()
|
||||
.ok_or("Missing accessJwt")?
|
||||
.to_string();
|
||||
let new_refresh_jwt = resp_body["refreshJwt"]
|
||||
.as_str()
|
||||
.ok_or("Missing refreshJwt")?
|
||||
.to_string();
|
||||
let did = resp_body["did"].as_str().ok_or("Missing did")?.to_string();
|
||||
|
||||
Ok(BlueskySession {
|
||||
access_jwt,
|
||||
refresh_jwt: new_refresh_jwt,
|
||||
did,
|
||||
created_at: Instant::now(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid access JWT, creating or refreshing the session as needed.
|
||||
async fn get_token(&self) -> Result<(String, String), Box<dyn std::error::Error>> {
|
||||
let guard = self.session.read().await;
|
||||
if let Some(ref session) = *guard {
|
||||
// Sessions last ~2 hours; refresh if older than 90 minutes
|
||||
if session.created_at.elapsed()
|
||||
< Duration::from_secs(5400 - SESSION_REFRESH_BUFFER_SECS)
|
||||
{
|
||||
return Ok((session.access_jwt.clone(), session.did.clone()));
|
||||
}
|
||||
let refresh_jwt = session.refresh_jwt.clone();
|
||||
drop(guard);
|
||||
|
||||
let new_session = self.refresh_session(&refresh_jwt).await?;
|
||||
let token = new_session.access_jwt.clone();
|
||||
let did = new_session.did.clone();
|
||||
*self.session.write().await = Some(new_session);
|
||||
return Ok((token, did));
|
||||
}
|
||||
drop(guard);
|
||||
|
||||
let session = self.create_session().await?;
|
||||
let token = session.access_jwt.clone();
|
||||
let did = session.did.clone();
|
||||
*self.session.write().await = Some(session);
|
||||
Ok((token, did))
|
||||
}
|
||||
|
||||
/// Validate credentials by creating a session.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let session = self.create_session().await?;
|
||||
let did = session.did.clone();
|
||||
*self.session.write().await = Some(session);
|
||||
Ok(did)
|
||||
}
|
||||
|
||||
/// Create a post (skeet) via `com.atproto.repo.createRecord`.
|
||||
async fn api_create_post(
|
||||
&self,
|
||||
text: &str,
|
||||
reply_ref: Option<&serde_json::Value>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (token, did) = self.get_token().await?;
|
||||
let url = format!("{}/xrpc/com.atproto.repo.createRecord", self.service_url);
|
||||
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let now = Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string();
|
||||
|
||||
let mut record = serde_json::json!({
|
||||
"$type": "app.bsky.feed.post",
|
||||
"text": chunk,
|
||||
"createdAt": now,
|
||||
});
|
||||
|
||||
if let Some(reply) = reply_ref {
|
||||
record["reply"] = reply.clone();
|
||||
}
|
||||
|
||||
let body = serde_json::json!({
|
||||
"repo": did,
|
||||
"collection": "app.bsky.feed.post",
|
||||
"record": record,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Bluesky createRecord error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Bluesky notification into a `ChannelMessage`.
|
||||
fn parse_bluesky_notification(
|
||||
notification: &serde_json::Value,
|
||||
own_did: &str,
|
||||
) -> Option<ChannelMessage> {
|
||||
let reason = notification["reason"].as_str().unwrap_or("");
|
||||
// We care about mentions and replies
|
||||
if reason != "mention" && reason != "reply" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let author = notification.get("author")?;
|
||||
let author_did = author["did"].as_str().unwrap_or("");
|
||||
// Skip own notifications
|
||||
if author_did == own_did {
|
||||
return None;
|
||||
}
|
||||
|
||||
let record = notification.get("record")?;
|
||||
let text = record["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let uri = notification["uri"].as_str().unwrap_or("").to_string();
|
||||
let cid = notification["cid"].as_str().unwrap_or("").to_string();
|
||||
let handle = author["handle"].as_str().unwrap_or("").to_string();
|
||||
let display_name = author["displayName"]
|
||||
.as_str()
|
||||
.unwrap_or(&handle)
|
||||
.to_string();
|
||||
let indexed_at = notification["indexedAt"].as_str().unwrap_or("").to_string();
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("uri".to_string(), serde_json::Value::String(uri.clone()));
|
||||
metadata.insert("cid".to_string(), serde_json::Value::String(cid));
|
||||
metadata.insert("handle".to_string(), serde_json::Value::String(handle));
|
||||
metadata.insert(
|
||||
"reason".to_string(),
|
||||
serde_json::Value::String(reason.to_string()),
|
||||
);
|
||||
metadata.insert(
|
||||
"indexed_at".to_string(),
|
||||
serde_json::Value::String(indexed_at),
|
||||
);
|
||||
|
||||
// Extract reply reference if present
|
||||
if let Some(reply) = record.get("reply") {
|
||||
metadata.insert("reply_ref".to_string(), reply.clone());
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("bluesky".to_string()),
|
||||
platform_message_id: uri,
|
||||
sender: ChannelUser {
|
||||
platform_id: author_did.to_string(),
|
||||
display_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false, // Bluesky mentions are treated as direct interactions
|
||||
thread_id: None,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for BlueskyAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"bluesky"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("bluesky".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let did = self.validate().await?;
|
||||
info!("Bluesky adapter authenticated as {did}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let service_url = self.service_url.clone();
|
||||
let session = Arc::clone(&self.session);
|
||||
let own_did = did;
|
||||
let client = self.client.clone();
|
||||
let identifier = self.identifier.clone();
|
||||
let app_password = self.app_password.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let mut last_seen_at: Option<String> = None;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Bluesky adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Get current access token
|
||||
let token = {
|
||||
let guard = session.read().await;
|
||||
match &*guard {
|
||||
Some(s) => s.access_jwt.clone(),
|
||||
None => {
|
||||
// Re-create session
|
||||
drop(guard);
|
||||
let url =
|
||||
format!("{}/xrpc/com.atproto.server.createSession", service_url);
|
||||
let body = serde_json::json!({
|
||||
"identifier": identifier,
|
||||
"password": app_password.as_str(),
|
||||
});
|
||||
match client.post(&url).json(&body).send().await {
|
||||
Ok(resp) => {
|
||||
let resp_body: serde_json::Value =
|
||||
resp.json().await.unwrap_or_default();
|
||||
let tok =
|
||||
resp_body["accessJwt"].as_str().unwrap_or("").to_string();
|
||||
if tok.is_empty() {
|
||||
warn!("Bluesky: failed to create session");
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
tokio::time::sleep(backoff).await;
|
||||
continue;
|
||||
}
|
||||
let new_session = BlueskySession {
|
||||
access_jwt: tok.clone(),
|
||||
refresh_jwt: resp_body["refreshJwt"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
did: resp_body["did"].as_str().unwrap_or("").to_string(),
|
||||
created_at: Instant::now(),
|
||||
};
|
||||
*session.write().await = Some(new_session);
|
||||
tok
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Bluesky: session create error: {e}");
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
tokio::time::sleep(backoff).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Poll notifications
|
||||
let mut url = format!(
|
||||
"{}/xrpc/app.bsky.notification.listNotifications?limit=25",
|
||||
service_url
|
||||
);
|
||||
if let Some(ref seen) = last_seen_at {
|
||||
url.push_str(&format!("&seenAt={}", seen));
|
||||
}
|
||||
|
||||
let resp = match client.get(&url).bearer_auth(&token).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Bluesky: notification fetch error: {e}");
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
warn!("Bluesky: notification fetch returned {}", resp.status());
|
||||
if resp.status().as_u16() == 401 {
|
||||
// Session expired, clear it so next iteration re-creates
|
||||
*session.write().await = None;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Bluesky: failed to parse notifications: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let notifications = match body["notifications"].as_array() {
|
||||
Some(arr) => arr,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
for notif in notifications {
|
||||
// Track latest indexed_at
|
||||
if let Some(indexed) = notif["indexedAt"].as_str() {
|
||||
if last_seen_at
|
||||
.as_ref()
|
||||
.map(|s| indexed > s.as_str())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
last_seen_at = Some(indexed.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(msg) = parse_bluesky_notification(notif, &own_did) {
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update seen marker
|
||||
if last_seen_at.is_some() {
|
||||
let mark_url = format!("{}/xrpc/app.bsky.notification.updateSeen", service_url);
|
||||
let mark_body = serde_json::json!({
|
||||
"seenAt": Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
|
||||
});
|
||||
let _ = client
|
||||
.post(&mark_url)
|
||||
.bearer_auth(&token)
|
||||
.json(&mark_body)
|
||||
.send()
|
||||
.await;
|
||||
}
|
||||
|
||||
backoff = Duration::from_secs(1);
|
||||
}
|
||||
|
||||
info!("Bluesky polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_create_post(&text, None).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_create_post("(Unsupported content type)", None)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Bluesky/AT Protocol does not support typing indicators
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_bluesky_adapter_creation() {
|
||||
let adapter = BlueskyAdapter::new(
|
||||
"alice.bsky.social".to_string(),
|
||||
"app-password-123".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.name(), "bluesky");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("bluesky".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bluesky_default_service_url() {
|
||||
let adapter = BlueskyAdapter::new("alice.bsky.social".to_string(), "pwd".to_string());
|
||||
assert_eq!(adapter.service_url, "https://bsky.social");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bluesky_custom_service_url() {
|
||||
let adapter = BlueskyAdapter::with_service_url(
|
||||
"alice.example.com".to_string(),
|
||||
"pwd".to_string(),
|
||||
"https://pds.example.com/".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.service_url, "https://pds.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bluesky_identifier_stored() {
|
||||
let adapter = BlueskyAdapter::new("did:plc:abc123".to_string(), "pwd".to_string());
|
||||
assert_eq!(adapter.identifier, "did:plc:abc123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_bluesky_notification_mention() {
|
||||
let notif = serde_json::json!({
|
||||
"uri": "at://did:plc:sender/app.bsky.feed.post/abc123",
|
||||
"cid": "bafyrei...",
|
||||
"author": {
|
||||
"did": "did:plc:sender",
|
||||
"handle": "alice.bsky.social",
|
||||
"displayName": "Alice"
|
||||
},
|
||||
"reason": "mention",
|
||||
"record": {
|
||||
"text": "@bot hello there!",
|
||||
"createdAt": "2024-01-01T00:00:00.000Z"
|
||||
},
|
||||
"indexedAt": "2024-01-01T00:00:01.000Z"
|
||||
});
|
||||
|
||||
let msg = parse_bluesky_notification(¬if, "did:plc:bot").unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("bluesky".to_string()));
|
||||
assert_eq!(msg.sender.display_name, "Alice");
|
||||
assert_eq!(msg.sender.platform_id, "did:plc:sender");
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "@bot hello there!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_bluesky_notification_reply() {
|
||||
let notif = serde_json::json!({
|
||||
"uri": "at://did:plc:sender/app.bsky.feed.post/def456",
|
||||
"cid": "bafyrei...",
|
||||
"author": {
|
||||
"did": "did:plc:sender",
|
||||
"handle": "bob.bsky.social",
|
||||
"displayName": "Bob"
|
||||
},
|
||||
"reason": "reply",
|
||||
"record": {
|
||||
"text": "Nice post!",
|
||||
"createdAt": "2024-01-01T00:00:00.000Z",
|
||||
"reply": {
|
||||
"root": { "uri": "at://...", "cid": "..." },
|
||||
"parent": { "uri": "at://...", "cid": "..." }
|
||||
}
|
||||
},
|
||||
"indexedAt": "2024-01-01T00:00:01.000Z"
|
||||
});
|
||||
|
||||
let msg = parse_bluesky_notification(¬if, "did:plc:bot").unwrap();
|
||||
assert!(msg.metadata.contains_key("reply_ref"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_bluesky_notification_skips_own() {
|
||||
let notif = serde_json::json!({
|
||||
"uri": "at://did:plc:bot/app.bsky.feed.post/abc",
|
||||
"cid": "...",
|
||||
"author": {
|
||||
"did": "did:plc:bot",
|
||||
"handle": "bot.bsky.social"
|
||||
},
|
||||
"reason": "mention",
|
||||
"record": {
|
||||
"text": "self mention"
|
||||
},
|
||||
"indexedAt": "2024-01-01T00:00:00.000Z"
|
||||
});
|
||||
|
||||
assert!(parse_bluesky_notification(¬if, "did:plc:bot").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_bluesky_notification_skips_like() {
|
||||
let notif = serde_json::json!({
|
||||
"uri": "at://...",
|
||||
"cid": "...",
|
||||
"author": {
|
||||
"did": "did:plc:other",
|
||||
"handle": "other.bsky.social"
|
||||
},
|
||||
"reason": "like",
|
||||
"record": {},
|
||||
"indexedAt": "2024-01-01T00:00:00.000Z"
|
||||
});
|
||||
|
||||
assert!(parse_bluesky_notification(¬if, "did:plc:bot").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_bluesky_notification_command() {
|
||||
let notif = serde_json::json!({
|
||||
"uri": "at://did:plc:sender/app.bsky.feed.post/cmd1",
|
||||
"cid": "...",
|
||||
"author": {
|
||||
"did": "did:plc:sender",
|
||||
"handle": "alice.bsky.social",
|
||||
"displayName": "Alice"
|
||||
},
|
||||
"reason": "mention",
|
||||
"record": {
|
||||
"text": "/status check"
|
||||
},
|
||||
"indexedAt": "2024-01-01T00:00:00.000Z"
|
||||
});
|
||||
|
||||
let msg = parse_bluesky_notification(¬if, "did:plc:bot").unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "status");
|
||||
assert_eq!(args, &["check"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
1091
crates/openfang-channels/src/bridge.rs
Normal file
1091
crates/openfang-channels/src/bridge.rs
Normal file
File diff suppressed because it is too large
Load Diff
425
crates/openfang-channels/src/dingtalk.rs
Normal file
425
crates/openfang-channels/src/dingtalk.rs
Normal file
@@ -0,0 +1,425 @@
|
||||
//! DingTalk Robot channel adapter.
|
||||
//!
|
||||
//! Integrates with the DingTalk (Alibaba) custom robot API. Incoming messages
|
||||
//! are received via an HTTP webhook callback server, and outbound messages are
|
||||
//! posted to the robot send endpoint with HMAC-SHA256 signature verification.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 20000;
|
||||
const DINGTALK_SEND_URL: &str = "https://oapi.dingtalk.com/robot/send";
|
||||
|
||||
/// DingTalk Robot channel adapter.
|
||||
///
|
||||
/// Uses a webhook listener to receive incoming messages from DingTalk
|
||||
/// conversations and posts replies via the signed Robot Send API.
|
||||
pub struct DingTalkAdapter {
|
||||
/// SECURITY: Robot access token is zeroized on drop.
|
||||
access_token: Zeroizing<String>,
|
||||
/// SECURITY: Signing secret for HMAC-SHA256 verification.
|
||||
secret: Zeroizing<String>,
|
||||
/// Port for the incoming webhook HTTP server.
|
||||
webhook_port: u16,
|
||||
/// HTTP client for outbound requests.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl DingTalkAdapter {
|
||||
/// Create a new DingTalk Robot adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `access_token` - Robot access token from DingTalk.
|
||||
/// * `secret` - Signing secret for request verification.
|
||||
/// * `webhook_port` - Local port to listen for DingTalk callbacks.
|
||||
pub fn new(access_token: String, secret: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
access_token: Zeroizing::new(access_token),
|
||||
secret: Zeroizing::new(secret),
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute the HMAC-SHA256 signature for a DingTalk request.
|
||||
///
|
||||
/// DingTalk signature = Base64(HMAC-SHA256(secret, timestamp + "\n" + secret))
|
||||
fn compute_signature(secret: &str, timestamp: i64) -> String {
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
let string_to_sign = format!("{}\n{}", timestamp, secret);
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key size");
|
||||
mac.update(string_to_sign.as_bytes());
|
||||
let result = mac.finalize();
|
||||
use base64::Engine;
|
||||
base64::engine::general_purpose::STANDARD.encode(result.into_bytes())
|
||||
}
|
||||
|
||||
/// Verify an incoming DingTalk callback signature.
|
||||
fn verify_signature(secret: &str, timestamp: i64, signature: &str) -> bool {
|
||||
let expected = Self::compute_signature(secret, timestamp);
|
||||
// Constant-time comparison
|
||||
if expected.len() != signature.len() {
|
||||
return false;
|
||||
}
|
||||
let mut diff = 0u8;
|
||||
for (a, b) in expected.bytes().zip(signature.bytes()) {
|
||||
diff |= a ^ b;
|
||||
}
|
||||
diff == 0
|
||||
}
|
||||
|
||||
/// Build the signed send URL with access_token, timestamp, and signature.
|
||||
fn build_send_url(&self) -> String {
|
||||
let timestamp = Utc::now().timestamp_millis();
|
||||
let sign = Self::compute_signature(&self.secret, timestamp);
|
||||
let encoded_sign = url::form_urlencoded::Serializer::new(String::new())
|
||||
.append_pair("sign", &sign)
|
||||
.finish();
|
||||
format!(
|
||||
"{}?access_token={}×tamp={}&{}",
|
||||
DINGTALK_SEND_URL,
|
||||
self.access_token.as_str(),
|
||||
timestamp,
|
||||
encoded_sign
|
||||
)
|
||||
}
|
||||
|
||||
/// Parse a DingTalk webhook JSON body into extracted fields.
|
||||
fn parse_callback(body: &serde_json::Value) -> Option<(String, String, String, String, bool)> {
|
||||
let msg_type = body["msgtype"].as_str()?;
|
||||
let text = match msg_type {
|
||||
"text" => body["text"]["content"].as_str()?.trim().to_string(),
|
||||
_ => return None,
|
||||
};
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let sender_id = body["senderId"].as_str().unwrap_or("unknown").to_string();
|
||||
let sender_nick = body["senderNick"].as_str().unwrap_or("Unknown").to_string();
|
||||
let conversation_id = body["conversationId"].as_str().unwrap_or("").to_string();
|
||||
let is_group = body["conversationType"].as_str() == Some("2");
|
||||
|
||||
Some((text, sender_id, sender_nick, conversation_id, is_group))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for DingTalkAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"dingtalk"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("dingtalk".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let secret = self.secret.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
info!("DingTalk adapter starting webhook server on port {port}");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let tx_shared = Arc::new(tx);
|
||||
let secret_shared = Arc::new(secret);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/",
|
||||
axum::routing::post({
|
||||
let tx = Arc::clone(&tx_shared);
|
||||
let secret = Arc::clone(&secret_shared);
|
||||
move |headers: axum::http::HeaderMap,
|
||||
body: axum::extract::Json<serde_json::Value>| {
|
||||
let tx = Arc::clone(&tx);
|
||||
let secret = Arc::clone(&secret);
|
||||
async move {
|
||||
// Extract timestamp and sign from headers
|
||||
let timestamp_str = headers
|
||||
.get("timestamp")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("0");
|
||||
let signature = headers
|
||||
.get("sign")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
// Verify signature
|
||||
if let Ok(ts) = timestamp_str.parse::<i64>() {
|
||||
if !DingTalkAdapter::verify_signature(&secret, ts, signature) {
|
||||
warn!("DingTalk: invalid signature");
|
||||
return axum::http::StatusCode::FORBIDDEN;
|
||||
}
|
||||
|
||||
// Check timestamp freshness (1 hour window)
|
||||
let now = Utc::now().timestamp_millis();
|
||||
if (now - ts).unsigned_abs() > 3_600_000 {
|
||||
warn!("DingTalk: stale timestamp");
|
||||
return axum::http::StatusCode::FORBIDDEN;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((text, sender_id, sender_nick, conv_id, is_group)) =
|
||||
DingTalkAdapter::parse_callback(&body)
|
||||
{
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text)
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("dingtalk".to_string()),
|
||||
platform_message_id: format!(
|
||||
"dt-{}",
|
||||
Utc::now().timestamp_millis()
|
||||
),
|
||||
sender: ChannelUser {
|
||||
platform_id: sender_id,
|
||||
display_name: sender_nick,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"conversation_id".to_string(),
|
||||
serde_json::Value::String(conv_id),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
|
||||
axum::http::StatusCode::OK
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("DingTalk webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("DingTalk: failed to bind port {port}: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("DingTalk webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("DingTalk adapter shutting down");
|
||||
}
|
||||
}
|
||||
|
||||
info!("DingTalk webhook server stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
let chunks = split_message(&text, MAX_MESSAGE_LEN);
|
||||
let num_chunks = chunks.len();
|
||||
|
||||
for chunk in chunks {
|
||||
let url = self.build_send_url();
|
||||
let body = serde_json::json!({
|
||||
"msgtype": "text",
|
||||
"text": {
|
||||
"content": chunk,
|
||||
}
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("DingTalk API error {status}: {err_body}").into());
|
||||
}
|
||||
|
||||
// DingTalk returns {"errcode": 0, "errmsg": "ok"} on success
|
||||
let result: serde_json::Value = resp.json().await?;
|
||||
if result["errcode"].as_i64() != Some(0) {
|
||||
return Err(format!(
|
||||
"DingTalk error: {}",
|
||||
result["errmsg"].as_str().unwrap_or("unknown")
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
// Rate limit: small delay between chunks
|
||||
if num_chunks > 1 {
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// DingTalk Robot API does not support typing indicators.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_dingtalk_adapter_creation() {
|
||||
let adapter =
|
||||
DingTalkAdapter::new("test-token".to_string(), "test-secret".to_string(), 8080);
|
||||
assert_eq!(adapter.name(), "dingtalk");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("dingtalk".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dingtalk_signature_computation() {
|
||||
let timestamp: i64 = 1700000000000;
|
||||
let secret = "my-secret";
|
||||
let sig = DingTalkAdapter::compute_signature(secret, timestamp);
|
||||
assert!(!sig.is_empty());
|
||||
// Verify deterministic output
|
||||
let sig2 = DingTalkAdapter::compute_signature(secret, timestamp);
|
||||
assert_eq!(sig, sig2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dingtalk_signature_verification() {
|
||||
let secret = "test-secret-123";
|
||||
let timestamp: i64 = 1700000000000;
|
||||
let sig = DingTalkAdapter::compute_signature(secret, timestamp);
|
||||
assert!(DingTalkAdapter::verify_signature(secret, timestamp, &sig));
|
||||
assert!(!DingTalkAdapter::verify_signature(
|
||||
secret, timestamp, "bad-sig"
|
||||
));
|
||||
assert!(!DingTalkAdapter::verify_signature(
|
||||
"wrong-secret",
|
||||
timestamp,
|
||||
&sig
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dingtalk_parse_callback_text() {
|
||||
let body = serde_json::json!({
|
||||
"msgtype": "text",
|
||||
"text": { "content": "Hello bot" },
|
||||
"senderId": "user123",
|
||||
"senderNick": "Alice",
|
||||
"conversationId": "conv456",
|
||||
"conversationType": "2",
|
||||
});
|
||||
let result = DingTalkAdapter::parse_callback(&body);
|
||||
assert!(result.is_some());
|
||||
let (text, sender_id, sender_nick, conv_id, is_group) = result.unwrap();
|
||||
assert_eq!(text, "Hello bot");
|
||||
assert_eq!(sender_id, "user123");
|
||||
assert_eq!(sender_nick, "Alice");
|
||||
assert_eq!(conv_id, "conv456");
|
||||
assert!(is_group);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dingtalk_parse_callback_unsupported_type() {
|
||||
let body = serde_json::json!({
|
||||
"msgtype": "image",
|
||||
"image": { "downloadCode": "abc" },
|
||||
});
|
||||
assert!(DingTalkAdapter::parse_callback(&body).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dingtalk_parse_callback_dm() {
|
||||
let body = serde_json::json!({
|
||||
"msgtype": "text",
|
||||
"text": { "content": "DM message" },
|
||||
"senderId": "u1",
|
||||
"senderNick": "Bob",
|
||||
"conversationId": "c1",
|
||||
"conversationType": "1",
|
||||
});
|
||||
let result = DingTalkAdapter::parse_callback(&body);
|
||||
assert!(result.is_some());
|
||||
let (_, _, _, _, is_group) = result.unwrap();
|
||||
assert!(!is_group);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dingtalk_send_url_contains_token_and_sign() {
|
||||
let adapter = DingTalkAdapter::new("my-token".to_string(), "my-secret".to_string(), 8080);
|
||||
let url = adapter.build_send_url();
|
||||
assert!(url.contains("access_token=my-token"));
|
||||
assert!(url.contains("timestamp="));
|
||||
assert!(url.contains("sign="));
|
||||
}
|
||||
}
|
||||
692
crates/openfang-channels/src/discord.rs
Normal file
692
crates/openfang-channels/src/discord.rs
Normal file
@@ -0,0 +1,692 @@
|
||||
//! Discord Gateway adapter for the OpenFang channel bridge.
|
||||
//!
|
||||
//! Uses Discord Gateway WebSocket (v10) for receiving messages and the REST API
|
||||
//! for sending responses. No external Discord crate — just `tokio-tungstenite` + `reqwest`.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::{SinkExt, Stream, StreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const DISCORD_API_BASE: &str = "https://discord.com/api/v10";
|
||||
const MAX_BACKOFF: Duration = Duration::from_secs(60);
|
||||
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
|
||||
const DISCORD_MSG_LIMIT: usize = 2000;
|
||||
|
||||
/// Discord Gateway opcodes.
|
||||
mod opcode {
|
||||
pub const DISPATCH: u64 = 0;
|
||||
pub const HEARTBEAT: u64 = 1;
|
||||
pub const IDENTIFY: u64 = 2;
|
||||
pub const RESUME: u64 = 6;
|
||||
pub const RECONNECT: u64 = 7;
|
||||
pub const INVALID_SESSION: u64 = 9;
|
||||
pub const HELLO: u64 = 10;
|
||||
pub const HEARTBEAT_ACK: u64 = 11;
|
||||
}
|
||||
|
||||
/// Discord Gateway adapter using WebSocket.
|
||||
pub struct DiscordAdapter {
|
||||
/// SECURITY: Bot token is zeroized on drop to prevent memory disclosure.
|
||||
token: Zeroizing<String>,
|
||||
client: reqwest::Client,
|
||||
allowed_guilds: Vec<u64>,
|
||||
intents: u64,
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Bot's own user ID (populated after READY event).
|
||||
bot_user_id: Arc<RwLock<Option<String>>>,
|
||||
/// Session ID for resume (populated after READY event).
|
||||
session_id: Arc<RwLock<Option<String>>>,
|
||||
/// Resume gateway URL.
|
||||
resume_gateway_url: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl DiscordAdapter {
|
||||
pub fn new(token: String, allowed_guilds: Vec<u64>, intents: u64) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
token: Zeroizing::new(token),
|
||||
client: reqwest::Client::new(),
|
||||
allowed_guilds,
|
||||
intents,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
bot_user_id: Arc::new(RwLock::new(None)),
|
||||
session_id: Arc::new(RwLock::new(None)),
|
||||
resume_gateway_url: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the WebSocket gateway URL from the Discord API.
|
||||
async fn get_gateway_url(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{DISCORD_API_BASE}/gateway/bot");
|
||||
let resp: serde_json::Value = self
|
||||
.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bot {}", self.token.as_str()))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let ws_url = resp["url"]
|
||||
.as_str()
|
||||
.ok_or("Missing 'url' in gateway response")?;
|
||||
|
||||
Ok(format!("{ws_url}/?v=10&encoding=json"))
|
||||
}
|
||||
|
||||
/// Send a message to a Discord channel via REST API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{DISCORD_API_BASE}/channels/{channel_id}/messages");
|
||||
let chunks = split_message(text, DISCORD_MSG_LIMIT);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({ "content": chunk });
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bot {}", self.token.as_str()))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
warn!("Discord sendMessage failed: {body_text}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send typing indicator to a Discord channel.
|
||||
async fn api_send_typing(&self, channel_id: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{DISCORD_API_BASE}/channels/{channel_id}/typing");
|
||||
let _ = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bot {}", self.token.as_str()))
|
||||
.send()
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for DiscordAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"discord"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Discord
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let gateway_url = self.get_gateway_url().await?;
|
||||
info!("Discord gateway URL obtained");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
|
||||
let token = self.token.clone();
|
||||
let intents = self.intents;
|
||||
let allowed_guilds = self.allowed_guilds.clone();
|
||||
let bot_user_id = self.bot_user_id.clone();
|
||||
let session_id_store = self.session_id.clone();
|
||||
let resume_url_store = self.resume_gateway_url.clone();
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = INITIAL_BACKOFF;
|
||||
let mut connect_url = gateway_url;
|
||||
// Sequence persists across reconnections for RESUME
|
||||
let sequence: Arc<RwLock<Option<u64>>> = Arc::new(RwLock::new(None));
|
||||
|
||||
loop {
|
||||
if *shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
info!("Connecting to Discord gateway...");
|
||||
|
||||
let ws_result = tokio_tungstenite::connect_async(&connect_url).await;
|
||||
let ws_stream = match ws_result {
|
||||
Ok((stream, _)) => stream,
|
||||
Err(e) => {
|
||||
warn!("Discord gateway connection failed: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
backoff = INITIAL_BACKOFF;
|
||||
info!("Discord gateway connected");
|
||||
|
||||
let (mut ws_tx, mut ws_rx) = ws_stream.split();
|
||||
let mut _heartbeat_interval: Option<u64> = None;
|
||||
|
||||
// Inner message loop — returns true if we should reconnect
|
||||
let should_reconnect = 'inner: loop {
|
||||
let msg = tokio::select! {
|
||||
msg = ws_rx.next() => msg,
|
||||
_ = shutdown.changed() => {
|
||||
if *shutdown.borrow() {
|
||||
info!("Discord shutdown requested");
|
||||
let _ = ws_tx.close().await;
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let msg = match msg {
|
||||
Some(Ok(m)) => m,
|
||||
Some(Err(e)) => {
|
||||
warn!("Discord WebSocket error: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
None => {
|
||||
info!("Discord WebSocket closed");
|
||||
break 'inner true;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match msg {
|
||||
tokio_tungstenite::tungstenite::Message::Text(t) => t,
|
||||
tokio_tungstenite::tungstenite::Message::Close(_) => {
|
||||
info!("Discord gateway closed by server");
|
||||
break 'inner true;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let payload: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("Discord: failed to parse gateway message: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let op = payload["op"].as_u64().unwrap_or(999);
|
||||
|
||||
// Update sequence number
|
||||
if let Some(s) = payload["s"].as_u64() {
|
||||
*sequence.write().await = Some(s);
|
||||
}
|
||||
|
||||
match op {
|
||||
opcode::HELLO => {
|
||||
let interval =
|
||||
payload["d"]["heartbeat_interval"].as_u64().unwrap_or(45000);
|
||||
_heartbeat_interval = Some(interval);
|
||||
debug!("Discord HELLO: heartbeat_interval={interval}ms");
|
||||
|
||||
// Try RESUME if we have a session, otherwise IDENTIFY
|
||||
let has_session = session_id_store.read().await.is_some();
|
||||
let has_seq = sequence.read().await.is_some();
|
||||
|
||||
let gateway_msg = if has_session && has_seq {
|
||||
let sid = session_id_store.read().await.clone().unwrap();
|
||||
let seq = *sequence.read().await;
|
||||
info!("Discord: sending RESUME (session={sid})");
|
||||
serde_json::json!({
|
||||
"op": opcode::RESUME,
|
||||
"d": {
|
||||
"token": token.as_str(),
|
||||
"session_id": sid,
|
||||
"seq": seq
|
||||
}
|
||||
})
|
||||
} else {
|
||||
info!("Discord: sending IDENTIFY");
|
||||
serde_json::json!({
|
||||
"op": opcode::IDENTIFY,
|
||||
"d": {
|
||||
"token": token.as_str(),
|
||||
"intents": intents,
|
||||
"properties": {
|
||||
"os": "linux",
|
||||
"browser": "openfang",
|
||||
"device": "openfang"
|
||||
}
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
if let Err(e) = ws_tx
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(
|
||||
serde_json::to_string(&gateway_msg).unwrap(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
error!("Discord: failed to send IDENTIFY/RESUME: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
}
|
||||
|
||||
opcode::DISPATCH => {
|
||||
let event_name = payload["t"].as_str().unwrap_or("");
|
||||
let d = &payload["d"];
|
||||
|
||||
match event_name {
|
||||
"READY" => {
|
||||
let user_id =
|
||||
d["user"]["id"].as_str().unwrap_or("").to_string();
|
||||
let username =
|
||||
d["user"]["username"].as_str().unwrap_or("unknown");
|
||||
let sid = d["session_id"].as_str().unwrap_or("").to_string();
|
||||
let resume_url =
|
||||
d["resume_gateway_url"].as_str().unwrap_or("").to_string();
|
||||
|
||||
*bot_user_id.write().await = Some(user_id.clone());
|
||||
*session_id_store.write().await = Some(sid);
|
||||
if !resume_url.is_empty() {
|
||||
*resume_url_store.write().await = Some(resume_url);
|
||||
}
|
||||
|
||||
info!("Discord bot ready: {username} ({user_id})");
|
||||
}
|
||||
|
||||
"MESSAGE_CREATE" | "MESSAGE_UPDATE" => {
|
||||
if let Some(msg) =
|
||||
parse_discord_message(d, &bot_user_id, &allowed_guilds)
|
||||
.await
|
||||
{
|
||||
debug!(
|
||||
"Discord {event_name} from {}: {:?}",
|
||||
msg.sender.display_name, msg.content
|
||||
);
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"RESUMED" => {
|
||||
info!("Discord session resumed successfully");
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!("Discord event: {event_name}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
opcode::HEARTBEAT => {
|
||||
// Server requests immediate heartbeat
|
||||
let seq = *sequence.read().await;
|
||||
let hb = serde_json::json!({ "op": opcode::HEARTBEAT, "d": seq });
|
||||
let _ = ws_tx
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(
|
||||
serde_json::to_string(&hb).unwrap(),
|
||||
))
|
||||
.await;
|
||||
}
|
||||
|
||||
opcode::HEARTBEAT_ACK => {
|
||||
debug!("Discord heartbeat ACK received");
|
||||
}
|
||||
|
||||
opcode::RECONNECT => {
|
||||
info!("Discord: server requested reconnect");
|
||||
break 'inner true;
|
||||
}
|
||||
|
||||
opcode::INVALID_SESSION => {
|
||||
let resumable = payload["d"].as_bool().unwrap_or(false);
|
||||
if resumable {
|
||||
info!("Discord: invalid session (resumable)");
|
||||
} else {
|
||||
info!("Discord: invalid session (not resumable), clearing session");
|
||||
*session_id_store.write().await = None;
|
||||
*sequence.write().await = None;
|
||||
}
|
||||
break 'inner true;
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!("Discord: unknown opcode {op}");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Try resume URL if available
|
||||
if let Some(ref url) = *resume_url_store.read().await {
|
||||
connect_url = format!("{url}/?v=10&encoding=json");
|
||||
}
|
||||
|
||||
warn!("Discord: reconnecting in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
}
|
||||
|
||||
info!("Discord gateway loop stopped");
|
||||
});
|
||||
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// platform_id is the channel_id for Discord
|
||||
let channel_id = &user.platform_id;
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(channel_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(channel_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.api_send_typing(&user.platform_id).await
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Discord MESSAGE_CREATE or MESSAGE_UPDATE payload into a `ChannelMessage`.
|
||||
async fn parse_discord_message(
|
||||
d: &serde_json::Value,
|
||||
bot_user_id: &Arc<RwLock<Option<String>>>,
|
||||
allowed_guilds: &[u64],
|
||||
) -> Option<ChannelMessage> {
|
||||
let author = d.get("author")?;
|
||||
let author_id = author["id"].as_str()?;
|
||||
|
||||
// Filter out bot's own messages
|
||||
if let Some(ref bid) = *bot_user_id.read().await {
|
||||
if author_id == bid {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out other bots
|
||||
if author["bot"].as_bool() == Some(true) {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Filter by allowed guilds
|
||||
if !allowed_guilds.is_empty() {
|
||||
if let Some(guild_id) = d["guild_id"].as_str() {
|
||||
let gid: u64 = guild_id.parse().unwrap_or(0);
|
||||
if !allowed_guilds.contains(&gid) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let content_text = d["content"].as_str().unwrap_or("");
|
||||
if content_text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let channel_id = d["channel_id"].as_str()?;
|
||||
let message_id = d["id"].as_str().unwrap_or("0");
|
||||
let username = author["username"].as_str().unwrap_or("Unknown");
|
||||
let discriminator = author["discriminator"].as_str().unwrap_or("0000");
|
||||
let display_name = if discriminator == "0" {
|
||||
username.to_string()
|
||||
} else {
|
||||
format!("{username}#{discriminator}")
|
||||
};
|
||||
|
||||
let timestamp = d["timestamp"]
|
||||
.as_str()
|
||||
.and_then(|ts| chrono::DateTime::parse_from_rfc3339(ts).ok())
|
||||
.map(|dt| dt.with_timezone(&chrono::Utc))
|
||||
.unwrap_or_else(chrono::Utc::now);
|
||||
|
||||
// Parse commands (messages starting with /)
|
||||
let content = if content_text.starts_with('/') {
|
||||
let parts: Vec<&str> = content_text.splitn(2, ' ').collect();
|
||||
let cmd_name = &parts[0][1..];
|
||||
let args = if parts.len() > 1 {
|
||||
parts[1].split_whitespace().map(String::from).collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(content_text.to_string())
|
||||
};
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Discord,
|
||||
platform_message_id: message_id.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: channel_id.to_string(),
|
||||
display_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp,
|
||||
is_group: true,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_message_basic() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("bot123".to_string())));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"content": "Hello agent!",
|
||||
"author": {
|
||||
"id": "user456",
|
||||
"username": "alice",
|
||||
"discriminator": "0",
|
||||
"bot": false
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00"
|
||||
});
|
||||
|
||||
let msg = parse_discord_message(&d, &bot_id, &[]).await.unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Discord);
|
||||
assert_eq!(msg.sender.display_name, "alice");
|
||||
assert_eq!(msg.sender.platform_id, "ch1");
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello agent!"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_message_filters_bot() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("bot123".to_string())));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"content": "My own message",
|
||||
"author": {
|
||||
"id": "bot123",
|
||||
"username": "openfang",
|
||||
"discriminator": "0"
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00"
|
||||
});
|
||||
|
||||
let msg = parse_discord_message(&d, &bot_id, &[]).await;
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_message_filters_other_bots() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("bot123".to_string())));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"content": "Bot message",
|
||||
"author": {
|
||||
"id": "other_bot",
|
||||
"username": "somebot",
|
||||
"discriminator": "0",
|
||||
"bot": true
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00"
|
||||
});
|
||||
|
||||
let msg = parse_discord_message(&d, &bot_id, &[]).await;
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_message_guild_filter() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("bot123".to_string())));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"guild_id": "999",
|
||||
"content": "Hello",
|
||||
"author": {
|
||||
"id": "user1",
|
||||
"username": "bob",
|
||||
"discriminator": "0"
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00"
|
||||
});
|
||||
|
||||
// Not in allowed guilds
|
||||
let msg = parse_discord_message(&d, &bot_id, &[111, 222]).await;
|
||||
assert!(msg.is_none());
|
||||
|
||||
// In allowed guilds
|
||||
let msg = parse_discord_message(&d, &bot_id, &[999]).await;
|
||||
assert!(msg.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_command() {
|
||||
let bot_id = Arc::new(RwLock::new(None));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"content": "/agent hello-world",
|
||||
"author": {
|
||||
"id": "user1",
|
||||
"username": "alice",
|
||||
"discriminator": "0"
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00"
|
||||
});
|
||||
|
||||
let msg = parse_discord_message(&d, &bot_id, &[]).await.unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agent");
|
||||
assert_eq!(args, &["hello-world"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_empty_content() {
|
||||
let bot_id = Arc::new(RwLock::new(None));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"content": "",
|
||||
"author": {
|
||||
"id": "user1",
|
||||
"username": "alice",
|
||||
"discriminator": "0"
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00"
|
||||
});
|
||||
|
||||
let msg = parse_discord_message(&d, &bot_id, &[]).await;
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_discriminator() {
|
||||
let bot_id = Arc::new(RwLock::new(None));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"content": "Hi",
|
||||
"author": {
|
||||
"id": "user1",
|
||||
"username": "alice",
|
||||
"discriminator": "1234"
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00"
|
||||
});
|
||||
|
||||
let msg = parse_discord_message(&d, &bot_id, &[]).await.unwrap();
|
||||
assert_eq!(msg.sender.display_name, "alice#1234");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_discord_message_update() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("bot123".to_string())));
|
||||
let d = serde_json::json!({
|
||||
"id": "msg1",
|
||||
"channel_id": "ch1",
|
||||
"content": "Edited message content",
|
||||
"author": {
|
||||
"id": "user456",
|
||||
"username": "alice",
|
||||
"discriminator": "0",
|
||||
"bot": false
|
||||
},
|
||||
"timestamp": "2024-01-01T00:00:00+00:00",
|
||||
"edited_timestamp": "2024-01-01T00:01:00+00:00"
|
||||
});
|
||||
|
||||
// MESSAGE_UPDATE uses the same parse function as MESSAGE_CREATE
|
||||
let msg = parse_discord_message(&d, &bot_id, &[]).await.unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Discord);
|
||||
assert!(
|
||||
matches!(msg.content, ChannelContent::Text(ref t) if t == "Edited message content")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_discord_adapter_creation() {
|
||||
let adapter = DiscordAdapter::new("test-token".to_string(), vec![123, 456], 33280);
|
||||
assert_eq!(adapter.name(), "discord");
|
||||
assert_eq!(adapter.channel_type(), ChannelType::Discord);
|
||||
}
|
||||
}
|
||||
469
crates/openfang-channels/src/discourse.rs
Normal file
469
crates/openfang-channels/src/discourse.rs
Normal file
@@ -0,0 +1,469 @@
|
||||
//! Discourse channel adapter.
|
||||
//!
|
||||
//! Integrates with the Discourse forum REST API. Uses long-polling on
|
||||
//! `posts.json` to receive new posts and creates replies via the same API.
|
||||
//! Authentication uses the `Api-Key` and `Api-Username` headers.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const POLL_INTERVAL_SECS: u64 = 10;
|
||||
const MAX_MESSAGE_LEN: usize = 32000;
|
||||
|
||||
/// Discourse forum channel adapter.
|
||||
///
|
||||
/// Polls the Discourse `/posts.json` endpoint for new posts and creates
|
||||
/// replies via `POST /posts.json`. Filters posts by category if configured.
|
||||
pub struct DiscourseAdapter {
|
||||
/// Base URL of the Discourse instance (e.g., `"https://forum.example.com"`).
|
||||
base_url: String,
|
||||
/// SECURITY: API key is zeroized on drop.
|
||||
api_key: Zeroizing<String>,
|
||||
/// Username associated with the API key.
|
||||
api_username: String,
|
||||
/// Category slugs to filter (empty = all categories).
|
||||
categories: Vec<String>,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Last seen post ID (for incremental polling).
|
||||
last_post_id: Arc<RwLock<u64>>,
|
||||
}
|
||||
|
||||
impl DiscourseAdapter {
|
||||
/// Create a new Discourse adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `base_url` - Base URL of the Discourse instance.
|
||||
/// * `api_key` - Discourse API key (admin or user-scoped).
|
||||
/// * `api_username` - Username for the API key (usually "system" or a bot account).
|
||||
/// * `categories` - Category slugs to listen to (empty = all).
|
||||
pub fn new(
|
||||
base_url: String,
|
||||
api_key: String,
|
||||
api_username: String,
|
||||
categories: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let base_url = base_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
base_url,
|
||||
api_key: Zeroizing::new(api_key),
|
||||
api_username,
|
||||
categories,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
last_post_id: Arc::new(RwLock::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add Discourse API auth headers to a request builder.
|
||||
fn auth_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
builder
|
||||
.header("Api-Key", self.api_key.as_str())
|
||||
.header("Api-Username", &self.api_username)
|
||||
}
|
||||
|
||||
/// Validate credentials by calling `/session/current.json`.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/session/current.json", self.base_url);
|
||||
let resp = self.auth_headers(self.client.get(&url)).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Discourse auth failed (HTTP {})", resp.status()).into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let username = body["current_user"]["username"]
|
||||
.as_str()
|
||||
.unwrap_or(&self.api_username)
|
||||
.to_string();
|
||||
Ok(username)
|
||||
}
|
||||
|
||||
/// Fetch the latest posts since `before_id`.
|
||||
async fn fetch_latest_posts(
|
||||
client: &reqwest::Client,
|
||||
base_url: &str,
|
||||
api_key: &str,
|
||||
api_username: &str,
|
||||
before_id: u64,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let url = if before_id > 0 {
|
||||
format!("{}/posts.json?before={}", base_url, before_id)
|
||||
} else {
|
||||
format!("{}/posts.json", base_url)
|
||||
};
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.header("Api-Key", api_key)
|
||||
.header("Api-Username", api_username)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Discourse: HTTP {}", resp.status()).into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let posts = body["latest_posts"].as_array().cloned().unwrap_or_default();
|
||||
Ok(posts)
|
||||
}
|
||||
|
||||
/// Create a reply to a topic.
|
||||
async fn create_post(
|
||||
&self,
|
||||
topic_id: u64,
|
||||
raw: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/posts.json", self.base_url);
|
||||
let chunks = split_message(raw, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"topic_id": topic_id,
|
||||
"raw": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.auth_headers(self.client.post(&url))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Discourse API error {status}: {err_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a category slug matches the filter.
|
||||
#[allow(dead_code)]
|
||||
fn matches_category(&self, category_slug: &str) -> bool {
|
||||
self.categories.is_empty() || self.categories.iter().any(|c| c == category_slug)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for DiscourseAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"discourse"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("discourse".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let own_username = self.validate().await?;
|
||||
info!("Discourse adapter authenticated as {own_username}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let base_url = self.base_url.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let api_username = self.api_username.clone();
|
||||
let categories = self.categories.clone();
|
||||
let client = self.client.clone();
|
||||
let last_post_id = Arc::clone(&self.last_post_id);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
// Initialize last_post_id to skip historical posts
|
||||
{
|
||||
let posts = Self::fetch_latest_posts(&client, &base_url, &api_key, &api_username, 0)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
if let Some(latest) = posts.first() {
|
||||
let id = latest["id"].as_u64().unwrap_or(0);
|
||||
*last_post_id.write().await = id;
|
||||
}
|
||||
}
|
||||
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("Discourse adapter shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
let current_last = *last_post_id.read().await;
|
||||
|
||||
let poll_result =
|
||||
Self::fetch_latest_posts(&client, &base_url, &api_key, &api_username, 0)
|
||||
.await
|
||||
.map_err(|e| e.to_string());
|
||||
|
||||
let posts = match poll_result {
|
||||
Ok(p) => {
|
||||
backoff = Duration::from_secs(1);
|
||||
p
|
||||
}
|
||||
Err(msg) => {
|
||||
warn!("Discourse: poll error: {msg}, backing off {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(120));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut max_id = current_last;
|
||||
|
||||
// Process posts in chronological order (API returns newest first)
|
||||
for post in posts.iter().rev() {
|
||||
let post_id = post["id"].as_u64().unwrap_or(0);
|
||||
if post_id <= current_last {
|
||||
continue;
|
||||
}
|
||||
|
||||
let username = post["username"].as_str().unwrap_or("unknown");
|
||||
// Skip own posts
|
||||
if username == own_username || username == api_username {
|
||||
continue;
|
||||
}
|
||||
|
||||
let raw = post["raw"].as_str().unwrap_or("");
|
||||
if raw.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Category filter
|
||||
let category_slug = post["category_slug"].as_str().unwrap_or("");
|
||||
if !categories.is_empty() && !categories.iter().any(|c| c == category_slug) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let topic_id = post["topic_id"].as_u64().unwrap_or(0);
|
||||
let topic_slug = post["topic_slug"].as_str().unwrap_or("").to_string();
|
||||
let post_number = post["post_number"].as_u64().unwrap_or(0);
|
||||
let display_name = post["display_username"]
|
||||
.as_str()
|
||||
.unwrap_or(username)
|
||||
.to_string();
|
||||
|
||||
if post_id > max_id {
|
||||
max_id = post_id;
|
||||
}
|
||||
|
||||
let content = if raw.starts_with('/') {
|
||||
let parts: Vec<&str> = raw.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(raw.to_string())
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("discourse".to_string()),
|
||||
platform_message_id: format!("discourse-post-{}", post_id),
|
||||
sender: ChannelUser {
|
||||
platform_id: username.to_string(),
|
||||
display_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: Some(format!("topic-{}", topic_id)),
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"topic_id".to_string(),
|
||||
serde_json::Value::Number(topic_id.into()),
|
||||
);
|
||||
m.insert(
|
||||
"topic_slug".to_string(),
|
||||
serde_json::Value::String(topic_slug),
|
||||
);
|
||||
m.insert(
|
||||
"post_number".to_string(),
|
||||
serde_json::Value::Number(post_number.into()),
|
||||
);
|
||||
m.insert(
|
||||
"category".to_string(),
|
||||
serde_json::Value::String(category_slug.to_string()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if max_id > current_last {
|
||||
*last_post_id.write().await = max_id;
|
||||
}
|
||||
}
|
||||
|
||||
info!("Discourse polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// Extract topic_id from user.platform_id or metadata
|
||||
// Convention: platform_id holds the topic_id for replies
|
||||
let topic_id: u64 = user.platform_id.parse().unwrap_or(0);
|
||||
|
||||
if topic_id == 0 {
|
||||
return Err("Discourse: cannot send without topic_id in platform_id".into());
|
||||
}
|
||||
|
||||
self.create_post(topic_id, &text).await
|
||||
}
|
||||
|
||||
async fn send_in_thread(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
thread_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// thread_id format: "topic-{id}"
|
||||
let topic_id: u64 = thread_id
|
||||
.strip_prefix("topic-")
|
||||
.unwrap_or(thread_id)
|
||||
.parse()
|
||||
.map_err(|_| "Discourse: invalid thread_id format")?;
|
||||
|
||||
self.create_post(topic_id, &text).await
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Discourse does not have typing indicators.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_discourse_adapter_creation() {
|
||||
let adapter = DiscourseAdapter::new(
|
||||
"https://forum.example.com".to_string(),
|
||||
"api-key-123".to_string(),
|
||||
"system".to_string(),
|
||||
vec!["general".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "discourse");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("discourse".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_discourse_url_normalization() {
|
||||
let adapter = DiscourseAdapter::new(
|
||||
"https://forum.example.com/".to_string(),
|
||||
"key".to_string(),
|
||||
"bot".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.base_url, "https://forum.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_discourse_category_filter() {
|
||||
let adapter = DiscourseAdapter::new(
|
||||
"https://forum.example.com".to_string(),
|
||||
"key".to_string(),
|
||||
"bot".to_string(),
|
||||
vec!["dev".to_string(), "support".to_string()],
|
||||
);
|
||||
assert!(adapter.matches_category("dev"));
|
||||
assert!(adapter.matches_category("support"));
|
||||
assert!(!adapter.matches_category("random"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_discourse_category_filter_empty_allows_all() {
|
||||
let adapter = DiscourseAdapter::new(
|
||||
"https://forum.example.com".to_string(),
|
||||
"key".to_string(),
|
||||
"bot".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert!(adapter.matches_category("anything"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_discourse_auth_headers() {
|
||||
let adapter = DiscourseAdapter::new(
|
||||
"https://forum.example.com".to_string(),
|
||||
"my-api-key".to_string(),
|
||||
"bot-user".to_string(),
|
||||
vec![],
|
||||
);
|
||||
let builder = adapter.client.get("https://example.com");
|
||||
let builder = adapter.auth_headers(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert_eq!(request.headers().get("Api-Key").unwrap(), "my-api-key");
|
||||
assert_eq!(request.headers().get("Api-Username").unwrap(), "bot-user");
|
||||
}
|
||||
}
|
||||
601
crates/openfang-channels/src/email.rs
Normal file
601
crates/openfang-channels/src/email.rs
Normal file
@@ -0,0 +1,601 @@
|
||||
//! Email channel adapter (IMAP + SMTP).
|
||||
//!
|
||||
//! Polls IMAP for new emails and sends responses via SMTP using `lettre`.
|
||||
//! Uses the subject line for agent routing (e.g., "\[coder\] Fix this bug").
|
||||
|
||||
use crate::types::{ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use dashmap::DashMap;
|
||||
use futures::Stream;
|
||||
use lettre::message::Mailbox;
|
||||
use lettre::transport::smtp::authentication::Credentials;
|
||||
use lettre::AsyncSmtpTransport;
|
||||
use lettre::AsyncTransport;
|
||||
use lettre::Tokio1Executor;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Reply context for email threading (In-Reply-To / Subject continuity).
|
||||
#[derive(Debug, Clone)]
|
||||
struct ReplyCtx {
|
||||
subject: String,
|
||||
message_id: String,
|
||||
}
|
||||
|
||||
/// Email channel adapter using IMAP for receiving and SMTP for sending.
|
||||
pub struct EmailAdapter {
|
||||
/// IMAP server host.
|
||||
imap_host: String,
|
||||
/// IMAP port (993 for TLS).
|
||||
imap_port: u16,
|
||||
/// SMTP server host.
|
||||
smtp_host: String,
|
||||
/// SMTP port (587 for STARTTLS, 465 for implicit TLS).
|
||||
smtp_port: u16,
|
||||
/// Email address (used for both IMAP and SMTP).
|
||||
username: String,
|
||||
/// SECURITY: Password is zeroized on drop.
|
||||
password: Zeroizing<String>,
|
||||
/// How often to check for new emails.
|
||||
poll_interval: Duration,
|
||||
/// Which IMAP folders to monitor.
|
||||
folders: Vec<String>,
|
||||
/// Only process emails from these senders (empty = all).
|
||||
allowed_senders: Vec<String>,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Tracks reply context per sender for email threading.
|
||||
reply_ctx: Arc<DashMap<String, ReplyCtx>>,
|
||||
}
|
||||
|
||||
impl EmailAdapter {
|
||||
/// Create a new email adapter.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
imap_host: String,
|
||||
imap_port: u16,
|
||||
smtp_host: String,
|
||||
smtp_port: u16,
|
||||
username: String,
|
||||
password: String,
|
||||
poll_interval_secs: u64,
|
||||
folders: Vec<String>,
|
||||
allowed_senders: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
imap_host,
|
||||
imap_port,
|
||||
smtp_host,
|
||||
smtp_port,
|
||||
username,
|
||||
password: Zeroizing::new(password),
|
||||
poll_interval: Duration::from_secs(poll_interval_secs),
|
||||
folders: if folders.is_empty() {
|
||||
vec!["INBOX".to_string()]
|
||||
} else {
|
||||
folders
|
||||
},
|
||||
allowed_senders,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
reply_ctx: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a sender is in the allowlist (empty = allow all). Used in tests.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_sender(&self, sender: &str) -> bool {
|
||||
self.allowed_senders.is_empty() || self.allowed_senders.iter().any(|s| sender.contains(s))
|
||||
}
|
||||
|
||||
/// Extract agent name from subject line brackets, e.g., "[coder] Fix the bug" -> Some("coder")
|
||||
fn extract_agent_from_subject(subject: &str) -> Option<String> {
|
||||
let subject = subject.trim();
|
||||
if subject.starts_with('[') {
|
||||
if let Some(end) = subject.find(']') {
|
||||
let agent = &subject[1..end];
|
||||
if !agent.is_empty() {
|
||||
return Some(agent.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Strip the agent tag from a subject line.
|
||||
fn strip_agent_tag(subject: &str) -> String {
|
||||
let subject = subject.trim();
|
||||
if subject.starts_with('[') {
|
||||
if let Some(end) = subject.find(']') {
|
||||
return subject[end + 1..].trim().to_string();
|
||||
}
|
||||
}
|
||||
subject.to_string()
|
||||
}
|
||||
|
||||
/// Build an async SMTP transport for sending emails.
|
||||
async fn build_smtp_transport(
|
||||
&self,
|
||||
) -> Result<AsyncSmtpTransport<Tokio1Executor>, Box<dyn std::error::Error>> {
|
||||
let creds =
|
||||
Credentials::new(self.username.clone(), self.password.as_str().to_string());
|
||||
|
||||
let transport = if self.smtp_port == 465 {
|
||||
// Implicit TLS (port 465)
|
||||
AsyncSmtpTransport::<Tokio1Executor>::relay(&self.smtp_host)?
|
||||
.port(self.smtp_port)
|
||||
.credentials(creds)
|
||||
.build()
|
||||
} else {
|
||||
// STARTTLS (port 587 or other)
|
||||
AsyncSmtpTransport::<Tokio1Executor>::starttls_relay(&self.smtp_host)?
|
||||
.port(self.smtp_port)
|
||||
.credentials(creds)
|
||||
.build()
|
||||
};
|
||||
|
||||
Ok(transport)
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract `user@domain` from a potentially formatted email string like `"Name <user@domain>"`.
|
||||
fn extract_email_addr(raw: &str) -> String {
|
||||
let raw = raw.trim();
|
||||
if let Some(start) = raw.find('<') {
|
||||
if let Some(end) = raw.find('>') {
|
||||
if end > start {
|
||||
return raw[start + 1..end].trim().to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
raw.to_string()
|
||||
}
|
||||
|
||||
/// Get a specific header value from a parsed email.
|
||||
fn get_header(parsed: &mailparse::ParsedMail<'_>, name: &str) -> Option<String> {
|
||||
parsed
|
||||
.headers
|
||||
.iter()
|
||||
.find(|h| h.get_key().eq_ignore_ascii_case(name))
|
||||
.map(|h| h.get_value())
|
||||
}
|
||||
|
||||
/// Extract the text/plain body from a parsed email (handles multipart).
|
||||
fn extract_text_body(parsed: &mailparse::ParsedMail<'_>) -> String {
|
||||
if parsed.subparts.is_empty() {
|
||||
return parsed.get_body().unwrap_or_default();
|
||||
}
|
||||
// Walk subparts looking for text/plain
|
||||
for part in &parsed.subparts {
|
||||
let ct = part.ctype.mimetype.to_lowercase();
|
||||
if ct == "text/plain" {
|
||||
return part.get_body().unwrap_or_default();
|
||||
}
|
||||
}
|
||||
// Fallback: first subpart body
|
||||
parsed
|
||||
.subparts
|
||||
.first()
|
||||
.and_then(|p| p.get_body().ok())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Fetch unseen emails from IMAP using blocking I/O.
|
||||
/// Returns a Vec of (from_addr, subject, message_id, body).
|
||||
fn fetch_unseen_emails(
|
||||
host: &str,
|
||||
port: u16,
|
||||
username: &str,
|
||||
password: &str,
|
||||
folders: &[String],
|
||||
) -> Result<Vec<(String, String, String, String)>, String> {
|
||||
let tls = native_tls::TlsConnector::builder()
|
||||
.build()
|
||||
.map_err(|e| format!("TLS connector error: {e}"))?;
|
||||
|
||||
let client = imap::connect((host, port), host, &tls)
|
||||
.map_err(|e| format!("IMAP connect failed: {e}"))?;
|
||||
|
||||
let mut session = client
|
||||
.login(username, password)
|
||||
.map_err(|(e, _)| format!("IMAP login failed: {e}"))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
for folder in folders {
|
||||
if let Err(e) = session.select(folder) {
|
||||
warn!(folder, error = %e, "IMAP SELECT failed, skipping folder");
|
||||
continue;
|
||||
}
|
||||
|
||||
let uids = match session.uid_search("UNSEEN") {
|
||||
Ok(uids) => uids,
|
||||
Err(e) => {
|
||||
warn!(folder, error = %e, "IMAP SEARCH UNSEEN failed");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if uids.is_empty() {
|
||||
debug!(folder, "No unseen emails");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fetch in batches of up to 50 to avoid huge responses
|
||||
let uid_list: Vec<u32> = uids.into_iter().take(50).collect();
|
||||
let uid_set: String = uid_list
|
||||
.iter()
|
||||
.map(|u| u.to_string())
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
|
||||
let fetches = match session.uid_fetch(&uid_set, "RFC822") {
|
||||
Ok(f) => f,
|
||||
Err(e) => {
|
||||
warn!(folder, error = %e, "IMAP FETCH failed");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for fetch in fetches.iter() {
|
||||
let body_bytes = match fetch.body() {
|
||||
Some(b) => b,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let parsed = match mailparse::parse_mail(body_bytes) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to parse email");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let from = get_header(&parsed, "From").unwrap_or_default();
|
||||
let subject = get_header(&parsed, "Subject").unwrap_or_default();
|
||||
let message_id = get_header(&parsed, "Message-ID").unwrap_or_default();
|
||||
let text_body = extract_text_body(&parsed);
|
||||
|
||||
let from_addr = extract_email_addr(&from);
|
||||
results.push((from_addr, subject, message_id, text_body));
|
||||
}
|
||||
|
||||
// Mark fetched messages as Seen
|
||||
if let Err(e) = session.uid_store(&uid_set, "+FLAGS (\\Seen)") {
|
||||
warn!(error = %e, "Failed to mark emails as Seen");
|
||||
}
|
||||
}
|
||||
|
||||
let _ = session.logout();
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for EmailAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"email"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Email
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let poll_interval = self.poll_interval;
|
||||
let imap_host = self.imap_host.clone();
|
||||
let imap_port = self.imap_port;
|
||||
let username = self.username.clone();
|
||||
let password = self.password.clone();
|
||||
let folders = self.folders.clone();
|
||||
let allowed_senders = self.allowed_senders.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
let reply_ctx = self.reply_ctx.clone();
|
||||
|
||||
info!(
|
||||
"Starting email adapter (IMAP: {}:{}, SMTP: {}:{}, polling every {:?})",
|
||||
imap_host, imap_port, self.smtp_host, self.smtp_port, poll_interval
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Email adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
// IMAP operations are blocking I/O — run in spawn_blocking
|
||||
let host = imap_host.clone();
|
||||
let port = imap_port;
|
||||
let user = username.clone();
|
||||
let pass = password.clone();
|
||||
let fldrs = folders.clone();
|
||||
|
||||
let emails = tokio::task::spawn_blocking(move || {
|
||||
fetch_unseen_emails(&host, port, &user, pass.as_str(), &fldrs)
|
||||
})
|
||||
.await;
|
||||
|
||||
let emails = match emails {
|
||||
Ok(Ok(emails)) => emails,
|
||||
Ok(Err(e)) => {
|
||||
error!("IMAP poll error: {e}");
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("IMAP spawn_blocking panic: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for (from_addr, subject, message_id, body) in emails {
|
||||
// Check allowed senders
|
||||
if !allowed_senders.is_empty()
|
||||
&& !allowed_senders.iter().any(|s| from_addr.contains(s))
|
||||
{
|
||||
debug!(from = %from_addr, "Email from non-allowed sender, skipping");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Store reply context for threading
|
||||
if !message_id.is_empty() {
|
||||
reply_ctx.insert(
|
||||
from_addr.clone(),
|
||||
ReplyCtx {
|
||||
subject: subject.clone(),
|
||||
message_id: message_id.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
// Extract target agent from subject brackets (stored in metadata for router)
|
||||
let _target_agent =
|
||||
EmailAdapter::extract_agent_from_subject(&subject);
|
||||
let clean_subject = EmailAdapter::strip_agent_tag(&subject);
|
||||
|
||||
// Build the message body: prepend subject context
|
||||
let text = if clean_subject.is_empty() {
|
||||
body.trim().to_string()
|
||||
} else {
|
||||
format!("Subject: {clean_subject}\n\n{}", body.trim())
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Email,
|
||||
platform_message_id: message_id.clone(),
|
||||
sender: ChannelUser {
|
||||
platform_id: from_addr.clone(),
|
||||
display_name: from_addr.clone(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: ChannelContent::Text(text),
|
||||
target_agent: None, // Routing handled by bridge AgentRouter
|
||||
timestamp: Utc::now(),
|
||||
is_group: false,
|
||||
thread_id: None,
|
||||
metadata: std::collections::HashMap::new(),
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
info!("Email channel receiver dropped, stopping poll");
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
// Parse recipient address
|
||||
let to_addr = extract_email_addr(&user.platform_id);
|
||||
let to_mailbox: Mailbox = to_addr
|
||||
.parse()
|
||||
.map_err(|e| format!("Invalid recipient email '{}': {}", to_addr, e))?;
|
||||
|
||||
let from_mailbox: Mailbox = self
|
||||
.username
|
||||
.parse()
|
||||
.map_err(|e| format!("Invalid sender email '{}': {}", self.username, e))?;
|
||||
|
||||
// Extract subject from text body convention: "Subject: ...\n\n..."
|
||||
let (subject, body) = if text.starts_with("Subject: ") {
|
||||
if let Some(pos) = text.find("\n\n") {
|
||||
let subj = text[9..pos].trim().to_string();
|
||||
let body = text[pos + 2..].to_string();
|
||||
(subj, body)
|
||||
} else {
|
||||
("OpenFang Reply".to_string(), text)
|
||||
}
|
||||
} else {
|
||||
// Check reply context for subject continuity
|
||||
let subj = self
|
||||
.reply_ctx
|
||||
.get(&to_addr)
|
||||
.map(|ctx| format!("Re: {}", ctx.subject))
|
||||
.unwrap_or_else(|| "OpenFang Reply".to_string());
|
||||
(subj, text)
|
||||
};
|
||||
|
||||
// Build email message
|
||||
let mut builder = lettre::Message::builder()
|
||||
.from(from_mailbox)
|
||||
.to(to_mailbox)
|
||||
.subject(&subject);
|
||||
|
||||
// Add In-Reply-To header for threading
|
||||
if let Some(ctx) = self.reply_ctx.get(&to_addr) {
|
||||
if !ctx.message_id.is_empty() {
|
||||
builder = builder.in_reply_to(ctx.message_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let email = builder
|
||||
.body(body)
|
||||
.map_err(|e| format!("Failed to build email: {e}"))?;
|
||||
|
||||
// Send via SMTP
|
||||
let transport = self.build_smtp_transport().await?;
|
||||
transport
|
||||
.send(email)
|
||||
.await
|
||||
.map_err(|e| format!("SMTP send failed: {e}"))?;
|
||||
|
||||
info!(
|
||||
to = %to_addr,
|
||||
subject = %subject,
|
||||
"Email sent successfully via SMTP"
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
warn!(
|
||||
"Unsupported email content type for {}, only text is supported",
|
||||
user.platform_id
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_email_adapter_creation() {
|
||||
let adapter = EmailAdapter::new(
|
||||
"imap.gmail.com".to_string(),
|
||||
993,
|
||||
"smtp.gmail.com".to_string(),
|
||||
587,
|
||||
"user@gmail.com".to_string(),
|
||||
"password".to_string(),
|
||||
30,
|
||||
vec![],
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.name(), "email");
|
||||
assert_eq!(adapter.folders, vec!["INBOX".to_string()]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_senders() {
|
||||
let adapter = EmailAdapter::new(
|
||||
"imap.example.com".to_string(),
|
||||
993,
|
||||
"smtp.example.com".to_string(),
|
||||
587,
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
30,
|
||||
vec![],
|
||||
vec!["boss@company.com".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_sender("boss@company.com"));
|
||||
assert!(!adapter.is_allowed_sender("random@other.com"));
|
||||
|
||||
let open = EmailAdapter::new(
|
||||
"imap.example.com".to_string(),
|
||||
993,
|
||||
"smtp.example.com".to_string(),
|
||||
587,
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
30,
|
||||
vec![],
|
||||
vec![],
|
||||
);
|
||||
assert!(open.is_allowed_sender("anyone@anywhere.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_agent_from_subject() {
|
||||
assert_eq!(
|
||||
EmailAdapter::extract_agent_from_subject("[coder] Fix the bug"),
|
||||
Some("coder".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
EmailAdapter::extract_agent_from_subject("[researcher] Find papers on AI"),
|
||||
Some("researcher".to_string())
|
||||
);
|
||||
assert_eq!(
|
||||
EmailAdapter::extract_agent_from_subject("No brackets here"),
|
||||
None
|
||||
);
|
||||
assert_eq!(
|
||||
EmailAdapter::extract_agent_from_subject("[] Empty brackets"),
|
||||
None
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_agent_tag() {
|
||||
assert_eq!(
|
||||
EmailAdapter::strip_agent_tag("[coder] Fix the bug"),
|
||||
"Fix the bug"
|
||||
);
|
||||
assert_eq!(EmailAdapter::strip_agent_tag("No brackets"), "No brackets");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_email_addr() {
|
||||
assert_eq!(
|
||||
extract_email_addr("John Doe <john@example.com>"),
|
||||
"john@example.com"
|
||||
);
|
||||
assert_eq!(extract_email_addr("user@example.com"), "user@example.com");
|
||||
assert_eq!(extract_email_addr("<user@test.com>"), "user@test.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_subject_extraction_from_body() {
|
||||
let text = "Subject: Test Subject\n\nThis is the body.";
|
||||
assert!(text.starts_with("Subject: "));
|
||||
let pos = text.find("\n\n").unwrap();
|
||||
let subject = &text[9..pos];
|
||||
let body = &text[pos + 2..];
|
||||
assert_eq!(subject, "Test Subject");
|
||||
assert_eq!(body, "This is the body.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reply_ctx_threading() {
|
||||
let ctx_map: DashMap<String, ReplyCtx> = DashMap::new();
|
||||
ctx_map.insert(
|
||||
"user@test.com".to_string(),
|
||||
ReplyCtx {
|
||||
subject: "Original Subject".to_string(),
|
||||
message_id: "<msg-123@test.com>".to_string(),
|
||||
},
|
||||
);
|
||||
let ctx = ctx_map.get("user@test.com").unwrap();
|
||||
assert_eq!(ctx.subject, "Original Subject");
|
||||
assert_eq!(ctx.message_id, "<msg-123@test.com>");
|
||||
}
|
||||
}
|
||||
799
crates/openfang-channels/src/feishu.rs
Normal file
799
crates/openfang-channels/src/feishu.rs
Normal file
@@ -0,0 +1,799 @@
|
||||
//! Feishu/Lark Open Platform channel adapter.
|
||||
//!
|
||||
//! Uses the Feishu Open API for sending messages and a webhook HTTP server for
|
||||
//! receiving inbound events. Authentication is performed via a tenant access token
|
||||
//! obtained from `https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal`.
|
||||
//! The token is cached and refreshed automatically (2-hour expiry).
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Feishu tenant access token endpoint.
|
||||
const FEISHU_TOKEN_URL: &str =
|
||||
"https://open.feishu.cn/open-apis/auth/v3/tenant_access_token/internal";
|
||||
|
||||
/// Feishu send message endpoint.
|
||||
const FEISHU_SEND_URL: &str = "https://open.feishu.cn/open-apis/im/v1/messages";
|
||||
|
||||
/// Feishu bot info endpoint.
|
||||
const FEISHU_BOT_INFO_URL: &str = "https://open.feishu.cn/open-apis/bot/v3/info";
|
||||
|
||||
/// Maximum Feishu message text length (characters).
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
|
||||
/// Token refresh buffer — refresh 5 minutes before actual expiry.
|
||||
const TOKEN_REFRESH_BUFFER_SECS: u64 = 300;
|
||||
|
||||
/// Feishu/Lark Open Platform adapter.
|
||||
///
|
||||
/// Inbound messages arrive via a webhook HTTP server that receives event
|
||||
/// callbacks from the Feishu platform. Outbound messages are sent via the
|
||||
/// Feishu IM API with a tenant access token for authentication.
|
||||
pub struct FeishuAdapter {
|
||||
/// Feishu app ID.
|
||||
app_id: String,
|
||||
/// SECURITY: Feishu app secret, zeroized on drop.
|
||||
app_secret: Zeroizing<String>,
|
||||
/// Port on which the inbound webhook HTTP server listens.
|
||||
webhook_port: u16,
|
||||
/// Optional verification token for webhook event validation.
|
||||
verification_token: Option<String>,
|
||||
/// Optional encrypt key for webhook event decryption.
|
||||
encrypt_key: Option<String>,
|
||||
/// HTTP client for API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Cached tenant access token and its expiry instant.
|
||||
cached_token: Arc<RwLock<Option<(String, Instant)>>>,
|
||||
}
|
||||
|
||||
impl FeishuAdapter {
|
||||
/// Create a new Feishu adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `app_id` - Feishu application ID.
|
||||
/// * `app_secret` - Feishu application secret.
|
||||
/// * `webhook_port` - Local port for the inbound webhook HTTP server.
|
||||
pub fn new(app_id: String, app_secret: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
app_id,
|
||||
app_secret: Zeroizing::new(app_secret),
|
||||
webhook_port,
|
||||
verification_token: None,
|
||||
encrypt_key: None,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
cached_token: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Feishu adapter with webhook verification.
|
||||
pub fn with_verification(
|
||||
app_id: String,
|
||||
app_secret: String,
|
||||
webhook_port: u16,
|
||||
verification_token: Option<String>,
|
||||
encrypt_key: Option<String>,
|
||||
) -> Self {
|
||||
let mut adapter = Self::new(app_id, app_secret, webhook_port);
|
||||
adapter.verification_token = verification_token;
|
||||
adapter.encrypt_key = encrypt_key;
|
||||
adapter
|
||||
}
|
||||
|
||||
/// Obtain a valid tenant access token, refreshing if expired or missing.
|
||||
async fn get_token(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
// Check cache first
|
||||
{
|
||||
let guard = self.cached_token.read().await;
|
||||
if let Some((ref token, expiry)) = *guard {
|
||||
if Instant::now() < expiry {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch a new tenant access token
|
||||
let body = serde_json::json!({
|
||||
"app_id": self.app_id,
|
||||
"app_secret": self.app_secret.as_str(),
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(FEISHU_TOKEN_URL)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Feishu token request failed {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
let code = resp_body["code"].as_i64().unwrap_or(-1);
|
||||
if code != 0 {
|
||||
let msg = resp_body["msg"].as_str().unwrap_or("unknown error");
|
||||
return Err(format!("Feishu token error: {msg}").into());
|
||||
}
|
||||
|
||||
let tenant_access_token = resp_body["tenant_access_token"]
|
||||
.as_str()
|
||||
.ok_or("Missing tenant_access_token")?
|
||||
.to_string();
|
||||
let expire = resp_body["expire"].as_u64().unwrap_or(7200);
|
||||
|
||||
// Cache with safety buffer
|
||||
let expiry =
|
||||
Instant::now() + Duration::from_secs(expire.saturating_sub(TOKEN_REFRESH_BUFFER_SECS));
|
||||
*self.cached_token.write().await = Some((tenant_access_token.clone(), expiry));
|
||||
|
||||
Ok(tenant_access_token)
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching bot info.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let token = self.get_token().await?;
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(FEISHU_BOT_INFO_URL)
|
||||
.bearer_auth(&token)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Feishu authentication failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let code = body["code"].as_i64().unwrap_or(-1);
|
||||
if code != 0 {
|
||||
let msg = body["msg"].as_str().unwrap_or("unknown error");
|
||||
return Err(format!("Feishu bot info error: {msg}").into());
|
||||
}
|
||||
|
||||
let bot_name = body["bot"]["app_name"]
|
||||
.as_str()
|
||||
.unwrap_or("Feishu Bot")
|
||||
.to_string();
|
||||
Ok(bot_name)
|
||||
}
|
||||
|
||||
/// Send a text message to a Feishu chat.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
receive_id: &str,
|
||||
receive_id_type: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let token = self.get_token().await?;
|
||||
let url = format!("{}?receive_id_type={}", FEISHU_SEND_URL, receive_id_type);
|
||||
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let content = serde_json::json!({
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let body = serde_json::json!({
|
||||
"receive_id": receive_id,
|
||||
"msg_type": "text",
|
||||
"content": content.to_string(),
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Feishu send message error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
let code = resp_body["code"].as_i64().unwrap_or(-1);
|
||||
if code != 0 {
|
||||
let msg = resp_body["msg"].as_str().unwrap_or("unknown error");
|
||||
warn!("Feishu send message API error: {msg}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Reply to a message in a thread.
|
||||
#[allow(dead_code)]
|
||||
async fn api_reply_message(
|
||||
&self,
|
||||
message_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let token = self.get_token().await?;
|
||||
let url = format!(
|
||||
"https://open.feishu.cn/open-apis/im/v1/messages/{}/reply",
|
||||
message_id
|
||||
);
|
||||
|
||||
let content = serde_json::json!({
|
||||
"text": text,
|
||||
});
|
||||
|
||||
let body = serde_json::json!({
|
||||
"msg_type": "text",
|
||||
"content": content.to_string(),
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Feishu reply message error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Feishu webhook event into a `ChannelMessage`.
|
||||
///
|
||||
/// Handles `im.message.receive_v1` events with text message type.
|
||||
fn parse_feishu_event(event: &serde_json::Value) -> Option<ChannelMessage> {
|
||||
// Feishu v2 event schema
|
||||
let header = event.get("header")?;
|
||||
let event_type = header["event_type"].as_str().unwrap_or("");
|
||||
|
||||
if event_type != "im.message.receive_v1" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let event_data = event.get("event")?;
|
||||
let message = event_data.get("message")?;
|
||||
let sender = event_data.get("sender")?;
|
||||
|
||||
let msg_type = message["message_type"].as_str().unwrap_or("");
|
||||
if msg_type != "text" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Parse the content JSON string
|
||||
let content_str = message["content"].as_str().unwrap_or("{}");
|
||||
let content_json: serde_json::Value = serde_json::from_str(content_str).unwrap_or_default();
|
||||
let text = content_json["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message_id = message["message_id"].as_str().unwrap_or("").to_string();
|
||||
let chat_id = message["chat_id"].as_str().unwrap_or("").to_string();
|
||||
let chat_type = message["chat_type"].as_str().unwrap_or("p2p");
|
||||
let root_id = message["root_id"].as_str().map(|s| s.to_string());
|
||||
|
||||
let sender_id = sender
|
||||
.get("sender_id")
|
||||
.and_then(|s| s.get("open_id"))
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let sender_type = sender["sender_type"].as_str().unwrap_or("user");
|
||||
|
||||
// Skip bot messages
|
||||
if sender_type == "bot" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let is_group = chat_type == "group";
|
||||
|
||||
let msg_content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"chat_id".to_string(),
|
||||
serde_json::Value::String(chat_id.clone()),
|
||||
);
|
||||
metadata.insert(
|
||||
"message_id".to_string(),
|
||||
serde_json::Value::String(message_id.clone()),
|
||||
);
|
||||
metadata.insert(
|
||||
"chat_type".to_string(),
|
||||
serde_json::Value::String(chat_type.to_string()),
|
||||
);
|
||||
metadata.insert(
|
||||
"sender_id".to_string(),
|
||||
serde_json::Value::String(sender_id.clone()),
|
||||
);
|
||||
if let Some(mentions) = message.get("mentions") {
|
||||
metadata.insert("mentions".to_string(), mentions.clone());
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("feishu".to_string()),
|
||||
platform_message_id: message_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: chat_id,
|
||||
display_name: sender_id,
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: root_id,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for FeishuAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"feishu"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("feishu".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_name = self.validate().await?;
|
||||
info!("Feishu adapter authenticated as {bot_name}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let verification_token = self.verification_token.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let verification_token = Arc::new(verification_token);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/feishu/webhook",
|
||||
axum::routing::post({
|
||||
let vt = Arc::clone(&verification_token);
|
||||
let tx = Arc::clone(&tx);
|
||||
move |body: axum::extract::Json<serde_json::Value>| {
|
||||
let vt = Arc::clone(&vt);
|
||||
let tx = Arc::clone(&tx);
|
||||
async move {
|
||||
// Handle URL verification challenge
|
||||
if let Some(challenge) = body.0.get("challenge") {
|
||||
// Verify token if configured
|
||||
if let Some(ref expected_token) = *vt {
|
||||
let token = body.0["token"].as_str().unwrap_or("");
|
||||
if token != expected_token {
|
||||
warn!("Feishu: invalid verification token");
|
||||
return (
|
||||
axum::http::StatusCode::FORBIDDEN,
|
||||
axum::Json(serde_json::json!({})),
|
||||
);
|
||||
}
|
||||
}
|
||||
return (
|
||||
axum::http::StatusCode::OK,
|
||||
axum::Json(serde_json::json!({
|
||||
"challenge": challenge,
|
||||
})),
|
||||
);
|
||||
}
|
||||
|
||||
// Handle event callback
|
||||
if let Some(schema) = body.0["schema"].as_str() {
|
||||
if schema == "2.0" {
|
||||
// V2 event format
|
||||
if let Some(msg) = parse_feishu_event(&body.0) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// V1 event format (legacy)
|
||||
let event_type = body.0["event"]["type"].as_str().unwrap_or("");
|
||||
if event_type == "message" {
|
||||
// Legacy format handling
|
||||
let event = &body.0["event"];
|
||||
let text = event["text"].as_str().unwrap_or("");
|
||||
if !text.is_empty() {
|
||||
let open_id =
|
||||
event["open_id"].as_str().unwrap_or("").to_string();
|
||||
let chat_id = event["open_chat_id"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let msg_id = event["open_message_id"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let is_group =
|
||||
event["chat_type"].as_str().unwrap_or("") == "group";
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| {
|
||||
a.split_whitespace().map(String::from).collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("feishu".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: chat_id,
|
||||
display_name: open_id,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
|
||||
let _ = tx.send(channel_msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(
|
||||
axum::http::StatusCode::OK,
|
||||
axum::Json(serde_json::json!({})),
|
||||
)
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("Feishu webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Feishu webhook bind failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Feishu webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Feishu adapter shutting down");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
// Use chat_id as receive_id with chat_id type
|
||||
self.api_send_message(&user.platform_id, "chat_id", &text)
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "chat_id", "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Feishu does not support typing indicators via REST API
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_feishu_adapter_creation() {
|
||||
let adapter =
|
||||
FeishuAdapter::new("cli_abc123".to_string(), "app-secret-456".to_string(), 9000);
|
||||
assert_eq!(adapter.name(), "feishu");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("feishu".to_string())
|
||||
);
|
||||
assert_eq!(adapter.webhook_port, 9000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feishu_with_verification() {
|
||||
let adapter = FeishuAdapter::with_verification(
|
||||
"cli_abc123".to_string(),
|
||||
"secret".to_string(),
|
||||
9000,
|
||||
Some("verify-token".to_string()),
|
||||
Some("encrypt-key".to_string()),
|
||||
);
|
||||
assert_eq!(adapter.verification_token, Some("verify-token".to_string()));
|
||||
assert_eq!(adapter.encrypt_key, Some("encrypt-key".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_feishu_app_id_stored() {
|
||||
let adapter = FeishuAdapter::new("cli_test".to_string(), "secret".to_string(), 8080);
|
||||
assert_eq!(adapter.app_id, "cli_test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_feishu_event_v2_text() {
|
||||
let event = serde_json::json!({
|
||||
"schema": "2.0",
|
||||
"header": {
|
||||
"event_id": "evt-001",
|
||||
"event_type": "im.message.receive_v1",
|
||||
"create_time": "1234567890000",
|
||||
"token": "verify-token",
|
||||
"app_id": "cli_abc123",
|
||||
"tenant_key": "tenant-key-1"
|
||||
},
|
||||
"event": {
|
||||
"sender": {
|
||||
"sender_id": {
|
||||
"open_id": "ou_abc123",
|
||||
"user_id": "user-1"
|
||||
},
|
||||
"sender_type": "user"
|
||||
},
|
||||
"message": {
|
||||
"message_id": "om_abc123",
|
||||
"root_id": null,
|
||||
"chat_id": "oc_chat123",
|
||||
"chat_type": "p2p",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"Hello from Feishu!\"}"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_feishu_event(&event).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("feishu".to_string()));
|
||||
assert_eq!(msg.platform_message_id, "om_abc123");
|
||||
assert!(!msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Feishu!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_feishu_event_group_message() {
|
||||
let event = serde_json::json!({
|
||||
"schema": "2.0",
|
||||
"header": {
|
||||
"event_id": "evt-002",
|
||||
"event_type": "im.message.receive_v1"
|
||||
},
|
||||
"event": {
|
||||
"sender": {
|
||||
"sender_id": {
|
||||
"open_id": "ou_abc123"
|
||||
},
|
||||
"sender_type": "user"
|
||||
},
|
||||
"message": {
|
||||
"message_id": "om_grp1",
|
||||
"chat_id": "oc_grp123",
|
||||
"chat_type": "group",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"Group message\"}"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_feishu_event(&event).unwrap();
|
||||
assert!(msg.is_group);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_feishu_event_command() {
|
||||
let event = serde_json::json!({
|
||||
"schema": "2.0",
|
||||
"header": {
|
||||
"event_id": "evt-003",
|
||||
"event_type": "im.message.receive_v1"
|
||||
},
|
||||
"event": {
|
||||
"sender": {
|
||||
"sender_id": {
|
||||
"open_id": "ou_abc123"
|
||||
},
|
||||
"sender_type": "user"
|
||||
},
|
||||
"message": {
|
||||
"message_id": "om_cmd1",
|
||||
"chat_id": "oc_chat1",
|
||||
"chat_type": "p2p",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"/help all\"}"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_feishu_event(&event).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "help");
|
||||
assert_eq!(args, &["all"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_feishu_event_skips_bot() {
|
||||
let event = serde_json::json!({
|
||||
"schema": "2.0",
|
||||
"header": {
|
||||
"event_id": "evt-004",
|
||||
"event_type": "im.message.receive_v1"
|
||||
},
|
||||
"event": {
|
||||
"sender": {
|
||||
"sender_id": {
|
||||
"open_id": "ou_bot"
|
||||
},
|
||||
"sender_type": "bot"
|
||||
},
|
||||
"message": {
|
||||
"message_id": "om_bot1",
|
||||
"chat_id": "oc_chat1",
|
||||
"chat_type": "p2p",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"Bot message\"}"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_feishu_event(&event).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_feishu_event_non_text() {
|
||||
let event = serde_json::json!({
|
||||
"schema": "2.0",
|
||||
"header": {
|
||||
"event_id": "evt-005",
|
||||
"event_type": "im.message.receive_v1"
|
||||
},
|
||||
"event": {
|
||||
"sender": {
|
||||
"sender_id": {
|
||||
"open_id": "ou_user1"
|
||||
},
|
||||
"sender_type": "user"
|
||||
},
|
||||
"message": {
|
||||
"message_id": "om_img1",
|
||||
"chat_id": "oc_chat1",
|
||||
"chat_type": "p2p",
|
||||
"message_type": "image",
|
||||
"content": "{\"image_key\":\"img_v2_abc123\"}"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_feishu_event(&event).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_feishu_event_wrong_type() {
|
||||
let event = serde_json::json!({
|
||||
"schema": "2.0",
|
||||
"header": {
|
||||
"event_id": "evt-006",
|
||||
"event_type": "im.chat.member_bot.added_v1"
|
||||
},
|
||||
"event": {}
|
||||
});
|
||||
|
||||
assert!(parse_feishu_event(&event).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_feishu_event_thread_id() {
|
||||
let event = serde_json::json!({
|
||||
"schema": "2.0",
|
||||
"header": {
|
||||
"event_id": "evt-007",
|
||||
"event_type": "im.message.receive_v1"
|
||||
},
|
||||
"event": {
|
||||
"sender": {
|
||||
"sender_id": {
|
||||
"open_id": "ou_user1"
|
||||
},
|
||||
"sender_type": "user"
|
||||
},
|
||||
"message": {
|
||||
"message_id": "om_thread1",
|
||||
"root_id": "om_root1",
|
||||
"chat_id": "oc_chat1",
|
||||
"chat_type": "group",
|
||||
"message_type": "text",
|
||||
"content": "{\"text\":\"Thread reply\"}"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_feishu_event(&event).unwrap();
|
||||
assert_eq!(msg.thread_id, Some("om_root1".to_string()));
|
||||
}
|
||||
}
|
||||
465
crates/openfang-channels/src/flock.rs
Normal file
465
crates/openfang-channels/src/flock.rs
Normal file
@@ -0,0 +1,465 @@
|
||||
//! Flock Bot channel adapter.
|
||||
//!
|
||||
//! Uses the Flock Messaging API with a local webhook HTTP server for receiving
|
||||
//! inbound event callbacks and the REST API for sending messages. Authentication
|
||||
//! is performed via a Bot token parameter. Flock delivers events as JSON POST
|
||||
//! requests to the configured webhook endpoint.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Flock REST API base URL.
|
||||
const FLOCK_API_BASE: &str = "https://api.flock.com/v2";
|
||||
|
||||
/// Maximum message length for Flock messages.
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
|
||||
/// Flock Bot channel adapter using webhook for receiving and REST API for sending.
|
||||
///
|
||||
/// Listens for inbound event callbacks via a configurable HTTP webhook server
|
||||
/// and sends outbound messages via the Flock `chat.sendMessage` endpoint.
|
||||
/// Supports channel-receive and app-install event types.
|
||||
pub struct FlockAdapter {
|
||||
/// SECURITY: Bot token is zeroized on drop.
|
||||
bot_token: Zeroizing<String>,
|
||||
/// Port for the inbound webhook HTTP listener.
|
||||
webhook_port: u16,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl FlockAdapter {
|
||||
/// Create a new Flock adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `bot_token` - Flock Bot token for API authentication.
|
||||
/// * `webhook_port` - Local port to bind the webhook listener on.
|
||||
pub fn new(bot_token: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
bot_token: Zeroizing::new(bot_token),
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching bot/app info.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/users.getInfo?token={}",
|
||||
FLOCK_API_BASE,
|
||||
self.bot_token.as_str()
|
||||
);
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Flock authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let user_id = body["userId"]
|
||||
.as_str()
|
||||
.or_else(|| body["id"].as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
/// Send a text message to a Flock channel or user.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
to: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/chat.sendMessage", FLOCK_API_BASE);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"token": self.bot_token.as_str(),
|
||||
"to": to,
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Flock API error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
// Check for API-level errors in response body
|
||||
let result: serde_json::Value = match resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if let Some(error) = result.get("error") {
|
||||
return Err(format!("Flock API error: {error}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a rich message with attachments to a Flock channel.
|
||||
#[allow(dead_code)]
|
||||
async fn api_send_rich_message(
|
||||
&self,
|
||||
to: &str,
|
||||
text: &str,
|
||||
attachment_title: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/chat.sendMessage", FLOCK_API_BASE);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"token": self.bot_token.as_str(),
|
||||
"to": to,
|
||||
"text": text,
|
||||
"attachments": [{
|
||||
"title": attachment_title,
|
||||
"description": text,
|
||||
"color": "#4CAF50",
|
||||
}]
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Flock rich message error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an inbound Flock event callback into a `ChannelMessage`.
|
||||
///
|
||||
/// Flock delivers various event types; we only process `chat.receiveMessage`
|
||||
/// events (incoming messages sent to the bot).
|
||||
fn parse_flock_event(event: &serde_json::Value, own_user_id: &str) -> Option<ChannelMessage> {
|
||||
let event_name = event["name"].as_str().unwrap_or("");
|
||||
|
||||
// Handle app.install and client.slashCommand events by ignoring them
|
||||
match event_name {
|
||||
"chat.receiveMessage" => {}
|
||||
"client.messageAction" => {}
|
||||
_ => return None,
|
||||
}
|
||||
|
||||
let message = &event["message"];
|
||||
|
||||
let text = message["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let from = message["from"].as_str().unwrap_or("");
|
||||
let to = message["to"].as_str().unwrap_or("");
|
||||
|
||||
// Skip messages from the bot itself
|
||||
if from == own_user_id {
|
||||
return None;
|
||||
}
|
||||
|
||||
let msg_id = message["uid"]
|
||||
.as_str()
|
||||
.or_else(|| message["id"].as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let sender_name = message["fromName"].as_str().unwrap_or(from);
|
||||
|
||||
// Determine if group or DM
|
||||
// In Flock, channels start with 'g:' for groups, user IDs for DMs
|
||||
let is_group = to.starts_with("g:");
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"from".to_string(),
|
||||
serde_json::Value::String(from.to_string()),
|
||||
);
|
||||
metadata.insert("to".to_string(), serde_json::Value::String(to.to_string()));
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("flock".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: to.to_string(),
|
||||
display_name: sender_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for FlockAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"flock"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("flock".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_user_id = self.validate().await?;
|
||||
info!("Flock adapter authenticated (user_id: {bot_user_id})");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let own_user_id = bot_user_id;
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let user_id_shared = Arc::new(own_user_id);
|
||||
let tx_shared = Arc::new(tx);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/flock/events",
|
||||
axum::routing::post({
|
||||
let user_id = Arc::clone(&user_id_shared);
|
||||
let tx = Arc::clone(&tx_shared);
|
||||
move |body: axum::extract::Json<serde_json::Value>| {
|
||||
let user_id = Arc::clone(&user_id);
|
||||
let tx = Arc::clone(&tx);
|
||||
async move {
|
||||
// Handle Flock's event verification
|
||||
if body["name"].as_str() == Some("app.install") {
|
||||
return axum::http::StatusCode::OK;
|
||||
}
|
||||
|
||||
if let Some(msg) = parse_flock_event(&body, &user_id) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
|
||||
axum::http::StatusCode::OK
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("Flock webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Flock webhook bind failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Flock webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Flock adapter shutting down");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Flock does not expose a typing indicator API for bots
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_flock_adapter_creation() {
|
||||
let adapter = FlockAdapter::new("test-bot-token".to_string(), 8181);
|
||||
assert_eq!(adapter.name(), "flock");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("flock".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flock_token_zeroized() {
|
||||
let adapter = FlockAdapter::new("secret-flock-token".to_string(), 8181);
|
||||
assert_eq!(adapter.bot_token.as_str(), "secret-flock-token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flock_webhook_port() {
|
||||
let adapter = FlockAdapter::new("token".to_string(), 7777);
|
||||
assert_eq!(adapter.webhook_port, 7777);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_flock_event_message() {
|
||||
let event = serde_json::json!({
|
||||
"name": "chat.receiveMessage",
|
||||
"message": {
|
||||
"text": "Hello from Flock!",
|
||||
"from": "u:user123",
|
||||
"to": "g:channel456",
|
||||
"uid": "msg-001",
|
||||
"fromName": "Alice"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_flock_event(&event, "u:bot001").unwrap();
|
||||
assert_eq!(msg.sender.display_name, "Alice");
|
||||
assert_eq!(msg.sender.platform_id, "g:channel456");
|
||||
assert!(msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Flock!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_flock_event_command() {
|
||||
let event = serde_json::json!({
|
||||
"name": "chat.receiveMessage",
|
||||
"message": {
|
||||
"text": "/status check",
|
||||
"from": "u:user123",
|
||||
"to": "u:bot001",
|
||||
"uid": "msg-002"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_flock_event(&event, "u:bot001-different").unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "status");
|
||||
assert_eq!(args, &["check"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_flock_event_skip_bot() {
|
||||
let event = serde_json::json!({
|
||||
"name": "chat.receiveMessage",
|
||||
"message": {
|
||||
"text": "Bot response",
|
||||
"from": "u:bot001",
|
||||
"to": "g:channel456"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_flock_event(&event, "u:bot001");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_flock_event_dm() {
|
||||
let event = serde_json::json!({
|
||||
"name": "chat.receiveMessage",
|
||||
"message": {
|
||||
"text": "Direct msg",
|
||||
"from": "u:user123",
|
||||
"to": "u:bot001",
|
||||
"uid": "msg-003",
|
||||
"fromName": "Bob"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_flock_event(&event, "u:bot001-different").unwrap();
|
||||
assert!(!msg.is_group); // "to" doesn't start with "g:"
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_flock_event_unknown_type() {
|
||||
let event = serde_json::json!({
|
||||
"name": "app.install",
|
||||
"userId": "u:user123"
|
||||
});
|
||||
|
||||
let msg = parse_flock_event(&event, "u:bot001");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_flock_event_empty_text() {
|
||||
let event = serde_json::json!({
|
||||
"name": "chat.receiveMessage",
|
||||
"message": {
|
||||
"text": "",
|
||||
"from": "u:user123",
|
||||
"to": "g:channel456"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_flock_event(&event, "u:bot001");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
}
|
||||
252
crates/openfang-channels/src/formatter.rs
Normal file
252
crates/openfang-channels/src/formatter.rs
Normal file
@@ -0,0 +1,252 @@
|
||||
//! Channel-specific message formatting.
|
||||
//!
|
||||
//! Converts standard Markdown into platform-specific markup:
|
||||
//! - Telegram HTML: `**bold**` → `<b>bold</b>`
|
||||
//! - Slack mrkdwn: `**bold**` → `*bold*`, `[text](url)` → `<url|text>`
|
||||
//! - Plain text: strips all formatting
|
||||
|
||||
use openfang_types::config::OutputFormat;
|
||||
|
||||
/// Format a message for a specific channel output format.
|
||||
pub fn format_for_channel(text: &str, format: OutputFormat) -> String {
|
||||
match format {
|
||||
OutputFormat::Markdown => text.to_string(),
|
||||
OutputFormat::TelegramHtml => markdown_to_telegram_html(text),
|
||||
OutputFormat::SlackMrkdwn => markdown_to_slack_mrkdwn(text),
|
||||
OutputFormat::PlainText => markdown_to_plain(text),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Markdown to Telegram HTML subset.
|
||||
///
|
||||
/// Supported tags: `<b>`, `<i>`, `<code>`, `<pre>`, `<a href="">`.
|
||||
fn markdown_to_telegram_html(text: &str) -> String {
|
||||
let mut result = text.to_string();
|
||||
|
||||
// Bold: **text** → <b>text</b>
|
||||
while let Some(start) = result.find("**") {
|
||||
if let Some(end) = result[start + 2..].find("**") {
|
||||
let end = start + 2 + end;
|
||||
let inner = result[start + 2..end].to_string();
|
||||
result = format!("{}<b>{}</b>{}", &result[..start], inner, &result[end + 2..]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Italic: *text* → <i>text</i> (but not inside bold tags)
|
||||
// Simple heuristic: match single * not preceded/followed by *
|
||||
let mut out = String::with_capacity(result.len());
|
||||
let chars: Vec<char> = result.chars().collect();
|
||||
let mut i = 0;
|
||||
let mut in_italic = false;
|
||||
while i < chars.len() {
|
||||
if chars[i] == '*'
|
||||
&& (i == 0 || chars[i - 1] != '*')
|
||||
&& (i + 1 >= chars.len() || chars[i + 1] != '*')
|
||||
{
|
||||
if in_italic {
|
||||
out.push_str("</i>");
|
||||
} else {
|
||||
out.push_str("<i>");
|
||||
}
|
||||
in_italic = !in_italic;
|
||||
} else {
|
||||
out.push(chars[i]);
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
result = out;
|
||||
|
||||
// Inline code: `text` → <code>text</code>
|
||||
while let Some(start) = result.find('`') {
|
||||
if let Some(end) = result[start + 1..].find('`') {
|
||||
let end = start + 1 + end;
|
||||
let inner = result[start + 1..end].to_string();
|
||||
result = format!(
|
||||
"{}<code>{}</code>{}",
|
||||
&result[..start],
|
||||
inner,
|
||||
&result[end + 1..]
|
||||
);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Links: [text](url) → <a href="url">text</a>
|
||||
while let Some(bracket_start) = result.find('[') {
|
||||
if let Some(bracket_end) = result[bracket_start..].find("](") {
|
||||
let bracket_end = bracket_start + bracket_end;
|
||||
if let Some(paren_end) = result[bracket_end + 2..].find(')') {
|
||||
let paren_end = bracket_end + 2 + paren_end;
|
||||
let link_text = &result[bracket_start + 1..bracket_end];
|
||||
let url = &result[bracket_end + 2..paren_end];
|
||||
result = format!(
|
||||
"{}<a href=\"{}\">{}</a>{}",
|
||||
&result[..bracket_start],
|
||||
url,
|
||||
link_text,
|
||||
&result[paren_end + 1..]
|
||||
);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Convert Markdown to Slack mrkdwn format.
|
||||
fn markdown_to_slack_mrkdwn(text: &str) -> String {
|
||||
let mut result = text.to_string();
|
||||
|
||||
// Bold: **text** → *text*
|
||||
while let Some(start) = result.find("**") {
|
||||
if let Some(end) = result[start + 2..].find("**") {
|
||||
let end = start + 2 + end;
|
||||
let inner = result[start + 2..end].to_string();
|
||||
result = format!("{}*{}*{}", &result[..start], inner, &result[end + 2..]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Links: [text](url) → <url|text>
|
||||
while let Some(bracket_start) = result.find('[') {
|
||||
if let Some(bracket_end) = result[bracket_start..].find("](") {
|
||||
let bracket_end = bracket_start + bracket_end;
|
||||
if let Some(paren_end) = result[bracket_end + 2..].find(')') {
|
||||
let paren_end = bracket_end + 2 + paren_end;
|
||||
let link_text = &result[bracket_start + 1..bracket_end];
|
||||
let url = &result[bracket_end + 2..paren_end];
|
||||
result = format!(
|
||||
"{}<{}|{}>{}",
|
||||
&result[..bracket_start],
|
||||
url,
|
||||
link_text,
|
||||
&result[paren_end + 1..]
|
||||
);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Strip all Markdown formatting, producing plain text.
|
||||
fn markdown_to_plain(text: &str) -> String {
|
||||
let mut result = text.to_string();
|
||||
|
||||
// Remove bold markers
|
||||
result = result.replace("**", "");
|
||||
|
||||
// Remove italic markers (single *)
|
||||
// Simple approach: remove isolated *
|
||||
let mut out = String::with_capacity(result.len());
|
||||
let chars: Vec<char> = result.chars().collect();
|
||||
for (i, &ch) in chars.iter().enumerate() {
|
||||
if ch == '*'
|
||||
&& (i == 0 || chars[i - 1] != '*')
|
||||
&& (i + 1 >= chars.len() || chars[i + 1] != '*')
|
||||
{
|
||||
continue;
|
||||
}
|
||||
out.push(ch);
|
||||
}
|
||||
result = out;
|
||||
|
||||
// Remove inline code markers
|
||||
result = result.replace('`', "");
|
||||
|
||||
// Convert links: [text](url) → text (url)
|
||||
while let Some(bracket_start) = result.find('[') {
|
||||
if let Some(bracket_end) = result[bracket_start..].find("](") {
|
||||
let bracket_end = bracket_start + bracket_end;
|
||||
if let Some(paren_end) = result[bracket_end + 2..].find(')') {
|
||||
let paren_end = bracket_end + 2 + paren_end;
|
||||
let link_text = &result[bracket_start + 1..bracket_end];
|
||||
let url = &result[bracket_end + 2..paren_end];
|
||||
result = format!(
|
||||
"{}{} ({}){}",
|
||||
&result[..bracket_start],
|
||||
link_text,
|
||||
url,
|
||||
&result[paren_end + 1..]
|
||||
);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_format_markdown_passthrough() {
|
||||
let text = "**bold** and *italic*";
|
||||
assert_eq!(format_for_channel(text, OutputFormat::Markdown), text);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_telegram_html_bold() {
|
||||
let result = markdown_to_telegram_html("Hello **world**!");
|
||||
assert_eq!(result, "Hello <b>world</b>!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_telegram_html_italic() {
|
||||
let result = markdown_to_telegram_html("Hello *world*!");
|
||||
assert_eq!(result, "Hello <i>world</i>!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_telegram_html_code() {
|
||||
let result = markdown_to_telegram_html("Use `println!`");
|
||||
assert_eq!(result, "Use <code>println!</code>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_telegram_html_link() {
|
||||
let result = markdown_to_telegram_html("[click here](https://example.com)");
|
||||
assert_eq!(result, "<a href=\"https://example.com\">click here</a>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slack_mrkdwn_bold() {
|
||||
let result = markdown_to_slack_mrkdwn("Hello **world**!");
|
||||
assert_eq!(result, "Hello *world*!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slack_mrkdwn_link() {
|
||||
let result = markdown_to_slack_mrkdwn("[click](https://example.com)");
|
||||
assert_eq!(result, "<https://example.com|click>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plain_text_strips_formatting() {
|
||||
let result = markdown_to_plain("**bold** and `code` and *italic*");
|
||||
assert_eq!(result, "bold and code and italic");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_plain_text_converts_links() {
|
||||
let result = markdown_to_plain("[click](https://example.com)");
|
||||
assert_eq!(result, "click (https://example.com)");
|
||||
}
|
||||
}
|
||||
413
crates/openfang-channels/src/gitter.rs
Normal file
413
crates/openfang-channels/src/gitter.rs
Normal file
@@ -0,0 +1,413 @@
|
||||
//! Gitter channel adapter.
|
||||
//!
|
||||
//! Connects to the Gitter Streaming API for real-time messages and posts
|
||||
//! replies via the REST API. Uses Bearer token authentication and
|
||||
//! newline-delimited JSON streaming.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
const GITTER_STREAM_URL: &str = "https://stream.gitter.im/v1/rooms";
|
||||
const GITTER_API_URL: &str = "https://api.gitter.im/v1/rooms";
|
||||
|
||||
/// Gitter streaming channel adapter.
|
||||
///
|
||||
/// Receives messages via the Gitter Streaming API (newline-delimited JSON)
|
||||
/// and sends replies via the REST API.
|
||||
pub struct GitterAdapter {
|
||||
/// SECURITY: Bearer token is zeroized on drop.
|
||||
token: Zeroizing<String>,
|
||||
/// Gitter room ID to listen on.
|
||||
room_id: String,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl GitterAdapter {
|
||||
/// Create a new Gitter adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - Gitter personal access token.
|
||||
/// * `room_id` - Gitter room ID to listen on and send to.
|
||||
pub fn new(token: String, room_id: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
token: Zeroizing::new(token),
|
||||
room_id,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate token by fetching the authenticated user.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = "https://api.gitter.im/v1/user";
|
||||
let resp = self
|
||||
.client
|
||||
.get(url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Gitter auth failed (HTTP {})", resp.status()).into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
// /v1/user returns an array with a single user object
|
||||
let username = body
|
||||
.as_array()
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|u| u["username"].as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Ok(username)
|
||||
}
|
||||
|
||||
/// Fetch room info to resolve display name.
|
||||
async fn get_room_name(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/{}", GITTER_API_URL, self.room_id);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Gitter: failed to fetch room (HTTP {})", resp.status()).into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let name = body["name"].as_str().unwrap_or("unknown-room").to_string();
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
/// Send a text message to the room via REST API.
|
||||
async fn api_send_message(&self, text: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/{}/chatMessages", GITTER_API_URL, self.room_id);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Gitter API error {status}: {err_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse a newline-delimited JSON message from the streaming API.
|
||||
fn parse_stream_message(line: &str) -> Option<(String, String, String, String)> {
|
||||
let val: serde_json::Value = serde_json::from_str(line).ok()?;
|
||||
let id = val["id"].as_str()?.to_string();
|
||||
let text = val["text"].as_str()?.to_string();
|
||||
let username = val["fromUser"]["username"].as_str()?.to_string();
|
||||
let display_name = val["fromUser"]["displayName"]
|
||||
.as_str()
|
||||
.unwrap_or(&username)
|
||||
.to_string();
|
||||
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some((id, text, username, display_name))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for GitterAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"gitter"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("gitter".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let own_username = self.validate().await?;
|
||||
let room_name = self.get_room_name().await.unwrap_or_default();
|
||||
info!("Gitter adapter authenticated as {own_username} in room {room_name}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let room_id = self.room_id.clone();
|
||||
let token = self.token.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let stream_client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(0)) // No timeout for streaming
|
||||
.build()
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
let url = format!("{}/{}/chatMessages", GITTER_STREAM_URL, room_id);
|
||||
|
||||
let response = match stream_client
|
||||
.get(&url)
|
||||
.bearer_auth(token.as_str())
|
||||
.header("Accept", "application/json")
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => {
|
||||
if !r.status().is_success() {
|
||||
warn!("Gitter: stream returned HTTP {}", r.status());
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(120));
|
||||
continue;
|
||||
}
|
||||
backoff = Duration::from_secs(1);
|
||||
r
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Gitter: stream connection error: {e}, backing off {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(120));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Gitter: streaming connection established for room {room_id}");
|
||||
|
||||
// Read the streaming response as bytes, splitting on newlines
|
||||
let mut stream = response.bytes_stream();
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut line_buffer = String::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("Gitter adapter shutting down");
|
||||
return;
|
||||
}
|
||||
}
|
||||
chunk = stream.next() => {
|
||||
match chunk {
|
||||
Some(Ok(bytes)) => {
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
line_buffer.push_str(&text);
|
||||
|
||||
// Process complete lines
|
||||
while let Some(newline_pos) = line_buffer.find('\n') {
|
||||
let line = line_buffer[..newline_pos].trim().to_string();
|
||||
line_buffer = line_buffer[newline_pos + 1..].to_string();
|
||||
|
||||
// Skip heartbeat (empty lines / whitespace-only)
|
||||
if line.is_empty() || line.chars().all(|c| c.is_whitespace()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((id, text, username, display_name)) =
|
||||
Self::parse_stream_message(&line)
|
||||
{
|
||||
// Skip own messages
|
||||
if username == own_username {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| {
|
||||
a.split_whitespace()
|
||||
.map(String::from)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text)
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Custom(
|
||||
"gitter".to_string(),
|
||||
),
|
||||
platform_message_id: id,
|
||||
sender: ChannelUser {
|
||||
platform_id: username.clone(),
|
||||
display_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"room_id".to_string(),
|
||||
serde_json::Value::String(
|
||||
room_id.clone(),
|
||||
),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("Gitter: stream read error: {e}");
|
||||
break; // Reconnect
|
||||
}
|
||||
None => {
|
||||
info!("Gitter: stream ended, reconnecting...");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exponential backoff before reconnect
|
||||
if !*shutdown_rx.borrow() {
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
}
|
||||
|
||||
info!("Gitter streaming loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
self.api_send_message(&text).await
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Gitter does not have a typing indicator API.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_gitter_adapter_creation() {
|
||||
let adapter = GitterAdapter::new("test-token".to_string(), "abc123room".to_string());
|
||||
assert_eq!(adapter.name(), "gitter");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("gitter".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitter_room_id() {
|
||||
let adapter = GitterAdapter::new("tok".to_string(), "my-room-id".to_string());
|
||||
assert_eq!(adapter.room_id, "my-room-id");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitter_parse_stream_message() {
|
||||
let json = r#"{"id":"msg1","text":"Hello world","fromUser":{"username":"alice","displayName":"Alice B"}}"#;
|
||||
let result = GitterAdapter::parse_stream_message(json);
|
||||
assert!(result.is_some());
|
||||
let (id, text, username, display_name) = result.unwrap();
|
||||
assert_eq!(id, "msg1");
|
||||
assert_eq!(text, "Hello world");
|
||||
assert_eq!(username, "alice");
|
||||
assert_eq!(display_name, "Alice B");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitter_parse_stream_message_missing_fields() {
|
||||
let json = r#"{"id":"msg1"}"#;
|
||||
assert!(GitterAdapter::parse_stream_message(json).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitter_parse_stream_message_empty_text() {
|
||||
let json =
|
||||
r#"{"id":"msg1","text":"","fromUser":{"username":"alice","displayName":"Alice"}}"#;
|
||||
assert!(GitterAdapter::parse_stream_message(json).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitter_parse_stream_message_no_display_name() {
|
||||
let json = r#"{"id":"msg1","text":"hi","fromUser":{"username":"bob"}}"#;
|
||||
let result = GitterAdapter::parse_stream_message(json);
|
||||
assert!(result.is_some());
|
||||
let (_, _, username, display_name) = result.unwrap();
|
||||
assert_eq!(username, "bob");
|
||||
assert_eq!(display_name, "bob"); // Falls back to username
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gitter_parse_invalid_json() {
|
||||
assert!(GitterAdapter::parse_stream_message("not json").is_none());
|
||||
assert!(GitterAdapter::parse_stream_message("").is_none());
|
||||
}
|
||||
}
|
||||
412
crates/openfang-channels/src/google_chat.rs
Normal file
412
crates/openfang-channels/src/google_chat.rs
Normal file
@@ -0,0 +1,412 @@
|
||||
//! Google Chat channel adapter.
|
||||
//!
|
||||
//! Uses Google Chat REST API with service account JWT authentication for sending
|
||||
//! messages and a webhook listener for receiving inbound messages from Google Chat
|
||||
//! spaces.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
const TOKEN_REFRESH_MARGIN_SECS: u64 = 300;
|
||||
|
||||
/// Google Chat channel adapter using service account authentication and REST API.
|
||||
///
|
||||
/// Inbound messages arrive via a configurable webhook HTTP listener.
|
||||
/// Outbound messages are sent via the Google Chat REST API using an OAuth2 access
|
||||
/// token obtained from a service account JWT.
|
||||
pub struct GoogleChatAdapter {
|
||||
/// SECURITY: Service account key JSON is zeroized on drop.
|
||||
service_account_key: Zeroizing<String>,
|
||||
/// Space IDs to listen to (e.g., "spaces/AAAA").
|
||||
space_ids: Vec<String>,
|
||||
/// Port for the inbound webhook HTTP listener.
|
||||
webhook_port: u16,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Cached OAuth2 access token with expiry instant.
|
||||
cached_token: Arc<RwLock<Option<(String, Instant)>>>,
|
||||
}
|
||||
|
||||
impl GoogleChatAdapter {
|
||||
/// Create a new Google Chat adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `service_account_key` - JSON content of the Google service account key file.
|
||||
/// * `space_ids` - Google Chat space IDs to interact with.
|
||||
/// * `webhook_port` - Local port to bind the inbound webhook listener on.
|
||||
pub fn new(service_account_key: String, space_ids: Vec<String>, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
service_account_key: Zeroizing::new(service_account_key),
|
||||
space_ids,
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
cached_token: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a valid access token, refreshing if expired or missing.
|
||||
///
|
||||
/// In a full implementation this would perform JWT signing and exchange with
|
||||
/// Google's OAuth2 token endpoint. For now it parses a pre-supplied token
|
||||
/// from the service account key JSON (field "access_token") or returns an
|
||||
/// error indicating that full JWT auth is not yet wired.
|
||||
async fn get_access_token(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
// Check cache first
|
||||
{
|
||||
let cache = self.cached_token.read().await;
|
||||
if let Some((ref token, expiry)) = *cache {
|
||||
if Instant::now() + Duration::from_secs(TOKEN_REFRESH_MARGIN_SECS) < expiry {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parse the service account key to extract project/client info
|
||||
let key_json: serde_json::Value = serde_json::from_str(&self.service_account_key)
|
||||
.map_err(|e| format!("Invalid service account key JSON: {e}"))?;
|
||||
|
||||
// For a real implementation: build a JWT, sign with the private key,
|
||||
// exchange at https://oauth2.googleapis.com/token for an access token.
|
||||
// This adapter currently expects an "access_token" field for testing or
|
||||
// a pre-authorized token workflow.
|
||||
let token = key_json["access_token"]
|
||||
.as_str()
|
||||
.ok_or("Service account key missing 'access_token' field; full JWT auth not yet implemented")?
|
||||
.to_string();
|
||||
|
||||
let expiry = Instant::now() + Duration::from_secs(3600);
|
||||
*self.cached_token.write().await = Some((token.clone(), expiry));
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Send a text message to a Google Chat space.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
space_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let token = self.get_access_token().await?;
|
||||
let url = format!("https://chat.googleapis.com/v1/{}/messages", space_id);
|
||||
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Google Chat API error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a space ID is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_space(&self, space_id: &str) -> bool {
|
||||
self.space_ids.is_empty() || self.space_ids.iter().any(|s| s == space_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for GoogleChatAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"google_chat"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("google_chat".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate we can parse the service account key
|
||||
let _key: serde_json::Value = serde_json::from_str(&self.service_account_key)
|
||||
.map_err(|e| format!("Invalid service account key: {e}"))?;
|
||||
|
||||
info!(
|
||||
"Google Chat adapter starting webhook listener on port {}",
|
||||
self.webhook_port
|
||||
);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let space_ids = self.space_ids.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Bind a minimal HTTP listener for inbound webhooks
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Google Chat: failed to bind webhook on port {port}: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Google Chat webhook listener bound on {addr}");
|
||||
|
||||
loop {
|
||||
let (stream, _peer) = tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Google Chat adapter shutting down");
|
||||
break;
|
||||
}
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
warn!("Google Chat: accept error: {e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let tx = tx.clone();
|
||||
let space_ids = space_ids.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Read HTTP request from the TCP stream
|
||||
let mut reader = tokio::io::BufReader::new(stream);
|
||||
let mut request_line = String::new();
|
||||
if tokio::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Read headers to find Content-Length
|
||||
let mut content_length: usize = 0;
|
||||
loop {
|
||||
let mut header_line = String::new();
|
||||
if tokio::io::AsyncBufReadExt::read_line(&mut reader, &mut header_line)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
let trimmed = header_line.trim();
|
||||
if trimmed.is_empty() {
|
||||
break;
|
||||
}
|
||||
if let Some(val) = trimmed.strip_prefix("Content-Length:") {
|
||||
if let Ok(len) = val.trim().parse::<usize>() {
|
||||
content_length = len;
|
||||
}
|
||||
}
|
||||
if let Some(val) = trimmed.strip_prefix("content-length:") {
|
||||
if let Ok(len) = val.trim().parse::<usize>() {
|
||||
content_length = len;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Read body
|
||||
let mut body_buf = vec![0u8; content_length.min(65536)];
|
||||
use tokio::io::AsyncReadExt;
|
||||
if content_length > 0
|
||||
&& reader
|
||||
.read_exact(&mut body_buf[..content_length.min(65536)])
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Send 200 OK response
|
||||
use tokio::io::AsyncWriteExt;
|
||||
let resp = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
|
||||
let _ = reader.get_mut().write_all(resp).await;
|
||||
|
||||
// Parse the Google Chat event payload
|
||||
let payload: serde_json::Value =
|
||||
match serde_json::from_slice(&body_buf[..content_length.min(65536)]) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
let event_type = payload["type"].as_str().unwrap_or("");
|
||||
if event_type != "MESSAGE" {
|
||||
return;
|
||||
}
|
||||
|
||||
let message = &payload["message"];
|
||||
let text = message["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let space_name = payload["space"]["name"].as_str().unwrap_or("");
|
||||
if !space_ids.is_empty() && !space_ids.iter().any(|s| s == space_name) {
|
||||
return;
|
||||
}
|
||||
|
||||
let sender_name = message["sender"]["displayName"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown");
|
||||
let sender_id = message["sender"]["name"].as_str().unwrap_or("unknown");
|
||||
let message_name = message["name"].as_str().unwrap_or("").to_string();
|
||||
let thread_name = message["thread"]["name"].as_str().map(String::from);
|
||||
let space_type = payload["space"]["type"].as_str().unwrap_or("ROOM");
|
||||
let is_group = space_type != "DM";
|
||||
|
||||
let msg_content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("google_chat".to_string()),
|
||||
platform_message_id: message_name,
|
||||
sender: ChannelUser {
|
||||
platform_id: space_name.to_string(),
|
||||
display_name: sender_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: thread_name,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"sender_id".to_string(),
|
||||
serde_json::Value::String(sender_id.to_string()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
let _ = tx.send(channel_msg).await;
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_google_chat_adapter_creation() {
|
||||
let adapter = GoogleChatAdapter::new(
|
||||
r#"{"access_token":"test-token","project_id":"test"}"#.to_string(),
|
||||
vec!["spaces/AAAA".to_string()],
|
||||
8090,
|
||||
);
|
||||
assert_eq!(adapter.name(), "google_chat");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("google_chat".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_google_chat_allowed_spaces() {
|
||||
let adapter = GoogleChatAdapter::new(
|
||||
r#"{"access_token":"tok"}"#.to_string(),
|
||||
vec!["spaces/AAAA".to_string()],
|
||||
8090,
|
||||
);
|
||||
assert!(adapter.is_allowed_space("spaces/AAAA"));
|
||||
assert!(!adapter.is_allowed_space("spaces/BBBB"));
|
||||
|
||||
let open = GoogleChatAdapter::new(r#"{"access_token":"tok"}"#.to_string(), vec![], 8090);
|
||||
assert!(open.is_allowed_space("spaces/anything"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_google_chat_token_caching() {
|
||||
let adapter = GoogleChatAdapter::new(
|
||||
r#"{"access_token":"cached-tok","project_id":"p"}"#.to_string(),
|
||||
vec![],
|
||||
8091,
|
||||
);
|
||||
|
||||
// First call should parse and cache
|
||||
let token = adapter.get_access_token().await.unwrap();
|
||||
assert_eq!(token, "cached-tok");
|
||||
|
||||
// Second call should return from cache
|
||||
let token2 = adapter.get_access_token().await.unwrap();
|
||||
assert_eq!(token2, "cached-tok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_google_chat_invalid_key() {
|
||||
let adapter = GoogleChatAdapter::new("not-json".to_string(), vec![], 8092);
|
||||
// Can't call async get_access_token in sync test, but verify construction works
|
||||
assert_eq!(adapter.webhook_port, 8092);
|
||||
}
|
||||
}
|
||||
418
crates/openfang-channels/src/gotify.rs
Normal file
418
crates/openfang-channels/src/gotify.rs
Normal file
@@ -0,0 +1,418 @@
|
||||
//! Gotify channel adapter.
|
||||
//!
|
||||
//! Connects to a Gotify server via WebSocket for receiving push notifications
|
||||
//! and sends messages via the REST API. Uses separate app and client tokens
|
||||
//! for publishing and subscribing respectively.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 65535;
|
||||
|
||||
/// Gotify push notification channel adapter.
|
||||
///
|
||||
/// Receives messages via the Gotify WebSocket stream (`/stream`) using a
|
||||
/// client token and sends messages via the REST API (`/message`) using an
|
||||
/// app token.
|
||||
pub struct GotifyAdapter {
|
||||
/// Gotify server URL (e.g., `"https://gotify.example.com"`).
|
||||
server_url: String,
|
||||
/// SECURITY: App token for sending messages (zeroized on drop).
|
||||
app_token: Zeroizing<String>,
|
||||
/// SECURITY: Client token for receiving messages (zeroized on drop).
|
||||
client_token: Zeroizing<String>,
|
||||
/// HTTP client for REST API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl GotifyAdapter {
|
||||
/// Create a new Gotify adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `server_url` - Base URL of the Gotify server.
|
||||
/// * `app_token` - Token for an application (used to send messages).
|
||||
/// * `client_token` - Token for a client (used to receive messages via WebSocket).
|
||||
pub fn new(server_url: String, app_token: String, client_token: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let server_url = server_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
server_url,
|
||||
app_token: Zeroizing::new(app_token),
|
||||
client_token: Zeroizing::new(client_token),
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the app token by checking the application info.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/current/user?token={}",
|
||||
self.server_url,
|
||||
self.client_token.as_str()
|
||||
);
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Gotify auth failed (HTTP {})", resp.status()).into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let name = body["name"].as_str().unwrap_or("gotify-user").to_string();
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
/// Build the WebSocket URL for the stream endpoint.
|
||||
fn build_ws_url(&self) -> String {
|
||||
let base = self
|
||||
.server_url
|
||||
.replace("https://", "wss://")
|
||||
.replace("http://", "ws://");
|
||||
format!("{}/stream?token={}", base, self.client_token.as_str())
|
||||
}
|
||||
|
||||
/// Send a message via the Gotify REST API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
title: &str,
|
||||
message: &str,
|
||||
priority: u8,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/message?token={}",
|
||||
self.server_url,
|
||||
self.app_token.as_str()
|
||||
);
|
||||
let chunks = split_message(message, MAX_MESSAGE_LEN);
|
||||
|
||||
for (i, chunk) in chunks.iter().enumerate() {
|
||||
let chunk_title = if chunks.len() > 1 {
|
||||
format!("{} ({}/{})", title, i + 1, chunks.len())
|
||||
} else {
|
||||
title.to_string()
|
||||
};
|
||||
|
||||
let body = serde_json::json!({
|
||||
"title": chunk_title,
|
||||
"message": chunk,
|
||||
"priority": priority,
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Gotify API error {status}: {err_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse a Gotify WebSocket message (JSON).
|
||||
fn parse_ws_message(text: &str) -> Option<(u64, String, String, u8, u64)> {
|
||||
let val: serde_json::Value = serde_json::from_str(text).ok()?;
|
||||
let id = val["id"].as_u64()?;
|
||||
let message = val["message"].as_str()?.to_string();
|
||||
let title = val["title"].as_str().unwrap_or("").to_string();
|
||||
let priority = val["priority"].as_u64().unwrap_or(0) as u8;
|
||||
let app_id = val["appid"].as_u64().unwrap_or(0);
|
||||
|
||||
if message.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some((id, message, title, priority, app_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for GotifyAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"gotify"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("gotify".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let user_name = self.validate().await?;
|
||||
info!("Gotify adapter authenticated as {user_name}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let ws_url = self.build_ws_url();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
info!("Gotify: connecting WebSocket...");
|
||||
|
||||
let ws_connect = match tokio_tungstenite::connect_async(&ws_url).await {
|
||||
Ok((ws_stream, _)) => {
|
||||
backoff = Duration::from_secs(1);
|
||||
ws_stream
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Gotify: WebSocket connection failed: {e}, backing off {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(120));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Gotify: WebSocket connected");
|
||||
|
||||
use futures::StreamExt;
|
||||
let (mut _ws_write, mut ws_read) = ws_connect.split();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("Gotify adapter shutting down");
|
||||
return;
|
||||
}
|
||||
}
|
||||
msg = ws_read.next() => {
|
||||
match msg {
|
||||
Some(Ok(ws_msg)) => {
|
||||
let text = match ws_msg {
|
||||
tokio_tungstenite::tungstenite::Message::Text(t) => t,
|
||||
tokio_tungstenite::tungstenite::Message::Ping(_) => continue,
|
||||
tokio_tungstenite::tungstenite::Message::Pong(_) => continue,
|
||||
tokio_tungstenite::tungstenite::Message::Close(_) => {
|
||||
info!("Gotify: WebSocket closed by server");
|
||||
break;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
if let Some((id, message, title, priority, app_id)) =
|
||||
Self::parse_ws_message(&text)
|
||||
{
|
||||
let content = if message.starts_with('/') {
|
||||
let parts: Vec<&str> =
|
||||
message.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| {
|
||||
a.split_whitespace()
|
||||
.map(String::from)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(message)
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Custom(
|
||||
"gotify".to_string(),
|
||||
),
|
||||
platform_message_id: format!("gotify-{id}"),
|
||||
sender: ChannelUser {
|
||||
platform_id: format!("app-{app_id}"),
|
||||
display_name: if title.is_empty() {
|
||||
format!("app-{app_id}")
|
||||
} else {
|
||||
title.clone()
|
||||
},
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"title".to_string(),
|
||||
serde_json::Value::String(title),
|
||||
);
|
||||
m.insert(
|
||||
"priority".to_string(),
|
||||
serde_json::Value::Number(priority.into()),
|
||||
);
|
||||
m.insert(
|
||||
"app_id".to_string(),
|
||||
serde_json::Value::Number(app_id.into()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("Gotify: WebSocket read error: {e}");
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
info!("Gotify: WebSocket stream ended");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Exponential backoff before reconnect
|
||||
if !*shutdown_rx.borrow() {
|
||||
warn!("Gotify: reconnecting in {backoff:?}...");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
}
|
||||
|
||||
info!("Gotify WebSocket loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
self.api_send_message("OpenFang", &text, 5).await
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Gotify has no typing indicator.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_gotify_adapter_creation() {
|
||||
let adapter = GotifyAdapter::new(
|
||||
"https://gotify.example.com".to_string(),
|
||||
"app-token".to_string(),
|
||||
"client-token".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.name(), "gotify");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("gotify".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gotify_url_normalization() {
|
||||
let adapter = GotifyAdapter::new(
|
||||
"https://gotify.example.com/".to_string(),
|
||||
"app".to_string(),
|
||||
"client".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.server_url, "https://gotify.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gotify_ws_url_https() {
|
||||
let adapter = GotifyAdapter::new(
|
||||
"https://gotify.example.com".to_string(),
|
||||
"app".to_string(),
|
||||
"client-tok".to_string(),
|
||||
);
|
||||
let ws_url = adapter.build_ws_url();
|
||||
assert!(ws_url.starts_with("wss://"));
|
||||
assert!(ws_url.contains("/stream?token=client-tok"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gotify_ws_url_http() {
|
||||
let adapter = GotifyAdapter::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"app".to_string(),
|
||||
"client-tok".to_string(),
|
||||
);
|
||||
let ws_url = adapter.build_ws_url();
|
||||
assert!(ws_url.starts_with("ws://"));
|
||||
assert!(ws_url.contains("/stream?token=client-tok"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gotify_parse_ws_message() {
|
||||
let json = r#"{"id":42,"appid":7,"message":"Hello Gotify","title":"Test App","priority":5,"date":"2024-01-01T00:00:00Z"}"#;
|
||||
let result = GotifyAdapter::parse_ws_message(json);
|
||||
assert!(result.is_some());
|
||||
let (id, message, title, priority, app_id) = result.unwrap();
|
||||
assert_eq!(id, 42);
|
||||
assert_eq!(message, "Hello Gotify");
|
||||
assert_eq!(title, "Test App");
|
||||
assert_eq!(priority, 5);
|
||||
assert_eq!(app_id, 7);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gotify_parse_ws_message_empty() {
|
||||
let json = r#"{"id":1,"appid":1,"message":"","title":"","priority":0}"#;
|
||||
assert!(GotifyAdapter::parse_ws_message(json).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gotify_parse_ws_message_minimal() {
|
||||
let json = r#"{"id":1,"message":"hi"}"#;
|
||||
let result = GotifyAdapter::parse_ws_message(json);
|
||||
assert!(result.is_some());
|
||||
let (_, msg, title, priority, app_id) = result.unwrap();
|
||||
assert_eq!(msg, "hi");
|
||||
assert_eq!(title, "");
|
||||
assert_eq!(priority, 0);
|
||||
assert_eq!(app_id, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gotify_parse_invalid_json() {
|
||||
assert!(GotifyAdapter::parse_ws_message("not json").is_none());
|
||||
}
|
||||
}
|
||||
390
crates/openfang-channels/src/guilded.rs
Normal file
390
crates/openfang-channels/src/guilded.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
//! Guilded Bot channel adapter.
|
||||
//!
|
||||
//! Connects to the Guilded Bot API via WebSocket for receiving real-time events
|
||||
//! and uses the REST API for sending messages. Authentication is performed via
|
||||
//! Bearer token. The WebSocket gateway at `wss://www.guilded.gg/websocket/v1`
|
||||
//! delivers `ChatMessageCreated` events for incoming messages.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Guilded REST API base URL.
|
||||
const GUILDED_API_BASE: &str = "https://www.guilded.gg/api/v1";
|
||||
|
||||
/// Guilded WebSocket gateway URL.
|
||||
const GUILDED_WS_URL: &str = "wss://www.guilded.gg/websocket/v1";
|
||||
|
||||
/// Maximum message length for Guilded messages.
|
||||
const MAX_MESSAGE_LEN: usize = 4000;
|
||||
|
||||
/// Guilded Bot API channel adapter using WebSocket for events and REST for sending.
|
||||
///
|
||||
/// Connects to the Guilded WebSocket gateway for real-time message events and
|
||||
/// sends replies via the REST API. Supports filtering by server (guild) IDs.
|
||||
pub struct GuildedAdapter {
|
||||
/// SECURITY: Bot token is zeroized on drop.
|
||||
bot_token: Zeroizing<String>,
|
||||
/// Server (guild) IDs to listen on (empty = all servers the bot is in).
|
||||
server_ids: Vec<String>,
|
||||
/// HTTP client for REST API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl GuildedAdapter {
|
||||
/// Create a new Guilded adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `bot_token` - Guilded bot authentication token.
|
||||
/// * `server_ids` - Server IDs to filter events for (empty = all).
|
||||
pub fn new(bot_token: String, server_ids: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
bot_token: Zeroizing::new(bot_token),
|
||||
server_ids,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching the bot's own user info.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/users/@me", GUILDED_API_BASE);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Guilded authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let bot_id = body["user"]["id"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(bot_id)
|
||||
}
|
||||
|
||||
/// Send a text message to a Guilded channel via REST API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/channels/{}/messages", GUILDED_API_BASE, channel_id);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"content": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Guilded API error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a server ID is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_server(&self, server_id: &str) -> bool {
|
||||
self.server_ids.is_empty() || self.server_ids.iter().any(|s| s == server_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for GuildedAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"guilded"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("guilded".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_id = self.validate().await?;
|
||||
info!("Guilded adapter authenticated as bot {bot_id}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let bot_token = self.bot_token.clone();
|
||||
let server_ids = self.server_ids.clone();
|
||||
let own_bot_id = bot_id;
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Build WebSocket request with auth header
|
||||
let mut request =
|
||||
match tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(GUILDED_WS_URL) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Guilded: failed to build WS request: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
request.headers_mut().insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", bot_token.as_str()).parse().unwrap(),
|
||||
);
|
||||
|
||||
// Connect to WebSocket
|
||||
let ws_stream = match tokio_tungstenite::connect_async(request).await {
|
||||
Ok((stream, _resp)) => stream,
|
||||
Err(e) => {
|
||||
warn!("Guilded: WebSocket connection failed: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Guilded WebSocket connected");
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
use futures::StreamExt;
|
||||
let (mut _write, mut read) = ws_stream.split();
|
||||
|
||||
// Read events from WebSocket
|
||||
let should_reconnect = loop {
|
||||
let msg = tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Guilded adapter shutting down");
|
||||
return;
|
||||
}
|
||||
msg = read.next() => msg,
|
||||
};
|
||||
|
||||
let msg = match msg {
|
||||
Some(Ok(m)) => m,
|
||||
Some(Err(e)) => {
|
||||
warn!("Guilded WS read error: {e}");
|
||||
break true;
|
||||
}
|
||||
None => {
|
||||
info!("Guilded WS stream ended");
|
||||
break true;
|
||||
}
|
||||
};
|
||||
|
||||
// Only process text messages
|
||||
let text = match msg {
|
||||
tokio_tungstenite::tungstenite::Message::Text(t) => t,
|
||||
tokio_tungstenite::tungstenite::Message::Ping(_) => continue,
|
||||
tokio_tungstenite::tungstenite::Message::Close(_) => {
|
||||
info!("Guilded WS received close frame");
|
||||
break true;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let event: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let event_type = event["t"].as_str().unwrap_or("");
|
||||
|
||||
// Handle welcome event (op 1) — contains heartbeat interval
|
||||
let op = event["op"].as_i64().unwrap_or(0);
|
||||
if op == 1 {
|
||||
info!("Guilded: received welcome event");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Only process ChatMessageCreated events
|
||||
if event_type != "ChatMessageCreated" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = &event["d"]["message"];
|
||||
let msg_server_id = event["d"]["serverId"].as_str().unwrap_or("");
|
||||
|
||||
// Filter by server ID if configured
|
||||
if !server_ids.is_empty() && !server_ids.iter().any(|s| s == msg_server_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let created_by = message["createdBy"].as_str().unwrap_or("");
|
||||
// Skip messages from the bot itself
|
||||
if created_by == own_bot_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = message["content"].as_str().unwrap_or("");
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg_id = message["id"].as_str().unwrap_or("").to_string();
|
||||
let channel_id = message["channelId"].as_str().unwrap_or("").to_string();
|
||||
|
||||
let msg_content = if content.starts_with('/') {
|
||||
let parts: Vec<&str> = content.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(content.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("guilded".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: channel_id,
|
||||
display_name: created_by.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"server_id".to_string(),
|
||||
serde_json::Value::String(msg_server_id.to_string()),
|
||||
);
|
||||
m.insert(
|
||||
"created_by".to_string(),
|
||||
serde_json::Value::String(created_by.to_string()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
warn!("Guilded: reconnecting in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
|
||||
info!("Guilded WebSocket loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Guilded does not expose a public typing indicator API for bots
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_guilded_adapter_creation() {
|
||||
let adapter =
|
||||
GuildedAdapter::new("test-bot-token".to_string(), vec!["server1".to_string()]);
|
||||
assert_eq!(adapter.name(), "guilded");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("guilded".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_guilded_allowed_servers() {
|
||||
let adapter = GuildedAdapter::new(
|
||||
"tok".to_string(),
|
||||
vec!["srv-1".to_string(), "srv-2".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_server("srv-1"));
|
||||
assert!(adapter.is_allowed_server("srv-2"));
|
||||
assert!(!adapter.is_allowed_server("srv-3"));
|
||||
|
||||
let open = GuildedAdapter::new("tok".to_string(), vec![]);
|
||||
assert!(open.is_allowed_server("any-server"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_guilded_token_zeroized() {
|
||||
let adapter = GuildedAdapter::new("secret-bot-token".to_string(), vec![]);
|
||||
assert_eq!(adapter.bot_token.as_str(), "secret-bot-token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_guilded_constants() {
|
||||
assert_eq!(MAX_MESSAGE_LEN, 4000);
|
||||
assert_eq!(GUILDED_WS_URL, "wss://www.guilded.gg/websocket/v1");
|
||||
}
|
||||
}
|
||||
653
crates/openfang-channels/src/irc.rs
Normal file
653
crates/openfang-channels/src/irc.rs
Normal file
@@ -0,0 +1,653 @@
|
||||
//! IRC channel adapter for the OpenFang channel bridge.
|
||||
//!
|
||||
//! Uses raw TCP via `tokio::net::TcpStream` with `tokio::io` buffered I/O for
|
||||
//! plaintext IRC connections. Implements the core IRC protocol: NICK, USER, JOIN,
|
||||
//! PRIVMSG, PING/PONG. A `use_tls: bool` field is reserved for future TLS support
|
||||
//! (would require a `tokio-native-tls` dependency).
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Maximum IRC message length per RFC 2812 (including CRLF).
|
||||
/// We use 510 for the payload (512 minus CRLF).
|
||||
const MAX_MESSAGE_LEN: usize = 510;
|
||||
|
||||
/// Maximum length for a single PRIVMSG payload, accounting for the
|
||||
/// `:nick!user@host PRIVMSG #channel :` prefix overhead (~80 chars conservative).
|
||||
const MAX_PRIVMSG_PAYLOAD: usize = 400;
|
||||
|
||||
const MAX_BACKOFF: Duration = Duration::from_secs(60);
|
||||
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
|
||||
|
||||
/// IRC channel adapter using raw TCP and the IRC text protocol.
|
||||
///
|
||||
/// Connects to an IRC server, authenticates with NICK/USER (and optional PASS),
|
||||
/// joins configured channels, and listens for PRIVMSG events.
|
||||
pub struct IrcAdapter {
|
||||
/// IRC server hostname (e.g., "irc.libera.chat").
|
||||
server: String,
|
||||
/// IRC server port (typically 6667 for plaintext, 6697 for TLS).
|
||||
port: u16,
|
||||
/// Bot's IRC nickname.
|
||||
nick: String,
|
||||
/// SECURITY: Optional server password, zeroized on drop.
|
||||
password: Option<Zeroizing<String>>,
|
||||
/// IRC channels to join (e.g., ["#openfang", "#bots"]).
|
||||
channels: Vec<String>,
|
||||
/// Reserved for future TLS support. Currently only plaintext is implemented.
|
||||
#[allow(dead_code)]
|
||||
use_tls: bool,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Shared write handle for sending messages from the `send()` method.
|
||||
/// Populated after `start()` connects to the server.
|
||||
write_tx: Arc<RwLock<Option<mpsc::Sender<String>>>>,
|
||||
}
|
||||
|
||||
impl IrcAdapter {
|
||||
/// Create a new IRC adapter.
|
||||
///
|
||||
/// * `server` — IRC server hostname.
|
||||
/// * `port` — IRC server port (6667 for plaintext).
|
||||
/// * `nick` — Bot's IRC nickname.
|
||||
/// * `password` — Optional server password (PASS command).
|
||||
/// * `channels` — IRC channels to join (must start with `#`).
|
||||
/// * `use_tls` — Reserved for future TLS support (currently ignored).
|
||||
pub fn new(
|
||||
server: String,
|
||||
port: u16,
|
||||
nick: String,
|
||||
password: Option<String>,
|
||||
channels: Vec<String>,
|
||||
use_tls: bool,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
server,
|
||||
port,
|
||||
nick,
|
||||
password: password.map(Zeroizing::new),
|
||||
channels,
|
||||
use_tls,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
write_tx: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Format the server address as `host:port`.
|
||||
fn addr(&self) -> String {
|
||||
format!("{}:{}", self.server, self.port)
|
||||
}
|
||||
}
|
||||
|
||||
/// An IRC protocol line parsed into its components.
|
||||
#[derive(Debug)]
|
||||
struct IrcLine {
|
||||
/// Optional prefix (e.g., ":nick!user@host").
|
||||
prefix: Option<String>,
|
||||
/// The IRC command (e.g., "PRIVMSG", "PING", "001").
|
||||
command: String,
|
||||
/// Parameters following the command.
|
||||
params: Vec<String>,
|
||||
/// Trailing parameter (after `:` in the params).
|
||||
trailing: Option<String>,
|
||||
}
|
||||
|
||||
/// Parse a raw IRC line into structured components.
|
||||
///
|
||||
/// IRC line format: `[:prefix] COMMAND [params...] [:trailing]`
|
||||
fn parse_irc_line(line: &str) -> Option<IrcLine> {
|
||||
let line = line.trim();
|
||||
if line.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut remaining = line;
|
||||
let prefix = if remaining.starts_with(':') {
|
||||
let space = remaining.find(' ')?;
|
||||
let pfx = remaining[1..space].to_string();
|
||||
remaining = &remaining[space + 1..];
|
||||
Some(pfx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Split off the trailing parameter (after " :")
|
||||
let (main_part, trailing) = if let Some(idx) = remaining.find(" :") {
|
||||
let trail = remaining[idx + 2..].to_string();
|
||||
(&remaining[..idx], Some(trail))
|
||||
} else {
|
||||
(remaining, None)
|
||||
};
|
||||
|
||||
let mut parts = main_part.split_whitespace();
|
||||
let command = parts.next()?.to_string();
|
||||
let params: Vec<String> = parts.map(String::from).collect();
|
||||
|
||||
Some(IrcLine {
|
||||
prefix,
|
||||
command,
|
||||
params,
|
||||
trailing,
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract the nickname from an IRC prefix like "nick!user@host".
|
||||
fn nick_from_prefix(prefix: &str) -> &str {
|
||||
prefix.split('!').next().unwrap_or(prefix)
|
||||
}
|
||||
|
||||
/// Parse a PRIVMSG IRC line into a `ChannelMessage`.
|
||||
fn parse_privmsg(line: &IrcLine, bot_nick: &str) -> Option<ChannelMessage> {
|
||||
if line.command != "PRIVMSG" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let prefix = line.prefix.as_deref()?;
|
||||
let sender_nick = nick_from_prefix(prefix);
|
||||
|
||||
// Skip messages from the bot itself
|
||||
if sender_nick.eq_ignore_ascii_case(bot_nick) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let target = line.params.first()?;
|
||||
let text = line.trailing.as_deref().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Determine if this is a channel message (group) or a DM
|
||||
let is_group = target.starts_with('#') || target.starts_with('&');
|
||||
|
||||
// The "platform_id" is the channel name for group messages, or the
|
||||
// sender's nick for DMs (so replies go back to the right place).
|
||||
let platform_id = if is_group {
|
||||
target.to_string()
|
||||
} else {
|
||||
sender_nick.to_string()
|
||||
};
|
||||
|
||||
// Parse commands (messages starting with /)
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = &parts[0][1..];
|
||||
let args = if parts.len() > 1 {
|
||||
parts[1].split_whitespace().map(String::from).collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("irc".to_string()),
|
||||
platform_message_id: String::new(), // IRC has no message IDs
|
||||
sender: ChannelUser {
|
||||
platform_id,
|
||||
display_name: sender_nick.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for IrcAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"irc"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("irc".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let (write_cmd_tx, mut write_cmd_rx) = mpsc::channel::<String>(64);
|
||||
|
||||
// Store the write channel so `send()` can use it
|
||||
*self.write_tx.write().await = Some(write_cmd_tx.clone());
|
||||
|
||||
let addr = self.addr();
|
||||
let nick = self.nick.clone();
|
||||
let password = self.password.clone();
|
||||
let channels = self.channels.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = INITIAL_BACKOFF;
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
info!("Connecting to IRC server at {addr}...");
|
||||
|
||||
let stream = match TcpStream::connect(&addr).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!("IRC connection failed: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
backoff = INITIAL_BACKOFF;
|
||||
info!("IRC connected to {addr}");
|
||||
|
||||
let (reader, mut writer) = stream.into_split();
|
||||
let mut lines = BufReader::new(reader).lines();
|
||||
|
||||
// Send PASS (if configured), NICK, and USER
|
||||
let mut registration = String::new();
|
||||
if let Some(ref pass) = password {
|
||||
registration.push_str(&format!("PASS {}\r\n", pass.as_str()));
|
||||
}
|
||||
registration.push_str(&format!("NICK {nick}\r\n"));
|
||||
registration.push_str(&format!("USER {nick} 0 * :OpenFang Bot\r\n"));
|
||||
|
||||
if let Err(e) = writer.write_all(registration.as_bytes()).await {
|
||||
warn!("IRC registration send failed: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
|
||||
let nick_clone = nick.clone();
|
||||
let channels_clone = channels.clone();
|
||||
let mut joined = false;
|
||||
|
||||
// Inner message loop — returns true if we should reconnect
|
||||
let should_reconnect = 'inner: loop {
|
||||
tokio::select! {
|
||||
line_result = lines.next_line() => {
|
||||
let line = match line_result {
|
||||
Ok(Some(l)) => l,
|
||||
Ok(None) => {
|
||||
info!("IRC connection closed");
|
||||
break 'inner true;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("IRC read error: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("IRC < {line}");
|
||||
|
||||
let parsed = match parse_irc_line(&line) {
|
||||
Some(p) => p,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
match parsed.command.as_str() {
|
||||
// PING/PONG keepalive
|
||||
"PING" => {
|
||||
let pong_param = parsed.trailing
|
||||
.as_deref()
|
||||
.or(parsed.params.first().map(|s| s.as_str()))
|
||||
.unwrap_or("");
|
||||
let pong = format!("PONG :{pong_param}\r\n");
|
||||
if let Err(e) = writer.write_all(pong.as_bytes()).await {
|
||||
warn!("IRC PONG send failed: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
}
|
||||
|
||||
// RPL_WELCOME (001) — registration complete, join channels
|
||||
"001" => {
|
||||
if !joined {
|
||||
info!("IRC registered as {nick_clone}");
|
||||
for ch in &channels_clone {
|
||||
let join_cmd = format!("JOIN {ch}\r\n");
|
||||
if let Err(e) = writer.write_all(join_cmd.as_bytes()).await {
|
||||
warn!("IRC JOIN send failed: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
info!("IRC joining {ch}");
|
||||
}
|
||||
joined = true;
|
||||
}
|
||||
}
|
||||
|
||||
// PRIVMSG — incoming message
|
||||
"PRIVMSG" => {
|
||||
if let Some(msg) = parse_privmsg(&parsed, &nick_clone) {
|
||||
debug!(
|
||||
"IRC message from {}: {:?}",
|
||||
msg.sender.display_name, msg.content
|
||||
);
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ERR_NICKNAMEINUSE (433) — nickname taken
|
||||
"433" => {
|
||||
warn!("IRC: nickname '{nick_clone}' is already in use");
|
||||
let alt_nick = format!("{nick_clone}_");
|
||||
let cmd = format!("NICK {alt_nick}\r\n");
|
||||
let _ = writer.write_all(cmd.as_bytes()).await;
|
||||
}
|
||||
|
||||
// JOIN confirmation
|
||||
"JOIN" => {
|
||||
if let Some(ref prefix) = parsed.prefix {
|
||||
let joiner = nick_from_prefix(prefix);
|
||||
let channel = parsed.trailing
|
||||
.as_deref()
|
||||
.or(parsed.params.first().map(|s| s.as_str()))
|
||||
.unwrap_or("?");
|
||||
if joiner.eq_ignore_ascii_case(&nick_clone) {
|
||||
info!("IRC joined {channel}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Ignore other commands
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Outbound message requests from `send()`
|
||||
Some(raw_cmd) = write_cmd_rx.recv() => {
|
||||
if let Err(e) = writer.write_all(raw_cmd.as_bytes()).await {
|
||||
warn!("IRC write failed: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
}
|
||||
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("IRC adapter shutting down");
|
||||
let _ = writer.write_all(b"QUIT :OpenFang shutting down\r\n").await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
warn!("IRC: reconnecting in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
}
|
||||
|
||||
info!("IRC connection loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let write_tx = self.write_tx.read().await;
|
||||
let write_tx = write_tx
|
||||
.as_ref()
|
||||
.ok_or("IRC adapter not started — call start() first")?;
|
||||
|
||||
let target = &user.platform_id; // channel name or nick
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
let chunks = split_message(&text, MAX_PRIVMSG_PAYLOAD);
|
||||
for chunk in chunks {
|
||||
let raw = format!("PRIVMSG {target} :{chunk}\r\n");
|
||||
if raw.len() > MAX_MESSAGE_LEN + 2 {
|
||||
// Shouldn't happen with MAX_PRIVMSG_PAYLOAD, but be safe
|
||||
warn!("IRC message exceeds 512 bytes, truncating");
|
||||
}
|
||||
write_tx.send(raw).await.map_err(|e| {
|
||||
Box::<dyn std::error::Error>::from(format!("IRC write channel closed: {e}"))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_irc_adapter_creation() {
|
||||
let adapter = IrcAdapter::new(
|
||||
"irc.libera.chat".to_string(),
|
||||
6667,
|
||||
"openfang".to_string(),
|
||||
None,
|
||||
vec!["#openfang".to_string()],
|
||||
false,
|
||||
);
|
||||
assert_eq!(adapter.name(), "irc");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("irc".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_irc_addr() {
|
||||
let adapter = IrcAdapter::new(
|
||||
"irc.libera.chat".to_string(),
|
||||
6667,
|
||||
"bot".to_string(),
|
||||
None,
|
||||
vec![],
|
||||
false,
|
||||
);
|
||||
assert_eq!(adapter.addr(), "irc.libera.chat:6667");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_irc_addr_custom_port() {
|
||||
let adapter = IrcAdapter::new(
|
||||
"localhost".to_string(),
|
||||
6697,
|
||||
"bot".to_string(),
|
||||
Some("secret".to_string()),
|
||||
vec!["#test".to_string()],
|
||||
true,
|
||||
);
|
||||
assert_eq!(adapter.addr(), "localhost:6697");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_irc_line_ping() {
|
||||
let line = parse_irc_line("PING :server.example.com").unwrap();
|
||||
assert!(line.prefix.is_none());
|
||||
assert_eq!(line.command, "PING");
|
||||
assert_eq!(line.trailing.as_deref(), Some("server.example.com"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_irc_line_privmsg() {
|
||||
let line = parse_irc_line(":alice!alice@host PRIVMSG #openfang :Hello everyone!").unwrap();
|
||||
assert_eq!(line.prefix.as_deref(), Some("alice!alice@host"));
|
||||
assert_eq!(line.command, "PRIVMSG");
|
||||
assert_eq!(line.params, vec!["#openfang"]);
|
||||
assert_eq!(line.trailing.as_deref(), Some("Hello everyone!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_irc_line_numeric() {
|
||||
let line = parse_irc_line(":server 001 botnick :Welcome to the IRC network").unwrap();
|
||||
assert_eq!(line.prefix.as_deref(), Some("server"));
|
||||
assert_eq!(line.command, "001");
|
||||
assert_eq!(line.params, vec!["botnick"]);
|
||||
assert_eq!(line.trailing.as_deref(), Some("Welcome to the IRC network"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_irc_line_no_trailing() {
|
||||
let line = parse_irc_line(":alice!alice@host JOIN #openfang").unwrap();
|
||||
assert_eq!(line.command, "JOIN");
|
||||
assert_eq!(line.params, vec!["#openfang"]);
|
||||
assert!(line.trailing.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_irc_line_empty() {
|
||||
assert!(parse_irc_line("").is_none());
|
||||
assert!(parse_irc_line(" ").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nick_from_prefix_full() {
|
||||
assert_eq!(nick_from_prefix("alice!alice@host.example.com"), "alice");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nick_from_prefix_nick_only() {
|
||||
assert_eq!(nick_from_prefix("alice"), "alice");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_channel() {
|
||||
let line = IrcLine {
|
||||
prefix: Some("alice!alice@host".to_string()),
|
||||
command: "PRIVMSG".to_string(),
|
||||
params: vec!["#openfang".to_string()],
|
||||
trailing: Some("Hello from IRC!".to_string()),
|
||||
};
|
||||
|
||||
let msg = parse_privmsg(&line, "openfang-bot").unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("irc".to_string()));
|
||||
assert_eq!(msg.sender.display_name, "alice");
|
||||
assert_eq!(msg.sender.platform_id, "#openfang");
|
||||
assert!(msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from IRC!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_dm() {
|
||||
let line = IrcLine {
|
||||
prefix: Some("bob!bob@host".to_string()),
|
||||
command: "PRIVMSG".to_string(),
|
||||
params: vec!["openfang-bot".to_string()],
|
||||
trailing: Some("Private message".to_string()),
|
||||
};
|
||||
|
||||
let msg = parse_privmsg(&line, "openfang-bot").unwrap();
|
||||
assert!(!msg.is_group);
|
||||
assert_eq!(msg.sender.platform_id, "bob"); // DM replies go to sender
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_skips_self() {
|
||||
let line = IrcLine {
|
||||
prefix: Some("openfang-bot!bot@host".to_string()),
|
||||
command: "PRIVMSG".to_string(),
|
||||
params: vec!["#openfang".to_string()],
|
||||
trailing: Some("My own message".to_string()),
|
||||
};
|
||||
|
||||
let msg = parse_privmsg(&line, "openfang-bot");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_command() {
|
||||
let line = IrcLine {
|
||||
prefix: Some("alice!alice@host".to_string()),
|
||||
command: "PRIVMSG".to_string(),
|
||||
params: vec!["#openfang".to_string()],
|
||||
trailing: Some("/agent hello-world".to_string()),
|
||||
};
|
||||
|
||||
let msg = parse_privmsg(&line, "openfang-bot").unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agent");
|
||||
assert_eq!(args, &["hello-world"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_empty_text() {
|
||||
let line = IrcLine {
|
||||
prefix: Some("alice!alice@host".to_string()),
|
||||
command: "PRIVMSG".to_string(),
|
||||
params: vec!["#openfang".to_string()],
|
||||
trailing: Some("".to_string()),
|
||||
};
|
||||
|
||||
let msg = parse_privmsg(&line, "openfang-bot");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_no_trailing() {
|
||||
let line = IrcLine {
|
||||
prefix: Some("alice!alice@host".to_string()),
|
||||
command: "PRIVMSG".to_string(),
|
||||
params: vec!["#openfang".to_string()],
|
||||
trailing: None,
|
||||
};
|
||||
|
||||
let msg = parse_privmsg(&line, "openfang-bot");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_not_privmsg() {
|
||||
let line = IrcLine {
|
||||
prefix: Some("alice!alice@host".to_string()),
|
||||
command: "NOTICE".to_string(),
|
||||
params: vec!["#openfang".to_string()],
|
||||
trailing: Some("Notice text".to_string()),
|
||||
};
|
||||
|
||||
let msg = parse_privmsg(&line, "openfang-bot");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
}
|
||||
511
crates/openfang-channels/src/keybase.rs
Normal file
511
crates/openfang-channels/src/keybase.rs
Normal file
@@ -0,0 +1,511 @@
|
||||
//! Keybase Chat channel adapter.
|
||||
//!
|
||||
//! Uses the Keybase Chat API JSON protocol over HTTP for sending and receiving
|
||||
//! messages. Polls for new messages using the `list` + `read` API methods and
|
||||
//! sends messages via the `send` method. Authentication is performed using a
|
||||
//! Keybase username and paper key.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Maximum message length for Keybase messages.
|
||||
const MAX_MESSAGE_LEN: usize = 10000;
|
||||
|
||||
/// Polling interval in seconds for new messages.
|
||||
const POLL_INTERVAL_SECS: u64 = 3;
|
||||
|
||||
/// Keybase Chat API base URL (local daemon or remote API).
|
||||
const KEYBASE_API_URL: &str = "http://127.0.0.1:5222/api";
|
||||
|
||||
/// Keybase Chat channel adapter using JSON API protocol with polling.
|
||||
///
|
||||
/// Interfaces with the Keybase Chat API to send and receive messages. Supports
|
||||
/// filtering by team names for team-based conversations.
|
||||
pub struct KeybaseAdapter {
|
||||
/// Keybase username for authentication.
|
||||
username: String,
|
||||
/// SECURITY: Paper key is zeroized on drop.
|
||||
#[allow(dead_code)]
|
||||
paperkey: Zeroizing<String>,
|
||||
/// Team names to listen on (empty = all conversations).
|
||||
allowed_teams: Vec<String>,
|
||||
/// HTTP client for API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Last read message ID per conversation for incremental polling.
|
||||
last_msg_ids: Arc<RwLock<HashMap<String, i64>>>,
|
||||
}
|
||||
|
||||
impl KeybaseAdapter {
|
||||
/// Create a new Keybase adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `username` - Keybase username.
|
||||
/// * `paperkey` - Paper key for authentication.
|
||||
/// * `allowed_teams` - Team names to filter conversations (empty = all).
|
||||
pub fn new(username: String, paperkey: String, allowed_teams: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
username,
|
||||
paperkey: Zeroizing::new(paperkey),
|
||||
allowed_teams,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
last_msg_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the authentication payload for API requests.
|
||||
#[allow(dead_code)]
|
||||
fn auth_payload(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"username": self.username,
|
||||
"paperkey": self.paperkey.as_str(),
|
||||
})
|
||||
}
|
||||
|
||||
/// List conversations from the Keybase Chat API.
|
||||
#[allow(dead_code)]
|
||||
async fn list_conversations(
|
||||
&self,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let payload = serde_json::json!({
|
||||
"method": "list",
|
||||
"params": {
|
||||
"options": {}
|
||||
}
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(KEYBASE_API_URL)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Keybase: failed to list conversations".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let conversations = body["result"]["conversations"]
|
||||
.as_array()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
Ok(conversations)
|
||||
}
|
||||
|
||||
/// Read messages from a specific conversation channel.
|
||||
#[allow(dead_code)]
|
||||
async fn read_messages(
|
||||
&self,
|
||||
channel: &serde_json::Value,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let payload = serde_json::json!({
|
||||
"method": "read",
|
||||
"params": {
|
||||
"options": {
|
||||
"channel": channel,
|
||||
"pagination": {
|
||||
"num": 50,
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(KEYBASE_API_URL)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Keybase: failed to read messages".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let messages = body["result"]["messages"]
|
||||
.as_array()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
/// Send a text message to a Keybase conversation.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
channel: &serde_json::Value,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let payload = serde_json::json!({
|
||||
"method": "send",
|
||||
"params": {
|
||||
"options": {
|
||||
"channel": channel,
|
||||
"message": {
|
||||
"body": chunk,
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(KEYBASE_API_URL)
|
||||
.json(&payload)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Keybase API error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a team name is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_team(&self, team_name: &str) -> bool {
|
||||
self.allowed_teams.is_empty() || self.allowed_teams.iter().any(|t| t == team_name)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for KeybaseAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"keybase"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("keybase".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
info!("Keybase adapter starting for user {}", self.username);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let username = self.username.clone();
|
||||
let allowed_teams = self.allowed_teams.clone();
|
||||
let client = self.client.clone();
|
||||
let last_msg_ids = Arc::clone(&self.last_msg_ids);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Keybase adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// List conversations
|
||||
let list_payload = serde_json::json!({
|
||||
"method": "list",
|
||||
"params": {
|
||||
"options": {}
|
||||
}
|
||||
});
|
||||
|
||||
let conversations = match client
|
||||
.post(KEYBASE_API_URL)
|
||||
.json(&list_payload)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["result"]["conversations"]
|
||||
.as_array()
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Keybase: failed to list conversations: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
for conv in &conversations {
|
||||
let channel_info = &conv["channel"];
|
||||
let members_type = channel_info["members_type"].as_str().unwrap_or("");
|
||||
let team_name = channel_info["name"].as_str().unwrap_or("");
|
||||
let topic_name = channel_info["topic_name"].as_str().unwrap_or("general");
|
||||
|
||||
// Filter by team if configured
|
||||
if !allowed_teams.is_empty()
|
||||
&& members_type == "team"
|
||||
&& !allowed_teams.iter().any(|t| t == team_name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let conv_key = format!("{}:{}", team_name, topic_name);
|
||||
|
||||
// Read messages from this conversation
|
||||
let read_payload = serde_json::json!({
|
||||
"method": "read",
|
||||
"params": {
|
||||
"options": {
|
||||
"channel": channel_info,
|
||||
"pagination": {
|
||||
"num": 20,
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let messages = match client
|
||||
.post(KEYBASE_API_URL)
|
||||
.json(&read_payload)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["result"]["messages"]
|
||||
.as_array()
|
||||
.cloned()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Keybase: read error for {conv_key}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let last_id = {
|
||||
let ids = last_msg_ids.read().await;
|
||||
ids.get(&conv_key).copied().unwrap_or(0)
|
||||
};
|
||||
|
||||
let mut newest_id = last_id;
|
||||
|
||||
for msg_wrapper in &messages {
|
||||
let msg = &msg_wrapper["msg"];
|
||||
let msg_id = msg["id"].as_i64().unwrap_or(0);
|
||||
|
||||
// Skip already-seen messages
|
||||
if msg_id <= last_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let sender_username = msg["sender"]["username"].as_str().unwrap_or("");
|
||||
// Skip own messages
|
||||
if sender_username == username {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content_type = msg["content"]["type"].as_str().unwrap_or("");
|
||||
if content_type != "text" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = msg["content"]["text"]["body"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if msg_id > newest_id {
|
||||
newest_id = msg_id;
|
||||
}
|
||||
|
||||
let sender_device = msg["sender"]["device_name"].as_str().unwrap_or("");
|
||||
let is_group = members_type == "team";
|
||||
|
||||
let msg_content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("keybase".to_string()),
|
||||
platform_message_id: msg_id.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: conv_key.clone(),
|
||||
display_name: sender_username.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"team_name".to_string(),
|
||||
serde_json::Value::String(team_name.to_string()),
|
||||
);
|
||||
m.insert(
|
||||
"topic_name".to_string(),
|
||||
serde_json::Value::String(topic_name.to_string()),
|
||||
);
|
||||
m.insert(
|
||||
"sender_device".to_string(),
|
||||
serde_json::Value::String(sender_device.to_string()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Update last known ID
|
||||
if newest_id > last_id {
|
||||
last_msg_ids.write().await.insert(conv_key, newest_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Keybase polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(text) => text,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// Parse platform_id back into channel info (format: "team:topic")
|
||||
let parts: Vec<&str> = user.platform_id.splitn(2, ':').collect();
|
||||
let (team_name, topic_name) = if parts.len() == 2 {
|
||||
(parts[0], parts[1])
|
||||
} else {
|
||||
(user.platform_id.as_str(), "general")
|
||||
};
|
||||
|
||||
let channel_info = serde_json::json!({
|
||||
"name": team_name,
|
||||
"topic_name": topic_name,
|
||||
"members_type": "team",
|
||||
});
|
||||
|
||||
self.api_send_message(&channel_info, &text).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Keybase does not expose a typing indicator via the JSON API
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_keybase_adapter_creation() {
|
||||
let adapter = KeybaseAdapter::new(
|
||||
"testuser".to_string(),
|
||||
"paper-key-phrase".to_string(),
|
||||
vec!["myteam".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "keybase");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("keybase".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keybase_allowed_teams() {
|
||||
let adapter = KeybaseAdapter::new(
|
||||
"user".to_string(),
|
||||
"paperkey".to_string(),
|
||||
vec!["team-a".to_string(), "team-b".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_team("team-a"));
|
||||
assert!(adapter.is_allowed_team("team-b"));
|
||||
assert!(!adapter.is_allowed_team("team-c"));
|
||||
|
||||
let open = KeybaseAdapter::new("user".to_string(), "paperkey".to_string(), vec![]);
|
||||
assert!(open.is_allowed_team("any-team"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keybase_paperkey_zeroized() {
|
||||
let adapter = KeybaseAdapter::new(
|
||||
"user".to_string(),
|
||||
"my secret paper key".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.paperkey.as_str(), "my secret paper key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keybase_auth_payload() {
|
||||
let adapter = KeybaseAdapter::new("myuser".to_string(), "my-paper-key".to_string(), vec![]);
|
||||
let payload = adapter.auth_payload();
|
||||
assert_eq!(payload["username"], "myuser");
|
||||
assert_eq!(payload["paperkey"], "my-paper-key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keybase_username_stored() {
|
||||
let adapter = KeybaseAdapter::new("alice".to_string(), "key".to_string(), vec![]);
|
||||
assert_eq!(adapter.username, "alice");
|
||||
}
|
||||
}
|
||||
52
crates/openfang-channels/src/lib.rs
Normal file
52
crates/openfang-channels/src/lib.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
//! Channel Bridge Layer for the OpenFang Agent OS.
|
||||
//!
|
||||
//! Provides 40 pluggable messaging integrations that convert platform messages
|
||||
//! into unified `ChannelMessage` events for the kernel.
|
||||
|
||||
pub mod bridge;
|
||||
pub mod discord;
|
||||
pub mod email;
|
||||
pub mod formatter;
|
||||
pub mod google_chat;
|
||||
pub mod irc;
|
||||
pub mod matrix;
|
||||
pub mod mattermost;
|
||||
pub mod rocketchat;
|
||||
pub mod router;
|
||||
pub mod signal;
|
||||
pub mod slack;
|
||||
pub mod teams;
|
||||
pub mod telegram;
|
||||
pub mod twitch;
|
||||
pub mod types;
|
||||
pub mod whatsapp;
|
||||
pub mod xmpp;
|
||||
pub mod zulip;
|
||||
// Wave 3 — High-value channels
|
||||
pub mod bluesky;
|
||||
pub mod feishu;
|
||||
pub mod line;
|
||||
pub mod mastodon;
|
||||
pub mod messenger;
|
||||
pub mod reddit;
|
||||
pub mod revolt;
|
||||
pub mod viber;
|
||||
// Wave 4 — Enterprise & community channels
|
||||
pub mod flock;
|
||||
pub mod guilded;
|
||||
pub mod keybase;
|
||||
pub mod nextcloud;
|
||||
pub mod nostr;
|
||||
pub mod pumble;
|
||||
pub mod threema;
|
||||
pub mod twist;
|
||||
pub mod webex;
|
||||
// Wave 5 — Niche & differentiating channels
|
||||
pub mod dingtalk;
|
||||
pub mod discourse;
|
||||
pub mod gitter;
|
||||
pub mod gotify;
|
||||
pub mod linkedin;
|
||||
pub mod mumble;
|
||||
pub mod ntfy;
|
||||
pub mod webhook;
|
||||
650
crates/openfang-channels/src/line.rs
Normal file
650
crates/openfang-channels/src/line.rs
Normal file
@@ -0,0 +1,650 @@
|
||||
//! LINE Messaging API channel adapter.
|
||||
//!
|
||||
//! Uses the LINE Messaging API v2 for sending push/reply messages and a lightweight
|
||||
//! axum HTTP webhook server for receiving inbound events. Webhook signature
|
||||
//! verification uses HMAC-SHA256 with the channel secret. Authentication for
|
||||
//! outbound calls uses `Authorization: Bearer {channel_access_token}`.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// LINE push message API endpoint.
|
||||
const LINE_PUSH_URL: &str = "https://api.line.me/v2/bot/message/push";
|
||||
|
||||
/// LINE reply message API endpoint.
|
||||
const LINE_REPLY_URL: &str = "https://api.line.me/v2/bot/message/reply";
|
||||
|
||||
/// LINE profile API endpoint.
|
||||
#[allow(dead_code)]
|
||||
const LINE_PROFILE_URL: &str = "https://api.line.me/v2/bot/profile";
|
||||
|
||||
/// Maximum LINE message text length (characters).
|
||||
const MAX_MESSAGE_LEN: usize = 5000;
|
||||
|
||||
/// LINE Messaging API adapter.
|
||||
///
|
||||
/// Inbound messages arrive via an axum HTTP webhook server that accepts POST
|
||||
/// requests from the LINE Platform. Each request body is validated using
|
||||
/// HMAC-SHA256 (`X-Line-Signature` header) with the channel secret.
|
||||
///
|
||||
/// Outbound messages are sent via the push message API with a bearer token.
|
||||
pub struct LineAdapter {
|
||||
/// SECURITY: Channel secret for webhook signature verification, zeroized on drop.
|
||||
channel_secret: Zeroizing<String>,
|
||||
/// SECURITY: Channel access token for outbound API calls, zeroized on drop.
|
||||
access_token: Zeroizing<String>,
|
||||
/// Port on which the inbound webhook HTTP server listens.
|
||||
webhook_port: u16,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl LineAdapter {
|
||||
/// Create a new LINE adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `channel_secret` - Channel secret for HMAC-SHA256 signature verification.
|
||||
/// * `access_token` - Long-lived channel access token for sending messages.
|
||||
/// * `webhook_port` - Local port for the inbound webhook HTTP server.
|
||||
pub fn new(channel_secret: String, access_token: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
channel_secret: Zeroizing::new(channel_secret),
|
||||
access_token: Zeroizing::new(access_token),
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Verify the X-Line-Signature header using HMAC-SHA256.
|
||||
///
|
||||
/// The signature is computed as `Base64(HMAC-SHA256(channel_secret, body))`.
|
||||
fn verify_signature(&self, body: &[u8], signature: &str) -> bool {
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
type HmacSha256 = Hmac<Sha256>;
|
||||
|
||||
let Ok(mut mac) = HmacSha256::new_from_slice(self.channel_secret.as_bytes()) else {
|
||||
warn!("LINE: failed to create HMAC instance");
|
||||
return false;
|
||||
};
|
||||
mac.update(body);
|
||||
let result = mac.finalize().into_bytes();
|
||||
|
||||
// Compare with constant-time base64 decode + verify
|
||||
use base64::Engine;
|
||||
let Ok(expected) = base64::engine::general_purpose::STANDARD.decode(signature) else {
|
||||
warn!("LINE: invalid base64 in X-Line-Signature");
|
||||
return false;
|
||||
};
|
||||
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
if result.len() != expected.len() {
|
||||
return false;
|
||||
}
|
||||
let mut diff = 0u8;
|
||||
for (a, b) in result.iter().zip(expected.iter()) {
|
||||
diff |= a ^ b;
|
||||
}
|
||||
diff == 0
|
||||
}
|
||||
|
||||
/// Validate the channel access token by fetching the bot's own profile.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
// Verify token by calling the bot info endpoint
|
||||
let resp = self
|
||||
.client
|
||||
.get("https://api.line.me/v2/bot/info")
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("LINE authentication failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let display_name = body["displayName"]
|
||||
.as_str()
|
||||
.unwrap_or("LINE Bot")
|
||||
.to_string();
|
||||
Ok(display_name)
|
||||
}
|
||||
|
||||
/// Fetch a user's display name from the LINE profile API.
|
||||
#[allow(dead_code)]
|
||||
async fn get_user_display_name(&self, user_id: &str) -> String {
|
||||
let url = format!("{}/{}", LINE_PROFILE_URL, user_id);
|
||||
match self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) if resp.status().is_success() => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["displayName"]
|
||||
.as_str()
|
||||
.unwrap_or("Unknown")
|
||||
.to_string()
|
||||
}
|
||||
_ => "Unknown".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a push message to a LINE user or group.
|
||||
async fn api_push_message(
|
||||
&self,
|
||||
to: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"to": to,
|
||||
"messages": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": chunk,
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(LINE_PUSH_URL)
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("LINE push API error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a reply message using a reply token (must be used within 30s).
|
||||
#[allow(dead_code)]
|
||||
async fn api_reply_message(
|
||||
&self,
|
||||
reply_token: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
// LINE reply API allows up to 5 messages per reply
|
||||
let messages: Vec<serde_json::Value> = chunks
|
||||
.into_iter()
|
||||
.take(5)
|
||||
.map(|chunk| {
|
||||
serde_json::json!({
|
||||
"type": "text",
|
||||
"text": chunk,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let body = serde_json::json!({
|
||||
"replyToken": reply_token,
|
||||
"messages": messages,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(LINE_REPLY_URL)
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("LINE reply API error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a LINE webhook event into a `ChannelMessage`.
|
||||
///
|
||||
/// Handles `message` events with text type. Returns `None` for unsupported
|
||||
/// event types (follow, unfollow, postback, beacon, etc.).
|
||||
fn parse_line_event(event: &serde_json::Value) -> Option<ChannelMessage> {
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
if event_type != "message" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = event.get("message")?;
|
||||
let msg_type = message["type"].as_str().unwrap_or("");
|
||||
|
||||
// Only handle text messages for now
|
||||
if msg_type != "text" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let text = message["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let source = event.get("source")?;
|
||||
let source_type = source["type"].as_str().unwrap_or("user");
|
||||
let user_id = source["userId"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Determine the target (user, group, or room) for replies
|
||||
let (reply_to, is_group) = match source_type {
|
||||
"group" => {
|
||||
let group_id = source["groupId"].as_str().unwrap_or("").to_string();
|
||||
(group_id, true)
|
||||
}
|
||||
"room" => {
|
||||
let room_id = source["roomId"].as_str().unwrap_or("").to_string();
|
||||
(room_id, true)
|
||||
}
|
||||
_ => (user_id.clone(), false),
|
||||
};
|
||||
|
||||
let msg_id = message["id"].as_str().unwrap_or("").to_string();
|
||||
let reply_token = event["replyToken"].as_str().unwrap_or("").to_string();
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"user_id".to_string(),
|
||||
serde_json::Value::String(user_id.clone()),
|
||||
);
|
||||
metadata.insert(
|
||||
"reply_to".to_string(),
|
||||
serde_json::Value::String(reply_to.clone()),
|
||||
);
|
||||
if !reply_token.is_empty() {
|
||||
metadata.insert(
|
||||
"reply_token".to_string(),
|
||||
serde_json::Value::String(reply_token),
|
||||
);
|
||||
}
|
||||
metadata.insert(
|
||||
"source_type".to_string(),
|
||||
serde_json::Value::String(source_type.to_string()),
|
||||
);
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("line".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: reply_to,
|
||||
display_name: user_id,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for LineAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"line"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("line".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_name = self.validate().await?;
|
||||
info!("LINE adapter authenticated as {bot_name}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let channel_secret = self.channel_secret.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let channel_secret = Arc::new(channel_secret);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/webhook",
|
||||
axum::routing::post({
|
||||
let secret = Arc::clone(&channel_secret);
|
||||
let tx = Arc::clone(&tx);
|
||||
move |headers: axum::http::HeaderMap,
|
||||
body: axum::extract::Json<serde_json::Value>| {
|
||||
let secret = Arc::clone(&secret);
|
||||
let tx = Arc::clone(&tx);
|
||||
async move {
|
||||
// Verify X-Line-Signature
|
||||
let signature = headers
|
||||
.get("x-line-signature")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
let body_bytes = serde_json::to_vec(&body.0).unwrap_or_default();
|
||||
|
||||
// Create a temporary adapter-like verifier
|
||||
let adapter = LineAdapter {
|
||||
channel_secret: secret.as_ref().clone(),
|
||||
access_token: Zeroizing::new(String::new()),
|
||||
webhook_port: 0,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(watch::channel(false).0),
|
||||
shutdown_rx: watch::channel(false).1,
|
||||
};
|
||||
|
||||
if !signature.is_empty()
|
||||
&& !adapter.verify_signature(&body_bytes, signature)
|
||||
{
|
||||
warn!("LINE: invalid webhook signature");
|
||||
return axum::http::StatusCode::UNAUTHORIZED;
|
||||
}
|
||||
|
||||
// Parse events array
|
||||
if let Some(events) = body.0["events"].as_array() {
|
||||
for event in events {
|
||||
if let Some(msg) = parse_line_event(event) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
axum::http::StatusCode::OK
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("LINE webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("LINE webhook bind failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("LINE webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("LINE adapter shutting down");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_push_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
ChannelContent::Image { url, caption } => {
|
||||
// LINE supports image messages with a preview
|
||||
let body = serde_json::json!({
|
||||
"to": user.platform_id,
|
||||
"messages": [
|
||||
{
|
||||
"type": "image",
|
||||
"originalContentUrl": url,
|
||||
"previewImageUrl": url,
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(LINE_PUSH_URL)
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
warn!("LINE image push error {status}: {resp_body}");
|
||||
}
|
||||
|
||||
// Send caption as separate text if present
|
||||
if let Some(cap) = caption {
|
||||
if !cap.is_empty() {
|
||||
self.api_push_message(&user.platform_id, &cap).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.api_push_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// LINE does not support typing indicators via REST API
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_line_adapter_creation() {
|
||||
let adapter = LineAdapter::new(
|
||||
"channel-secret-123".to_string(),
|
||||
"access-token-456".to_string(),
|
||||
8080,
|
||||
);
|
||||
assert_eq!(adapter.name(), "line");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("line".to_string())
|
||||
);
|
||||
assert_eq!(adapter.webhook_port, 8080);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_line_adapter_both_tokens() {
|
||||
let adapter = LineAdapter::new("secret".to_string(), "token".to_string(), 9000);
|
||||
// Verify both secrets are stored as Zeroizing
|
||||
assert_eq!(adapter.channel_secret.as_str(), "secret");
|
||||
assert_eq!(adapter.access_token.as_str(), "token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_event_text_message() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"replyToken": "reply-token-123",
|
||||
"source": {
|
||||
"type": "user",
|
||||
"userId": "U1234567890"
|
||||
},
|
||||
"message": {
|
||||
"id": "msg-001",
|
||||
"type": "text",
|
||||
"text": "Hello from LINE!"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_line_event(&event).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("line".to_string()));
|
||||
assert_eq!(msg.platform_message_id, "msg-001");
|
||||
assert!(!msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from LINE!"));
|
||||
assert!(msg.metadata.contains_key("reply_token"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_event_group_message() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"replyToken": "reply-token-456",
|
||||
"source": {
|
||||
"type": "group",
|
||||
"groupId": "C1234567890",
|
||||
"userId": "U1234567890"
|
||||
},
|
||||
"message": {
|
||||
"id": "msg-002",
|
||||
"type": "text",
|
||||
"text": "Group message"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_line_event(&event).unwrap();
|
||||
assert!(msg.is_group);
|
||||
assert_eq!(msg.sender.platform_id, "C1234567890");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_event_command() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"replyToken": "rt",
|
||||
"source": {
|
||||
"type": "user",
|
||||
"userId": "U123"
|
||||
},
|
||||
"message": {
|
||||
"id": "msg-003",
|
||||
"type": "text",
|
||||
"text": "/status all"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_line_event(&event).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "status");
|
||||
assert_eq!(args, &["all"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_event_non_message() {
|
||||
let event = serde_json::json!({
|
||||
"type": "follow",
|
||||
"replyToken": "rt",
|
||||
"source": {
|
||||
"type": "user",
|
||||
"userId": "U123"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_line_event(&event).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_event_non_text() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"replyToken": "rt",
|
||||
"source": {
|
||||
"type": "user",
|
||||
"userId": "U123"
|
||||
},
|
||||
"message": {
|
||||
"id": "msg-004",
|
||||
"type": "sticker",
|
||||
"packageId": "1",
|
||||
"stickerId": "1"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_line_event(&event).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_line_event_room_source() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"replyToken": "rt",
|
||||
"source": {
|
||||
"type": "room",
|
||||
"roomId": "R1234567890",
|
||||
"userId": "U123"
|
||||
},
|
||||
"message": {
|
||||
"id": "msg-005",
|
||||
"type": "text",
|
||||
"text": "Room message"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_line_event(&event).unwrap();
|
||||
assert!(msg.is_group);
|
||||
assert_eq!(msg.sender.platform_id, "R1234567890");
|
||||
}
|
||||
}
|
||||
484
crates/openfang-channels/src/linkedin.rs
Normal file
484
crates/openfang-channels/src/linkedin.rs
Normal file
@@ -0,0 +1,484 @@
|
||||
//! LinkedIn Messaging channel adapter.
|
||||
//!
|
||||
//! Integrates with the LinkedIn Organization Messaging API using OAuth2
|
||||
//! Bearer token authentication. Polls for new messages and sends replies
|
||||
//! via the REST API.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const POLL_INTERVAL_SECS: u64 = 10;
|
||||
const MAX_MESSAGE_LEN: usize = 3000;
|
||||
const LINKEDIN_API_BASE: &str = "https://api.linkedin.com/v2";
|
||||
|
||||
/// LinkedIn Messaging channel adapter.
|
||||
///
|
||||
/// Polls the LinkedIn Organization Messaging API for new inbound messages
|
||||
/// and sends replies via the same API. Requires a valid OAuth2 access token
|
||||
/// with `r_organization_social` and `w_organization_social` scopes.
|
||||
pub struct LinkedInAdapter {
|
||||
/// SECURITY: OAuth2 access token is zeroized on drop.
|
||||
access_token: Zeroizing<String>,
|
||||
/// LinkedIn organization URN (e.g., "urn:li:organization:12345").
|
||||
organization_id: String,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Last seen message timestamp for incremental polling (epoch millis).
|
||||
last_seen_ts: Arc<RwLock<i64>>,
|
||||
}
|
||||
|
||||
impl LinkedInAdapter {
|
||||
/// Create a new LinkedIn adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `access_token` - OAuth2 Bearer token with messaging permissions.
|
||||
/// * `organization_id` - LinkedIn organization URN or numeric ID.
|
||||
pub fn new(access_token: String, organization_id: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
// Normalize organization_id to URN format
|
||||
let organization_id = if organization_id.starts_with("urn:") {
|
||||
organization_id
|
||||
} else {
|
||||
format!("urn:li:organization:{}", organization_id)
|
||||
};
|
||||
Self {
|
||||
access_token: Zeroizing::new(access_token),
|
||||
organization_id,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
last_seen_ts: Arc::new(RwLock::new(0)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an authenticated request builder.
|
||||
fn auth_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
builder
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.header("X-Restli-Protocol-Version", "2.0.0")
|
||||
.header("LinkedIn-Version", "202401")
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching the organization info.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/organizations/{}",
|
||||
LINKEDIN_API_BASE,
|
||||
self.organization_id
|
||||
.strip_prefix("urn:li:organization:")
|
||||
.unwrap_or(&self.organization_id)
|
||||
);
|
||||
let resp = self.auth_request(self.client.get(&url)).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("LinkedIn auth failed (HTTP {})", resp.status()).into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let name = body["localizedName"]
|
||||
.as_str()
|
||||
.unwrap_or("LinkedIn Org")
|
||||
.to_string();
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
/// Fetch new messages from the organization messaging inbox.
|
||||
async fn fetch_messages(
|
||||
client: &reqwest::Client,
|
||||
access_token: &str,
|
||||
organization_id: &str,
|
||||
after_ts: i64,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/organizationMessages?q=organization&organization={}&count=50",
|
||||
LINKEDIN_API_BASE,
|
||||
url::form_urlencoded::Serializer::new(String::new())
|
||||
.append_pair("org", organization_id)
|
||||
.finish()
|
||||
.split('=')
|
||||
.nth(1)
|
||||
.unwrap_or(organization_id)
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.bearer_auth(access_token)
|
||||
.header("X-Restli-Protocol-Version", "2.0.0")
|
||||
.header("LinkedIn-Version", "202401")
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("LinkedIn: HTTP {}", resp.status()).into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let elements = body["elements"].as_array().cloned().unwrap_or_default();
|
||||
|
||||
// Filter to messages after the given timestamp
|
||||
let filtered: Vec<serde_json::Value> = elements
|
||||
.into_iter()
|
||||
.filter(|msg| {
|
||||
let created = msg["createdAt"].as_i64().unwrap_or(0);
|
||||
created > after_ts
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(filtered)
|
||||
}
|
||||
|
||||
/// Send a message via the LinkedIn Organization Messaging API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
recipient_urn: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/organizationMessages", LINKEDIN_API_BASE);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
let num_chunks = chunks.len();
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"recipients": [recipient_urn],
|
||||
"organization": self.organization_id,
|
||||
"body": {
|
||||
"text": chunk,
|
||||
},
|
||||
"messageType": "MEMBER_TO_MEMBER",
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.auth_request(self.client.post(&url))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("LinkedIn API error {status}: {err_body}").into());
|
||||
}
|
||||
|
||||
// LinkedIn rate limit: max 100 requests per day for messaging
|
||||
// Small delay between chunks to be respectful
|
||||
if num_chunks > 1 {
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Parse a LinkedIn message element into usable fields.
|
||||
fn parse_message_element(
|
||||
element: &serde_json::Value,
|
||||
) -> Option<(String, String, String, String, i64)> {
|
||||
let id = element["id"].as_str()?.to_string();
|
||||
let body_text = element["body"]["text"].as_str()?.to_string();
|
||||
if body_text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let sender_urn = element["from"].as_str().unwrap_or("unknown").to_string();
|
||||
let sender_name = element["fromName"]
|
||||
.as_str()
|
||||
.or_else(|| element["senderName"].as_str())
|
||||
.unwrap_or("LinkedIn User")
|
||||
.to_string();
|
||||
let created_at = element["createdAt"].as_i64().unwrap_or(0);
|
||||
|
||||
Some((id, body_text, sender_urn, sender_name, created_at))
|
||||
}
|
||||
|
||||
/// Get the numeric organization ID.
|
||||
pub fn org_numeric_id(&self) -> &str {
|
||||
self.organization_id
|
||||
.strip_prefix("urn:li:organization:")
|
||||
.unwrap_or(&self.organization_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for LinkedInAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"linkedin"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("linkedin".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let org_name = self.validate().await?;
|
||||
info!("LinkedIn adapter authenticated for org: {org_name}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let access_token = self.access_token.clone();
|
||||
let organization_id = self.organization_id.clone();
|
||||
let client = self.client.clone();
|
||||
let last_seen_ts = Arc::clone(&self.last_seen_ts);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
// Initialize last_seen_ts to now so we only get new messages
|
||||
{
|
||||
*last_seen_ts.write().await = Utc::now().timestamp_millis();
|
||||
}
|
||||
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("LinkedIn adapter shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
let after_ts = *last_seen_ts.read().await;
|
||||
|
||||
let poll_result =
|
||||
Self::fetch_messages(&client, &access_token, &organization_id, after_ts)
|
||||
.await
|
||||
.map_err(|e| e.to_string());
|
||||
|
||||
let messages = match poll_result {
|
||||
Ok(m) => {
|
||||
backoff = Duration::from_secs(1);
|
||||
m
|
||||
}
|
||||
Err(msg) => {
|
||||
warn!("LinkedIn: poll error: {msg}, backing off {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(300));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let mut max_ts = after_ts;
|
||||
|
||||
for element in &messages {
|
||||
let (id, body_text, sender_urn, sender_name, created_at) =
|
||||
match Self::parse_message_element(element) {
|
||||
Some(parsed) => parsed,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// Skip messages from own organization
|
||||
if sender_urn.contains(&organization_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if created_at > max_ts {
|
||||
max_ts = created_at;
|
||||
}
|
||||
|
||||
let thread_id = element["conversationId"]
|
||||
.as_str()
|
||||
.or_else(|| element["threadId"].as_str())
|
||||
.map(String::from);
|
||||
|
||||
let content = if body_text.starts_with('/') {
|
||||
let parts: Vec<&str> = body_text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(body_text)
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("linkedin".to_string()),
|
||||
platform_message_id: id,
|
||||
sender: ChannelUser {
|
||||
platform_id: sender_urn.clone(),
|
||||
display_name: sender_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false,
|
||||
thread_id,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"sender_urn".to_string(),
|
||||
serde_json::Value::String(sender_urn),
|
||||
);
|
||||
m.insert(
|
||||
"organization_id".to_string(),
|
||||
serde_json::Value::String(organization_id.clone()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if max_ts > after_ts {
|
||||
*last_seen_ts.write().await = max_ts;
|
||||
}
|
||||
}
|
||||
|
||||
info!("LinkedIn polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// user.platform_id should be the recipient's LinkedIn URN
|
||||
self.api_send_message(&user.platform_id, &text).await
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// LinkedIn Messaging API does not support typing indicators.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_adapter_creation() {
|
||||
let adapter = LinkedInAdapter::new("test-token".to_string(), "12345".to_string());
|
||||
assert_eq!(adapter.name(), "linkedin");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("linkedin".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_organization_id_normalization() {
|
||||
let adapter = LinkedInAdapter::new("tok".to_string(), "12345".to_string());
|
||||
assert_eq!(adapter.organization_id, "urn:li:organization:12345");
|
||||
|
||||
let adapter2 =
|
||||
LinkedInAdapter::new("tok".to_string(), "urn:li:organization:67890".to_string());
|
||||
assert_eq!(adapter2.organization_id, "urn:li:organization:67890");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_org_numeric_id() {
|
||||
let adapter = LinkedInAdapter::new("tok".to_string(), "12345".to_string());
|
||||
assert_eq!(adapter.org_numeric_id(), "12345");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_auth_headers() {
|
||||
let adapter = LinkedInAdapter::new("my-oauth-token".to_string(), "12345".to_string());
|
||||
let builder = adapter.client.get("https://api.linkedin.com/v2/me");
|
||||
let builder = adapter.auth_request(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert!(request.headers().contains_key("authorization"));
|
||||
assert_eq!(
|
||||
request.headers().get("X-Restli-Protocol-Version").unwrap(),
|
||||
"2.0.0"
|
||||
);
|
||||
assert_eq!(request.headers().get("LinkedIn-Version").unwrap(), "202401");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_parse_message_element() {
|
||||
let element = serde_json::json!({
|
||||
"id": "msg-001",
|
||||
"body": { "text": "Hello from LinkedIn" },
|
||||
"from": "urn:li:person:abc123",
|
||||
"fromName": "Jane Doe",
|
||||
"createdAt": 1700000000000_i64,
|
||||
});
|
||||
let result = LinkedInAdapter::parse_message_element(&element);
|
||||
assert!(result.is_some());
|
||||
let (id, body, from, name, ts) = result.unwrap();
|
||||
assert_eq!(id, "msg-001");
|
||||
assert_eq!(body, "Hello from LinkedIn");
|
||||
assert_eq!(from, "urn:li:person:abc123");
|
||||
assert_eq!(name, "Jane Doe");
|
||||
assert_eq!(ts, 1700000000000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_parse_message_empty_body() {
|
||||
let element = serde_json::json!({
|
||||
"id": "msg-002",
|
||||
"body": { "text": "" },
|
||||
"from": "urn:li:person:xyz",
|
||||
});
|
||||
assert!(LinkedInAdapter::parse_message_element(&element).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_parse_message_missing_body() {
|
||||
let element = serde_json::json!({
|
||||
"id": "msg-003",
|
||||
"from": "urn:li:person:xyz",
|
||||
});
|
||||
assert!(LinkedInAdapter::parse_message_element(&element).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_linkedin_parse_message_defaults() {
|
||||
let element = serde_json::json!({
|
||||
"id": "msg-004",
|
||||
"body": { "text": "Hi" },
|
||||
});
|
||||
let result = LinkedInAdapter::parse_message_element(&element);
|
||||
assert!(result.is_some());
|
||||
let (_, _, from, name, _) = result.unwrap();
|
||||
assert_eq!(from, "unknown");
|
||||
assert_eq!(name, "LinkedIn User");
|
||||
}
|
||||
}
|
||||
706
crates/openfang-channels/src/mastodon.rs
Normal file
706
crates/openfang-channels/src/mastodon.rs
Normal file
@@ -0,0 +1,706 @@
|
||||
//! Mastodon Streaming API channel adapter.
|
||||
//!
|
||||
//! Uses the Mastodon REST API v1 for sending statuses (toots) and the Streaming
|
||||
//! API (Server-Sent Events) for real-time notification reception. Authentication
|
||||
//! is performed via `Authorization: Bearer {access_token}` on all API calls.
|
||||
//! Mentions/notifications are received via the SSE user stream endpoint.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Maximum Mastodon status length (default server limit).
|
||||
const MAX_MESSAGE_LEN: usize = 500;
|
||||
|
||||
/// SSE reconnect delay on error.
|
||||
const SSE_RECONNECT_DELAY_SECS: u64 = 5;
|
||||
|
||||
/// Maximum backoff for SSE reconnection.
|
||||
const MAX_BACKOFF_SECS: u64 = 60;
|
||||
|
||||
/// Mastodon Streaming API adapter.
|
||||
///
|
||||
/// Inbound mentions are received via Server-Sent Events (SSE) from the
|
||||
/// Mastodon streaming user endpoint. Outbound replies are posted as new
|
||||
/// statuses with `in_reply_to_id` set to the original status ID.
|
||||
pub struct MastodonAdapter {
|
||||
/// Mastodon instance URL (e.g., `"https://mastodon.social"`).
|
||||
instance_url: String,
|
||||
/// SECURITY: Access token (OAuth2 bearer token), zeroized on drop.
|
||||
access_token: Zeroizing<String>,
|
||||
/// HTTP client for API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Bot's own account ID (populated after verification).
|
||||
own_account_id: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl MastodonAdapter {
|
||||
/// Create a new Mastodon adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `instance_url` - Base URL of the Mastodon instance (no trailing slash).
|
||||
/// * `access_token` - OAuth2 access token with `read` and `write` scopes.
|
||||
pub fn new(instance_url: String, access_token: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let instance_url = instance_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
instance_url,
|
||||
access_token: Zeroizing::new(access_token),
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
own_account_id: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the access token by calling `/api/v1/accounts/verify_credentials`.
|
||||
async fn validate(&self) -> Result<(String, String), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v1/accounts/verify_credentials", self.instance_url);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Mastodon authentication failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let account_id = body["id"].as_str().unwrap_or("").to_string();
|
||||
let username = body["username"].as_str().unwrap_or("unknown").to_string();
|
||||
|
||||
// Store own account ID
|
||||
*self.own_account_id.write().await = Some(account_id.clone());
|
||||
|
||||
Ok((account_id, username))
|
||||
}
|
||||
|
||||
/// Post a status (toot), optionally as a reply.
|
||||
async fn api_post_status(
|
||||
&self,
|
||||
text: &str,
|
||||
in_reply_to_id: Option<&str>,
|
||||
visibility: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v1/statuses", self.instance_url);
|
||||
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
let mut reply_id = in_reply_to_id.map(|s| s.to_string());
|
||||
|
||||
for chunk in chunks {
|
||||
let mut params: HashMap<&str, &str> = HashMap::new();
|
||||
params.insert("status", chunk);
|
||||
params.insert("visibility", visibility);
|
||||
|
||||
if let Some(ref rid) = reply_id {
|
||||
params.insert("in_reply_to_id", rid);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Mastodon post status error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
// If we're posting a thread, chain replies
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
reply_id = resp_body["id"].as_str().map(|s| s.to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fetch notifications (mentions) since a given ID.
|
||||
#[allow(dead_code)]
|
||||
async fn fetch_notifications(
|
||||
&self,
|
||||
since_id: Option<&str>,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let mut url = format!(
|
||||
"{}/api/v1/notifications?types[]=mention&limit=30",
|
||||
self.instance_url
|
||||
);
|
||||
|
||||
if let Some(sid) = since_id {
|
||||
url.push_str(&format!("&since_id={}", sid));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.access_token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Failed to fetch Mastodon notifications".into());
|
||||
}
|
||||
|
||||
let notifications: Vec<serde_json::Value> = resp.json().await?;
|
||||
Ok(notifications)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Mastodon notification (mention) into a `ChannelMessage`.
|
||||
fn parse_mastodon_notification(
|
||||
notification: &serde_json::Value,
|
||||
own_account_id: &str,
|
||||
) -> Option<ChannelMessage> {
|
||||
let notif_type = notification["type"].as_str().unwrap_or("");
|
||||
if notif_type != "mention" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let status = notification.get("status")?;
|
||||
let account = notification.get("account")?;
|
||||
|
||||
let account_id = account["id"].as_str().unwrap_or("");
|
||||
// Skip own mentions (shouldn't happen but guard)
|
||||
if account_id == own_account_id {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Extract text content (strip HTML tags for plain text)
|
||||
let content_html = status["content"].as_str().unwrap_or("");
|
||||
let text = strip_html_tags(content_html);
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let status_id = status["id"].as_str().unwrap_or("").to_string();
|
||||
let notif_id = notification["id"].as_str().unwrap_or("").to_string();
|
||||
let username = account["username"].as_str().unwrap_or("").to_string();
|
||||
let display_name = account["display_name"]
|
||||
.as_str()
|
||||
.unwrap_or(&username)
|
||||
.to_string();
|
||||
let acct = account["acct"].as_str().unwrap_or("").to_string();
|
||||
let visibility = status["visibility"]
|
||||
.as_str()
|
||||
.unwrap_or("public")
|
||||
.to_string();
|
||||
let in_reply_to = status["in_reply_to_id"].as_str().map(|s| s.to_string());
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text)
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"status_id".to_string(),
|
||||
serde_json::Value::String(status_id.clone()),
|
||||
);
|
||||
metadata.insert(
|
||||
"notification_id".to_string(),
|
||||
serde_json::Value::String(notif_id),
|
||||
);
|
||||
metadata.insert("acct".to_string(), serde_json::Value::String(acct));
|
||||
metadata.insert(
|
||||
"visibility".to_string(),
|
||||
serde_json::Value::String(visibility),
|
||||
);
|
||||
if let Some(ref reply_to) = in_reply_to {
|
||||
metadata.insert(
|
||||
"in_reply_to_id".to_string(),
|
||||
serde_json::Value::String(reply_to.clone()),
|
||||
);
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("mastodon".to_string()),
|
||||
platform_message_id: status_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: account_id.to_string(),
|
||||
display_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false, // Mentions are treated as DM-like interactions
|
||||
thread_id: in_reply_to,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
/// Simple HTML tag stripper for Mastodon status content.
|
||||
///
|
||||
/// Mastodon returns HTML in status content. This strips tags and decodes
|
||||
/// common HTML entities. For production, consider a proper HTML sanitizer.
|
||||
fn strip_html_tags(html: &str) -> String {
|
||||
let mut result = String::with_capacity(html.len());
|
||||
let mut in_tag = false;
|
||||
let mut tag_buf = String::new();
|
||||
|
||||
for ch in html.chars() {
|
||||
match ch {
|
||||
'<' => {
|
||||
in_tag = true;
|
||||
tag_buf.clear();
|
||||
}
|
||||
'>' if in_tag => {
|
||||
in_tag = false;
|
||||
// Insert newline for block-level closing tags
|
||||
let tag_lower = tag_buf.to_lowercase();
|
||||
if tag_lower.starts_with("br")
|
||||
|| tag_lower.starts_with("/p")
|
||||
|| tag_lower.starts_with("/div")
|
||||
|| tag_lower.starts_with("/li")
|
||||
{
|
||||
result.push('\n');
|
||||
}
|
||||
tag_buf.clear();
|
||||
}
|
||||
_ if in_tag => {
|
||||
tag_buf.push(ch);
|
||||
}
|
||||
_ => {
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decode HTML entities
|
||||
let decoded = result
|
||||
.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace(""", "\"")
|
||||
.replace("'", "'")
|
||||
.replace("'", "'")
|
||||
.replace("'", "'")
|
||||
.replace(" ", " ");
|
||||
|
||||
decoded.trim().to_string()
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for MastodonAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"mastodon"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("mastodon".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let (account_id, username) = self.validate().await?;
|
||||
info!("Mastodon adapter authenticated as @{username} (id: {account_id})");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let instance_url = self.instance_url.clone();
|
||||
let access_token = self.access_token.clone();
|
||||
let own_account_id = account_id;
|
||||
let client = self.client.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let poll_interval = Duration::from_secs(SSE_RECONNECT_DELAY_SECS);
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let mut last_notification_id: Option<String> = None;
|
||||
let mut use_streaming = true;
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Mastodon adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
if use_streaming {
|
||||
// Attempt SSE connection to streaming API
|
||||
let stream_url = format!("{}/api/v1/streaming/user", instance_url);
|
||||
|
||||
match client
|
||||
.get(&stream_url)
|
||||
.bearer_auth(access_token.as_str())
|
||||
.header("Accept", "text/event-stream")
|
||||
.timeout(Duration::from_secs(5))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.status().is_success() => {
|
||||
info!("Mastodon: connected to SSE stream");
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
use futures::StreamExt;
|
||||
let mut bytes_stream = r.bytes_stream();
|
||||
let mut event_type = String::new();
|
||||
|
||||
while let Some(chunk_result) = bytes_stream.next().await {
|
||||
if *shutdown_rx.borrow_and_update() {
|
||||
return;
|
||||
}
|
||||
|
||||
let chunk = match chunk_result {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
warn!("Mastodon SSE stream error: {e}");
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let text = String::from_utf8_lossy(&chunk);
|
||||
for line in text.lines() {
|
||||
if let Some(ev) = line.strip_prefix("event: ") {
|
||||
event_type = ev.trim().to_string();
|
||||
} else if let Some(data) = line.strip_prefix("data: ") {
|
||||
if event_type == "notification" {
|
||||
if let Ok(notif) =
|
||||
serde_json::from_str::<serde_json::Value>(data)
|
||||
{
|
||||
if let Some(msg) = parse_mastodon_notification(
|
||||
¬if,
|
||||
&own_account_id,
|
||||
) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
event_type.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream ended, will reconnect
|
||||
}
|
||||
Ok(r) => {
|
||||
warn!(
|
||||
"Mastodon SSE: non-success status {}, falling back to polling",
|
||||
r.status()
|
||||
);
|
||||
use_streaming = false;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Mastodon SSE connection failed: {e}, falling back to polling");
|
||||
use_streaming = false;
|
||||
}
|
||||
}
|
||||
|
||||
// Backoff before reconnect attempt
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(MAX_BACKOFF_SECS));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Polling fallback: fetch notifications via REST
|
||||
let mut url = format!(
|
||||
"{}/api/v1/notifications?types[]=mention&limit=30",
|
||||
instance_url
|
||||
);
|
||||
if let Some(ref sid) = last_notification_id {
|
||||
url.push_str(&format!("&since_id={}", sid));
|
||||
}
|
||||
|
||||
let poll_resp = match client
|
||||
.get(&url)
|
||||
.bearer_auth(access_token.as_str())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Mastodon: notification poll error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !poll_resp.status().is_success() {
|
||||
warn!(
|
||||
"Mastodon: notification poll returned {}",
|
||||
poll_resp.status()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let notifications: Vec<serde_json::Value> =
|
||||
poll_resp.json().await.unwrap_or_default();
|
||||
|
||||
for notif in ¬ifications {
|
||||
if let Some(nid) = notif["id"].as_str() {
|
||||
last_notification_id = Some(nid.to_string());
|
||||
}
|
||||
if let Some(msg) = parse_mastodon_notification(notif, &own_account_id) {
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
backoff = Duration::from_secs(1);
|
||||
}
|
||||
|
||||
info!("Mastodon polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
// _user.platform_id is the account_id; we use status_id from metadata for reply
|
||||
self.api_post_status(&text, None, "unlisted").await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_post_status("(Unsupported content type)", None, "unlisted")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_in_thread(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
thread_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_post_status(&text, Some(thread_id), "unlisted")
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_post_status("(Unsupported content type)", Some(thread_id), "unlisted")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Mastodon does not support typing indicators
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mastodon_adapter_creation() {
|
||||
let adapter = MastodonAdapter::new(
|
||||
"https://mastodon.social".to_string(),
|
||||
"access-token-123".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.name(), "mastodon");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("mastodon".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mastodon_url_normalization() {
|
||||
let adapter =
|
||||
MastodonAdapter::new("https://mastodon.social/".to_string(), "tok".to_string());
|
||||
assert_eq!(adapter.instance_url, "https://mastodon.social");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mastodon_custom_instance() {
|
||||
let adapter =
|
||||
MastodonAdapter::new("https://infosec.exchange".to_string(), "tok".to_string());
|
||||
assert_eq!(adapter.instance_url, "https://infosec.exchange");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_basic() {
|
||||
assert_eq!(
|
||||
strip_html_tags("<p>Hello <strong>world</strong></p>"),
|
||||
"Hello world"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_entities() {
|
||||
assert_eq!(strip_html_tags("a & b < c"), "a & b < c");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_empty() {
|
||||
assert_eq!(strip_html_tags(""), "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_no_tags() {
|
||||
assert_eq!(strip_html_tags("plain text"), "plain text");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_emoji() {
|
||||
assert_eq!(
|
||||
strip_html_tags("<p>Hello 🦀🔥 world</p>"),
|
||||
"Hello 🦀🔥 world"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_cjk() {
|
||||
assert_eq!(
|
||||
strip_html_tags("<p>你好 <strong>世界</strong></p>"),
|
||||
"你好 世界"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_numeric_entities() {
|
||||
assert_eq!(strip_html_tags("'hello'"), "'hello'");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_html_tags_div_newline() {
|
||||
assert_eq!(
|
||||
strip_html_tags("<div>one</div><div>two</div>").trim(),
|
||||
"one\ntwo"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mastodon_notification_mention() {
|
||||
let notif = serde_json::json!({
|
||||
"id": "notif-1",
|
||||
"type": "mention",
|
||||
"account": {
|
||||
"id": "acct-123",
|
||||
"username": "alice",
|
||||
"display_name": "Alice",
|
||||
"acct": "alice@mastodon.social"
|
||||
},
|
||||
"status": {
|
||||
"id": "status-456",
|
||||
"content": "<p>@bot Hello!</p>",
|
||||
"visibility": "public",
|
||||
"in_reply_to_id": null
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_mastodon_notification(¬if, "acct-999").unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("mastodon".to_string()));
|
||||
assert_eq!(msg.sender.display_name, "Alice");
|
||||
assert_eq!(msg.platform_message_id, "status-456");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mastodon_notification_non_mention() {
|
||||
let notif = serde_json::json!({
|
||||
"id": "notif-1",
|
||||
"type": "favourite",
|
||||
"account": {
|
||||
"id": "acct-123",
|
||||
"username": "alice"
|
||||
},
|
||||
"status": {
|
||||
"id": "status-456",
|
||||
"content": "<p>liked</p>"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_mastodon_notification(¬if, "acct-999").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mastodon_notification_own_mention() {
|
||||
let notif = serde_json::json!({
|
||||
"id": "notif-1",
|
||||
"type": "mention",
|
||||
"account": {
|
||||
"id": "acct-999",
|
||||
"username": "bot"
|
||||
},
|
||||
"status": {
|
||||
"id": "status-1",
|
||||
"content": "<p>self mention</p>",
|
||||
"visibility": "public"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_mastodon_notification(¬if, "acct-999").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mastodon_notification_visibility() {
|
||||
let notif = serde_json::json!({
|
||||
"id": "notif-1",
|
||||
"type": "mention",
|
||||
"account": {
|
||||
"id": "acct-123",
|
||||
"username": "alice",
|
||||
"display_name": "Alice",
|
||||
"acct": "alice"
|
||||
},
|
||||
"status": {
|
||||
"id": "status-1",
|
||||
"content": "<p>DM to bot</p>",
|
||||
"visibility": "direct",
|
||||
"in_reply_to_id": null
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_mastodon_notification(¬if, "acct-999").unwrap();
|
||||
assert_eq!(
|
||||
msg.metadata.get("visibility").and_then(|v| v.as_str()),
|
||||
Some("direct")
|
||||
);
|
||||
}
|
||||
}
|
||||
356
crates/openfang-channels/src/matrix.rs
Normal file
356
crates/openfang-channels/src/matrix.rs
Normal file
@@ -0,0 +1,356 @@
|
||||
//! Matrix channel adapter.
|
||||
//!
|
||||
//! Uses the Matrix Client-Server API (via reqwest) for sending and receiving messages.
|
||||
//! Implements /sync long-polling for real-time message reception.
|
||||
|
||||
use crate::types::{ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const SYNC_TIMEOUT_MS: u64 = 30000;
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
|
||||
/// Matrix channel adapter using the Client-Server API.
|
||||
pub struct MatrixAdapter {
|
||||
/// Matrix homeserver URL (e.g., `"https://matrix.org"`).
|
||||
homeserver_url: String,
|
||||
/// Bot's user ID (e.g., "@openfang:matrix.org").
|
||||
user_id: String,
|
||||
/// SECURITY: Access token is zeroized on drop.
|
||||
access_token: Zeroizing<String>,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Allowed room IDs (empty = all joined rooms).
|
||||
allowed_rooms: Vec<String>,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Sync token for resuming /sync.
|
||||
since_token: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl MatrixAdapter {
|
||||
/// Create a new Matrix adapter.
|
||||
pub fn new(
|
||||
homeserver_url: String,
|
||||
user_id: String,
|
||||
access_token: String,
|
||||
allowed_rooms: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
homeserver_url,
|
||||
user_id,
|
||||
access_token: Zeroizing::new(access_token),
|
||||
client: reqwest::Client::new(),
|
||||
allowed_rooms,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
since_token: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a text message to a Matrix room.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
room_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let txn_id = uuid::Uuid::new_v4().to_string();
|
||||
let url = format!(
|
||||
"{}/_matrix/client/v3/rooms/{}/send/m.room.message/{}",
|
||||
self.homeserver_url, room_id, txn_id
|
||||
);
|
||||
|
||||
let chunks = crate::types::split_message(text, MAX_MESSAGE_LEN);
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"msgtype": "m.text",
|
||||
"body": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.put(&url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Matrix API error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate credentials by calling /whoami.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/_matrix/client/v3/account/whoami", self.homeserver_url);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Matrix authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let user_id = body["user_id"].as_str().unwrap_or("unknown").to_string();
|
||||
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_room(&self, room_id: &str) -> bool {
|
||||
self.allowed_rooms.is_empty() || self.allowed_rooms.iter().any(|r| r == room_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for MatrixAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"matrix"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Matrix
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let validated_user = self.validate().await?;
|
||||
info!("Matrix adapter authenticated as {validated_user}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let homeserver = self.homeserver_url.clone();
|
||||
let access_token = self.access_token.clone();
|
||||
let user_id = self.user_id.clone();
|
||||
let allowed_rooms = self.allowed_rooms.clone();
|
||||
let client = self.client.clone();
|
||||
let since_token = Arc::clone(&self.since_token);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
// Build /sync URL
|
||||
let since = since_token.read().await.clone();
|
||||
let mut url = format!(
|
||||
"{}/_matrix/client/v3/sync?timeout={}&filter={{\"room\":{{\"timeline\":{{\"limit\":10}}}}}}",
|
||||
homeserver, SYNC_TIMEOUT_MS
|
||||
);
|
||||
if let Some(ref token) = since {
|
||||
url.push_str(&format!("&since={token}"));
|
||||
}
|
||||
|
||||
let resp = tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Matrix adapter shutting down");
|
||||
break;
|
||||
}
|
||||
result = client.get(&url).bearer_auth(&*access_token).send() => {
|
||||
match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Matrix sync error: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
warn!("Matrix sync returned {}", resp.status());
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Matrix sync parse error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Update since token
|
||||
if let Some(next) = body["next_batch"].as_str() {
|
||||
*since_token.write().await = Some(next.to_string());
|
||||
}
|
||||
|
||||
// Process room events
|
||||
if let Some(rooms) = body["rooms"]["join"].as_object() {
|
||||
for (room_id, room_data) in rooms {
|
||||
if !allowed_rooms.is_empty() && !allowed_rooms.iter().any(|r| r == room_id)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(events) = room_data["timeline"]["events"].as_array() {
|
||||
for event in events {
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
if event_type != "m.room.message" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let sender = event["sender"].as_str().unwrap_or("");
|
||||
if sender == user_id {
|
||||
continue; // Skip own messages
|
||||
}
|
||||
|
||||
let content = event["content"]["body"].as_str().unwrap_or("");
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg_content = if content.starts_with('/') {
|
||||
let parts: Vec<&str> = content.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(content.to_string())
|
||||
};
|
||||
|
||||
let event_id = event["event_id"].as_str().unwrap_or("").to_string();
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Matrix,
|
||||
platform_message_id: event_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: room_id.clone(),
|
||||
display_name: sender.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/_matrix/client/v3/rooms/{}/typing/{}",
|
||||
self.homeserver_url, user.platform_id, self.user_id
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"typing": true,
|
||||
"timeout": 5000,
|
||||
});
|
||||
|
||||
let _ = self
|
||||
.client
|
||||
.put(&url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_matrix_adapter_creation() {
|
||||
let adapter = MatrixAdapter::new(
|
||||
"https://matrix.org".to_string(),
|
||||
"@bot:matrix.org".to_string(),
|
||||
"access_token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.name(), "matrix");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_allowed_rooms() {
|
||||
let adapter = MatrixAdapter::new(
|
||||
"https://matrix.org".to_string(),
|
||||
"@bot:matrix.org".to_string(),
|
||||
"token".to_string(),
|
||||
vec!["!room1:matrix.org".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_room("!room1:matrix.org"));
|
||||
assert!(!adapter.is_allowed_room("!room2:matrix.org"));
|
||||
|
||||
let open = MatrixAdapter::new(
|
||||
"https://matrix.org".to_string(),
|
||||
"@bot:matrix.org".to_string(),
|
||||
"token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert!(open.is_allowed_room("!any:matrix.org"));
|
||||
}
|
||||
}
|
||||
729
crates/openfang-channels/src/mattermost.rs
Normal file
729
crates/openfang-channels/src/mattermost.rs
Normal file
@@ -0,0 +1,729 @@
|
||||
//! Mattermost channel adapter for the OpenFang channel bridge.
|
||||
//!
|
||||
//! Uses the Mattermost WebSocket API v4 for real-time message reception and the
|
||||
//! REST API v4 for sending messages. No external Mattermost crate — just
|
||||
//! `tokio-tungstenite` + `reqwest`.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::{SinkExt, Stream, StreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Maximum Mattermost message length (characters). The server limit is 16383.
|
||||
const MAX_MESSAGE_LEN: usize = 16383;
|
||||
const MAX_BACKOFF: Duration = Duration::from_secs(60);
|
||||
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
|
||||
|
||||
/// Mattermost WebSocket + REST API v4 adapter.
|
||||
///
|
||||
/// Inbound messages arrive via WebSocket events (`posted`).
|
||||
/// Outbound messages are sent via `POST /api/v4/posts`.
|
||||
pub struct MattermostAdapter {
|
||||
/// Mattermost server URL (e.g., `"https://mattermost.example.com"`).
|
||||
server_url: String,
|
||||
/// SECURITY: Auth token is zeroized on drop to prevent memory disclosure.
|
||||
token: Zeroizing<String>,
|
||||
/// Restrict to specific channel IDs (empty = all channels the bot is in).
|
||||
allowed_channels: Vec<String>,
|
||||
/// HTTP client for outbound REST API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Bot's own user ID (populated after /api/v4/users/me).
|
||||
bot_user_id: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl MattermostAdapter {
|
||||
/// Create a new Mattermost adapter.
|
||||
///
|
||||
/// * `server_url` — Base Mattermost server URL (no trailing slash).
|
||||
/// * `token` — Personal access token or bot token.
|
||||
/// * `allowed_channels` — Channel IDs to listen on (empty = all).
|
||||
pub fn new(server_url: String, token: String, allowed_channels: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
server_url: server_url.trim_end_matches('/').to_string(),
|
||||
token: Zeroizing::new(token),
|
||||
allowed_channels,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
bot_user_id: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the token by calling `GET /api/v4/users/me`.
|
||||
async fn validate_token(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v4/users/me", self.server_url);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Mattermost auth failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let user_id = body["id"].as_str().unwrap_or("unknown").to_string();
|
||||
let username = body["username"].as_str().unwrap_or("unknown");
|
||||
info!("Mattermost authenticated as {username} ({user_id})");
|
||||
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
/// Build the WebSocket URL for the Mattermost API v4.
|
||||
fn ws_url(&self) -> String {
|
||||
let base = if self.server_url.starts_with("https://") {
|
||||
self.server_url.replacen("https://", "wss://", 1)
|
||||
} else if self.server_url.starts_with("http://") {
|
||||
self.server_url.replacen("http://", "ws://", 1)
|
||||
} else {
|
||||
format!("wss://{}", self.server_url)
|
||||
};
|
||||
format!("{base}/api/v4/websocket")
|
||||
}
|
||||
|
||||
/// Send a text message to a Mattermost channel via REST API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v4/posts", self.server_url);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"channel_id": channel_id,
|
||||
"message": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
warn!("Mattermost sendMessage failed {status}: {resp_body}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether a channel ID is allowed (empty list = allow all).
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_channel(&self, channel_id: &str) -> bool {
|
||||
self.allowed_channels.is_empty() || self.allowed_channels.iter().any(|c| c == channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Mattermost WebSocket `posted` event into a `ChannelMessage`.
|
||||
///
|
||||
/// The `data` field of a `posted` event contains a JSON string under `post`
|
||||
/// which holds the actual post payload.
|
||||
fn parse_mattermost_event(
|
||||
event: &serde_json::Value,
|
||||
bot_user_id: &Option<String>,
|
||||
allowed_channels: &[String],
|
||||
) -> Option<ChannelMessage> {
|
||||
let event_type = event["event"].as_str().unwrap_or("");
|
||||
if event_type != "posted" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// The `data.post` field is a JSON string that needs a second parse
|
||||
let post_str = event["data"]["post"].as_str()?;
|
||||
let post: serde_json::Value = serde_json::from_str(post_str).ok()?;
|
||||
|
||||
let user_id = post["user_id"].as_str().unwrap_or("");
|
||||
let channel_id = post["channel_id"].as_str().unwrap_or("");
|
||||
let message = post["message"].as_str().unwrap_or("");
|
||||
let post_id = post["id"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Skip messages from the bot itself
|
||||
if let Some(ref bid) = bot_user_id {
|
||||
if user_id == bid {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
// Filter by allowed channels
|
||||
if !allowed_channels.is_empty() && !allowed_channels.iter().any(|c| c == channel_id) {
|
||||
return None;
|
||||
}
|
||||
|
||||
if message.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Determine if group conversation from channel_type in event data
|
||||
let channel_type = event["data"]["channel_type"].as_str().unwrap_or("");
|
||||
let is_group = channel_type != "D"; // "D" = direct message
|
||||
|
||||
// Extract thread root id if this is a threaded reply
|
||||
let root_id = post["root_id"].as_str().unwrap_or("");
|
||||
let thread_id = if root_id.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(root_id.to_string())
|
||||
};
|
||||
|
||||
// Extract sender display name from event data
|
||||
let sender_name = event["data"]["sender_name"].as_str().unwrap_or(user_id);
|
||||
|
||||
// Parse commands (messages starting with /)
|
||||
let content = if message.starts_with('/') {
|
||||
let parts: Vec<&str> = message.splitn(2, ' ').collect();
|
||||
let cmd_name = &parts[0][1..];
|
||||
let args = if parts.len() > 1 {
|
||||
parts[1].split_whitespace().map(String::from).collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(message.to_string())
|
||||
};
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Mattermost,
|
||||
platform_message_id: post_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: channel_id.to_string(),
|
||||
display_name: sender_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id,
|
||||
metadata: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for MattermostAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"mattermost"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Mattermost
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate token and get bot user ID
|
||||
let user_id = self.validate_token().await?;
|
||||
*self.bot_user_id.write().await = Some(user_id);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let ws_url = self.ws_url();
|
||||
let token = self.token.clone();
|
||||
let bot_user_id = self.bot_user_id.clone();
|
||||
let allowed_channels = self.allowed_channels.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = INITIAL_BACKOFF;
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
info!("Connecting to Mattermost WebSocket at {ws_url}...");
|
||||
|
||||
let ws_result = tokio_tungstenite::connect_async(&ws_url).await;
|
||||
let ws_stream = match ws_result {
|
||||
Ok((stream, _)) => stream,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Mattermost WebSocket connection failed: {e}, retrying in {backoff:?}"
|
||||
);
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
backoff = INITIAL_BACKOFF;
|
||||
info!("Mattermost WebSocket connected");
|
||||
|
||||
let (mut ws_tx, mut ws_rx) = ws_stream.split();
|
||||
|
||||
// Authenticate over WebSocket with the token
|
||||
let auth_msg = serde_json::json!({
|
||||
"seq": 1,
|
||||
"action": "authentication_challenge",
|
||||
"data": {
|
||||
"token": token.as_str()
|
||||
}
|
||||
});
|
||||
|
||||
if let Err(e) = ws_tx
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(
|
||||
serde_json::to_string(&auth_msg).unwrap(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
warn!("Mattermost WebSocket auth send failed: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Inner message loop — returns true if we should reconnect
|
||||
let should_reconnect = 'inner: loop {
|
||||
let msg = tokio::select! {
|
||||
msg = ws_rx.next() => msg,
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("Mattermost adapter shutting down");
|
||||
let _ = ws_tx.close().await;
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let msg = match msg {
|
||||
Some(Ok(m)) => m,
|
||||
Some(Err(e)) => {
|
||||
warn!("Mattermost WebSocket error: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
None => {
|
||||
info!("Mattermost WebSocket closed");
|
||||
break 'inner true;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match msg {
|
||||
tokio_tungstenite::tungstenite::Message::Text(t) => t,
|
||||
tokio_tungstenite::tungstenite::Message::Close(_) => {
|
||||
info!("Mattermost WebSocket closed by server");
|
||||
break 'inner true;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let payload: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("Mattermost: failed to parse message: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Check for auth response
|
||||
if payload.get("status").is_some() {
|
||||
let status = payload["status"].as_str().unwrap_or("");
|
||||
if status == "OK" {
|
||||
debug!("Mattermost WebSocket authentication successful");
|
||||
} else {
|
||||
warn!("Mattermost WebSocket auth response: {status}");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse events
|
||||
let bot_id_guard = bot_user_id.read().await;
|
||||
if let Some(channel_msg) =
|
||||
parse_mattermost_event(&payload, &bot_id_guard, &allowed_channels)
|
||||
{
|
||||
debug!(
|
||||
"Mattermost message from {}: {:?}",
|
||||
channel_msg.sender.display_name, channel_msg.content
|
||||
);
|
||||
drop(bot_id_guard);
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
warn!("Mattermost: reconnecting in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
}
|
||||
|
||||
info!("Mattermost WebSocket loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let channel_id = &user.platform_id;
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(channel_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(channel_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Mattermost supports typing indicators via the WebSocket, but since we
|
||||
// only hold a WebSocket reader in the spawn loop, we use the REST API
|
||||
// userTyping action via a POST to /api/v4/users/me/typing.
|
||||
let url = format!("{}/api/v4/users/me/typing", self.server_url);
|
||||
let body = serde_json::json!({
|
||||
"channel_id": user.platform_id,
|
||||
});
|
||||
|
||||
let _ = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_in_thread(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
thread_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let channel_id = &user.platform_id;
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
let url = format!("{}/api/v4/posts", self.server_url);
|
||||
let chunks = split_message(&text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"channel_id": channel_id,
|
||||
"message": chunk,
|
||||
"root_id": thread_id,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
warn!("Mattermost send_in_thread failed {status}: {resp_body}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mattermost_adapter_creation() {
|
||||
let adapter = MattermostAdapter::new(
|
||||
"https://mattermost.example.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.name(), "mattermost");
|
||||
assert_eq!(adapter.channel_type(), ChannelType::Mattermost);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mattermost_ws_url_https() {
|
||||
let adapter = MattermostAdapter::new(
|
||||
"https://mm.example.com".to_string(),
|
||||
"token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.ws_url(), "wss://mm.example.com/api/v4/websocket");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mattermost_ws_url_http() {
|
||||
let adapter = MattermostAdapter::new(
|
||||
"http://localhost:8065".to_string(),
|
||||
"token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.ws_url(), "ws://localhost:8065/api/v4/websocket");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mattermost_ws_url_trailing_slash() {
|
||||
let adapter = MattermostAdapter::new(
|
||||
"https://mm.example.com/".to_string(),
|
||||
"token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
// Constructor trims trailing slash
|
||||
assert_eq!(adapter.ws_url(), "wss://mm.example.com/api/v4/websocket");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mattermost_allowed_channels() {
|
||||
let adapter = MattermostAdapter::new(
|
||||
"https://mm.example.com".to_string(),
|
||||
"token".to_string(),
|
||||
vec!["ch1".to_string(), "ch2".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_channel("ch1"));
|
||||
assert!(adapter.is_allowed_channel("ch2"));
|
||||
assert!(!adapter.is_allowed_channel("ch3"));
|
||||
|
||||
let open = MattermostAdapter::new(
|
||||
"https://mm.example.com".to_string(),
|
||||
"token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert!(open.is_allowed_channel("any-channel"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_basic() {
|
||||
let post = serde_json::json!({
|
||||
"id": "post-1",
|
||||
"user_id": "user-456",
|
||||
"channel_id": "ch-789",
|
||||
"message": "Hello from Mattermost!",
|
||||
"root_id": ""
|
||||
});
|
||||
|
||||
let event = serde_json::json!({
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": serde_json::to_string(&post).unwrap(),
|
||||
"channel_type": "O",
|
||||
"sender_name": "alice"
|
||||
}
|
||||
});
|
||||
|
||||
let bot_id = Some("bot-123".to_string());
|
||||
let msg = parse_mattermost_event(&event, &bot_id, &[]).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Mattermost);
|
||||
assert_eq!(msg.sender.display_name, "alice");
|
||||
assert_eq!(msg.sender.platform_id, "ch-789");
|
||||
assert!(msg.is_group);
|
||||
assert!(msg.thread_id.is_none());
|
||||
assert!(
|
||||
matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Mattermost!")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_dm() {
|
||||
let post = serde_json::json!({
|
||||
"id": "post-1",
|
||||
"user_id": "user-456",
|
||||
"channel_id": "ch-789",
|
||||
"message": "DM message",
|
||||
"root_id": ""
|
||||
});
|
||||
|
||||
let event = serde_json::json!({
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": serde_json::to_string(&post).unwrap(),
|
||||
"channel_type": "D",
|
||||
"sender_name": "bob"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_mattermost_event(&event, &None, &[]).unwrap();
|
||||
assert!(!msg.is_group);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_threaded() {
|
||||
let post = serde_json::json!({
|
||||
"id": "post-2",
|
||||
"user_id": "user-456",
|
||||
"channel_id": "ch-789",
|
||||
"message": "Thread reply",
|
||||
"root_id": "post-1"
|
||||
});
|
||||
|
||||
let event = serde_json::json!({
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": serde_json::to_string(&post).unwrap(),
|
||||
"channel_type": "O",
|
||||
"sender_name": "alice"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_mattermost_event(&event, &None, &[]).unwrap();
|
||||
assert_eq!(msg.thread_id, Some("post-1".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_skips_bot() {
|
||||
let post = serde_json::json!({
|
||||
"id": "post-1",
|
||||
"user_id": "bot-123",
|
||||
"channel_id": "ch-789",
|
||||
"message": "Bot message",
|
||||
"root_id": ""
|
||||
});
|
||||
|
||||
let event = serde_json::json!({
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": serde_json::to_string(&post).unwrap(),
|
||||
"channel_type": "O",
|
||||
"sender_name": "openfang-bot"
|
||||
}
|
||||
});
|
||||
|
||||
let bot_id = Some("bot-123".to_string());
|
||||
let msg = parse_mattermost_event(&event, &bot_id, &[]);
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_channel_filter() {
|
||||
let post = serde_json::json!({
|
||||
"id": "post-1",
|
||||
"user_id": "user-456",
|
||||
"channel_id": "ch-789",
|
||||
"message": "Hello",
|
||||
"root_id": ""
|
||||
});
|
||||
|
||||
let event = serde_json::json!({
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": serde_json::to_string(&post).unwrap(),
|
||||
"channel_type": "O",
|
||||
"sender_name": "alice"
|
||||
}
|
||||
});
|
||||
|
||||
// Not in allowed channels
|
||||
let msg =
|
||||
parse_mattermost_event(&event, &None, &["ch-111".to_string(), "ch-222".to_string()]);
|
||||
assert!(msg.is_none());
|
||||
|
||||
// In allowed channels
|
||||
let msg = parse_mattermost_event(&event, &None, &["ch-789".to_string()]);
|
||||
assert!(msg.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_command() {
|
||||
let post = serde_json::json!({
|
||||
"id": "post-1",
|
||||
"user_id": "user-456",
|
||||
"channel_id": "ch-789",
|
||||
"message": "/agent hello-world",
|
||||
"root_id": ""
|
||||
});
|
||||
|
||||
let event = serde_json::json!({
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": serde_json::to_string(&post).unwrap(),
|
||||
"channel_type": "O",
|
||||
"sender_name": "alice"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_mattermost_event(&event, &None, &[]).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agent");
|
||||
assert_eq!(args, &["hello-world"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_non_posted() {
|
||||
let event = serde_json::json!({
|
||||
"event": "typing",
|
||||
"data": {}
|
||||
});
|
||||
|
||||
let msg = parse_mattermost_event(&event, &None, &[]);
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_mattermost_event_empty_message() {
|
||||
let post = serde_json::json!({
|
||||
"id": "post-1",
|
||||
"user_id": "user-456",
|
||||
"channel_id": "ch-789",
|
||||
"message": "",
|
||||
"root_id": ""
|
||||
});
|
||||
|
||||
let event = serde_json::json!({
|
||||
"event": "posted",
|
||||
"data": {
|
||||
"post": serde_json::to_string(&post).unwrap(),
|
||||
"channel_type": "O",
|
||||
"sender_name": "alice"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_mattermost_event(&event, &None, &[]);
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
}
|
||||
625
crates/openfang-channels/src/messenger.rs
Normal file
625
crates/openfang-channels/src/messenger.rs
Normal file
@@ -0,0 +1,625 @@
|
||||
//! Facebook Messenger Platform channel adapter.
|
||||
//!
|
||||
//! Uses the Facebook Messenger Platform Send API (Graph API v18.0) for sending
|
||||
//! messages and a webhook HTTP server for receiving inbound events. The webhook
|
||||
//! supports both GET (verification challenge) and POST (message events).
|
||||
//! Authentication uses the page access token as a query parameter on the Send API.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Facebook Graph API base URL for sending messages.
|
||||
const GRAPH_API_BASE: &str = "https://graph.facebook.com/v18.0";
|
||||
|
||||
/// Maximum Messenger message text length (characters).
|
||||
const MAX_MESSAGE_LEN: usize = 2000;
|
||||
|
||||
/// Facebook Messenger Platform adapter.
|
||||
///
|
||||
/// Inbound messages arrive via a webhook HTTP server that supports:
|
||||
/// - GET requests for Facebook's webhook verification challenge
|
||||
/// - POST requests for incoming message events
|
||||
///
|
||||
/// Outbound messages are sent via the Messenger Send API using
|
||||
/// the page access token for authentication.
|
||||
pub struct MessengerAdapter {
|
||||
/// SECURITY: Page access token for the Send API, zeroized on drop.
|
||||
page_token: Zeroizing<String>,
|
||||
/// SECURITY: Verify token for webhook registration, zeroized on drop.
|
||||
verify_token: Zeroizing<String>,
|
||||
/// Port on which the inbound webhook HTTP server listens.
|
||||
webhook_port: u16,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl MessengerAdapter {
|
||||
/// Create a new Messenger adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `page_token` - Facebook page access token for the Send API.
|
||||
/// * `verify_token` - Token used to verify the webhook during Facebook's setup.
|
||||
/// * `webhook_port` - Local port for the inbound webhook HTTP server.
|
||||
pub fn new(page_token: String, verify_token: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
page_token: Zeroizing::new(page_token),
|
||||
verify_token: Zeroizing::new(verify_token),
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the page token by calling the Graph API to get page info.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/me?access_token={}",
|
||||
GRAPH_API_BASE,
|
||||
self.page_token.as_str()
|
||||
);
|
||||
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Messenger authentication failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let page_name = body["name"].as_str().unwrap_or("Messenger Bot").to_string();
|
||||
Ok(page_name)
|
||||
}
|
||||
|
||||
/// Send a text message to a Messenger user via the Send API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
recipient_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/me/messages?access_token={}",
|
||||
GRAPH_API_BASE,
|
||||
self.page_token.as_str()
|
||||
);
|
||||
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"recipient": {
|
||||
"id": recipient_id,
|
||||
},
|
||||
"message": {
|
||||
"text": chunk,
|
||||
},
|
||||
"messaging_type": "RESPONSE",
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Messenger Send API error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a typing indicator (sender action) to a Messenger user.
|
||||
async fn api_send_action(
|
||||
&self,
|
||||
recipient_id: &str,
|
||||
action: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/me/messages?access_token={}",
|
||||
GRAPH_API_BASE,
|
||||
self.page_token.as_str()
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"recipient": {
|
||||
"id": recipient_id,
|
||||
},
|
||||
"sender_action": action,
|
||||
});
|
||||
|
||||
let _ = self.client.post(&url).json(&body).send().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark a message as seen via sender action.
|
||||
#[allow(dead_code)]
|
||||
async fn mark_seen(&self, recipient_id: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.api_send_action(recipient_id, "mark_seen").await
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse Facebook Messenger webhook entry into `ChannelMessage` values.
|
||||
///
|
||||
/// A single webhook POST can contain multiple entries, each with multiple
|
||||
/// messaging events. This function processes one entry and returns all
|
||||
/// valid messages found.
|
||||
fn parse_messenger_entry(entry: &serde_json::Value) -> Vec<ChannelMessage> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
let messaging = match entry["messaging"].as_array() {
|
||||
Some(arr) => arr,
|
||||
None => return messages,
|
||||
};
|
||||
|
||||
for event in messaging {
|
||||
// Only handle message events (not delivery, read, postback, etc.)
|
||||
let message = match event.get("message") {
|
||||
Some(m) => m,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
// Skip echo messages (sent by the page itself)
|
||||
if message["is_echo"].as_bool().unwrap_or(false) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = match message["text"].as_str() {
|
||||
Some(t) if !t.is_empty() => t,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let sender_id = event["sender"]["id"].as_str().unwrap_or("").to_string();
|
||||
let recipient_id = event["recipient"]["id"].as_str().unwrap_or("").to_string();
|
||||
let msg_id = message["mid"].as_str().unwrap_or("").to_string();
|
||||
let timestamp = event["timestamp"].as_u64().unwrap_or(0);
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"sender_id".to_string(),
|
||||
serde_json::Value::String(sender_id.clone()),
|
||||
);
|
||||
metadata.insert(
|
||||
"recipient_id".to_string(),
|
||||
serde_json::Value::String(recipient_id),
|
||||
);
|
||||
metadata.insert(
|
||||
"timestamp".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(timestamp)),
|
||||
);
|
||||
|
||||
// Check for quick reply payload
|
||||
if let Some(qr) = message.get("quick_reply") {
|
||||
if let Some(payload) = qr["payload"].as_str() {
|
||||
metadata.insert(
|
||||
"quick_reply_payload".to_string(),
|
||||
serde_json::Value::String(payload.to_string()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for NLP entities (if enabled on the page)
|
||||
if let Some(nlp) = message.get("nlp") {
|
||||
if let Some(entities) = nlp.get("entities") {
|
||||
metadata.insert("nlp_entities".to_string(), entities.clone());
|
||||
}
|
||||
}
|
||||
|
||||
messages.push(ChannelMessage {
|
||||
channel: ChannelType::Custom("messenger".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: sender_id,
|
||||
display_name: String::new(), // Messenger doesn't include name in webhook
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false, // Messenger Bot API is always 1:1
|
||||
thread_id: None,
|
||||
metadata,
|
||||
});
|
||||
}
|
||||
|
||||
messages
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for MessengerAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"messenger"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("messenger".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let page_name = self.validate().await?;
|
||||
info!("Messenger adapter authenticated as {page_name}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let verify_token = self.verify_token.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let verify_token = Arc::new(verify_token);
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/webhook",
|
||||
axum::routing::get({
|
||||
// Facebook webhook verification handler
|
||||
let vt = Arc::clone(&verify_token);
|
||||
move |query: axum::extract::Query<HashMap<String, String>>| {
|
||||
let vt = Arc::clone(&vt);
|
||||
async move {
|
||||
let mode = query.get("hub.mode").map(|s| s.as_str()).unwrap_or("");
|
||||
let token = query
|
||||
.get("hub.verify_token")
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("");
|
||||
let challenge = query.get("hub.challenge").cloned().unwrap_or_default();
|
||||
|
||||
if mode == "subscribe" && token == vt.as_str() {
|
||||
info!("Messenger webhook verified");
|
||||
(axum::http::StatusCode::OK, challenge)
|
||||
} else {
|
||||
warn!("Messenger webhook verification failed");
|
||||
(axum::http::StatusCode::FORBIDDEN, String::new())
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.post({
|
||||
// Incoming message handler
|
||||
let tx = Arc::clone(&tx);
|
||||
move |body: axum::extract::Json<serde_json::Value>| {
|
||||
let tx = Arc::clone(&tx);
|
||||
async move {
|
||||
let object = body.0["object"].as_str().unwrap_or("");
|
||||
if object != "page" {
|
||||
return axum::http::StatusCode::OK;
|
||||
}
|
||||
|
||||
if let Some(entries) = body.0["entry"].as_array() {
|
||||
for entry in entries {
|
||||
let msgs = parse_messenger_entry(entry);
|
||||
for msg in msgs {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
axum::http::StatusCode::OK
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("Messenger webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Messenger webhook bind failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Messenger webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Messenger adapter shutting down");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
ChannelContent::Image { url, caption } => {
|
||||
// Send image attachment via Messenger
|
||||
let api_url = format!(
|
||||
"{}/me/messages?access_token={}",
|
||||
GRAPH_API_BASE,
|
||||
self.page_token.as_str()
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"recipient": {
|
||||
"id": user.platform_id,
|
||||
},
|
||||
"message": {
|
||||
"attachment": {
|
||||
"type": "image",
|
||||
"payload": {
|
||||
"url": url,
|
||||
"is_reusable": true,
|
||||
}
|
||||
}
|
||||
},
|
||||
"messaging_type": "RESPONSE",
|
||||
});
|
||||
|
||||
let resp = self.client.post(&api_url).json(&body).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
warn!("Messenger image send error {status}: {resp_body}");
|
||||
}
|
||||
|
||||
// Send caption as a separate text message
|
||||
if let Some(cap) = caption {
|
||||
if !cap.is_empty() {
|
||||
self.api_send_message(&user.platform_id, &cap).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.api_send_action(&user.platform_id, "typing_on").await
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_messenger_adapter_creation() {
|
||||
let adapter = MessengerAdapter::new(
|
||||
"page-token-123".to_string(),
|
||||
"verify-token-456".to_string(),
|
||||
8080,
|
||||
);
|
||||
assert_eq!(adapter.name(), "messenger");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("messenger".to_string())
|
||||
);
|
||||
assert_eq!(adapter.webhook_port, 8080);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_messenger_both_tokens() {
|
||||
let adapter = MessengerAdapter::new("page-tok".to_string(), "verify-tok".to_string(), 9000);
|
||||
assert_eq!(adapter.page_token.as_str(), "page-tok");
|
||||
assert_eq!(adapter.verify_token.as_str(), "verify-tok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_messenger_entry_text_message() {
|
||||
let entry = serde_json::json!({
|
||||
"id": "page-id-123",
|
||||
"time": 1458692752478_u64,
|
||||
"messaging": [
|
||||
{
|
||||
"sender": { "id": "user-123" },
|
||||
"recipient": { "id": "page-456" },
|
||||
"timestamp": 1458692752478_u64,
|
||||
"message": {
|
||||
"mid": "mid.123",
|
||||
"text": "Hello from Messenger!"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let msgs = parse_messenger_entry(&entry);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert_eq!(
|
||||
msgs[0].channel,
|
||||
ChannelType::Custom("messenger".to_string())
|
||||
);
|
||||
assert_eq!(msgs[0].sender.platform_id, "user-123");
|
||||
assert!(
|
||||
matches!(msgs[0].content, ChannelContent::Text(ref t) if t == "Hello from Messenger!")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_messenger_entry_command() {
|
||||
let entry = serde_json::json!({
|
||||
"id": "page-id",
|
||||
"messaging": [
|
||||
{
|
||||
"sender": { "id": "user-1" },
|
||||
"recipient": { "id": "page-1" },
|
||||
"timestamp": 0,
|
||||
"message": {
|
||||
"mid": "mid.456",
|
||||
"text": "/models list"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let msgs = parse_messenger_entry(&entry);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
match &msgs[0].content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "models");
|
||||
assert_eq!(args, &["list"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_messenger_entry_skips_echo() {
|
||||
let entry = serde_json::json!({
|
||||
"id": "page-id",
|
||||
"messaging": [
|
||||
{
|
||||
"sender": { "id": "page-1" },
|
||||
"recipient": { "id": "user-1" },
|
||||
"timestamp": 0,
|
||||
"message": {
|
||||
"mid": "mid.789",
|
||||
"text": "Echo message",
|
||||
"is_echo": true,
|
||||
"app_id": 12345
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let msgs = parse_messenger_entry(&entry);
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_messenger_entry_skips_delivery() {
|
||||
let entry = serde_json::json!({
|
||||
"id": "page-id",
|
||||
"messaging": [
|
||||
{
|
||||
"sender": { "id": "user-1" },
|
||||
"recipient": { "id": "page-1" },
|
||||
"timestamp": 0,
|
||||
"delivery": {
|
||||
"mids": ["mid.123"],
|
||||
"watermark": 1458668856253_u64
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let msgs = parse_messenger_entry(&entry);
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_messenger_entry_quick_reply() {
|
||||
let entry = serde_json::json!({
|
||||
"id": "page-id",
|
||||
"messaging": [
|
||||
{
|
||||
"sender": { "id": "user-1" },
|
||||
"recipient": { "id": "page-1" },
|
||||
"timestamp": 0,
|
||||
"message": {
|
||||
"mid": "mid.qr",
|
||||
"text": "Red",
|
||||
"quick_reply": {
|
||||
"payload": "DEVELOPER_DEFINED_PAYLOAD_FOR_RED"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let msgs = parse_messenger_entry(&entry);
|
||||
assert_eq!(msgs.len(), 1);
|
||||
assert!(msgs[0].metadata.contains_key("quick_reply_payload"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_messenger_entry_empty_text() {
|
||||
let entry = serde_json::json!({
|
||||
"id": "page-id",
|
||||
"messaging": [
|
||||
{
|
||||
"sender": { "id": "user-1" },
|
||||
"recipient": { "id": "page-1" },
|
||||
"timestamp": 0,
|
||||
"message": {
|
||||
"mid": "mid.empty",
|
||||
"text": ""
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let msgs = parse_messenger_entry(&entry);
|
||||
assert!(msgs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_messenger_entry_multiple_messages() {
|
||||
let entry = serde_json::json!({
|
||||
"id": "page-id",
|
||||
"messaging": [
|
||||
{
|
||||
"sender": { "id": "user-1" },
|
||||
"recipient": { "id": "page-1" },
|
||||
"timestamp": 0,
|
||||
"message": { "mid": "mid.1", "text": "First" }
|
||||
},
|
||||
{
|
||||
"sender": { "id": "user-2" },
|
||||
"recipient": { "id": "page-1" },
|
||||
"timestamp": 0,
|
||||
"message": { "mid": "mid.2", "text": "Second" }
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let msgs = parse_messenger_entry(&entry);
|
||||
assert_eq!(msgs.len(), 2);
|
||||
}
|
||||
}
|
||||
598
crates/openfang-channels/src/mumble.rs
Normal file
598
crates/openfang-channels/src/mumble.rs
Normal file
@@ -0,0 +1,598 @@
|
||||
//! Mumble text-chat channel adapter.
|
||||
//!
|
||||
//! Connects to a Mumble server via TCP and exchanges text messages using a
|
||||
//! simplified protobuf-style framing protocol. Voice channels are ignored;
|
||||
//! only `TextMessage` packets (type 11) are processed.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, watch, Mutex};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 5000;
|
||||
const DEFAULT_PORT: u16 = 64738;
|
||||
|
||||
// Mumble packet types (protobuf message IDs)
|
||||
const MSG_TYPE_VERSION: u16 = 0;
|
||||
const MSG_TYPE_AUTHENTICATE: u16 = 2;
|
||||
const MSG_TYPE_PING: u16 = 3;
|
||||
const MSG_TYPE_TEXT_MESSAGE: u16 = 11;
|
||||
|
||||
/// Mumble text-chat channel adapter.
|
||||
///
|
||||
/// Connects to a Mumble server using TCP and handles text messages only
|
||||
/// (no voice). The protocol uses a 6-byte header: 2-byte big-endian message
|
||||
/// type followed by 4-byte big-endian payload length.
|
||||
pub struct MumbleAdapter {
|
||||
/// Mumble server hostname or IP.
|
||||
host: String,
|
||||
/// TCP port (default: 64738).
|
||||
port: u16,
|
||||
/// SECURITY: Server password is zeroized on drop.
|
||||
password: Zeroizing<String>,
|
||||
/// Username to authenticate with.
|
||||
username: String,
|
||||
/// Mumble channel to join (by name).
|
||||
channel_name: String,
|
||||
/// Shared TCP stream for sending (wrapped in Mutex for exclusive write access).
|
||||
stream: Arc<Mutex<Option<tokio::net::tcp::OwnedWriteHalf>>>,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl MumbleAdapter {
|
||||
/// Create a new Mumble text-chat adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `host` - Hostname or IP of the Mumble server.
|
||||
/// * `port` - TCP port (0 = use default 64738).
|
||||
/// * `password` - Server password (empty string if none).
|
||||
/// * `username` - Username for authentication.
|
||||
/// * `channel_name` - Mumble channel to join.
|
||||
pub fn new(
|
||||
host: String,
|
||||
port: u16,
|
||||
password: String,
|
||||
username: String,
|
||||
channel_name: String,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let port = if port == 0 { DEFAULT_PORT } else { port };
|
||||
Self {
|
||||
host,
|
||||
port,
|
||||
password: Zeroizing::new(password),
|
||||
username,
|
||||
channel_name,
|
||||
stream: Arc::new(Mutex::new(None)),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Encode a Mumble packet: 2-byte type (BE) + 4-byte length (BE) + payload.
|
||||
fn encode_packet(msg_type: u16, payload: &[u8]) -> Vec<u8> {
|
||||
let mut buf = Vec::with_capacity(6 + payload.len());
|
||||
buf.extend_from_slice(&msg_type.to_be_bytes());
|
||||
buf.extend_from_slice(&(payload.len() as u32).to_be_bytes());
|
||||
buf.extend_from_slice(payload);
|
||||
buf
|
||||
}
|
||||
|
||||
/// Build a minimal Version packet (type 0).
|
||||
///
|
||||
/// Simplified encoding: version fields as varint-like protobuf.
|
||||
/// Field 1 (version): 0x00010500 (1.5.0)
|
||||
/// Field 2 (release): "OpenFang"
|
||||
fn build_version_packet() -> Vec<u8> {
|
||||
let mut payload = Vec::new();
|
||||
// Field 1: fixed32 version = 0x00010500 (tag = 0x0D for wire type 5)
|
||||
payload.push(0x0D);
|
||||
payload.extend_from_slice(&0x0001_0500u32.to_le_bytes());
|
||||
// Field 2: string release (tag = 0x12)
|
||||
let release = b"OpenFang";
|
||||
payload.push(0x12);
|
||||
payload.push(release.len() as u8);
|
||||
payload.extend_from_slice(release);
|
||||
// Field 3: string os (tag = 0x1A)
|
||||
let os = std::env::consts::OS.as_bytes();
|
||||
payload.push(0x1A);
|
||||
payload.push(os.len() as u8);
|
||||
payload.extend_from_slice(os);
|
||||
payload
|
||||
}
|
||||
|
||||
/// Build an Authenticate packet (type 2).
|
||||
///
|
||||
/// Field 1 (username): string
|
||||
/// Field 2 (password): string
|
||||
fn build_authenticate_packet(username: &str, password: &str) -> Vec<u8> {
|
||||
let mut payload = Vec::new();
|
||||
// Field 1: string username (tag = 0x0A)
|
||||
let uname = username.as_bytes();
|
||||
payload.push(0x0A);
|
||||
Self::encode_varint(uname.len() as u64, &mut payload);
|
||||
payload.extend_from_slice(uname);
|
||||
// Field 2: string password (tag = 0x12)
|
||||
if !password.is_empty() {
|
||||
let pass = password.as_bytes();
|
||||
payload.push(0x12);
|
||||
Self::encode_varint(pass.len() as u64, &mut payload);
|
||||
payload.extend_from_slice(pass);
|
||||
}
|
||||
payload
|
||||
}
|
||||
|
||||
/// Build a TextMessage packet (type 11).
|
||||
///
|
||||
/// Field 1 (actor): uint32 (omitted — server assigns)
|
||||
/// Field 3 (channel_id): repeated uint32
|
||||
/// Field 5 (message): string
|
||||
fn build_text_message_packet(channel_id: u32, message: &str) -> Vec<u8> {
|
||||
let mut payload = Vec::new();
|
||||
// Field 3: uint32 channel_id (tag = 0x18, wire type 0 = varint)
|
||||
payload.push(0x18);
|
||||
Self::encode_varint(channel_id as u64, &mut payload);
|
||||
// Field 5: string message (tag = 0x2A, wire type 2 = length-delimited)
|
||||
let msg = message.as_bytes();
|
||||
payload.push(0x2A);
|
||||
Self::encode_varint(msg.len() as u64, &mut payload);
|
||||
payload.extend_from_slice(msg);
|
||||
payload
|
||||
}
|
||||
|
||||
/// Build a Ping packet (type 3). Minimal — just a timestamp field.
|
||||
fn build_ping_packet() -> Vec<u8> {
|
||||
let mut payload = Vec::new();
|
||||
// Field 1: uint64 timestamp (tag = 0x08)
|
||||
let ts = Utc::now().timestamp() as u64;
|
||||
payload.push(0x08);
|
||||
Self::encode_varint(ts, &mut payload);
|
||||
payload
|
||||
}
|
||||
|
||||
/// Encode a varint (protobuf base-128 encoding).
|
||||
fn encode_varint(mut value: u64, buf: &mut Vec<u8>) {
|
||||
loop {
|
||||
let byte = (value & 0x7F) as u8;
|
||||
value >>= 7;
|
||||
if value == 0 {
|
||||
buf.push(byte);
|
||||
break;
|
||||
} else {
|
||||
buf.push(byte | 0x80);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decode a varint from bytes. Returns (value, bytes_consumed).
|
||||
fn decode_varint(data: &[u8]) -> (u64, usize) {
|
||||
let mut value: u64 = 0;
|
||||
let mut shift = 0;
|
||||
for (i, &byte) in data.iter().enumerate() {
|
||||
value |= ((byte & 0x7F) as u64) << shift;
|
||||
if byte & 0x80 == 0 {
|
||||
return (value, i + 1);
|
||||
}
|
||||
shift += 7;
|
||||
if shift >= 64 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
(value, data.len())
|
||||
}
|
||||
|
||||
/// Parse a TextMessage protobuf payload.
|
||||
/// Returns (actor, channel_ids, tree_ids, session_ids, message).
|
||||
fn parse_text_message(payload: &[u8]) -> (u32, Vec<u32>, Vec<u32>, Vec<u32>, String) {
|
||||
let mut actor: u32 = 0;
|
||||
let mut channel_ids = Vec::new();
|
||||
let mut tree_ids = Vec::new();
|
||||
let mut session_ids = Vec::new();
|
||||
let mut message = String::new();
|
||||
|
||||
let mut pos = 0;
|
||||
while pos < payload.len() {
|
||||
let tag_byte = payload[pos];
|
||||
let field_number = tag_byte >> 3;
|
||||
let wire_type = tag_byte & 0x07;
|
||||
pos += 1;
|
||||
|
||||
match (field_number, wire_type) {
|
||||
// Field 1: actor (uint32, varint)
|
||||
(1, 0) => {
|
||||
let (val, consumed) = Self::decode_varint(&payload[pos..]);
|
||||
actor = val as u32;
|
||||
pos += consumed;
|
||||
}
|
||||
// Field 2: session (repeated uint32, varint)
|
||||
(2, 0) => {
|
||||
let (val, consumed) = Self::decode_varint(&payload[pos..]);
|
||||
session_ids.push(val as u32);
|
||||
pos += consumed;
|
||||
}
|
||||
// Field 3: channel_id (repeated uint32, varint)
|
||||
(3, 0) => {
|
||||
let (val, consumed) = Self::decode_varint(&payload[pos..]);
|
||||
channel_ids.push(val as u32);
|
||||
pos += consumed;
|
||||
}
|
||||
// Field 4: tree_id (repeated uint32, varint)
|
||||
(4, 0) => {
|
||||
let (val, consumed) = Self::decode_varint(&payload[pos..]);
|
||||
tree_ids.push(val as u32);
|
||||
pos += consumed;
|
||||
}
|
||||
// Field 5: message (string, length-delimited)
|
||||
(5, 2) => {
|
||||
let (len, consumed) = Self::decode_varint(&payload[pos..]);
|
||||
pos += consumed;
|
||||
let end = pos + len as usize;
|
||||
if end <= payload.len() {
|
||||
message = String::from_utf8_lossy(&payload[pos..end]).to_string();
|
||||
}
|
||||
pos = end;
|
||||
}
|
||||
// Unknown — skip
|
||||
(_, 0) => {
|
||||
let (_, consumed) = Self::decode_varint(&payload[pos..]);
|
||||
pos += consumed;
|
||||
}
|
||||
(_, 2) => {
|
||||
let (len, consumed) = Self::decode_varint(&payload[pos..]);
|
||||
pos += consumed + len as usize;
|
||||
}
|
||||
(_, 5) => {
|
||||
pos += 4; // fixed32
|
||||
}
|
||||
(_, 1) => {
|
||||
pos += 8; // fixed64
|
||||
}
|
||||
_ => {
|
||||
break; // Unrecoverable wire type
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(actor, channel_ids, tree_ids, session_ids, message)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for MumbleAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"mumble"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("mumble".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let addr = format!("{}:{}", self.host, self.port);
|
||||
info!("Mumble adapter connecting to {addr}");
|
||||
|
||||
let tcp = TcpStream::connect(&addr).await?;
|
||||
let (mut reader, writer) = tcp.into_split();
|
||||
|
||||
// Store writer for send()
|
||||
{
|
||||
let mut lock = self.stream.lock().await;
|
||||
*lock = Some(writer);
|
||||
}
|
||||
|
||||
// Send Version + Authenticate
|
||||
{
|
||||
let mut lock = self.stream.lock().await;
|
||||
if let Some(ref mut w) = *lock {
|
||||
let version_pkt =
|
||||
Self::encode_packet(MSG_TYPE_VERSION, &Self::build_version_packet());
|
||||
w.write_all(&version_pkt).await?;
|
||||
|
||||
let auth_pkt = Self::encode_packet(
|
||||
MSG_TYPE_AUTHENTICATE,
|
||||
&Self::build_authenticate_packet(&self.username, &self.password),
|
||||
);
|
||||
w.write_all(&auth_pkt).await?;
|
||||
w.flush().await?;
|
||||
}
|
||||
}
|
||||
|
||||
info!("Mumble adapter authenticated as {}", self.username);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let channel_name = self.channel_name.clone();
|
||||
let own_username = self.username.clone();
|
||||
let stream_handle = Arc::clone(&self.stream);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut header_buf = [0u8; 6];
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
let mut ping_interval = tokio::time::interval(Duration::from_secs(20));
|
||||
ping_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("Mumble adapter shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ = ping_interval.tick() => {
|
||||
// Send keepalive ping
|
||||
let mut lock = stream_handle.lock().await;
|
||||
if let Some(ref mut w) = *lock {
|
||||
let pkt = Self::encode_packet(MSG_TYPE_PING, &Self::build_ping_packet());
|
||||
if let Err(e) = w.write_all(&pkt).await {
|
||||
warn!("Mumble: ping write error: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
result = reader.read_exact(&mut header_buf) => {
|
||||
match result {
|
||||
Ok(_) => {
|
||||
backoff = Duration::from_secs(1);
|
||||
let msg_type = u16::from_be_bytes([header_buf[0], header_buf[1]]);
|
||||
let msg_len = u32::from_be_bytes([
|
||||
header_buf[2], header_buf[3],
|
||||
header_buf[4], header_buf[5],
|
||||
]) as usize;
|
||||
|
||||
// Sanity check — reject packets larger than 1 MB
|
||||
if msg_len > 1_048_576 {
|
||||
warn!("Mumble: oversized packet ({msg_len} bytes), skipping");
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut payload = vec![0u8; msg_len];
|
||||
if let Err(e) = reader.read_exact(&mut payload).await {
|
||||
warn!("Mumble: payload read error: {e}");
|
||||
break;
|
||||
}
|
||||
|
||||
if msg_type == MSG_TYPE_TEXT_MESSAGE {
|
||||
let (actor, _ch_ids, _tree_ids, _session_ids, message) =
|
||||
Self::parse_text_message(&payload);
|
||||
|
||||
if message.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Strip basic HTML tags that Mumble wraps text in
|
||||
let clean_msg = message
|
||||
.replace("<br>", "\n")
|
||||
.replace("<br/>", "\n")
|
||||
.replace("<br />", "\n");
|
||||
// Rough tag strip
|
||||
let clean_msg = {
|
||||
let mut out = String::with_capacity(clean_msg.len());
|
||||
let mut in_tag = false;
|
||||
for ch in clean_msg.chars() {
|
||||
if ch == '<' { in_tag = true; continue; }
|
||||
if ch == '>' { in_tag = false; continue; }
|
||||
if !in_tag { out.push(ch); }
|
||||
}
|
||||
out
|
||||
};
|
||||
|
||||
if clean_msg.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = if clean_msg.starts_with('/') {
|
||||
let parts: Vec<&str> = clean_msg.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(clean_msg)
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("mumble".to_string()),
|
||||
platform_message_id: format!(
|
||||
"mumble-{}-{}",
|
||||
actor,
|
||||
Utc::now().timestamp_millis()
|
||||
),
|
||||
sender: ChannelUser {
|
||||
platform_id: format!("session-{actor}"),
|
||||
display_name: format!("user-{actor}"),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"channel".to_string(),
|
||||
serde_json::Value::String(channel_name.clone()),
|
||||
);
|
||||
m.insert(
|
||||
"actor".to_string(),
|
||||
serde_json::Value::Number(actor.into()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
// Other packet types (ServerSync, ChannelState, etc.) silently ignored
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Mumble: read error: {e}, backing off {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
info!("Mumble polling loop stopped");
|
||||
let _ = own_username;
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
let chunks = split_message(&text, MAX_MESSAGE_LEN);
|
||||
|
||||
let mut lock = self.stream.lock().await;
|
||||
let writer = lock
|
||||
.as_mut()
|
||||
.ok_or("Mumble: not connected — call start() first")?;
|
||||
|
||||
for chunk in chunks {
|
||||
// Send to channel 0 (root). In production the channel_id would be
|
||||
// resolved from self.channel_name via a ChannelState mapping.
|
||||
let payload = Self::build_text_message_packet(0, chunk);
|
||||
let pkt = Self::encode_packet(MSG_TYPE_TEXT_MESSAGE, &payload);
|
||||
writer.write_all(&pkt).await?;
|
||||
}
|
||||
writer.flush().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Mumble has no typing indicator in its protocol.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
// Drop the writer to close the TCP connection
|
||||
let mut lock = self.stream.lock().await;
|
||||
*lock = None;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mumble_adapter_creation() {
|
||||
let adapter = MumbleAdapter::new(
|
||||
"mumble.example.com".to_string(),
|
||||
0,
|
||||
"secret".to_string(),
|
||||
"OpenFangBot".to_string(),
|
||||
"General".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.name(), "mumble");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("mumble".to_string())
|
||||
);
|
||||
assert_eq!(adapter.port, DEFAULT_PORT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mumble_custom_port() {
|
||||
let adapter = MumbleAdapter::new(
|
||||
"localhost".to_string(),
|
||||
12345,
|
||||
"".to_string(),
|
||||
"bot".to_string(),
|
||||
"Lobby".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.port, 12345);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mumble_packet_encoding() {
|
||||
let packet = MumbleAdapter::encode_packet(11, &[0xAA, 0xBB]);
|
||||
assert_eq!(packet.len(), 8); // 2 type + 4 len + 2 payload
|
||||
assert_eq!(packet[0..2], [0, 11]); // type = 11 (TextMessage)
|
||||
assert_eq!(packet[2..6], [0, 0, 0, 2]); // len = 2
|
||||
assert_eq!(packet[6..8], [0xAA, 0xBB]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mumble_varint_encode_decode() {
|
||||
let mut buf = Vec::new();
|
||||
MumbleAdapter::encode_varint(300, &mut buf);
|
||||
let (value, consumed) = MumbleAdapter::decode_varint(&buf);
|
||||
assert_eq!(value, 300);
|
||||
assert_eq!(consumed, buf.len());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mumble_text_message_roundtrip() {
|
||||
let payload = MumbleAdapter::build_text_message_packet(42, "Hello Mumble!");
|
||||
let (actor, ch_ids, _tree_ids, _session_ids, message) =
|
||||
MumbleAdapter::parse_text_message(&payload);
|
||||
// actor is not set (field 1 omitted) — build only sets channel + message
|
||||
assert_eq!(actor, 0);
|
||||
assert_eq!(ch_ids, vec![42]);
|
||||
assert_eq!(message, "Hello Mumble!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mumble_version_packet() {
|
||||
let payload = MumbleAdapter::build_version_packet();
|
||||
assert!(!payload.is_empty());
|
||||
// First byte should be field 1 tag
|
||||
assert_eq!(payload[0], 0x0D);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mumble_authenticate_packet() {
|
||||
let payload = MumbleAdapter::build_authenticate_packet("bot", "pass");
|
||||
assert!(!payload.is_empty());
|
||||
assert_eq!(payload[0], 0x0A); // field 1 tag
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mumble_authenticate_packet_no_password() {
|
||||
let payload = MumbleAdapter::build_authenticate_packet("bot", "");
|
||||
// No field 2 tag (0x12) should be present
|
||||
assert!(!payload.contains(&0x12));
|
||||
}
|
||||
}
|
||||
509
crates/openfang-channels/src/nextcloud.rs
Normal file
509
crates/openfang-channels/src/nextcloud.rs
Normal file
@@ -0,0 +1,509 @@
|
||||
//! Nextcloud Talk channel adapter.
|
||||
//!
|
||||
//! Uses the Nextcloud Talk REST API (OCS v2) for sending and receiving messages.
|
||||
//! Polls the chat endpoint with `lookIntoFuture=1` for near-real-time message
|
||||
//! delivery. Authentication is performed via Bearer token with OCS-specific
|
||||
//! headers.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Maximum message length for Nextcloud Talk messages.
|
||||
const MAX_MESSAGE_LEN: usize = 32000;
|
||||
|
||||
/// Polling interval in seconds for the chat endpoint.
|
||||
const POLL_INTERVAL_SECS: u64 = 3;
|
||||
|
||||
/// Nextcloud Talk channel adapter using OCS REST API with polling.
|
||||
///
|
||||
/// Polls the Nextcloud Talk chat endpoint for new messages and sends replies
|
||||
/// via the same REST API. Supports multiple room tokens for simultaneous
|
||||
/// monitoring.
|
||||
pub struct NextcloudAdapter {
|
||||
/// Nextcloud server URL (e.g., `"https://cloud.example.com"`).
|
||||
server_url: String,
|
||||
/// SECURITY: Authentication token is zeroized on drop.
|
||||
token: Zeroizing<String>,
|
||||
/// Room tokens to poll (empty = discover from server).
|
||||
allowed_rooms: Vec<String>,
|
||||
/// HTTP client for API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Last known message ID per room for incremental polling.
|
||||
last_known_ids: Arc<RwLock<HashMap<String, i64>>>,
|
||||
}
|
||||
|
||||
impl NextcloudAdapter {
|
||||
/// Create a new Nextcloud Talk adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `server_url` - Base URL of the Nextcloud instance.
|
||||
/// * `token` - Authentication token (app password or OAuth2 token).
|
||||
/// * `allowed_rooms` - Room tokens to listen on (empty = discover joined rooms).
|
||||
pub fn new(server_url: String, token: String, allowed_rooms: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let server_url = server_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
server_url,
|
||||
token: Zeroizing::new(token),
|
||||
allowed_rooms,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
last_known_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add OCS and authorization headers to a request builder.
|
||||
fn ocs_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
builder
|
||||
.header("Authorization", format!("Bearer {}", self.token.as_str()))
|
||||
.header("OCS-APIRequest", "true")
|
||||
.header("Accept", "application/json")
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching the user's own status.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/ocs/v2.php/cloud/user?format=json", self.server_url);
|
||||
let resp = self.ocs_headers(self.client.get(&url)).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Nextcloud authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let user_id = body["ocs"]["data"]["id"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
/// Fetch the list of joined rooms from the Nextcloud Talk API.
|
||||
#[allow(dead_code)]
|
||||
async fn fetch_rooms(&self) -> Result<Vec<String>, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/ocs/v2.php/apps/spreed/api/v4/room?format=json",
|
||||
self.server_url
|
||||
);
|
||||
let resp = self.ocs_headers(self.client.get(&url)).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Nextcloud: failed to fetch rooms".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let rooms = body["ocs"]["data"]
|
||||
.as_array()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|r| r["token"].as_str().map(String::from))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Ok(rooms)
|
||||
}
|
||||
|
||||
/// Send a text message to a Nextcloud Talk room.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
room_token: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/ocs/v2.php/apps/spreed/api/v1/chat/{}",
|
||||
self.server_url, room_token
|
||||
);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"message": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.ocs_headers(self.client.post(&url))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Nextcloud Talk API error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a room token is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_room(&self, room_token: &str) -> bool {
|
||||
self.allowed_rooms.is_empty() || self.allowed_rooms.iter().any(|r| r == room_token)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for NextcloudAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"nextcloud"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("nextcloud".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let username = self.validate().await?;
|
||||
info!("Nextcloud Talk adapter authenticated as {username}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let server_url = self.server_url.clone();
|
||||
let token = self.token.clone();
|
||||
let own_user = username;
|
||||
let allowed_rooms = self.allowed_rooms.clone();
|
||||
let client = self.client.clone();
|
||||
let last_known_ids = Arc::clone(&self.last_known_ids);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Determine rooms to poll
|
||||
let rooms_to_poll = if allowed_rooms.is_empty() {
|
||||
let url = format!(
|
||||
"{}/ocs/v2.php/apps/spreed/api/v4/room?format=json",
|
||||
server_url
|
||||
);
|
||||
match client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", token.as_str()))
|
||||
.header("OCS-APIRequest", "true")
|
||||
.header("Accept", "application/json")
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["ocs"]["data"]
|
||||
.as_array()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|r| r["token"].as_str().map(String::from))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Nextcloud: failed to list rooms: {e}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allowed_rooms
|
||||
};
|
||||
|
||||
if rooms_to_poll.is_empty() {
|
||||
warn!("Nextcloud Talk: no rooms to poll");
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Nextcloud Talk: polling {} room(s)", rooms_to_poll.len());
|
||||
|
||||
// Initialize last known IDs to 0 (server returns newest first,
|
||||
// we use lookIntoFuture to get only new messages)
|
||||
{
|
||||
let mut ids = last_known_ids.write().await;
|
||||
for room in &rooms_to_poll {
|
||||
ids.entry(room.clone()).or_insert(0);
|
||||
}
|
||||
}
|
||||
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Nextcloud Talk adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
for room_token in &rooms_to_poll {
|
||||
let last_id = {
|
||||
let ids = last_known_ids.read().await;
|
||||
ids.get(room_token).copied().unwrap_or(0)
|
||||
};
|
||||
|
||||
// Use lookIntoFuture=1 and lastKnownMessageId for incremental polling
|
||||
let url = format!(
|
||||
"{}/ocs/v2.php/apps/spreed/api/v4/room/{}/chat?format=json&lookIntoFuture=1&limit=100&lastKnownMessageId={}",
|
||||
server_url, room_token, last_id
|
||||
);
|
||||
|
||||
let resp = match client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", token.as_str()))
|
||||
.header("OCS-APIRequest", "true")
|
||||
.header("Accept", "application/json")
|
||||
.timeout(Duration::from_secs(30))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Nextcloud: poll error for room {room_token}: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// 304 Not Modified = no new messages
|
||||
if resp.status() == reqwest::StatusCode::NOT_MODIFIED {
|
||||
backoff = Duration::from_secs(1);
|
||||
continue;
|
||||
}
|
||||
|
||||
if !resp.status().is_success() {
|
||||
warn!(
|
||||
"Nextcloud: chat poll returned {} for room {room_token}",
|
||||
resp.status()
|
||||
);
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Nextcloud: failed to parse chat response: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let messages = match body["ocs"]["data"].as_array() {
|
||||
Some(arr) => arr,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let mut newest_id = last_id;
|
||||
|
||||
for msg in messages {
|
||||
// Only handle user messages (not system/command messages)
|
||||
let msg_type = msg["messageType"].as_str().unwrap_or("comment");
|
||||
if msg_type == "system" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let actor_id = msg["actorId"].as_str().unwrap_or("");
|
||||
// Skip own messages
|
||||
if actor_id == own_user {
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = msg["message"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg_id = msg["id"].as_i64().unwrap_or(0);
|
||||
let actor_display = msg["actorDisplayName"].as_str().unwrap_or("unknown");
|
||||
let reference_id = msg["referenceId"].as_str().map(String::from);
|
||||
|
||||
// Track newest message ID
|
||||
if msg_id > newest_id {
|
||||
newest_id = msg_id;
|
||||
}
|
||||
|
||||
let msg_content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("nextcloud".to_string()),
|
||||
platform_message_id: msg_id.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: room_token.clone(),
|
||||
display_name: actor_display.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: reference_id,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"actor_id".to_string(),
|
||||
serde_json::Value::String(actor_id.to_string()),
|
||||
);
|
||||
m.insert(
|
||||
"room_token".to_string(),
|
||||
serde_json::Value::String(room_token.clone()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Update last known message ID for this room
|
||||
if newest_id > last_id {
|
||||
last_known_ids
|
||||
.write()
|
||||
.await
|
||||
.insert(room_token.clone(), newest_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Nextcloud Talk polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Nextcloud Talk does not have a public typing indicator REST endpoint
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_nextcloud_adapter_creation() {
|
||||
let adapter = NextcloudAdapter::new(
|
||||
"https://cloud.example.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
vec!["room1".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "nextcloud");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("nextcloud".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nextcloud_server_url_normalization() {
|
||||
let adapter = NextcloudAdapter::new(
|
||||
"https://cloud.example.com/".to_string(),
|
||||
"tok".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.server_url, "https://cloud.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nextcloud_allowed_rooms() {
|
||||
let adapter = NextcloudAdapter::new(
|
||||
"https://cloud.example.com".to_string(),
|
||||
"tok".to_string(),
|
||||
vec!["room1".to_string(), "room2".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_room("room1"));
|
||||
assert!(adapter.is_allowed_room("room2"));
|
||||
assert!(!adapter.is_allowed_room("room3"));
|
||||
|
||||
let open = NextcloudAdapter::new(
|
||||
"https://cloud.example.com".to_string(),
|
||||
"tok".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert!(open.is_allowed_room("any-room"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nextcloud_ocs_headers() {
|
||||
let adapter = NextcloudAdapter::new(
|
||||
"https://cloud.example.com".to_string(),
|
||||
"my-token".to_string(),
|
||||
vec![],
|
||||
);
|
||||
let builder = adapter.client.get("https://example.com");
|
||||
let builder = adapter.ocs_headers(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert_eq!(request.headers().get("OCS-APIRequest").unwrap(), "true");
|
||||
assert_eq!(
|
||||
request.headers().get("Authorization").unwrap(),
|
||||
"Bearer my-token"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nextcloud_token_zeroized() {
|
||||
let adapter = NextcloudAdapter::new(
|
||||
"https://cloud.example.com".to_string(),
|
||||
"secret-token-value".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.token.as_str(), "secret-token-value");
|
||||
}
|
||||
}
|
||||
485
crates/openfang-channels/src/nostr.rs
Normal file
485
crates/openfang-channels/src/nostr.rs
Normal file
@@ -0,0 +1,485 @@
|
||||
//! Nostr NIP-01 channel adapter.
|
||||
//!
|
||||
//! Connects to Nostr relay(s) via WebSocket and subscribes to direct messages
|
||||
//! (kind 4, NIP-04) and public notes. Sends messages by creating signed events
|
||||
//! and publishing them to connected relays. Supports multiple relay connections
|
||||
//! with automatic reconnection.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Maximum message length for Nostr events.
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
|
||||
/// Nostr NIP-01 relay channel adapter using WebSocket.
|
||||
///
|
||||
/// Connects to one or more Nostr relays via WebSocket, subscribes to events
|
||||
/// matching the configured filters (kind 4 DMs by default), and sends messages
|
||||
/// by publishing signed events. The private key is used for signing events
|
||||
/// and deriving the public key for subscriptions.
|
||||
pub struct NostrAdapter {
|
||||
/// SECURITY: Private key (hex-encoded nsec or raw hex) is zeroized on drop.
|
||||
private_key: Zeroizing<String>,
|
||||
/// List of relay WebSocket URLs to connect to.
|
||||
relays: Vec<String>,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Set of already-seen event IDs to avoid duplicates across relays.
|
||||
seen_events: Arc<RwLock<std::collections::HashSet<String>>>,
|
||||
}
|
||||
|
||||
impl NostrAdapter {
|
||||
/// Create a new Nostr adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `private_key` - Hex-encoded private key for signing events.
|
||||
/// * `relays` - WebSocket URLs of Nostr relays to connect to.
|
||||
pub fn new(private_key: String, relays: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
private_key: Zeroizing::new(private_key),
|
||||
relays,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
seen_events: Arc::new(RwLock::new(std::collections::HashSet::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Derive a public key hex string from the private key.
|
||||
/// In a real implementation this would use secp256k1 scalar multiplication.
|
||||
/// For now, returns a placeholder derived from the private key hash.
|
||||
fn derive_pubkey(&self) -> String {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
let mut hasher = DefaultHasher::new();
|
||||
self.private_key.as_str().hash(&mut hasher);
|
||||
format!("{:064x}", hasher.finish())
|
||||
}
|
||||
|
||||
/// Build a NIP-01 REQ message for subscribing to DMs (kind 4).
|
||||
#[allow(dead_code)]
|
||||
fn build_subscription(&self, pubkey: &str) -> String {
|
||||
let filter = serde_json::json!([
|
||||
"REQ",
|
||||
"openfang-sub",
|
||||
{
|
||||
"kinds": [4],
|
||||
"#p": [pubkey],
|
||||
"limit": 0
|
||||
}
|
||||
]);
|
||||
serde_json::to_string(&filter).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Build a NIP-01 EVENT message for sending a DM (kind 4).
|
||||
fn build_event(&self, recipient_pubkey: &str, content: &str) -> String {
|
||||
let pubkey = self.derive_pubkey();
|
||||
let created_at = Utc::now().timestamp();
|
||||
|
||||
// In a real implementation, this would:
|
||||
// 1. Serialize the event for signing
|
||||
// 2. Compute SHA256 of the serialized event
|
||||
// 3. Sign with secp256k1 schnorr
|
||||
// 4. Encrypt content with NIP-04 (shared secret ECDH + AES-256-CBC)
|
||||
let event_id = format!("{:064x}", created_at);
|
||||
let sig = format!("{:0128x}", 0u8);
|
||||
|
||||
let event = serde_json::json!([
|
||||
"EVENT",
|
||||
{
|
||||
"id": event_id,
|
||||
"pubkey": pubkey,
|
||||
"created_at": created_at,
|
||||
"kind": 4,
|
||||
"tags": [["p", recipient_pubkey]],
|
||||
"content": content,
|
||||
"sig": sig
|
||||
}
|
||||
]);
|
||||
|
||||
serde_json::to_string(&event).unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Send a text message to a recipient via all connected relays.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
recipient_pubkey: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let event_msg = self.build_event(recipient_pubkey, chunk);
|
||||
|
||||
// Send to the first available relay
|
||||
for relay_url in &self.relays {
|
||||
match tokio_tungstenite::connect_async(relay_url.as_str()).await {
|
||||
Ok((mut ws, _)) => {
|
||||
use futures::SinkExt;
|
||||
let send_result = ws
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(
|
||||
event_msg.clone(),
|
||||
))
|
||||
.await;
|
||||
|
||||
if send_result.is_ok() {
|
||||
break; // Successfully sent to at least one relay
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Nostr: failed to connect to relay {relay_url}: {e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for NostrAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"nostr"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("nostr".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let pubkey = self.derive_pubkey();
|
||||
info!("Nostr adapter starting (pubkey: {}...)", &pubkey[..16]);
|
||||
|
||||
if self.relays.is_empty() {
|
||||
return Err("Nostr: no relay URLs configured".into());
|
||||
}
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let relays = self.relays.clone();
|
||||
let own_pubkey = pubkey.clone();
|
||||
let seen_events = Arc::clone(&self.seen_events);
|
||||
let private_key = self.private_key.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
// Spawn a task per relay for parallel connections
|
||||
for relay_url in relays {
|
||||
let tx = tx.clone();
|
||||
let own_pubkey = own_pubkey.clone();
|
||||
let seen_events = Arc::clone(&seen_events);
|
||||
let _private_key = private_key.clone();
|
||||
let mut relay_shutdown_rx = shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *relay_shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
let ws_stream = match tokio_tungstenite::connect_async(relay_url.as_str()).await
|
||||
{
|
||||
Ok((stream, _resp)) => stream,
|
||||
Err(e) => {
|
||||
warn!("Nostr: relay {relay_url} connection failed: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Nostr: connected to relay {relay_url}");
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
use futures::{SinkExt, StreamExt};
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Send REQ subscription
|
||||
// Build the subscription filter for DMs addressed to us
|
||||
let sub_msg = {
|
||||
let filter = serde_json::json!([
|
||||
"REQ",
|
||||
"openfang-sub",
|
||||
{
|
||||
"kinds": [4],
|
||||
"#p": [&own_pubkey],
|
||||
"limit": 0
|
||||
}
|
||||
]);
|
||||
serde_json::to_string(&filter).unwrap_or_default()
|
||||
};
|
||||
|
||||
if write
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(sub_msg))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
warn!("Nostr: failed to send REQ to {relay_url}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Read events
|
||||
let should_reconnect = loop {
|
||||
let msg = tokio::select! {
|
||||
_ = relay_shutdown_rx.changed() => {
|
||||
info!("Nostr: relay {relay_url} shutting down");
|
||||
// Send CLOSE
|
||||
let close_msg = serde_json::json!(["CLOSE", "openfang-sub"]);
|
||||
let _ = write.send(
|
||||
tokio_tungstenite::tungstenite::Message::Text(
|
||||
serde_json::to_string(&close_msg).unwrap_or_default()
|
||||
)
|
||||
).await;
|
||||
return;
|
||||
}
|
||||
msg = read.next() => msg,
|
||||
};
|
||||
|
||||
let msg = match msg {
|
||||
Some(Ok(m)) => m,
|
||||
Some(Err(e)) => {
|
||||
warn!("Nostr: relay {relay_url} read error: {e}");
|
||||
break true;
|
||||
}
|
||||
None => {
|
||||
info!("Nostr: relay {relay_url} stream ended");
|
||||
break true;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match msg {
|
||||
tokio_tungstenite::tungstenite::Message::Text(t) => t,
|
||||
tokio_tungstenite::tungstenite::Message::Close(_) => {
|
||||
break true;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
// Parse NIP-01 message: ["EVENT", "sub_id", {event}]
|
||||
let parsed: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let msg_type = parsed[0].as_str().unwrap_or("");
|
||||
if msg_type != "EVENT" {
|
||||
// Could be NOTICE, EOSE, OK, etc.
|
||||
continue;
|
||||
}
|
||||
|
||||
let event = &parsed[2];
|
||||
let event_id = event["id"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Dedup across relays
|
||||
{
|
||||
let mut seen = seen_events.write().await;
|
||||
if seen.contains(&event_id) {
|
||||
continue;
|
||||
}
|
||||
seen.insert(event_id.clone());
|
||||
// Cap the seen set size
|
||||
if seen.len() > 10000 {
|
||||
seen.clear();
|
||||
}
|
||||
}
|
||||
|
||||
let sender_pubkey = event["pubkey"].as_str().unwrap_or("").to_string();
|
||||
// Skip events from ourselves
|
||||
if sender_pubkey == own_pubkey {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = event["content"].as_str().unwrap_or("");
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// In a real implementation, kind-4 content would be
|
||||
// NIP-04 encrypted and would need decryption here
|
||||
let msg_content = if content.starts_with('/') {
|
||||
let parts: Vec<&str> = content.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(content.to_string())
|
||||
};
|
||||
|
||||
let kind = event["kind"].as_i64().unwrap_or(0);
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("nostr".to_string()),
|
||||
platform_message_id: event_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: sender_pubkey.clone(),
|
||||
display_name: format!(
|
||||
"{}...",
|
||||
&sender_pubkey[..8.min(sender_pubkey.len())]
|
||||
),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: kind != 4, // DMs are 1:1, other kinds are public
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"pubkey".to_string(),
|
||||
serde_json::Value::String(sender_pubkey),
|
||||
);
|
||||
m.insert(
|
||||
"kind".to_string(),
|
||||
serde_json::Value::Number(serde_json::Number::from(kind)),
|
||||
);
|
||||
m.insert(
|
||||
"relay".to_string(),
|
||||
serde_json::Value::String(relay_url.clone()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *relay_shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
warn!("Nostr: reconnecting to {relay_url} in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
|
||||
info!("Nostr: relay {relay_url} loop stopped");
|
||||
});
|
||||
}
|
||||
|
||||
// Wait for shutdown in the main task
|
||||
tokio::spawn(async move {
|
||||
let _ = shutdown_rx.changed().await;
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Nostr does not have a typing indicator protocol
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_nostr_adapter_creation() {
|
||||
let adapter = NostrAdapter::new(
|
||||
"deadbeef".repeat(8),
|
||||
vec!["wss://relay.damus.io".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "nostr");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("nostr".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nostr_private_key_zeroized() {
|
||||
let key = "a".repeat(64);
|
||||
let adapter = NostrAdapter::new(key.clone(), vec!["wss://relay.example.com".to_string()]);
|
||||
assert_eq!(adapter.private_key.as_str(), key);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nostr_derive_pubkey() {
|
||||
let adapter = NostrAdapter::new("deadbeef".repeat(8), vec![]);
|
||||
let pubkey = adapter.derive_pubkey();
|
||||
assert_eq!(pubkey.len(), 64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nostr_build_subscription() {
|
||||
let adapter = NostrAdapter::new("abc123".to_string(), vec![]);
|
||||
let pubkey = adapter.derive_pubkey();
|
||||
let sub = adapter.build_subscription(&pubkey);
|
||||
assert!(sub.contains("REQ"));
|
||||
assert!(sub.contains("openfang-sub"));
|
||||
assert!(sub.contains(&pubkey));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nostr_build_event() {
|
||||
let adapter = NostrAdapter::new("abc123".to_string(), vec![]);
|
||||
let event = adapter.build_event("recipient_pubkey_hex", "Hello Nostr!");
|
||||
assert!(event.contains("EVENT"));
|
||||
assert!(event.contains("Hello Nostr!"));
|
||||
assert!(event.contains("recipient_pubkey_hex"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nostr_multiple_relays() {
|
||||
let adapter = NostrAdapter::new(
|
||||
"key".to_string(),
|
||||
vec![
|
||||
"wss://relay1.example.com".to_string(),
|
||||
"wss://relay2.example.com".to_string(),
|
||||
"wss://relay3.example.com".to_string(),
|
||||
],
|
||||
);
|
||||
assert_eq!(adapter.relays.len(), 3);
|
||||
}
|
||||
}
|
||||
438
crates/openfang-channels/src/ntfy.rs
Normal file
438
crates/openfang-channels/src/ntfy.rs
Normal file
@@ -0,0 +1,438 @@
|
||||
//! ntfy.sh channel adapter.
|
||||
//!
|
||||
//! Subscribes to a ntfy topic via Server-Sent Events (SSE) for receiving
|
||||
//! messages and publishes replies by POSTing to the same topic endpoint.
|
||||
//! Supports self-hosted ntfy instances and optional Bearer token auth.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
const DEFAULT_SERVER_URL: &str = "https://ntfy.sh";
|
||||
|
||||
/// ntfy.sh pub/sub channel adapter.
|
||||
///
|
||||
/// Subscribes to notifications via SSE and publishes replies as new
|
||||
/// notifications. Supports authentication for protected topics.
|
||||
pub struct NtfyAdapter {
|
||||
/// ntfy server URL (default: `"https://ntfy.sh"`).
|
||||
server_url: String,
|
||||
/// Topic name to subscribe and publish to.
|
||||
topic: String,
|
||||
/// SECURITY: Bearer token is zeroized on drop (empty = no auth).
|
||||
token: Zeroizing<String>,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl NtfyAdapter {
|
||||
/// Create a new ntfy adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `server_url` - ntfy server URL (empty = default `"https://ntfy.sh"`).
|
||||
/// * `topic` - Topic name to subscribe/publish to.
|
||||
/// * `token` - Bearer token for authentication (empty = no auth).
|
||||
pub fn new(server_url: String, topic: String, token: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let server_url = if server_url.is_empty() {
|
||||
DEFAULT_SERVER_URL.to_string()
|
||||
} else {
|
||||
server_url.trim_end_matches('/').to_string()
|
||||
};
|
||||
Self {
|
||||
server_url,
|
||||
topic,
|
||||
token: Zeroizing::new(token),
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build an authenticated request builder.
|
||||
fn auth_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
if self.token.is_empty() {
|
||||
builder
|
||||
} else {
|
||||
builder.bearer_auth(self.token.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an SSE data line into a ntfy message.
|
||||
///
|
||||
/// ntfy SSE format:
|
||||
/// ```text
|
||||
/// event: message
|
||||
/// data: {"id":"abc","time":1234,"event":"message","topic":"test","message":"Hello"}
|
||||
/// ```
|
||||
fn parse_sse_data(data: &str) -> Option<(String, String, String, Option<String>)> {
|
||||
let val: serde_json::Value = serde_json::from_str(data).ok()?;
|
||||
|
||||
// Only process "message" events (skip "open", "keepalive", etc.)
|
||||
let event = val["event"].as_str().unwrap_or("");
|
||||
if event != "message" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let id = val["id"].as_str()?.to_string();
|
||||
let message = val["message"].as_str()?.to_string();
|
||||
let topic = val["topic"].as_str().unwrap_or("").to_string();
|
||||
|
||||
if message.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// ntfy messages can have a title (used as sender hint)
|
||||
let title = val["title"].as_str().map(String::from);
|
||||
|
||||
Some((id, message, topic, title))
|
||||
}
|
||||
|
||||
/// Publish a message to the topic.
|
||||
async fn publish(
|
||||
&self,
|
||||
text: &str,
|
||||
title: Option<&str>,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/{}", self.server_url, self.topic);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let mut builder = self.client.post(&url);
|
||||
builder = self.auth_request(builder);
|
||||
|
||||
// ntfy supports plain-text body publishing
|
||||
builder = builder.header("Content-Type", "text/plain");
|
||||
|
||||
if let Some(t) = title {
|
||||
builder = builder.header("Title", t);
|
||||
}
|
||||
|
||||
// Mark as UTF-8
|
||||
builder = builder.header("X-Message", chunk);
|
||||
let resp = builder.body(chunk.to_string()).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("ntfy publish error {status}: {err_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for NtfyAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"ntfy"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("ntfy".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
info!(
|
||||
"ntfy adapter subscribing to {}/{}",
|
||||
self.server_url, self.topic
|
||||
);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let server_url = self.server_url.clone();
|
||||
let topic = self.topic.clone();
|
||||
let token = self.token.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let sse_client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(0)) // No timeout for SSE
|
||||
.build()
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
let url = format!("{}/{}/sse", server_url, topic);
|
||||
let mut builder = sse_client.get(&url);
|
||||
if !token.is_empty() {
|
||||
builder = builder.bearer_auth(token.as_str());
|
||||
}
|
||||
|
||||
let response = match builder.send().await {
|
||||
Ok(r) => {
|
||||
if !r.status().is_success() {
|
||||
warn!("ntfy: SSE returned HTTP {}", r.status());
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(120));
|
||||
continue;
|
||||
}
|
||||
backoff = Duration::from_secs(1);
|
||||
r
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("ntfy: SSE connection error: {e}, backing off {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(120));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("ntfy: SSE stream connected for topic {topic}");
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut line_buffer = String::new();
|
||||
let mut current_data = String::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
if *shutdown_rx.borrow() {
|
||||
info!("ntfy adapter shutting down");
|
||||
return;
|
||||
}
|
||||
}
|
||||
chunk = stream.next() => {
|
||||
match chunk {
|
||||
Some(Ok(bytes)) => {
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
line_buffer.push_str(&text);
|
||||
|
||||
// SSE parsing: process complete lines
|
||||
while let Some(newline_pos) = line_buffer.find('\n') {
|
||||
let line = line_buffer[..newline_pos].trim_end_matches('\r').to_string();
|
||||
line_buffer = line_buffer[newline_pos + 1..].to_string();
|
||||
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
current_data = data.to_string();
|
||||
} else if line.is_empty() && !current_data.is_empty() {
|
||||
// Empty line = end of SSE event
|
||||
if let Some((id, message, _topic, title)) =
|
||||
Self::parse_sse_data(¤t_data)
|
||||
{
|
||||
let sender_name = title
|
||||
.as_deref()
|
||||
.unwrap_or("ntfy-user");
|
||||
|
||||
let content = if message.starts_with('/') {
|
||||
let parts: Vec<&str> =
|
||||
message.splitn(2, ' ').collect();
|
||||
let cmd =
|
||||
parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| {
|
||||
a.split_whitespace()
|
||||
.map(String::from)
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(message)
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Custom(
|
||||
"ntfy".to_string(),
|
||||
),
|
||||
platform_message_id: id,
|
||||
sender: ChannelUser {
|
||||
platform_id: sender_name.to_string(),
|
||||
display_name: sender_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"topic".to_string(),
|
||||
serde_json::Value::String(
|
||||
topic.clone(),
|
||||
),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
current_data.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("ntfy: SSE read error: {e}");
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
info!("ntfy: SSE stream ended, reconnecting...");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Backoff before reconnect
|
||||
if !*shutdown_rx.borrow() {
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
}
|
||||
|
||||
info!("ntfy SSE loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
self.publish(&text, Some("OpenFang")).await
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// ntfy has no typing indicator concept.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_adapter_creation() {
|
||||
let adapter = NtfyAdapter::new("".to_string(), "my-topic".to_string(), "".to_string());
|
||||
assert_eq!(adapter.name(), "ntfy");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("ntfy".to_string())
|
||||
);
|
||||
assert_eq!(adapter.server_url, DEFAULT_SERVER_URL);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_custom_server_url() {
|
||||
let adapter = NtfyAdapter::new(
|
||||
"https://ntfy.internal.corp/".to_string(),
|
||||
"alerts".to_string(),
|
||||
"token-123".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.server_url, "https://ntfy.internal.corp");
|
||||
assert_eq!(adapter.topic, "alerts");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_auth_request_with_token() {
|
||||
let adapter = NtfyAdapter::new(
|
||||
"".to_string(),
|
||||
"test".to_string(),
|
||||
"my-bearer-token".to_string(),
|
||||
);
|
||||
let builder = adapter.client.get("https://ntfy.sh/test");
|
||||
let builder = adapter.auth_request(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert!(request.headers().contains_key("authorization"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_auth_request_without_token() {
|
||||
let adapter = NtfyAdapter::new("".to_string(), "test".to_string(), "".to_string());
|
||||
let builder = adapter.client.get("https://ntfy.sh/test");
|
||||
let builder = adapter.auth_request(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert!(!request.headers().contains_key("authorization"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_parse_sse_message_event() {
|
||||
let data = r#"{"id":"abc123","time":1700000000,"event":"message","topic":"test","message":"Hello from ntfy","title":"Alice"}"#;
|
||||
let result = NtfyAdapter::parse_sse_data(data);
|
||||
assert!(result.is_some());
|
||||
let (id, message, topic, title) = result.unwrap();
|
||||
assert_eq!(id, "abc123");
|
||||
assert_eq!(message, "Hello from ntfy");
|
||||
assert_eq!(topic, "test");
|
||||
assert_eq!(title.as_deref(), Some("Alice"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_parse_sse_keepalive_event() {
|
||||
let data = r#"{"id":"ka1","time":1700000000,"event":"keepalive","topic":"test"}"#;
|
||||
assert!(NtfyAdapter::parse_sse_data(data).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_parse_sse_open_event() {
|
||||
let data = r#"{"id":"o1","time":1700000000,"event":"open","topic":"test"}"#;
|
||||
assert!(NtfyAdapter::parse_sse_data(data).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_parse_sse_empty_message() {
|
||||
let data = r#"{"id":"e1","time":1700000000,"event":"message","topic":"test","message":""}"#;
|
||||
assert!(NtfyAdapter::parse_sse_data(data).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_parse_sse_no_title() {
|
||||
let data =
|
||||
r#"{"id":"nt1","time":1700000000,"event":"message","topic":"test","message":"Hi"}"#;
|
||||
let result = NtfyAdapter::parse_sse_data(data);
|
||||
assert!(result.is_some());
|
||||
let (_, _, _, title) = result.unwrap();
|
||||
assert!(title.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ntfy_parse_invalid_json() {
|
||||
assert!(NtfyAdapter::parse_sse_data("not json").is_none());
|
||||
}
|
||||
}
|
||||
486
crates/openfang-channels/src/pumble.rs
Normal file
486
crates/openfang-channels/src/pumble.rs
Normal file
@@ -0,0 +1,486 @@
|
||||
//! Pumble Bot channel adapter.
|
||||
//!
|
||||
//! Uses the Pumble Bot API with a local webhook HTTP server for receiving
|
||||
//! inbound event subscriptions and the REST API for sending messages.
|
||||
//! Authentication is performed via a Bot Bearer token. Inbound events arrive
|
||||
//! as JSON POST requests to the configured webhook port.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Pumble REST API base URL.
|
||||
const PUMBLE_API_BASE: &str = "https://api.pumble.com/v1";
|
||||
|
||||
/// Maximum message length for Pumble messages.
|
||||
const MAX_MESSAGE_LEN: usize = 4000;
|
||||
|
||||
/// Pumble Bot channel adapter using webhook for receiving and REST API for sending.
|
||||
///
|
||||
/// Listens for inbound events via a configurable HTTP webhook server and sends
|
||||
/// outbound messages via the Pumble REST API. Supports Pumble's event subscription
|
||||
/// model including URL verification challenges.
|
||||
pub struct PumbleAdapter {
|
||||
/// SECURITY: Bot token is zeroized on drop.
|
||||
bot_token: Zeroizing<String>,
|
||||
/// Port for the inbound webhook HTTP listener.
|
||||
webhook_port: u16,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl PumbleAdapter {
|
||||
/// Create a new Pumble adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `bot_token` - Pumble Bot access token.
|
||||
/// * `webhook_port` - Local port to bind the webhook listener on.
|
||||
pub fn new(bot_token: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
bot_token: Zeroizing::new(bot_token),
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching bot info from the Pumble API.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/auth.test", PUMBLE_API_BASE);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Pumble authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let bot_id = body["user_id"]
|
||||
.as_str()
|
||||
.or_else(|| body["bot_id"].as_str())
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Ok(bot_id)
|
||||
}
|
||||
|
||||
/// Send a text message to a Pumble channel.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/messages", PUMBLE_API_BASE);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"channel": channel_id,
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Pumble API error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an inbound Pumble event JSON into a `ChannelMessage`.
|
||||
///
|
||||
/// Returns `None` for non-message events, URL verification challenges,
|
||||
/// or messages from the bot itself.
|
||||
fn parse_pumble_event(event: &serde_json::Value, own_bot_id: &str) -> Option<ChannelMessage> {
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
|
||||
// Handle URL verification challenge
|
||||
if event_type == "url_verification" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Only process message events
|
||||
if event_type != "message" && event_type != "message.new" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let text = event["text"]
|
||||
.as_str()
|
||||
.or_else(|| event["message"]["text"].as_str())
|
||||
.unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let user_id = event["user"]
|
||||
.as_str()
|
||||
.or_else(|| event["user_id"].as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Skip messages from the bot itself
|
||||
if user_id == own_bot_id {
|
||||
return None;
|
||||
}
|
||||
|
||||
let channel_id = event["channel"]
|
||||
.as_str()
|
||||
.or_else(|| event["channel_id"].as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let ts = event["ts"]
|
||||
.as_str()
|
||||
.or_else(|| event["timestamp"].as_str())
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let thread_ts = event["thread_ts"].as_str().map(String::from);
|
||||
let user_name = event["user_name"].as_str().unwrap_or("unknown");
|
||||
let channel_type = event["channel_type"].as_str().unwrap_or("channel");
|
||||
let is_group = channel_type != "im";
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"user_id".to_string(),
|
||||
serde_json::Value::String(user_id.to_string()),
|
||||
);
|
||||
if !ts.is_empty() {
|
||||
metadata.insert("ts".to_string(), serde_json::Value::String(ts.clone()));
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("pumble".to_string()),
|
||||
platform_message_id: ts,
|
||||
sender: ChannelUser {
|
||||
platform_id: channel_id,
|
||||
display_name: user_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: thread_ts,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for PumbleAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"pumble"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("pumble".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_id = self.validate().await?;
|
||||
info!("Pumble adapter authenticated (bot_id: {bot_id})");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let own_bot_id = bot_id;
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Build the axum webhook router
|
||||
let bot_id_shared = Arc::new(own_bot_id);
|
||||
let tx_shared = Arc::new(tx);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/pumble/events",
|
||||
axum::routing::post({
|
||||
let bot_id = Arc::clone(&bot_id_shared);
|
||||
let tx = Arc::clone(&tx_shared);
|
||||
move |body: axum::extract::Json<serde_json::Value>| {
|
||||
let bot_id = Arc::clone(&bot_id);
|
||||
let tx = Arc::clone(&tx);
|
||||
async move {
|
||||
// Handle URL verification challenge
|
||||
if body["type"].as_str() == Some("url_verification") {
|
||||
let challenge =
|
||||
body["challenge"].as_str().unwrap_or("").to_string();
|
||||
return (
|
||||
axum::http::StatusCode::OK,
|
||||
axum::Json(serde_json::json!({ "challenge": challenge })),
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(msg) = parse_pumble_event(&body, &bot_id) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
|
||||
(
|
||||
axum::http::StatusCode::OK,
|
||||
axum::Json(serde_json::json!({})),
|
||||
)
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("Pumble webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Pumble webhook bind failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Pumble webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Pumble adapter shutting down");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_in_thread(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
thread_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(text) => text,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
let url = format!("{}/messages", PUMBLE_API_BASE);
|
||||
let chunks = split_message(&text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"channel": user.platform_id,
|
||||
"text": chunk,
|
||||
"thread_ts": thread_id,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Pumble thread reply error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Pumble does not expose a public typing indicator API for bots
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_pumble_adapter_creation() {
|
||||
let adapter = PumbleAdapter::new("test-bot-token".to_string(), 8080);
|
||||
assert_eq!(adapter.name(), "pumble");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("pumble".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pumble_token_zeroized() {
|
||||
let adapter = PumbleAdapter::new("secret-pumble-token".to_string(), 8080);
|
||||
assert_eq!(adapter.bot_token.as_str(), "secret-pumble-token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pumble_webhook_port() {
|
||||
let adapter = PumbleAdapter::new("token".to_string(), 9999);
|
||||
assert_eq!(adapter.webhook_port, 9999);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pumble_event_message() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"text": "Hello from Pumble!",
|
||||
"user": "U12345",
|
||||
"channel": "C67890",
|
||||
"ts": "1234567890.123456",
|
||||
"user_name": "alice",
|
||||
"channel_type": "channel"
|
||||
});
|
||||
|
||||
let msg = parse_pumble_event(&event, "BOT001").unwrap();
|
||||
assert_eq!(msg.sender.display_name, "alice");
|
||||
assert_eq!(msg.sender.platform_id, "C67890");
|
||||
assert!(msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Pumble!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pumble_event_command() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"text": "/help agents",
|
||||
"user": "U12345",
|
||||
"channel": "C67890",
|
||||
"ts": "ts1",
|
||||
"user_name": "bob"
|
||||
});
|
||||
|
||||
let msg = parse_pumble_event(&event, "BOT001").unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "help");
|
||||
assert_eq!(args, &["agents"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pumble_event_skip_bot() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"text": "Bot message",
|
||||
"user": "BOT001",
|
||||
"channel": "C67890",
|
||||
"ts": "ts1"
|
||||
});
|
||||
|
||||
let msg = parse_pumble_event(&event, "BOT001");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pumble_event_url_verification() {
|
||||
let event = serde_json::json!({
|
||||
"type": "url_verification",
|
||||
"challenge": "abc123"
|
||||
});
|
||||
|
||||
let msg = parse_pumble_event(&event, "BOT001");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pumble_event_dm() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"text": "Direct message",
|
||||
"user": "U12345",
|
||||
"channel": "D11111",
|
||||
"ts": "ts2",
|
||||
"user_name": "carol",
|
||||
"channel_type": "im"
|
||||
});
|
||||
|
||||
let msg = parse_pumble_event(&event, "BOT001").unwrap();
|
||||
assert!(!msg.is_group);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_pumble_event_with_thread() {
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"text": "Thread reply",
|
||||
"user": "U12345",
|
||||
"channel": "C67890",
|
||||
"ts": "ts3",
|
||||
"thread_ts": "ts1",
|
||||
"user_name": "dave"
|
||||
});
|
||||
|
||||
let msg = parse_pumble_event(&event, "BOT001").unwrap();
|
||||
assert_eq!(msg.thread_id.as_deref(), Some("ts1"));
|
||||
}
|
||||
}
|
||||
704
crates/openfang-channels/src/reddit.rs
Normal file
704
crates/openfang-channels/src/reddit.rs
Normal file
@@ -0,0 +1,704 @@
|
||||
//! Reddit API channel adapter.
|
||||
//!
|
||||
//! Uses the Reddit OAuth2 API for both sending and receiving messages. Authentication
|
||||
//! is performed via the OAuth2 password grant (script app) at
|
||||
//! `https://www.reddit.com/api/v1/access_token`. Subreddit comments are polled
|
||||
//! periodically via `GET /r/{subreddit}/comments/new.json`. Replies are sent via
|
||||
//! `POST /api/comment`.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Reddit OAuth2 token endpoint.
|
||||
const REDDIT_TOKEN_URL: &str = "https://www.reddit.com/api/v1/access_token";
|
||||
|
||||
/// Reddit OAuth API base URL.
|
||||
const REDDIT_API_BASE: &str = "https://oauth.reddit.com";
|
||||
|
||||
/// Reddit poll interval (seconds). Reddit API rate limit is ~60 requests/minute.
|
||||
const POLL_INTERVAL_SECS: u64 = 5;
|
||||
|
||||
/// Maximum Reddit comment/message text length.
|
||||
const MAX_MESSAGE_LEN: usize = 10000;
|
||||
|
||||
/// OAuth2 token refresh buffer — refresh 5 minutes before actual expiry.
|
||||
const TOKEN_REFRESH_BUFFER_SECS: u64 = 300;
|
||||
|
||||
/// Custom User-Agent required by Reddit API guidelines.
|
||||
const USER_AGENT: &str = "openfang:v1.0.0 (by /u/openfang-bot)";
|
||||
|
||||
/// Reddit OAuth2 API adapter.
|
||||
///
|
||||
/// Inbound messages are received by polling subreddit comment streams.
|
||||
/// Outbound messages are sent as comment replies via the Reddit API.
|
||||
/// OAuth2 password grant is used for authentication (script-type app).
|
||||
pub struct RedditAdapter {
|
||||
/// Reddit OAuth2 client ID (from the app settings page).
|
||||
client_id: String,
|
||||
/// SECURITY: Reddit OAuth2 client secret, zeroized on drop.
|
||||
client_secret: Zeroizing<String>,
|
||||
/// Reddit username for OAuth2 password grant.
|
||||
username: String,
|
||||
/// SECURITY: Reddit password, zeroized on drop.
|
||||
password: Zeroizing<String>,
|
||||
/// Subreddits to monitor for new comments.
|
||||
subreddits: Vec<String>,
|
||||
/// HTTP client for API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Cached OAuth2 bearer token and its expiry instant.
|
||||
cached_token: Arc<RwLock<Option<(String, Instant)>>>,
|
||||
/// Track last seen comment IDs to avoid duplicates.
|
||||
seen_comments: Arc<RwLock<HashMap<String, bool>>>,
|
||||
}
|
||||
|
||||
impl RedditAdapter {
|
||||
/// Create a new Reddit adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `client_id` - Reddit OAuth2 app client ID.
|
||||
/// * `client_secret` - Reddit OAuth2 app client secret.
|
||||
/// * `username` - Reddit account username.
|
||||
/// * `password` - Reddit account password.
|
||||
/// * `subreddits` - Subreddits to monitor for new comments.
|
||||
pub fn new(
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
username: String,
|
||||
password: String,
|
||||
subreddits: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
|
||||
// Build HTTP client with required User-Agent
|
||||
let client = reqwest::Client::builder()
|
||||
.user_agent(USER_AGENT)
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.unwrap_or_else(|_| reqwest::Client::new());
|
||||
|
||||
Self {
|
||||
client_id,
|
||||
client_secret: Zeroizing::new(client_secret),
|
||||
username,
|
||||
password: Zeroizing::new(password),
|
||||
subreddits,
|
||||
client,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
cached_token: Arc::new(RwLock::new(None)),
|
||||
seen_comments: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtain a valid OAuth2 bearer token, refreshing if expired or missing.
|
||||
async fn get_token(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
// Check cache first
|
||||
{
|
||||
let guard = self.cached_token.read().await;
|
||||
if let Some((ref token, expiry)) = *guard {
|
||||
if Instant::now() < expiry {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch a new token via password grant
|
||||
let params = [
|
||||
("grant_type", "password"),
|
||||
("username", &self.username),
|
||||
("password", self.password.as_str()),
|
||||
];
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(REDDIT_TOKEN_URL)
|
||||
.basic_auth(&self.client_id, Some(self.client_secret.as_str()))
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Reddit OAuth2 token error {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let access_token = body["access_token"]
|
||||
.as_str()
|
||||
.ok_or("Missing access_token in Reddit OAuth2 response")?
|
||||
.to_string();
|
||||
let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
|
||||
|
||||
// Cache with a safety buffer
|
||||
let expiry = Instant::now()
|
||||
+ Duration::from_secs(expires_in.saturating_sub(TOKEN_REFRESH_BUFFER_SECS));
|
||||
*self.cached_token.write().await = Some((access_token.clone(), expiry));
|
||||
|
||||
Ok(access_token)
|
||||
}
|
||||
|
||||
/// Validate credentials by calling `/api/v1/me`.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let token = self.get_token().await?;
|
||||
let url = format!("{}/api/v1/me", REDDIT_API_BASE);
|
||||
|
||||
let resp = self.client.get(&url).bearer_auth(&token).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Reddit authentication failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let username = body["name"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(username)
|
||||
}
|
||||
|
||||
/// Post a comment reply to a Reddit thing (comment or post).
|
||||
async fn api_comment(
|
||||
&self,
|
||||
parent_fullname: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let token = self.get_token().await?;
|
||||
let url = format!("{}/api/comment", REDDIT_API_BASE);
|
||||
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
// Reddit only allows one reply per parent, so join chunks
|
||||
let full_text = chunks.join("\n\n---\n\n");
|
||||
|
||||
let params = [
|
||||
("api_type", "json"),
|
||||
("thing_id", parent_fullname),
|
||||
("text", &full_text),
|
||||
];
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&token)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Reddit comment API error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
if let Some(errors) = resp_body["json"]["errors"].as_array() {
|
||||
if !errors.is_empty() {
|
||||
warn!("Reddit comment errors: {:?}", errors);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a subreddit name is in the monitored list.
|
||||
#[allow(dead_code)]
|
||||
fn is_monitored_subreddit(&self, subreddit: &str) -> bool {
|
||||
self.subreddits.iter().any(|s| {
|
||||
s.eq_ignore_ascii_case(subreddit)
|
||||
|| s.trim_start_matches("r/").eq_ignore_ascii_case(subreddit)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Reddit comment JSON object into a `ChannelMessage`.
|
||||
fn parse_reddit_comment(comment: &serde_json::Value, own_username: &str) -> Option<ChannelMessage> {
|
||||
let data = comment.get("data")?;
|
||||
let kind = comment["kind"].as_str().unwrap_or("");
|
||||
|
||||
// Only process comments (t1) not posts (t3)
|
||||
if kind != "t1" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let author = data["author"].as_str().unwrap_or("");
|
||||
// Skip own comments
|
||||
if author.eq_ignore_ascii_case(own_username) {
|
||||
return None;
|
||||
}
|
||||
// Skip deleted/removed
|
||||
if author == "[deleted]" || author == "[removed]" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let body = data["body"].as_str().unwrap_or("");
|
||||
if body.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let comment_id = data["id"].as_str().unwrap_or("").to_string();
|
||||
let fullname = data["name"].as_str().unwrap_or("").to_string(); // e.g., "t1_abc123"
|
||||
let subreddit = data["subreddit"].as_str().unwrap_or("").to_string();
|
||||
let link_id = data["link_id"].as_str().unwrap_or("").to_string();
|
||||
let parent_id = data["parent_id"].as_str().unwrap_or("").to_string();
|
||||
let permalink = data["permalink"].as_str().unwrap_or("").to_string();
|
||||
|
||||
let content = if body.starts_with('/') {
|
||||
let parts: Vec<&str> = body.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(body.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("fullname".to_string(), serde_json::Value::String(fullname));
|
||||
metadata.insert(
|
||||
"subreddit".to_string(),
|
||||
serde_json::Value::String(subreddit.clone()),
|
||||
);
|
||||
metadata.insert("link_id".to_string(), serde_json::Value::String(link_id));
|
||||
metadata.insert(
|
||||
"parent_id".to_string(),
|
||||
serde_json::Value::String(parent_id),
|
||||
);
|
||||
if !permalink.is_empty() {
|
||||
metadata.insert(
|
||||
"permalink".to_string(),
|
||||
serde_json::Value::String(permalink),
|
||||
);
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("reddit".to_string()),
|
||||
platform_message_id: comment_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: author.to_string(),
|
||||
display_name: author.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true, // Subreddit comments are inherently public/group
|
||||
thread_id: Some(subreddit),
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for RedditAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"reddit"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("reddit".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let username = self.validate().await?;
|
||||
info!("Reddit adapter authenticated as u/{username}");
|
||||
|
||||
if self.subreddits.is_empty() {
|
||||
return Err("Reddit adapter: no subreddits configured to monitor".into());
|
||||
}
|
||||
|
||||
info!(
|
||||
"Reddit adapter monitoring {} subreddit(s): {}",
|
||||
self.subreddits.len(),
|
||||
self.subreddits.join(", ")
|
||||
);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let subreddits = self.subreddits.clone();
|
||||
let client = self.client.clone();
|
||||
let cached_token = Arc::clone(&self.cached_token);
|
||||
let seen_comments = Arc::clone(&self.seen_comments);
|
||||
let own_username = username;
|
||||
let client_id = self.client_id.clone();
|
||||
let client_secret = self.client_secret.clone();
|
||||
let password = self.password.clone();
|
||||
let reddit_username = self.username.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Reddit adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Get current token
|
||||
let token = {
|
||||
let guard = cached_token.read().await;
|
||||
match &*guard {
|
||||
Some((token, expiry)) if Instant::now() < *expiry => token.clone(),
|
||||
_ => {
|
||||
// Token expired, need to refresh
|
||||
drop(guard);
|
||||
let params = [
|
||||
("grant_type", "password"),
|
||||
("username", reddit_username.as_str()),
|
||||
("password", password.as_str()),
|
||||
];
|
||||
match client
|
||||
.post(REDDIT_TOKEN_URL)
|
||||
.basic_auth(&client_id, Some(client_secret.as_str()))
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value =
|
||||
resp.json().await.unwrap_or_default();
|
||||
let tok =
|
||||
body["access_token"].as_str().unwrap_or("").to_string();
|
||||
if tok.is_empty() {
|
||||
warn!("Reddit: failed to refresh token");
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
tokio::time::sleep(backoff).await;
|
||||
continue;
|
||||
}
|
||||
let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
|
||||
let expiry = Instant::now()
|
||||
+ Duration::from_secs(
|
||||
expires_in.saturating_sub(TOKEN_REFRESH_BUFFER_SECS),
|
||||
);
|
||||
*cached_token.write().await = Some((tok.clone(), expiry));
|
||||
tok
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Reddit: token refresh error: {e}");
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
tokio::time::sleep(backoff).await;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Poll each subreddit for new comments
|
||||
for subreddit in &subreddits {
|
||||
let sub = subreddit.trim_start_matches("r/");
|
||||
let url = format!("{}/r/{}/comments?limit=25&sort=new", REDDIT_API_BASE, sub);
|
||||
|
||||
let resp = match client.get(&url).bearer_auth(&token).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Reddit: comment fetch error for r/{sub}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
warn!(
|
||||
"Reddit: comment fetch returned {} for r/{sub}",
|
||||
resp.status()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Reddit: failed to parse comments for r/{sub}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let children = match body["data"]["children"].as_array() {
|
||||
Some(arr) => arr,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
for child in children {
|
||||
let comment_id = child["data"]["id"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Skip already-seen comments
|
||||
{
|
||||
let seen = seen_comments.read().await;
|
||||
if seen.contains_key(&comment_id) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(msg) = parse_reddit_comment(child, &own_username) {
|
||||
// Mark as seen
|
||||
seen_comments.write().await.insert(comment_id, true);
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Successful poll resets backoff
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
// Periodically trim seen_comments to prevent unbounded growth
|
||||
{
|
||||
let mut seen = seen_comments.write().await;
|
||||
if seen.len() > 10_000 {
|
||||
// Keep recent half (crude eviction)
|
||||
let to_remove: Vec<String> = seen.keys().take(5_000).cloned().collect();
|
||||
for key in to_remove {
|
||||
seen.remove(&key);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Reddit polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
// user.platform_id is the author username; we need the fullname from metadata
|
||||
// If not available, we can't reply directly
|
||||
self.api_comment(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_comment(
|
||||
&user.platform_id,
|
||||
"(Unsupported content type — Reddit only supports text replies)",
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Reddit does not support typing indicators
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_reddit_adapter_creation() {
|
||||
let adapter = RedditAdapter::new(
|
||||
"client-id".to_string(),
|
||||
"client-secret".to_string(),
|
||||
"bot-user".to_string(),
|
||||
"bot-pass".to_string(),
|
||||
vec!["rust".to_string(), "programming".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "reddit");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("reddit".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reddit_subreddit_list() {
|
||||
let adapter = RedditAdapter::new(
|
||||
"cid".to_string(),
|
||||
"csec".to_string(),
|
||||
"usr".to_string(),
|
||||
"pwd".to_string(),
|
||||
vec![
|
||||
"rust".to_string(),
|
||||
"programming".to_string(),
|
||||
"r/openfang".to_string(),
|
||||
],
|
||||
);
|
||||
assert_eq!(adapter.subreddits.len(), 3);
|
||||
assert!(adapter.is_monitored_subreddit("rust"));
|
||||
assert!(adapter.is_monitored_subreddit("programming"));
|
||||
assert!(adapter.is_monitored_subreddit("openfang"));
|
||||
assert!(!adapter.is_monitored_subreddit("news"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reddit_secrets_zeroized() {
|
||||
let adapter = RedditAdapter::new(
|
||||
"cid".to_string(),
|
||||
"secret-value".to_string(),
|
||||
"usr".to_string(),
|
||||
"pass-value".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.client_secret.as_str(), "secret-value");
|
||||
assert_eq!(adapter.password.as_str(), "pass-value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_reddit_comment_basic() {
|
||||
let comment = serde_json::json!({
|
||||
"kind": "t1",
|
||||
"data": {
|
||||
"id": "abc123",
|
||||
"name": "t1_abc123",
|
||||
"author": "alice",
|
||||
"body": "Hello from Reddit!",
|
||||
"subreddit": "rust",
|
||||
"link_id": "t3_xyz789",
|
||||
"parent_id": "t3_xyz789",
|
||||
"permalink": "/r/rust/comments/xyz789/title/abc123/"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_reddit_comment(&comment, "bot-user").unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("reddit".to_string()));
|
||||
assert_eq!(msg.sender.display_name, "alice");
|
||||
assert!(msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Reddit!"));
|
||||
assert_eq!(msg.thread_id, Some("rust".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_reddit_comment_skips_self() {
|
||||
let comment = serde_json::json!({
|
||||
"kind": "t1",
|
||||
"data": {
|
||||
"id": "abc123",
|
||||
"name": "t1_abc123",
|
||||
"author": "bot-user",
|
||||
"body": "My own comment",
|
||||
"subreddit": "rust",
|
||||
"link_id": "t3_xyz",
|
||||
"parent_id": "t3_xyz"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_reddit_comment(&comment, "bot-user").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_reddit_comment_skips_deleted() {
|
||||
let comment = serde_json::json!({
|
||||
"kind": "t1",
|
||||
"data": {
|
||||
"id": "abc123",
|
||||
"name": "t1_abc123",
|
||||
"author": "[deleted]",
|
||||
"body": "[deleted]",
|
||||
"subreddit": "rust",
|
||||
"link_id": "t3_xyz",
|
||||
"parent_id": "t3_xyz"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_reddit_comment(&comment, "bot-user").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_reddit_comment_command() {
|
||||
let comment = serde_json::json!({
|
||||
"kind": "t1",
|
||||
"data": {
|
||||
"id": "cmd1",
|
||||
"name": "t1_cmd1",
|
||||
"author": "alice",
|
||||
"body": "/ask what is rust?",
|
||||
"subreddit": "programming",
|
||||
"link_id": "t3_xyz",
|
||||
"parent_id": "t3_xyz"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_reddit_comment(&comment, "bot-user").unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "ask");
|
||||
assert_eq!(args, &["what", "is", "rust?"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_reddit_comment_skips_posts() {
|
||||
let comment = serde_json::json!({
|
||||
"kind": "t3",
|
||||
"data": {
|
||||
"id": "post1",
|
||||
"name": "t3_post1",
|
||||
"author": "alice",
|
||||
"body": "This is a post",
|
||||
"subreddit": "rust"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_reddit_comment(&comment, "bot-user").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_reddit_comment_metadata() {
|
||||
let comment = serde_json::json!({
|
||||
"kind": "t1",
|
||||
"data": {
|
||||
"id": "meta1",
|
||||
"name": "t1_meta1",
|
||||
"author": "alice",
|
||||
"body": "Test metadata",
|
||||
"subreddit": "rust",
|
||||
"link_id": "t3_link1",
|
||||
"parent_id": "t1_parent1",
|
||||
"permalink": "/r/rust/comments/link1/title/meta1/"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_reddit_comment(&comment, "bot-user").unwrap();
|
||||
assert!(msg.metadata.contains_key("fullname"));
|
||||
assert!(msg.metadata.contains_key("subreddit"));
|
||||
assert!(msg.metadata.contains_key("link_id"));
|
||||
assert!(msg.metadata.contains_key("parent_id"));
|
||||
assert!(msg.metadata.contains_key("permalink"));
|
||||
}
|
||||
}
|
||||
704
crates/openfang-channels/src/revolt.rs
Normal file
704
crates/openfang-channels/src/revolt.rs
Normal file
@@ -0,0 +1,704 @@
|
||||
//! Revolt API channel adapter.
|
||||
//!
|
||||
//! Uses the Revolt REST API for sending messages and WebSocket (Bonfire protocol)
|
||||
//! for real-time message reception. Authentication uses the bot token via
|
||||
//! `x-bot-token` header on REST calls and `Authenticate` frame on WebSocket.
|
||||
//! Revolt is an open-source, Discord-like chat platform.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::{SinkExt, Stream, StreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Default Revolt API URL.
|
||||
const DEFAULT_API_URL: &str = "https://api.revolt.chat";
|
||||
|
||||
/// Default Revolt WebSocket URL.
|
||||
const DEFAULT_WS_URL: &str = "wss://ws.revolt.chat";
|
||||
|
||||
/// Maximum Revolt message text length (characters).
|
||||
const MAX_MESSAGE_LEN: usize = 2000;
|
||||
|
||||
/// Maximum backoff duration for WebSocket reconnection.
|
||||
const MAX_BACKOFF_SECS: u64 = 60;
|
||||
|
||||
/// WebSocket heartbeat interval (seconds). Revolt expects pings every 30s.
|
||||
const HEARTBEAT_INTERVAL_SECS: u64 = 20;
|
||||
|
||||
/// Revolt API adapter using WebSocket (Bonfire) + REST.
|
||||
///
|
||||
/// Inbound messages are received via WebSocket connection to the Revolt
|
||||
/// Bonfire gateway. Outbound messages are sent via the REST API.
|
||||
/// The adapter handles automatic reconnection with exponential backoff.
|
||||
pub struct RevoltAdapter {
|
||||
/// SECURITY: Bot token is zeroized on drop to prevent memory disclosure.
|
||||
bot_token: Zeroizing<String>,
|
||||
/// Revolt API URL (default: `"https://api.revolt.chat"`).
|
||||
api_url: String,
|
||||
/// Revolt WebSocket URL (default: "wss://ws.revolt.chat").
|
||||
ws_url: String,
|
||||
/// Restrict to specific channel IDs (empty = all channels the bot is in).
|
||||
allowed_channels: Vec<String>,
|
||||
/// HTTP client for outbound REST API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Bot's own user ID (populated after authentication).
|
||||
bot_user_id: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl RevoltAdapter {
|
||||
/// Create a new Revolt adapter with default API and WebSocket URLs.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `bot_token` - Revolt bot token for authentication.
|
||||
pub fn new(bot_token: String) -> Self {
|
||||
Self::with_urls(
|
||||
bot_token,
|
||||
DEFAULT_API_URL.to_string(),
|
||||
DEFAULT_WS_URL.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a new Revolt adapter with custom API and WebSocket URLs.
|
||||
pub fn with_urls(bot_token: String, api_url: String, ws_url: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let api_url = api_url.trim_end_matches('/').to_string();
|
||||
let ws_url = ws_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
bot_token: Zeroizing::new(bot_token),
|
||||
api_url,
|
||||
ws_url,
|
||||
allowed_channels: Vec::new(),
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
bot_user_id: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Revolt adapter with channel restrictions.
|
||||
pub fn with_channels(bot_token: String, allowed_channels: Vec<String>) -> Self {
|
||||
let mut adapter = Self::new(bot_token);
|
||||
adapter.allowed_channels = allowed_channels;
|
||||
adapter
|
||||
}
|
||||
|
||||
/// Add the bot token header to a request builder.
|
||||
fn auth_header(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
builder.header("x-bot-token", self.bot_token.as_str())
|
||||
}
|
||||
|
||||
/// Validate the bot token by fetching the bot's own user info.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/users/@me", self.api_url);
|
||||
let resp = self.auth_header(self.client.get(&url)).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Revolt authentication failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let user_id = body["_id"].as_str().unwrap_or("").to_string();
|
||||
let username = body["username"].as_str().unwrap_or("unknown").to_string();
|
||||
|
||||
*self.bot_user_id.write().await = Some(user_id.clone());
|
||||
|
||||
Ok(format!("{username} ({user_id})"))
|
||||
}
|
||||
|
||||
/// Send a text message to a Revolt channel via REST API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/channels/{}/messages", self.api_url, channel_id);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"content": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.auth_header(self.client.post(&url))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Revolt send message error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a reply to a specific message in a Revolt channel.
|
||||
#[allow(dead_code)]
|
||||
async fn api_reply_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
message_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/channels/{}/messages", self.api_url, channel_id);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for (i, chunk) in chunks.iter().enumerate() {
|
||||
let mut body = serde_json::json!({
|
||||
"content": chunk,
|
||||
});
|
||||
|
||||
// Only add reply reference to the first message
|
||||
if i == 0 {
|
||||
body["replies"] = serde_json::json!([{
|
||||
"id": message_id,
|
||||
"mention": false,
|
||||
}]);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.auth_header(self.client.post(&url))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
warn!("Revolt reply error {status}: {resp_body}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a channel is in the allowed list (empty = allow all).
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_channel(&self, channel_id: &str) -> bool {
|
||||
self.allowed_channels.is_empty() || self.allowed_channels.iter().any(|c| c == channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Revolt WebSocket "Message" event into a `ChannelMessage`.
|
||||
fn parse_revolt_message(
|
||||
data: &serde_json::Value,
|
||||
bot_user_id: &str,
|
||||
allowed_channels: &[String],
|
||||
) -> Option<ChannelMessage> {
|
||||
let msg_type = data["type"].as_str().unwrap_or("");
|
||||
if msg_type != "Message" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let author = data["author"].as_str().unwrap_or("");
|
||||
// Skip own messages
|
||||
if author == bot_user_id {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Skip system messages (author = "00000000000000000000000000")
|
||||
if author.chars().all(|c| c == '0') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let channel_id = data["channel"].as_str().unwrap_or("").to_string();
|
||||
// Channel filter
|
||||
if !allowed_channels.is_empty() && !allowed_channels.iter().any(|c| c == &channel_id) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let content = data["content"].as_str().unwrap_or("");
|
||||
if content.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let msg_id = data["_id"].as_str().unwrap_or("").to_string();
|
||||
let nonce = data["nonce"].as_str().unwrap_or("").to_string();
|
||||
|
||||
let msg_content = if content.starts_with('/') {
|
||||
let parts: Vec<&str> = content.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(content.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"channel_id".to_string(),
|
||||
serde_json::Value::String(channel_id.clone()),
|
||||
);
|
||||
metadata.insert(
|
||||
"author_id".to_string(),
|
||||
serde_json::Value::String(author.to_string()),
|
||||
);
|
||||
if !nonce.is_empty() {
|
||||
metadata.insert("nonce".to_string(), serde_json::Value::String(nonce));
|
||||
}
|
||||
|
||||
// Check for reply references
|
||||
if let Some(replies) = data.get("replies") {
|
||||
metadata.insert("replies".to_string(), replies.clone());
|
||||
}
|
||||
|
||||
// Check for attachments
|
||||
if let Some(attachments) = data.get("attachments") {
|
||||
if let Some(arr) = attachments.as_array() {
|
||||
if !arr.is_empty() {
|
||||
metadata.insert("attachments".to_string(), attachments.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("revolt".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: channel_id,
|
||||
display_name: author.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true, // Revolt channels are inherently group-based
|
||||
thread_id: None,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for RevoltAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"revolt"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("revolt".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_info = self.validate().await?;
|
||||
info!("Revolt adapter authenticated as {bot_info}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let ws_url = self.ws_url.clone();
|
||||
let bot_token = self.bot_token.clone();
|
||||
let bot_user_id = Arc::clone(&self.bot_user_id);
|
||||
let allowed_channels = self.allowed_channels.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
let own_id = {
|
||||
let guard = bot_user_id.read().await;
|
||||
guard.clone().unwrap_or_default()
|
||||
};
|
||||
|
||||
// Connect to WebSocket
|
||||
let ws_connect_url = format!("{}/?format=json", ws_url);
|
||||
|
||||
let ws_stream = match tokio_tungstenite::connect_async(&ws_connect_url).await {
|
||||
Ok((stream, _)) => {
|
||||
info!("Revolt WebSocket connected");
|
||||
backoff = Duration::from_secs(1);
|
||||
stream
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Revolt WebSocket connection failed: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(MAX_BACKOFF_SECS));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
|
||||
|
||||
// Send Authenticate frame
|
||||
let auth_msg = serde_json::json!({
|
||||
"type": "Authenticate",
|
||||
"token": bot_token.as_str(),
|
||||
});
|
||||
|
||||
if let Err(e) = ws_sink
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(
|
||||
auth_msg.to_string(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
warn!("Revolt: failed to send auth frame: {e}");
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut heartbeat_interval =
|
||||
tokio::time::interval(Duration::from_secs(HEARTBEAT_INTERVAL_SECS));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Revolt adapter shutting down");
|
||||
let _ = ws_sink.close().await;
|
||||
return;
|
||||
}
|
||||
_ = heartbeat_interval.tick() => {
|
||||
// Send Ping to keep connection alive
|
||||
let ping = serde_json::json!({
|
||||
"type": "Ping",
|
||||
"data": 0,
|
||||
});
|
||||
if let Err(e) = ws_sink
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(
|
||||
ping.to_string(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
warn!("Revolt: heartbeat send failed: {e}");
|
||||
break;
|
||||
}
|
||||
}
|
||||
msg = ws_stream_rx.next() => {
|
||||
match msg {
|
||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Text(text))) => {
|
||||
let data: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
let event_type = data["type"].as_str().unwrap_or("");
|
||||
|
||||
match event_type {
|
||||
"Authenticated" => {
|
||||
info!("Revolt: successfully authenticated");
|
||||
}
|
||||
"Ready" => {
|
||||
info!("Revolt: ready, receiving events");
|
||||
}
|
||||
"Pong" => {
|
||||
debug!("Revolt: pong received");
|
||||
}
|
||||
"Message" => {
|
||||
if let Some(channel_msg) = parse_revolt_message(
|
||||
&data,
|
||||
&own_id,
|
||||
&allowed_channels,
|
||||
) {
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
"Error" => {
|
||||
let error = data["error"].as_str().unwrap_or("unknown");
|
||||
warn!("Revolt WebSocket error: {error}");
|
||||
if error == "InvalidSession" || error == "NotAuthenticated" {
|
||||
break; // Reconnect
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Ignore other event types (typing, presence, etc.)
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(Ok(tokio_tungstenite::tungstenite::Message::Close(_))) => {
|
||||
info!("Revolt WebSocket closed by server");
|
||||
break;
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
warn!("Revolt WebSocket error: {e}");
|
||||
break;
|
||||
}
|
||||
None => {
|
||||
info!("Revolt WebSocket stream ended");
|
||||
break;
|
||||
}
|
||||
_ => {} // Binary, Ping, Pong frames
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Backoff before reconnection
|
||||
warn!(
|
||||
"Revolt WebSocket disconnected, reconnecting in {}s",
|
||||
backoff.as_secs()
|
||||
);
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(MAX_BACKOFF_SECS));
|
||||
}
|
||||
|
||||
info!("Revolt WebSocket loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
ChannelContent::Image { url, caption } => {
|
||||
// Revolt supports embedding images in messages via markdown
|
||||
let markdown = if let Some(cap) = caption {
|
||||
format!("", cap, url)
|
||||
} else {
|
||||
format!("", url)
|
||||
};
|
||||
self.api_send_message(&user.platform_id, &markdown).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Revolt typing indicator via REST
|
||||
let url = format!("{}/channels/{}/typing", self.api_url, user.platform_id);
|
||||
|
||||
let _ = self.auth_header(self.client.post(&url)).send().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_revolt_adapter_creation() {
|
||||
let adapter = RevoltAdapter::new("bot-token-123".to_string());
|
||||
assert_eq!(adapter.name(), "revolt");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("revolt".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_revolt_default_urls() {
|
||||
let adapter = RevoltAdapter::new("tok".to_string());
|
||||
assert_eq!(adapter.api_url, "https://api.revolt.chat");
|
||||
assert_eq!(adapter.ws_url, "wss://ws.revolt.chat");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_revolt_custom_urls() {
|
||||
let adapter = RevoltAdapter::with_urls(
|
||||
"tok".to_string(),
|
||||
"https://api.revolt.example.com/".to_string(),
|
||||
"wss://ws.revolt.example.com/".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.api_url, "https://api.revolt.example.com");
|
||||
assert_eq!(adapter.ws_url, "wss://ws.revolt.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_revolt_with_channels() {
|
||||
let adapter = RevoltAdapter::with_channels(
|
||||
"tok".to_string(),
|
||||
vec!["ch1".to_string(), "ch2".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_channel("ch1"));
|
||||
assert!(adapter.is_allowed_channel("ch2"));
|
||||
assert!(!adapter.is_allowed_channel("ch3"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_revolt_empty_channels_allows_all() {
|
||||
let adapter = RevoltAdapter::new("tok".to_string());
|
||||
assert!(adapter.is_allowed_channel("any-channel"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_revolt_auth_header() {
|
||||
let adapter = RevoltAdapter::new("my-revolt-token".to_string());
|
||||
let builder = adapter.client.get("https://example.com");
|
||||
let builder = adapter.auth_header(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert_eq!(
|
||||
request.headers().get("x-bot-token").unwrap(),
|
||||
"my-revolt-token"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_basic() {
|
||||
let data = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-123",
|
||||
"channel": "ch-456",
|
||||
"author": "user-789",
|
||||
"content": "Hello from Revolt!",
|
||||
"nonce": "nonce-abc"
|
||||
});
|
||||
|
||||
let msg = parse_revolt_message(&data, "bot-id", &[]).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("revolt".to_string()));
|
||||
assert_eq!(msg.platform_message_id, "msg-123");
|
||||
assert_eq!(msg.sender.platform_id, "ch-456");
|
||||
assert!(msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Revolt!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_skips_bot() {
|
||||
let data = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-1",
|
||||
"channel": "ch-1",
|
||||
"author": "bot-id",
|
||||
"content": "Bot message"
|
||||
});
|
||||
|
||||
assert!(parse_revolt_message(&data, "bot-id", &[]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_skips_system() {
|
||||
let data = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-1",
|
||||
"channel": "ch-1",
|
||||
"author": "00000000000000000000000000",
|
||||
"content": "System message"
|
||||
});
|
||||
|
||||
assert!(parse_revolt_message(&data, "bot-id", &[]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_channel_filter() {
|
||||
let data = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-1",
|
||||
"channel": "ch-not-allowed",
|
||||
"author": "user-1",
|
||||
"content": "Filtered out"
|
||||
});
|
||||
|
||||
assert!(parse_revolt_message(&data, "bot-id", &["ch-allowed".to_string()]).is_none());
|
||||
|
||||
// Same message but with allowed channel
|
||||
let data2 = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-2",
|
||||
"channel": "ch-allowed",
|
||||
"author": "user-1",
|
||||
"content": "Allowed"
|
||||
});
|
||||
|
||||
assert!(parse_revolt_message(&data2, "bot-id", &["ch-allowed".to_string()]).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_command() {
|
||||
let data = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-cmd",
|
||||
"channel": "ch-1",
|
||||
"author": "user-1",
|
||||
"content": "/agent deploy-bot"
|
||||
});
|
||||
|
||||
let msg = parse_revolt_message(&data, "bot-id", &[]).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agent");
|
||||
assert_eq!(args, &["deploy-bot"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_non_message_type() {
|
||||
let data = serde_json::json!({
|
||||
"type": "ChannelStartTyping",
|
||||
"id": "ch-1",
|
||||
"user": "user-1"
|
||||
});
|
||||
|
||||
assert!(parse_revolt_message(&data, "bot-id", &[]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_empty_content() {
|
||||
let data = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-empty",
|
||||
"channel": "ch-1",
|
||||
"author": "user-1",
|
||||
"content": ""
|
||||
});
|
||||
|
||||
assert!(parse_revolt_message(&data, "bot-id", &[]).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_revolt_message_metadata() {
|
||||
let data = serde_json::json!({
|
||||
"type": "Message",
|
||||
"_id": "msg-meta",
|
||||
"channel": "ch-1",
|
||||
"author": "user-1",
|
||||
"content": "With metadata",
|
||||
"nonce": "nonce-1",
|
||||
"replies": ["msg-replied-to"],
|
||||
"attachments": [{"_id": "att-1", "filename": "file.txt"}]
|
||||
});
|
||||
|
||||
let msg = parse_revolt_message(&data, "bot-id", &[]).unwrap();
|
||||
assert!(msg.metadata.contains_key("channel_id"));
|
||||
assert!(msg.metadata.contains_key("author_id"));
|
||||
assert!(msg.metadata.contains_key("nonce"));
|
||||
assert!(msg.metadata.contains_key("replies"));
|
||||
assert!(msg.metadata.contains_key("attachments"));
|
||||
}
|
||||
}
|
||||
450
crates/openfang-channels/src/rocketchat.rs
Normal file
450
crates/openfang-channels/src/rocketchat.rs
Normal file
@@ -0,0 +1,450 @@
|
||||
//! Rocket.Chat channel adapter.
|
||||
//!
|
||||
//! Uses the Rocket.Chat REST API for sending messages and long-polling
|
||||
//! `channels.history` for receiving new messages. Authentication is performed
|
||||
//! via personal access token with `X-Auth-Token` and `X-User-Id` headers.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const POLL_INTERVAL_SECS: u64 = 2;
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
|
||||
/// Rocket.Chat channel adapter using REST API with long-polling.
|
||||
pub struct RocketChatAdapter {
|
||||
/// Rocket.Chat server URL (e.g., `"https://chat.example.com"`).
|
||||
server_url: String,
|
||||
/// SECURITY: Auth token is zeroized on drop.
|
||||
token: Zeroizing<String>,
|
||||
/// User ID for API authentication.
|
||||
user_id: String,
|
||||
/// Channel IDs (room IDs) to poll (empty = all).
|
||||
allowed_channels: Vec<String>,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Last polled timestamp per channel for incremental history fetch.
|
||||
last_timestamps: Arc<RwLock<HashMap<String, String>>>,
|
||||
}
|
||||
|
||||
impl RocketChatAdapter {
|
||||
/// Create a new Rocket.Chat adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `server_url` - Base URL of the Rocket.Chat instance.
|
||||
/// * `token` - Personal access token for authentication.
|
||||
/// * `user_id` - User ID associated with the token.
|
||||
/// * `allowed_channels` - Room IDs to listen on (empty = discover from server).
|
||||
pub fn new(
|
||||
server_url: String,
|
||||
token: String,
|
||||
user_id: String,
|
||||
allowed_channels: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let server_url = server_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
server_url,
|
||||
token: Zeroizing::new(token),
|
||||
user_id,
|
||||
allowed_channels,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
last_timestamps: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add auth headers to a request builder.
|
||||
fn auth_headers(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
builder
|
||||
.header("X-Auth-Token", self.token.as_str())
|
||||
.header("X-User-Id", &self.user_id)
|
||||
}
|
||||
|
||||
/// Validate credentials by calling `/api/v1/me`.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v1/me", self.server_url);
|
||||
let resp = self.auth_headers(self.client.get(&url)).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Rocket.Chat authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let username = body["username"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(username)
|
||||
}
|
||||
|
||||
/// Send a text message to a Rocket.Chat room.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
room_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v1/chat.sendMessage", self.server_url);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"message": {
|
||||
"rid": room_id,
|
||||
"msg": chunk,
|
||||
}
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.auth_headers(self.client.post(&url))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Rocket.Chat API error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a channel is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_channel(&self, channel_id: &str) -> bool {
|
||||
self.allowed_channels.is_empty() || self.allowed_channels.iter().any(|c| c == channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for RocketChatAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"rocketchat"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("rocketchat".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let username = self.validate().await?;
|
||||
info!("Rocket.Chat adapter authenticated as {username}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let server_url = self.server_url.clone();
|
||||
let token = self.token.clone();
|
||||
let user_id = self.user_id.clone();
|
||||
let own_username = username;
|
||||
let allowed_channels = self.allowed_channels.clone();
|
||||
let client = self.client.clone();
|
||||
let last_timestamps = Arc::clone(&self.last_timestamps);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Determine channels to poll
|
||||
let channels_to_poll = if allowed_channels.is_empty() {
|
||||
// Fetch joined channels
|
||||
let url = format!("{server_url}/api/v1/channels.list.joined?count=100");
|
||||
match client
|
||||
.get(&url)
|
||||
.header("X-Auth-Token", token.as_str())
|
||||
.header("X-User-Id", &user_id)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body["channels"]
|
||||
.as_array()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|c| c["_id"].as_str().map(String::from))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Rocket.Chat: failed to list channels: {e}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allowed_channels
|
||||
};
|
||||
|
||||
if channels_to_poll.is_empty() {
|
||||
warn!("Rocket.Chat: no channels to poll");
|
||||
return;
|
||||
}
|
||||
|
||||
info!("Rocket.Chat: polling {} channel(s)", channels_to_poll.len());
|
||||
|
||||
// Initialize timestamps to "now" so we only get new messages
|
||||
{
|
||||
let now = Utc::now().to_rfc3339();
|
||||
let mut ts = last_timestamps.write().await;
|
||||
for ch in &channels_to_poll {
|
||||
ts.entry(ch.clone()).or_insert_with(|| now.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Rocket.Chat adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
for channel_id in &channels_to_poll {
|
||||
let oldest = {
|
||||
let ts = last_timestamps.read().await;
|
||||
ts.get(channel_id).cloned().unwrap_or_default()
|
||||
};
|
||||
|
||||
let url = format!(
|
||||
"{}/api/v1/channels.history?roomId={}&oldest={}&count=50",
|
||||
server_url, channel_id, oldest
|
||||
);
|
||||
|
||||
let resp = match client
|
||||
.get(&url)
|
||||
.header("X-Auth-Token", token.as_str())
|
||||
.header("X-User-Id", &user_id)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Rocket.Chat: history fetch error for {channel_id}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
warn!(
|
||||
"Rocket.Chat: history fetch returned {} for {channel_id}",
|
||||
resp.status()
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Rocket.Chat: failed to parse history: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let messages = match body["messages"].as_array() {
|
||||
Some(arr) => arr,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let mut newest_ts = oldest.clone();
|
||||
|
||||
for msg in messages {
|
||||
let sender_username = msg["u"]["username"].as_str().unwrap_or("");
|
||||
// Skip own messages
|
||||
if sender_username == own_username {
|
||||
continue;
|
||||
}
|
||||
|
||||
let text = msg["msg"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg_id = msg["_id"].as_str().unwrap_or("").to_string();
|
||||
let msg_ts = msg["ts"].as_str().unwrap_or("").to_string();
|
||||
let sender_id = msg["u"]["_id"].as_str().unwrap_or("").to_string();
|
||||
let thread_id = msg["tmid"].as_str().map(String::from);
|
||||
|
||||
// Track newest timestamp
|
||||
if msg_ts > newest_ts {
|
||||
newest_ts = msg_ts;
|
||||
}
|
||||
|
||||
let msg_content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("rocketchat".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: channel_id.clone(),
|
||||
display_name: sender_username.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"sender_id".to_string(),
|
||||
serde_json::Value::String(sender_id),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Update the last timestamp for this channel
|
||||
if newest_ts != oldest {
|
||||
last_timestamps
|
||||
.write()
|
||||
.await
|
||||
.insert(channel_id.clone(), newest_ts);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Rocket.Chat polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Rocket.Chat supports typing notifications via REST
|
||||
let url = format!("{}/api/v1/chat.sendMessage", self.server_url);
|
||||
// There's no dedicated typing endpoint in REST; this is a no-op.
|
||||
// Real typing would need the realtime API (WebSocket/DDP).
|
||||
let _ = url;
|
||||
let _ = user;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_rocketchat_adapter_creation() {
|
||||
let adapter = RocketChatAdapter::new(
|
||||
"https://chat.example.com".to_string(),
|
||||
"test-token".to_string(),
|
||||
"user123".to_string(),
|
||||
vec!["room1".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "rocketchat");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("rocketchat".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rocketchat_server_url_normalization() {
|
||||
let adapter = RocketChatAdapter::new(
|
||||
"https://chat.example.com/".to_string(),
|
||||
"tok".to_string(),
|
||||
"uid".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.server_url, "https://chat.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rocketchat_allowed_channels() {
|
||||
let adapter = RocketChatAdapter::new(
|
||||
"https://chat.example.com".to_string(),
|
||||
"tok".to_string(),
|
||||
"uid".to_string(),
|
||||
vec!["room1".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_channel("room1"));
|
||||
assert!(!adapter.is_allowed_channel("room2"));
|
||||
|
||||
let open = RocketChatAdapter::new(
|
||||
"https://chat.example.com".to_string(),
|
||||
"tok".to_string(),
|
||||
"uid".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert!(open.is_allowed_channel("any-room"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rocketchat_auth_headers() {
|
||||
let adapter = RocketChatAdapter::new(
|
||||
"https://chat.example.com".to_string(),
|
||||
"my-token".to_string(),
|
||||
"user-42".to_string(),
|
||||
vec![],
|
||||
);
|
||||
// Verify the builder can be constructed (headers are added internally)
|
||||
let builder = adapter.client.get("https://example.com");
|
||||
let builder = adapter.auth_headers(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert_eq!(request.headers().get("X-Auth-Token").unwrap(), "my-token");
|
||||
assert_eq!(request.headers().get("X-User-Id").unwrap(), "user-42");
|
||||
}
|
||||
}
|
||||
576
crates/openfang-channels/src/router.rs
Normal file
576
crates/openfang-channels/src/router.rs
Normal file
@@ -0,0 +1,576 @@
|
||||
//! Agent router — routes incoming channel messages to the correct agent.
|
||||
|
||||
use crate::types::ChannelType;
|
||||
use dashmap::DashMap;
|
||||
use openfang_types::agent::AgentId;
|
||||
use openfang_types::config::{AgentBinding, BroadcastConfig, BroadcastStrategy};
|
||||
use std::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
/// Context for evaluating binding match rules against incoming messages.
|
||||
#[derive(Debug, Default)]
|
||||
pub struct BindingContext {
|
||||
/// Channel type string (e.g., "telegram", "discord").
|
||||
pub channel: String,
|
||||
/// Account/bot ID within the channel.
|
||||
pub account_id: Option<String>,
|
||||
/// Peer/user ID (platform_user_id).
|
||||
pub peer_id: String,
|
||||
/// Guild/server ID.
|
||||
pub guild_id: Option<String>,
|
||||
/// User's roles.
|
||||
pub roles: Vec<String>,
|
||||
}
|
||||
|
||||
/// Routes incoming messages to the correct agent.
|
||||
///
|
||||
/// Routing priority: bindings (most specific first) > direct routes > user defaults > system default.
|
||||
pub struct AgentRouter {
|
||||
/// Default agent per user (keyed by openfang_user or platform_id).
|
||||
user_defaults: DashMap<String, AgentId>,
|
||||
/// Direct routes: (channel_type_key, platform_user_id) -> AgentId.
|
||||
direct_routes: DashMap<(String, String), AgentId>,
|
||||
/// System-wide default agent.
|
||||
default_agent: Option<AgentId>,
|
||||
/// Sorted bindings (most specific first). Uses Mutex for runtime updates via Arc.
|
||||
bindings: Mutex<Vec<(AgentBinding, String)>>,
|
||||
/// Broadcast configuration. Uses Mutex for runtime updates via Arc.
|
||||
broadcast: Mutex<BroadcastConfig>,
|
||||
/// Agent name -> AgentId cache for binding resolution.
|
||||
agent_name_cache: DashMap<String, AgentId>,
|
||||
}
|
||||
|
||||
impl AgentRouter {
|
||||
/// Create a new router.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
user_defaults: DashMap::new(),
|
||||
direct_routes: DashMap::new(),
|
||||
default_agent: None,
|
||||
bindings: Mutex::new(Vec::new()),
|
||||
broadcast: Mutex::new(BroadcastConfig::default()),
|
||||
agent_name_cache: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the system-wide default agent.
|
||||
pub fn set_default(&mut self, agent_id: AgentId) {
|
||||
self.default_agent = Some(agent_id);
|
||||
}
|
||||
|
||||
/// Set a user's default agent.
|
||||
pub fn set_user_default(&self, user_key: String, agent_id: AgentId) {
|
||||
self.user_defaults.insert(user_key, agent_id);
|
||||
}
|
||||
|
||||
/// Set a direct route for a specific (channel, user) pair.
|
||||
pub fn set_direct_route(
|
||||
&self,
|
||||
channel_key: String,
|
||||
platform_user_id: String,
|
||||
agent_id: AgentId,
|
||||
) {
|
||||
self.direct_routes
|
||||
.insert((channel_key, platform_user_id), agent_id);
|
||||
}
|
||||
|
||||
/// Load agent bindings from configuration. Sorts by specificity (most specific first).
|
||||
pub fn load_bindings(&self, bindings: &[AgentBinding]) {
|
||||
let mut sorted: Vec<(AgentBinding, String)> = bindings
|
||||
.iter()
|
||||
.map(|b| (b.clone(), b.agent.clone()))
|
||||
.collect();
|
||||
// Sort by specificity descending (most specific first)
|
||||
sorted.sort_by(|a, b| {
|
||||
b.0.match_rule
|
||||
.specificity()
|
||||
.cmp(&a.0.match_rule.specificity())
|
||||
});
|
||||
*self.bindings.lock().unwrap_or_else(|e| e.into_inner()) = sorted;
|
||||
}
|
||||
|
||||
/// Load broadcast configuration.
|
||||
pub fn load_broadcast(&self, broadcast: BroadcastConfig) {
|
||||
*self.broadcast.lock().unwrap_or_else(|e| e.into_inner()) = broadcast;
|
||||
}
|
||||
|
||||
/// Register an agent name -> ID mapping for binding resolution.
|
||||
pub fn register_agent(&self, name: String, id: AgentId) {
|
||||
self.agent_name_cache.insert(name, id);
|
||||
}
|
||||
|
||||
/// Resolve which agent should handle a message.
|
||||
///
|
||||
/// Priority: bindings > direct route > user default > system default.
|
||||
pub fn resolve(
|
||||
&self,
|
||||
channel_type: &ChannelType,
|
||||
platform_user_id: &str,
|
||||
user_key: Option<&str>,
|
||||
) -> Option<AgentId> {
|
||||
let channel_key = format!("{channel_type:?}");
|
||||
|
||||
// 0. Check bindings (most specific first)
|
||||
let ctx = BindingContext {
|
||||
channel: channel_type_to_str(channel_type).to_string(),
|
||||
account_id: None,
|
||||
peer_id: platform_user_id.to_string(),
|
||||
guild_id: None,
|
||||
roles: Vec::new(),
|
||||
};
|
||||
if let Some(agent_id) = self.resolve_binding(&ctx) {
|
||||
return Some(agent_id);
|
||||
}
|
||||
|
||||
// 1. Check direct routes
|
||||
if let Some(agent) = self
|
||||
.direct_routes
|
||||
.get(&(channel_key, platform_user_id.to_string()))
|
||||
{
|
||||
return Some(*agent);
|
||||
}
|
||||
|
||||
// 2. Check user defaults
|
||||
if let Some(key) = user_key {
|
||||
if let Some(agent) = self.user_defaults.get(key) {
|
||||
return Some(*agent);
|
||||
}
|
||||
}
|
||||
// Also check by platform_user_id
|
||||
if let Some(agent) = self.user_defaults.get(platform_user_id) {
|
||||
return Some(*agent);
|
||||
}
|
||||
|
||||
// 3. System default
|
||||
self.default_agent
|
||||
}
|
||||
|
||||
/// Resolve with full binding context (supports guild_id, roles, account_id).
|
||||
pub fn resolve_with_context(
|
||||
&self,
|
||||
channel_type: &ChannelType,
|
||||
platform_user_id: &str,
|
||||
user_key: Option<&str>,
|
||||
ctx: &BindingContext,
|
||||
) -> Option<AgentId> {
|
||||
// 0. Check bindings first
|
||||
if let Some(agent_id) = self.resolve_binding(ctx) {
|
||||
return Some(agent_id);
|
||||
}
|
||||
// Fall back to standard resolution
|
||||
let channel_key = format!("{channel_type:?}");
|
||||
if let Some(agent) = self
|
||||
.direct_routes
|
||||
.get(&(channel_key, platform_user_id.to_string()))
|
||||
{
|
||||
return Some(*agent);
|
||||
}
|
||||
if let Some(key) = user_key {
|
||||
if let Some(agent) = self.user_defaults.get(key) {
|
||||
return Some(*agent);
|
||||
}
|
||||
}
|
||||
if let Some(agent) = self.user_defaults.get(platform_user_id) {
|
||||
return Some(*agent);
|
||||
}
|
||||
self.default_agent
|
||||
}
|
||||
|
||||
/// Resolve broadcast: returns all agents that should receive a message for the given peer.
|
||||
pub fn resolve_broadcast(&self, peer_id: &str) -> Vec<(String, Option<AgentId>)> {
|
||||
let bc = self.broadcast.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if let Some(agent_names) = bc.routes.get(peer_id) {
|
||||
agent_names
|
||||
.iter()
|
||||
.map(|name| {
|
||||
let id = self.agent_name_cache.get(name).map(|r| *r);
|
||||
(name.clone(), id)
|
||||
})
|
||||
.collect()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get broadcast strategy.
|
||||
pub fn broadcast_strategy(&self) -> BroadcastStrategy {
|
||||
self.broadcast
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.strategy
|
||||
}
|
||||
|
||||
/// Check if a peer has broadcast routing configured.
|
||||
pub fn has_broadcast(&self, peer_id: &str) -> bool {
|
||||
self.broadcast
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.routes
|
||||
.contains_key(peer_id)
|
||||
}
|
||||
|
||||
/// Get current bindings (read-only).
|
||||
pub fn bindings(&self) -> Vec<AgentBinding> {
|
||||
self.bindings
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.iter()
|
||||
.map(|(b, _)| b.clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Add a single binding at runtime.
|
||||
pub fn add_binding(&self, binding: AgentBinding) {
|
||||
let name = binding.agent.clone();
|
||||
let mut bindings = self.bindings.lock().unwrap_or_else(|e| e.into_inner());
|
||||
bindings.push((binding, name));
|
||||
// Re-sort by specificity
|
||||
bindings.sort_by(|a, b| {
|
||||
b.0.match_rule
|
||||
.specificity()
|
||||
.cmp(&a.0.match_rule.specificity())
|
||||
});
|
||||
}
|
||||
|
||||
/// Remove a binding by index (original insertion order after sort).
|
||||
pub fn remove_binding(&self, index: usize) -> Option<AgentBinding> {
|
||||
let mut bindings = self.bindings.lock().unwrap_or_else(|e| e.into_inner());
|
||||
if index < bindings.len() {
|
||||
Some(bindings.remove(index).0)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Evaluate bindings against a context, returning the first matching agent ID.
|
||||
fn resolve_binding(&self, ctx: &BindingContext) -> Option<AgentId> {
|
||||
let bindings = self.bindings.lock().unwrap_or_else(|e| e.into_inner());
|
||||
for (binding, _agent_name) in bindings.iter() {
|
||||
if self.binding_matches(binding, ctx) {
|
||||
// Look up agent by name in cache
|
||||
if let Some(id) = self.agent_name_cache.get(&binding.agent) {
|
||||
return Some(*id);
|
||||
}
|
||||
warn!(
|
||||
agent = %binding.agent,
|
||||
"Binding matched but agent not found in cache"
|
||||
);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a single binding's match_rule matches the context.
|
||||
fn binding_matches(&self, binding: &AgentBinding, ctx: &BindingContext) -> bool {
|
||||
let rule = &binding.match_rule;
|
||||
|
||||
// All specified fields must match
|
||||
if let Some(ref ch) = rule.channel {
|
||||
if ch != &ctx.channel {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(ref acc) = rule.account_id {
|
||||
if ctx.account_id.as_ref() != Some(acc) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(ref pid) = rule.peer_id {
|
||||
if pid != &ctx.peer_id {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(ref gid) = rule.guild_id {
|
||||
if ctx.guild_id.as_ref() != Some(gid) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if !rule.roles.is_empty() {
|
||||
// User must have at least one of the specified roles
|
||||
let has_role = rule.roles.iter().any(|r| ctx.roles.contains(r));
|
||||
if !has_role {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert ChannelType to lowercase string for binding matching.
|
||||
fn channel_type_to_str(ct: &ChannelType) -> &str {
|
||||
match ct {
|
||||
ChannelType::Telegram => "telegram",
|
||||
ChannelType::Discord => "discord",
|
||||
ChannelType::Slack => "slack",
|
||||
ChannelType::WhatsApp => "whatsapp",
|
||||
ChannelType::Signal => "signal",
|
||||
ChannelType::Matrix => "matrix",
|
||||
ChannelType::Email => "email",
|
||||
ChannelType::Teams => "teams",
|
||||
ChannelType::Mattermost => "mattermost",
|
||||
ChannelType::WebChat => "webchat",
|
||||
ChannelType::CLI => "cli",
|
||||
ChannelType::Custom(s) => s.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AgentRouter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_routing_priority() {
|
||||
let mut router = AgentRouter::new();
|
||||
let default_agent = AgentId::new();
|
||||
let user_agent = AgentId::new();
|
||||
let direct_agent = AgentId::new();
|
||||
|
||||
router.set_default(default_agent);
|
||||
router.set_user_default("alice".to_string(), user_agent);
|
||||
router.set_direct_route("Telegram".to_string(), "tg_123".to_string(), direct_agent);
|
||||
|
||||
// Direct route wins
|
||||
let resolved = router.resolve(&ChannelType::Telegram, "tg_123", Some("alice"));
|
||||
assert_eq!(resolved, Some(direct_agent));
|
||||
|
||||
// User default for non-direct-routed user
|
||||
let resolved = router.resolve(&ChannelType::WhatsApp, "wa_456", Some("alice"));
|
||||
assert_eq!(resolved, Some(user_agent));
|
||||
|
||||
// System default for unknown user
|
||||
let resolved = router.resolve(&ChannelType::Discord, "dc_789", None);
|
||||
assert_eq!(resolved, Some(default_agent));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_route() {
|
||||
let router = AgentRouter::new();
|
||||
let resolved = router.resolve(&ChannelType::CLI, "local", None);
|
||||
assert_eq!(resolved, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binding_channel_match() {
|
||||
let router = AgentRouter::new();
|
||||
let agent_id = AgentId::new();
|
||||
router.register_agent("coder".to_string(), agent_id);
|
||||
router.load_bindings(&[AgentBinding {
|
||||
agent: "coder".to_string(),
|
||||
match_rule: openfang_types::config::BindingMatchRule {
|
||||
channel: Some("telegram".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
}]);
|
||||
|
||||
// Should match telegram
|
||||
let resolved = router.resolve(&ChannelType::Telegram, "user1", None);
|
||||
assert_eq!(resolved, Some(agent_id));
|
||||
|
||||
// Should NOT match discord
|
||||
let resolved = router.resolve(&ChannelType::Discord, "user1", None);
|
||||
assert_eq!(resolved, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binding_peer_id_match() {
|
||||
let router = AgentRouter::new();
|
||||
let agent_id = AgentId::new();
|
||||
router.register_agent("support".to_string(), agent_id);
|
||||
router.load_bindings(&[AgentBinding {
|
||||
agent: "support".to_string(),
|
||||
match_rule: openfang_types::config::BindingMatchRule {
|
||||
peer_id: Some("vip_user".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
}]);
|
||||
|
||||
let resolved = router.resolve(&ChannelType::Discord, "vip_user", None);
|
||||
assert_eq!(resolved, Some(agent_id));
|
||||
|
||||
let resolved = router.resolve(&ChannelType::Discord, "other_user", None);
|
||||
assert_eq!(resolved, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binding_guild_and_role_match() {
|
||||
let router = AgentRouter::new();
|
||||
let agent_id = AgentId::new();
|
||||
router.register_agent("admin-bot".to_string(), agent_id);
|
||||
router.load_bindings(&[AgentBinding {
|
||||
agent: "admin-bot".to_string(),
|
||||
match_rule: openfang_types::config::BindingMatchRule {
|
||||
guild_id: Some("guild_123".to_string()),
|
||||
roles: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
},
|
||||
}]);
|
||||
|
||||
let ctx = BindingContext {
|
||||
channel: "discord".to_string(),
|
||||
peer_id: "user1".to_string(),
|
||||
guild_id: Some("guild_123".to_string()),
|
||||
roles: vec!["admin".to_string(), "user".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
let resolved = router.resolve_with_context(&ChannelType::Discord, "user1", None, &ctx);
|
||||
assert_eq!(resolved, Some(agent_id));
|
||||
|
||||
// Wrong guild
|
||||
let ctx2 = BindingContext {
|
||||
channel: "discord".to_string(),
|
||||
peer_id: "user1".to_string(),
|
||||
guild_id: Some("guild_999".to_string()),
|
||||
roles: vec!["admin".to_string()],
|
||||
..Default::default()
|
||||
};
|
||||
let resolved = router.resolve_with_context(&ChannelType::Discord, "user1", None, &ctx2);
|
||||
assert_eq!(resolved, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binding_specificity_ordering() {
|
||||
let router = AgentRouter::new();
|
||||
let general_id = AgentId::new();
|
||||
let specific_id = AgentId::new();
|
||||
router.register_agent("general".to_string(), general_id);
|
||||
router.register_agent("specific".to_string(), specific_id);
|
||||
|
||||
// Load in wrong order — less specific first
|
||||
router.load_bindings(&[
|
||||
AgentBinding {
|
||||
agent: "general".to_string(),
|
||||
match_rule: openfang_types::config::BindingMatchRule {
|
||||
channel: Some("discord".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
AgentBinding {
|
||||
agent: "specific".to_string(),
|
||||
match_rule: openfang_types::config::BindingMatchRule {
|
||||
channel: Some("discord".to_string()),
|
||||
peer_id: Some("user1".to_string()),
|
||||
guild_id: Some("guild_1".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
// More specific binding should win despite being loaded second
|
||||
let ctx = BindingContext {
|
||||
channel: "discord".to_string(),
|
||||
peer_id: "user1".to_string(),
|
||||
guild_id: Some("guild_1".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
let resolved = router.resolve_with_context(&ChannelType::Discord, "user1", None, &ctx);
|
||||
assert_eq!(resolved, Some(specific_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_broadcast_routing() {
|
||||
let router = AgentRouter::new();
|
||||
let id1 = AgentId::new();
|
||||
let id2 = AgentId::new();
|
||||
router.register_agent("agent-a".to_string(), id1);
|
||||
router.register_agent("agent-b".to_string(), id2);
|
||||
|
||||
let mut routes = std::collections::HashMap::new();
|
||||
routes.insert(
|
||||
"vip_user".to_string(),
|
||||
vec!["agent-a".to_string(), "agent-b".to_string()],
|
||||
);
|
||||
router.load_broadcast(BroadcastConfig {
|
||||
strategy: BroadcastStrategy::Parallel,
|
||||
routes,
|
||||
});
|
||||
|
||||
assert!(router.has_broadcast("vip_user"));
|
||||
assert!(!router.has_broadcast("normal_user"));
|
||||
|
||||
let targets = router.resolve_broadcast("vip_user");
|
||||
assert_eq!(targets.len(), 2);
|
||||
assert_eq!(targets[0].0, "agent-a");
|
||||
assert_eq!(targets[0].1, Some(id1));
|
||||
assert_eq!(targets[1].0, "agent-b");
|
||||
assert_eq!(targets[1].1, Some(id2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_bindings_legacy_behavior() {
|
||||
let mut router = AgentRouter::new();
|
||||
let default_id = AgentId::new();
|
||||
router.set_default(default_id);
|
||||
router.load_bindings(&[]);
|
||||
|
||||
// Should fall through to system default
|
||||
let resolved = router.resolve(&ChannelType::Telegram, "user1", None);
|
||||
assert_eq!(resolved, Some(default_id));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binding_nonexistent_agent_warning() {
|
||||
let router = AgentRouter::new();
|
||||
// Don't register the agent — binding should match but resolve_binding returns None
|
||||
router.load_bindings(&[AgentBinding {
|
||||
agent: "ghost-agent".to_string(),
|
||||
match_rule: openfang_types::config::BindingMatchRule {
|
||||
channel: Some("telegram".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
}]);
|
||||
|
||||
let resolved = router.resolve(&ChannelType::Telegram, "user1", None);
|
||||
assert_eq!(resolved, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add_remove_binding() {
|
||||
let router = AgentRouter::new();
|
||||
let id = AgentId::new();
|
||||
router.register_agent("test".to_string(), id);
|
||||
|
||||
assert!(router.bindings().is_empty());
|
||||
|
||||
router.add_binding(AgentBinding {
|
||||
agent: "test".to_string(),
|
||||
match_rule: openfang_types::config::BindingMatchRule {
|
||||
channel: Some("slack".to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
});
|
||||
assert_eq!(router.bindings().len(), 1);
|
||||
|
||||
let removed = router.remove_binding(0);
|
||||
assert!(removed.is_some());
|
||||
assert!(router.bindings().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_binding_specificity_scores() {
|
||||
use openfang_types::config::BindingMatchRule;
|
||||
|
||||
let empty = BindingMatchRule::default();
|
||||
assert_eq!(empty.specificity(), 0);
|
||||
|
||||
let channel_only = BindingMatchRule {
|
||||
channel: Some("discord".to_string()),
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(channel_only.specificity(), 1);
|
||||
|
||||
let full = BindingMatchRule {
|
||||
channel: Some("discord".to_string()),
|
||||
peer_id: Some("user".to_string()),
|
||||
guild_id: Some("guild".to_string()),
|
||||
roles: vec!["admin".to_string()],
|
||||
account_id: Some("bot".to_string()),
|
||||
};
|
||||
assert_eq!(full.specificity(), 17); // 8+4+2+2+1
|
||||
}
|
||||
}
|
||||
266
crates/openfang-channels/src/signal.rs
Normal file
266
crates/openfang-channels/src/signal.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
//! Signal channel adapter.
|
||||
//!
|
||||
//! Uses signal-cli's JSON-RPC daemon mode for sending/receiving messages.
|
||||
//! Requires signal-cli to be installed and registered with a phone number.
|
||||
|
||||
use crate::types::{ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{debug, info};
|
||||
|
||||
const POLL_INTERVAL: Duration = Duration::from_secs(2);
|
||||
|
||||
/// Signal adapter via signal-cli REST API.
|
||||
pub struct SignalAdapter {
|
||||
/// URL of signal-cli REST API (e.g., "http://localhost:8080").
|
||||
api_url: String,
|
||||
/// Registered phone number.
|
||||
phone_number: String,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Allowed phone numbers (empty = allow all).
|
||||
allowed_users: Vec<String>,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl SignalAdapter {
|
||||
/// Create a new Signal adapter.
|
||||
pub fn new(api_url: String, phone_number: String, allowed_users: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
api_url,
|
||||
phone_number,
|
||||
client: reqwest::Client::new(),
|
||||
allowed_users,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a message via signal-cli REST API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
recipient: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/v2/send", self.api_url);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"message": text,
|
||||
"number": self.phone_number,
|
||||
"recipients": [recipient],
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Signal API error {status}: {body}").into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Receive messages from signal-cli REST API.
|
||||
#[allow(dead_code)]
|
||||
async fn receive_messages(&self) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/v1/receive/{}", self.api_url, self.phone_number);
|
||||
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
||||
let messages: Vec<serde_json::Value> = resp.json().await.unwrap_or_default();
|
||||
Ok(messages)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed(&self, phone: &str) -> bool {
|
||||
self.allowed_users.is_empty() || self.allowed_users.iter().any(|u| u == phone)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for SignalAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"signal"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Signal
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let api_url = self.api_url.clone();
|
||||
let phone_number = self.phone_number.clone();
|
||||
let allowed_users = self.allowed_users.clone();
|
||||
let client = self.client.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
info!(
|
||||
"Starting Signal adapter (polling {} every {:?})",
|
||||
api_url, POLL_INTERVAL
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Signal adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(POLL_INTERVAL) => {}
|
||||
}
|
||||
|
||||
// Poll for new messages
|
||||
let url = format!("{}/v1/receive/{}", api_url, phone_number);
|
||||
let resp = match client.get(&url).send().await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
debug!("Signal poll error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let messages: Vec<serde_json::Value> = match resp.json().await {
|
||||
Ok(m) => m,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
for msg in messages {
|
||||
let envelope = msg.get("envelope").unwrap_or(&msg);
|
||||
|
||||
let source = envelope["source"].as_str().unwrap_or("").to_string();
|
||||
|
||||
if source.is_empty() || source == phone_number {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !allowed_users.is_empty() && !allowed_users.iter().any(|u| u == &source) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract text from dataMessage
|
||||
let text = envelope["dataMessage"]["message"].as_str().unwrap_or("");
|
||||
|
||||
if text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let source_name = envelope["sourceName"]
|
||||
.as_str()
|
||||
.unwrap_or(&source)
|
||||
.to_string();
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Signal,
|
||||
platform_message_id: envelope["timestamp"]
|
||||
.as_u64()
|
||||
.unwrap_or(0)
|
||||
.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: source.clone(),
|
||||
display_name: source_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_signal_adapter_creation() {
|
||||
let adapter = SignalAdapter::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"+1234567890".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.name(), "signal");
|
||||
assert_eq!(adapter.channel_type(), ChannelType::Signal);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_signal_allowed_check() {
|
||||
let adapter = SignalAdapter::new(
|
||||
"http://localhost:8080".to_string(),
|
||||
"+1234567890".to_string(),
|
||||
vec!["+9876543210".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed("+9876543210"));
|
||||
assert!(!adapter.is_allowed("+1111111111"));
|
||||
}
|
||||
}
|
||||
575
crates/openfang-channels/src/slack.rs
Normal file
575
crates/openfang-channels/src/slack.rs
Normal file
@@ -0,0 +1,575 @@
|
||||
//! Slack Socket Mode adapter for the OpenFang channel bridge.
|
||||
//!
|
||||
//! Uses Slack Socket Mode WebSocket (app token) for receiving events and the
|
||||
//! Web API (bot token) for sending responses. No external Slack crate.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::{SinkExt, Stream, StreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const SLACK_API_BASE: &str = "https://slack.com/api";
|
||||
const MAX_BACKOFF: Duration = Duration::from_secs(60);
|
||||
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
|
||||
const SLACK_MSG_LIMIT: usize = 3000;
|
||||
|
||||
/// Slack Socket Mode adapter.
|
||||
pub struct SlackAdapter {
|
||||
/// SECURITY: Tokens are zeroized on drop to prevent memory disclosure.
|
||||
app_token: Zeroizing<String>,
|
||||
bot_token: Zeroizing<String>,
|
||||
client: reqwest::Client,
|
||||
allowed_channels: Vec<String>,
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Bot's own user ID (populated after auth.test).
|
||||
bot_user_id: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl SlackAdapter {
|
||||
pub fn new(app_token: String, bot_token: String, allowed_channels: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
app_token: Zeroizing::new(app_token),
|
||||
bot_token: Zeroizing::new(bot_token),
|
||||
client: reqwest::Client::new(),
|
||||
allowed_channels,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
bot_user_id: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the bot token by calling auth.test.
|
||||
async fn validate_bot_token(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let resp: serde_json::Value = self
|
||||
.client
|
||||
.post(format!("{SLACK_API_BASE}/auth.test"))
|
||||
.header(
|
||||
"Authorization",
|
||||
format!("Bearer {}", self.bot_token.as_str()),
|
||||
)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if resp["ok"].as_bool() != Some(true) {
|
||||
let err = resp["error"].as_str().unwrap_or("unknown error");
|
||||
return Err(format!("Slack auth.test failed: {err}").into());
|
||||
}
|
||||
|
||||
let user_id = resp["user_id"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(user_id)
|
||||
}
|
||||
|
||||
/// Send a message to a Slack channel via chat.postMessage.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chunks = split_message(text, SLACK_MSG_LIMIT);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"channel": channel_id,
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp: serde_json::Value = self
|
||||
.client
|
||||
.post(format!("{SLACK_API_BASE}/chat.postMessage"))
|
||||
.header(
|
||||
"Authorization",
|
||||
format!("Bearer {}", self.bot_token.as_str()),
|
||||
)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if resp["ok"].as_bool() != Some(true) {
|
||||
let err = resp["error"].as_str().unwrap_or("unknown");
|
||||
warn!("Slack chat.postMessage failed: {err}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for SlackAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"slack"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Slack
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate bot token first
|
||||
let bot_user_id_val = self.validate_bot_token().await?;
|
||||
*self.bot_user_id.write().await = Some(bot_user_id_val.clone());
|
||||
info!("Slack bot authenticated (user_id: {bot_user_id_val})");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
|
||||
let app_token = self.app_token.clone();
|
||||
let bot_user_id = self.bot_user_id.clone();
|
||||
let allowed_channels = self.allowed_channels.clone();
|
||||
let client = self.client.clone();
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = INITIAL_BACKOFF;
|
||||
|
||||
loop {
|
||||
if *shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Get a fresh WebSocket URL
|
||||
let ws_url_result = get_socket_mode_url(&client, &app_token)
|
||||
.await
|
||||
.map_err(|e| e.to_string());
|
||||
let ws_url = match ws_url_result {
|
||||
Ok(url) => url,
|
||||
Err(err_msg) => {
|
||||
warn!("Slack: failed to get WebSocket URL: {err_msg}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Connecting to Slack Socket Mode...");
|
||||
|
||||
let ws_result = tokio_tungstenite::connect_async(&ws_url).await;
|
||||
let ws_stream = match ws_result {
|
||||
Ok((stream, _)) => stream,
|
||||
Err(e) => {
|
||||
warn!("Slack WebSocket connection failed: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
backoff = INITIAL_BACKOFF;
|
||||
info!("Slack Socket Mode connected");
|
||||
|
||||
let (mut ws_tx, mut ws_rx) = ws_stream.split();
|
||||
|
||||
let should_reconnect = 'inner: loop {
|
||||
let msg = tokio::select! {
|
||||
msg = ws_rx.next() => msg,
|
||||
_ = shutdown.changed() => {
|
||||
if *shutdown.borrow() {
|
||||
let _ = ws_tx.close().await;
|
||||
return;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let msg = match msg {
|
||||
Some(Ok(m)) => m,
|
||||
Some(Err(e)) => {
|
||||
warn!("Slack WebSocket error: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
None => {
|
||||
info!("Slack WebSocket closed");
|
||||
break 'inner true;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match msg {
|
||||
tokio_tungstenite::tungstenite::Message::Text(t) => t,
|
||||
tokio_tungstenite::tungstenite::Message::Close(_) => {
|
||||
info!("Slack Socket Mode closed by server");
|
||||
break 'inner true;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let payload: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("Slack: failed to parse message: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let envelope_type = payload["type"].as_str().unwrap_or("");
|
||||
|
||||
match envelope_type {
|
||||
"hello" => {
|
||||
debug!("Slack Socket Mode hello received");
|
||||
}
|
||||
|
||||
"events_api" => {
|
||||
// Acknowledge the envelope
|
||||
let envelope_id = payload["envelope_id"].as_str().unwrap_or("");
|
||||
if !envelope_id.is_empty() {
|
||||
let ack = serde_json::json!({ "envelope_id": envelope_id });
|
||||
if let Err(e) = ws_tx
|
||||
.send(tokio_tungstenite::tungstenite::Message::Text(
|
||||
serde_json::to_string(&ack).unwrap(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
error!("Slack: failed to send ack: {e}");
|
||||
break 'inner true;
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the event
|
||||
let event = &payload["payload"]["event"];
|
||||
if let Some(msg) =
|
||||
parse_slack_event(event, &bot_user_id, &allowed_channels).await
|
||||
{
|
||||
debug!(
|
||||
"Slack message from {}: {:?}",
|
||||
msg.sender.display_name, msg.content
|
||||
);
|
||||
if tx.send(msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
"disconnect" => {
|
||||
let reason = payload["reason"].as_str().unwrap_or("unknown");
|
||||
info!("Slack disconnect request: {reason}");
|
||||
break 'inner true;
|
||||
}
|
||||
|
||||
_ => {
|
||||
debug!("Slack envelope type: {envelope_type}");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
warn!("Slack: reconnecting in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
}
|
||||
|
||||
info!("Slack Socket Mode loop stopped");
|
||||
});
|
||||
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let channel_id = &user.platform_id;
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(channel_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(channel_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to get Socket Mode WebSocket URL.
|
||||
async fn get_socket_mode_url(
|
||||
client: &reqwest::Client,
|
||||
app_token: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let resp: serde_json::Value = client
|
||||
.post(format!("{SLACK_API_BASE}/apps.connections.open"))
|
||||
.header("Authorization", format!("Bearer {app_token}"))
|
||||
.header("Content-Type", "application/x-www-form-urlencoded")
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if resp["ok"].as_bool() != Some(true) {
|
||||
let err = resp["error"].as_str().unwrap_or("unknown error");
|
||||
return Err(format!("Slack apps.connections.open failed: {err}").into());
|
||||
}
|
||||
|
||||
resp["url"]
|
||||
.as_str()
|
||||
.map(String::from)
|
||||
.ok_or_else(|| "Missing 'url' in connections.open response".into())
|
||||
}
|
||||
|
||||
/// Parse a Slack event into a `ChannelMessage`.
|
||||
async fn parse_slack_event(
|
||||
event: &serde_json::Value,
|
||||
bot_user_id: &Arc<RwLock<Option<String>>>,
|
||||
allowed_channels: &[String],
|
||||
) -> Option<ChannelMessage> {
|
||||
let event_type = event["type"].as_str()?;
|
||||
if event_type != "message" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Handle message_changed subtype: extract inner message
|
||||
let subtype = event["subtype"].as_str();
|
||||
let (msg_data, is_edit) = match subtype {
|
||||
Some("message_changed") => {
|
||||
// Edited messages have the new content in event.message
|
||||
match event.get("message") {
|
||||
Some(inner) => (inner, true),
|
||||
None => return None,
|
||||
}
|
||||
}
|
||||
Some(_) => return None, // Skip other subtypes (joins, leaves, etc.)
|
||||
None => (event, false),
|
||||
};
|
||||
|
||||
// Filter out bot's own messages
|
||||
if msg_data.get("bot_id").is_some() {
|
||||
return None;
|
||||
}
|
||||
let user_id = msg_data["user"]
|
||||
.as_str()
|
||||
.or_else(|| event["user"].as_str())?;
|
||||
if let Some(ref bid) = *bot_user_id.read().await {
|
||||
if user_id == bid {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let channel = event["channel"].as_str()?;
|
||||
|
||||
// Filter by allowed channels
|
||||
if !allowed_channels.is_empty() && !allowed_channels.contains(&channel.to_string()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let text = msg_data["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ts = if is_edit {
|
||||
msg_data["ts"]
|
||||
.as_str()
|
||||
.unwrap_or(event["ts"].as_str().unwrap_or("0"))
|
||||
} else {
|
||||
event["ts"].as_str().unwrap_or("0")
|
||||
};
|
||||
|
||||
// Parse timestamp (Slack uses epoch.microseconds format)
|
||||
let timestamp = ts
|
||||
.split('.')
|
||||
.next()
|
||||
.and_then(|s| s.parse::<i64>().ok())
|
||||
.and_then(|epoch| chrono::DateTime::from_timestamp(epoch, 0))
|
||||
.unwrap_or_else(chrono::Utc::now);
|
||||
|
||||
// Parse commands (messages starting with /)
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = &parts[0][1..];
|
||||
let args = if parts.len() > 1 {
|
||||
parts[1].split_whitespace().map(String::from).collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Slack,
|
||||
platform_message_id: ts.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: channel.to_string(),
|
||||
display_name: user_id.to_string(), // Slack user IDs as display name
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp,
|
||||
is_group: true,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_slack_event_basic() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("B123".to_string())));
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"user": "U456",
|
||||
"channel": "C789",
|
||||
"text": "Hello agent!",
|
||||
"ts": "1700000000.000100"
|
||||
});
|
||||
|
||||
let msg = parse_slack_event(&event, &bot_id, &[]).await.unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Slack);
|
||||
assert_eq!(msg.sender.platform_id, "C789");
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello agent!"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_slack_event_filters_bot() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("B123".to_string())));
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"user": "U456",
|
||||
"channel": "C789",
|
||||
"text": "Bot message",
|
||||
"ts": "1700000000.000100",
|
||||
"bot_id": "B999"
|
||||
});
|
||||
|
||||
let msg = parse_slack_event(&event, &bot_id, &[]).await;
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_slack_event_filters_own_user() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("U456".to_string())));
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"user": "U456",
|
||||
"channel": "C789",
|
||||
"text": "My message",
|
||||
"ts": "1700000000.000100"
|
||||
});
|
||||
|
||||
let msg = parse_slack_event(&event, &bot_id, &[]).await;
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_slack_event_channel_filter() {
|
||||
let bot_id = Arc::new(RwLock::new(None));
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"user": "U456",
|
||||
"channel": "C789",
|
||||
"text": "Hello",
|
||||
"ts": "1700000000.000100"
|
||||
});
|
||||
|
||||
// Not in allowed channels
|
||||
let msg =
|
||||
parse_slack_event(&event, &bot_id, &["C111".to_string(), "C222".to_string()]).await;
|
||||
assert!(msg.is_none());
|
||||
|
||||
// In allowed channels
|
||||
let msg = parse_slack_event(&event, &bot_id, &["C789".to_string()]).await;
|
||||
assert!(msg.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_slack_event_skips_other_subtypes() {
|
||||
let bot_id = Arc::new(RwLock::new(None));
|
||||
// Non-message_changed subtypes should still be filtered
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"subtype": "channel_join",
|
||||
"user": "U456",
|
||||
"channel": "C789",
|
||||
"text": "joined",
|
||||
"ts": "1700000000.000100"
|
||||
});
|
||||
|
||||
let msg = parse_slack_event(&event, &bot_id, &[]).await;
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_slack_command() {
|
||||
let bot_id = Arc::new(RwLock::new(None));
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"user": "U456",
|
||||
"channel": "C789",
|
||||
"text": "/agent hello-world",
|
||||
"ts": "1700000000.000100"
|
||||
});
|
||||
|
||||
let msg = parse_slack_event(&event, &bot_id, &[]).await.unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agent");
|
||||
assert_eq!(args, &["hello-world"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_slack_event_message_changed() {
|
||||
let bot_id = Arc::new(RwLock::new(Some("B123".to_string())));
|
||||
let event = serde_json::json!({
|
||||
"type": "message",
|
||||
"subtype": "message_changed",
|
||||
"channel": "C789",
|
||||
"message": {
|
||||
"user": "U456",
|
||||
"text": "Edited message text",
|
||||
"ts": "1700000000.000100"
|
||||
},
|
||||
"ts": "1700000001.000200"
|
||||
});
|
||||
|
||||
let msg = parse_slack_event(&event, &bot_id, &[]).await.unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Slack);
|
||||
assert_eq!(msg.sender.platform_id, "C789");
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Edited message text"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slack_adapter_creation() {
|
||||
let adapter = SlackAdapter::new(
|
||||
"xapp-test".to_string(),
|
||||
"xoxb-test".to_string(),
|
||||
vec!["C123".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "slack");
|
||||
assert_eq!(adapter.channel_type(), ChannelType::Slack);
|
||||
}
|
||||
}
|
||||
590
crates/openfang-channels/src/teams.rs
Normal file
590
crates/openfang-channels/src/teams.rs
Normal file
@@ -0,0 +1,590 @@
|
||||
//! Microsoft Teams channel adapter for the OpenFang channel bridge.
|
||||
//!
|
||||
//! Uses Bot Framework v3 REST API for sending messages and a lightweight axum
|
||||
//! HTTP webhook server for receiving inbound activities. OAuth2 client credentials
|
||||
//! flow is used to obtain and cache access tokens for outbound API calls.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// OAuth2 token endpoint for Bot Framework.
|
||||
const OAUTH_TOKEN_URL: &str =
|
||||
"https://login.microsoftonline.com/botframework.com/oauth2/v2.0/token";
|
||||
|
||||
/// Maximum Teams message length (characters).
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
|
||||
/// OAuth2 token refresh buffer — refresh 5 minutes before actual expiry.
|
||||
const TOKEN_REFRESH_BUFFER_SECS: u64 = 300;
|
||||
|
||||
/// Microsoft Teams Bot Framework v3 adapter.
|
||||
///
|
||||
/// Inbound messages arrive via an axum HTTP webhook on `POST /api/messages`.
|
||||
/// Outbound messages are sent via the Bot Framework v3 REST API using a
|
||||
/// cached OAuth2 bearer token (client credentials flow).
|
||||
pub struct TeamsAdapter {
|
||||
/// Bot Framework App ID (also called "Microsoft App ID").
|
||||
app_id: String,
|
||||
/// SECURITY: App password is zeroized on drop to prevent memory disclosure.
|
||||
app_password: Zeroizing<String>,
|
||||
/// Port on which the inbound webhook HTTP server listens.
|
||||
webhook_port: u16,
|
||||
/// Restrict inbound activities to specific Azure AD tenant IDs (empty = allow all).
|
||||
allowed_tenants: Vec<String>,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Cached OAuth2 bearer token and its expiry instant.
|
||||
cached_token: Arc<RwLock<Option<(String, Instant)>>>,
|
||||
}
|
||||
|
||||
impl TeamsAdapter {
|
||||
/// Create a new Teams adapter.
|
||||
///
|
||||
/// * `app_id` — Bot Framework application ID.
|
||||
/// * `app_password` — Bot Framework application password (client secret).
|
||||
/// * `webhook_port` — Local port for the inbound webhook HTTP server.
|
||||
/// * `allowed_tenants` — Azure AD tenant IDs to accept (empty = accept all).
|
||||
pub fn new(
|
||||
app_id: String,
|
||||
app_password: String,
|
||||
webhook_port: u16,
|
||||
allowed_tenants: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
app_id,
|
||||
app_password: Zeroizing::new(app_password),
|
||||
webhook_port,
|
||||
allowed_tenants,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
cached_token: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtain a valid OAuth2 bearer token, refreshing if expired or missing.
|
||||
async fn get_token(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
// Check cache first
|
||||
{
|
||||
let guard = self.cached_token.read().await;
|
||||
if let Some((ref token, expiry)) = *guard {
|
||||
if Instant::now() < expiry {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch a new token via client credentials flow
|
||||
let params = [
|
||||
("grant_type", "client_credentials"),
|
||||
("client_id", &self.app_id),
|
||||
("client_secret", self.app_password.as_str()),
|
||||
("scope", "https://api.botframework.com/.default"),
|
||||
];
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(OAUTH_TOKEN_URL)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Teams OAuth2 token error {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let access_token = body["access_token"]
|
||||
.as_str()
|
||||
.ok_or("Missing access_token in OAuth2 response")?
|
||||
.to_string();
|
||||
let expires_in = body["expires_in"].as_u64().unwrap_or(3600);
|
||||
|
||||
// Cache with a safety buffer
|
||||
let expiry = Instant::now()
|
||||
+ Duration::from_secs(expires_in.saturating_sub(TOKEN_REFRESH_BUFFER_SECS));
|
||||
*self.cached_token.write().await = Some((access_token.clone(), expiry));
|
||||
|
||||
Ok(access_token)
|
||||
}
|
||||
|
||||
/// Send a text reply to a Teams conversation via Bot Framework v3.
|
||||
///
|
||||
/// * `service_url` — The per-conversation service URL provided in inbound activities.
|
||||
/// * `conversation_id` — The Teams conversation ID.
|
||||
/// * `text` — The message text to send.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
service_url: &str,
|
||||
conversation_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let token = self.get_token().await?;
|
||||
let url = format!(
|
||||
"{}/v3/conversations/{}/activities",
|
||||
service_url.trim_end_matches('/'),
|
||||
conversation_id
|
||||
);
|
||||
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"type": "message",
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
warn!("Teams API error {status}: {resp_body}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check whether a tenant ID is allowed (empty list = allow all).
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_tenant(&self, tenant_id: &str) -> bool {
|
||||
self.allowed_tenants.is_empty() || self.allowed_tenants.iter().any(|t| t == tenant_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an inbound Bot Framework activity JSON into a `ChannelMessage`.
|
||||
///
|
||||
/// Returns `None` for activities that should be ignored (non-message types,
|
||||
/// activities from the bot itself, activities from disallowed tenants, etc.).
|
||||
fn parse_teams_activity(
|
||||
activity: &serde_json::Value,
|
||||
app_id: &str,
|
||||
allowed_tenants: &[String],
|
||||
) -> Option<ChannelMessage> {
|
||||
let activity_type = activity["type"].as_str().unwrap_or("");
|
||||
if activity_type != "message" {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Extract sender info
|
||||
let from = activity.get("from")?;
|
||||
let from_id = from["id"].as_str().unwrap_or("");
|
||||
let from_name = from["name"].as_str().unwrap_or("Unknown");
|
||||
|
||||
// Skip messages from the bot itself
|
||||
if from_id == app_id {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Tenant filtering
|
||||
if !allowed_tenants.is_empty() {
|
||||
let tenant_id = activity["channelData"]["tenant"]["id"]
|
||||
.as_str()
|
||||
.unwrap_or("");
|
||||
if !allowed_tenants.iter().any(|t| t == tenant_id) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
let text = activity["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let conversation_id = activity["conversation"]["id"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
let activity_id = activity["id"].as_str().unwrap_or("").to_string();
|
||||
let service_url = activity["serviceUrl"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Determine if this is a group conversation
|
||||
let is_group = activity["conversation"]["isGroup"]
|
||||
.as_bool()
|
||||
.unwrap_or(false);
|
||||
|
||||
// Parse commands (messages starting with /)
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = &parts[0][1..];
|
||||
let args = if parts.len() > 1 {
|
||||
parts[1].split_whitespace().map(String::from).collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
// Store serviceUrl in metadata so outbound replies can use it
|
||||
if !service_url.is_empty() {
|
||||
metadata.insert(
|
||||
"serviceUrl".to_string(),
|
||||
serde_json::Value::String(service_url),
|
||||
);
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Teams,
|
||||
platform_message_id: activity_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: conversation_id,
|
||||
display_name: from_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for TeamsAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"teams"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Teams
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials by obtaining an initial token
|
||||
let _ = self.get_token().await?;
|
||||
info!("Teams adapter authenticated (app_id: {})", self.app_id);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let app_id = self.app_id.clone();
|
||||
let allowed_tenants = self.allowed_tenants.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Build the axum webhook router
|
||||
let app_id_shared = Arc::new(app_id);
|
||||
let tenants_shared = Arc::new(allowed_tenants);
|
||||
let tx_shared = Arc::new(tx);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/api/messages",
|
||||
axum::routing::post({
|
||||
let app_id = Arc::clone(&app_id_shared);
|
||||
let tenants = Arc::clone(&tenants_shared);
|
||||
let tx = Arc::clone(&tx_shared);
|
||||
move |body: axum::extract::Json<serde_json::Value>| {
|
||||
let app_id = Arc::clone(&app_id);
|
||||
let tenants = Arc::clone(&tenants);
|
||||
let tx = Arc::clone(&tx);
|
||||
async move {
|
||||
if let Some(msg) = parse_teams_activity(&body, &app_id, &tenants) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
axum::http::StatusCode::OK
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("Teams webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Teams webhook bind failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Teams webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Teams adapter shutting down");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// We need the serviceUrl from metadata; fall back to the default Bot Framework URL
|
||||
let default_service_url = "https://smba.trafficmanager.net/teams/".to_string();
|
||||
let conversation_id = &user.platform_id;
|
||||
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&default_service_url, conversation_id, &text)
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(
|
||||
&default_service_url,
|
||||
conversation_id,
|
||||
"(Unsupported content type)",
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let token = self.get_token().await?;
|
||||
let default_service_url = "https://smba.trafficmanager.net/teams/";
|
||||
let url = format!(
|
||||
"{}/v3/conversations/{}/activities",
|
||||
default_service_url.trim_end_matches('/'),
|
||||
user.platform_id
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"type": "typing",
|
||||
});
|
||||
|
||||
let _ = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_teams_adapter_creation() {
|
||||
let adapter = TeamsAdapter::new(
|
||||
"app-id-123".to_string(),
|
||||
"app-password".to_string(),
|
||||
3978,
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.name(), "teams");
|
||||
assert_eq!(adapter.channel_type(), ChannelType::Teams);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_teams_allowed_tenants() {
|
||||
let adapter = TeamsAdapter::new(
|
||||
"app-id".to_string(),
|
||||
"password".to_string(),
|
||||
3978,
|
||||
vec!["tenant-abc".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_tenant("tenant-abc"));
|
||||
assert!(!adapter.is_allowed_tenant("tenant-xyz"));
|
||||
|
||||
let open = TeamsAdapter::new("app-id".to_string(), "password".to_string(), 3978, vec![]);
|
||||
assert!(open.is_allowed_tenant("any-tenant"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_teams_activity_basic() {
|
||||
let activity = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": "activity-1",
|
||||
"text": "Hello from Teams!",
|
||||
"from": {
|
||||
"id": "user-456",
|
||||
"name": "Alice"
|
||||
},
|
||||
"conversation": {
|
||||
"id": "conv-789",
|
||||
"isGroup": false
|
||||
},
|
||||
"serviceUrl": "https://smba.trafficmanager.net/teams/",
|
||||
"channelData": {
|
||||
"tenant": {
|
||||
"id": "tenant-abc"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_teams_activity(&activity, "app-id-123", &[]).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Teams);
|
||||
assert_eq!(msg.sender.display_name, "Alice");
|
||||
assert_eq!(msg.sender.platform_id, "conv-789");
|
||||
assert!(!msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Teams!"));
|
||||
assert!(msg.metadata.contains_key("serviceUrl"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_teams_activity_skips_bot_self() {
|
||||
let activity = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": "activity-1",
|
||||
"text": "Bot reply",
|
||||
"from": {
|
||||
"id": "app-id-123",
|
||||
"name": "OpenFang Bot"
|
||||
},
|
||||
"conversation": {
|
||||
"id": "conv-789"
|
||||
},
|
||||
"serviceUrl": "https://smba.trafficmanager.net/teams/"
|
||||
});
|
||||
|
||||
let msg = parse_teams_activity(&activity, "app-id-123", &[]);
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_teams_activity_tenant_filter() {
|
||||
let activity = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": "activity-1",
|
||||
"text": "Hello",
|
||||
"from": {
|
||||
"id": "user-1",
|
||||
"name": "Bob"
|
||||
},
|
||||
"conversation": {
|
||||
"id": "conv-1"
|
||||
},
|
||||
"serviceUrl": "https://smba.trafficmanager.net/teams/",
|
||||
"channelData": {
|
||||
"tenant": {
|
||||
"id": "tenant-xyz"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Not in allowed tenants
|
||||
let msg = parse_teams_activity(&activity, "app-id", &["tenant-abc".to_string()]);
|
||||
assert!(msg.is_none());
|
||||
|
||||
// In allowed tenants
|
||||
let msg = parse_teams_activity(&activity, "app-id", &["tenant-xyz".to_string()]);
|
||||
assert!(msg.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_teams_activity_command() {
|
||||
let activity = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": "activity-1",
|
||||
"text": "/agent hello-world",
|
||||
"from": {
|
||||
"id": "user-1",
|
||||
"name": "Alice"
|
||||
},
|
||||
"conversation": {
|
||||
"id": "conv-1"
|
||||
},
|
||||
"serviceUrl": "https://smba.trafficmanager.net/teams/"
|
||||
});
|
||||
|
||||
let msg = parse_teams_activity(&activity, "app-id", &[]).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agent");
|
||||
assert_eq!(args, &["hello-world"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_teams_activity_non_message() {
|
||||
let activity = serde_json::json!({
|
||||
"type": "conversationUpdate",
|
||||
"id": "activity-1",
|
||||
"from": { "id": "user-1", "name": "Alice" },
|
||||
"conversation": { "id": "conv-1" },
|
||||
"serviceUrl": "https://smba.trafficmanager.net/teams/"
|
||||
});
|
||||
|
||||
let msg = parse_teams_activity(&activity, "app-id", &[]);
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_teams_activity_empty_text() {
|
||||
let activity = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": "activity-1",
|
||||
"text": "",
|
||||
"from": { "id": "user-1", "name": "Alice" },
|
||||
"conversation": { "id": "conv-1" },
|
||||
"serviceUrl": "https://smba.trafficmanager.net/teams/"
|
||||
});
|
||||
|
||||
let msg = parse_teams_activity(&activity, "app-id", &[]);
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_teams_activity_group() {
|
||||
let activity = serde_json::json!({
|
||||
"type": "message",
|
||||
"id": "activity-1",
|
||||
"text": "Group hello",
|
||||
"from": { "id": "user-1", "name": "Alice" },
|
||||
"conversation": {
|
||||
"id": "conv-1",
|
||||
"isGroup": true
|
||||
},
|
||||
"serviceUrl": "https://smba.trafficmanager.net/teams/"
|
||||
});
|
||||
|
||||
let msg = parse_teams_activity(&activity, "app-id", &[]).unwrap();
|
||||
assert!(msg.is_group);
|
||||
}
|
||||
}
|
||||
558
crates/openfang-channels/src/telegram.rs
Normal file
558
crates/openfang-channels/src/telegram.rs
Normal file
@@ -0,0 +1,558 @@
|
||||
//! Telegram Bot API adapter for the OpenFang channel bridge.
|
||||
//!
|
||||
//! Uses long-polling via `getUpdates` with exponential backoff on failures.
|
||||
//! No external Telegram crate — just `reqwest` for full control over error handling.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Maximum backoff duration on API failures.
|
||||
const MAX_BACKOFF: Duration = Duration::from_secs(60);
|
||||
/// Initial backoff duration on API failures.
|
||||
const INITIAL_BACKOFF: Duration = Duration::from_secs(1);
|
||||
/// Telegram long-polling timeout (seconds) — sent as the `timeout` parameter to getUpdates.
|
||||
const LONG_POLL_TIMEOUT: u64 = 30;
|
||||
|
||||
/// Telegram Bot API adapter using long-polling.
|
||||
pub struct TelegramAdapter {
|
||||
/// SECURITY: Bot token is zeroized on drop to prevent memory disclosure.
|
||||
token: Zeroizing<String>,
|
||||
client: reqwest::Client,
|
||||
allowed_users: Vec<i64>,
|
||||
poll_interval: Duration,
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl TelegramAdapter {
|
||||
/// Create a new Telegram adapter.
|
||||
///
|
||||
/// `token` is the raw bot token (read from env by the caller).
|
||||
/// `allowed_users` is the list of Telegram user IDs allowed to interact (empty = allow all).
|
||||
pub fn new(token: String, allowed_users: Vec<i64>, poll_interval: Duration) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
token: Zeroizing::new(token),
|
||||
client: reqwest::Client::new(),
|
||||
allowed_users,
|
||||
poll_interval,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate the bot token by calling `getMe`.
|
||||
pub async fn validate_token(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("https://api.telegram.org/bot{}/getMe", self.token.as_str());
|
||||
let resp: serde_json::Value = self.client.get(&url).send().await?.json().await?;
|
||||
|
||||
if resp["ok"].as_bool() != Some(true) {
|
||||
let desc = resp["description"].as_str().unwrap_or("unknown error");
|
||||
return Err(format!("Telegram getMe failed: {desc}").into());
|
||||
}
|
||||
|
||||
let bot_name = resp["result"]["username"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Ok(bot_name)
|
||||
}
|
||||
|
||||
/// Call `sendMessage` on the Telegram API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
chat_id: i64,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"https://api.telegram.org/bot{}/sendMessage",
|
||||
self.token.as_str()
|
||||
);
|
||||
|
||||
// Telegram has a 4096 character limit per message — split if needed
|
||||
let chunks = split_message(text, 4096);
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
let status = resp.status();
|
||||
if !status.is_success() {
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
warn!("Telegram sendMessage failed ({status}): {body_text}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Call `sendChatAction` to show "typing..." indicator.
|
||||
async fn api_send_typing(&self, chat_id: i64) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"https://api.telegram.org/bot{}/sendChatAction",
|
||||
self.token.as_str()
|
||||
);
|
||||
let body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"action": "typing",
|
||||
});
|
||||
let _ = self.client.post(&url).json(&body).send().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for TelegramAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"telegram"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Telegram
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate token first (fail fast)
|
||||
let bot_name = self.validate_token().await?;
|
||||
info!("Telegram bot @{bot_name} connected");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
|
||||
let token = self.token.clone();
|
||||
let client = self.client.clone();
|
||||
let allowed_users = self.allowed_users.clone();
|
||||
let poll_interval = self.poll_interval;
|
||||
let mut shutdown = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut offset: Option<i64> = None;
|
||||
let mut backoff = INITIAL_BACKOFF;
|
||||
|
||||
loop {
|
||||
// Check shutdown
|
||||
if *shutdown.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Build getUpdates request
|
||||
let url = format!("https://api.telegram.org/bot{}/getUpdates", token.as_str());
|
||||
let mut params = serde_json::json!({
|
||||
"timeout": LONG_POLL_TIMEOUT,
|
||||
"allowed_updates": ["message", "edited_message"],
|
||||
});
|
||||
if let Some(off) = offset {
|
||||
params["offset"] = serde_json::json!(off);
|
||||
}
|
||||
|
||||
// Make the request with a timeout slightly longer than the long-poll timeout
|
||||
let request_timeout = Duration::from_secs(LONG_POLL_TIMEOUT + 10);
|
||||
let result = tokio::select! {
|
||||
res = async {
|
||||
client
|
||||
.get(&url)
|
||||
.json(¶ms)
|
||||
.timeout(request_timeout)
|
||||
.send()
|
||||
.await
|
||||
} => res,
|
||||
_ = shutdown.changed() => {
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
let resp = match result {
|
||||
Ok(resp) => resp,
|
||||
Err(e) => {
|
||||
warn!("Telegram getUpdates network error: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let status = resp.status();
|
||||
|
||||
// Handle rate limiting
|
||||
if status.as_u16() == 429 {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
let retry_after = body["parameters"]["retry_after"].as_u64().unwrap_or(5);
|
||||
warn!("Telegram rate limited, retry after {retry_after}s");
|
||||
tokio::time::sleep(Duration::from_secs(retry_after)).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle conflict (another bot instance polling)
|
||||
if status.as_u16() == 409 {
|
||||
error!("Telegram 409 Conflict — another bot instance is running. Stopping.");
|
||||
break;
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
let body_text = resp.text().await.unwrap_or_default();
|
||||
warn!("Telegram getUpdates failed ({status}): {body_text}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse response
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
warn!("Telegram getUpdates parse error: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(MAX_BACKOFF);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Reset backoff on success
|
||||
backoff = INITIAL_BACKOFF;
|
||||
|
||||
if body["ok"].as_bool() != Some(true) {
|
||||
warn!("Telegram getUpdates returned ok=false");
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
let updates = match body["result"].as_array() {
|
||||
Some(arr) => arr,
|
||||
None => {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for update in updates {
|
||||
// Track offset for dedup
|
||||
if let Some(update_id) = update["update_id"].as_i64() {
|
||||
offset = Some(update_id + 1);
|
||||
}
|
||||
|
||||
// Parse the message
|
||||
let msg = match parse_telegram_update(update, &allowed_users) {
|
||||
Some(m) => m,
|
||||
None => continue, // filtered out or unparseable
|
||||
};
|
||||
|
||||
debug!(
|
||||
"Telegram message from {}: {:?}",
|
||||
msg.sender.display_name, msg.content
|
||||
);
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
// Receiver dropped — bridge is shutting down
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Small delay between polls even on success to avoid tight loops
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
}
|
||||
|
||||
info!("Telegram polling loop stopped");
|
||||
});
|
||||
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chat_id: i64 = user
|
||||
.platform_id
|
||||
.parse()
|
||||
.map_err(|_| format!("Invalid Telegram chat_id: {}", user.platform_id))?;
|
||||
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(chat_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(chat_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chat_id: i64 = user
|
||||
.platform_id
|
||||
.parse()
|
||||
.map_err(|_| format!("Invalid Telegram chat_id: {}", user.platform_id))?;
|
||||
self.api_send_typing(chat_id).await
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Telegram update JSON into a `ChannelMessage`, or `None` if filtered/unparseable.
|
||||
/// Handles both `message` and `edited_message` update types.
|
||||
fn parse_telegram_update(
|
||||
update: &serde_json::Value,
|
||||
allowed_users: &[i64],
|
||||
) -> Option<ChannelMessage> {
|
||||
let message = update
|
||||
.get("message")
|
||||
.or_else(|| update.get("edited_message"))?;
|
||||
let from = message.get("from")?;
|
||||
let user_id = from["id"].as_i64()?;
|
||||
|
||||
// Security: check allowed_users
|
||||
if !allowed_users.is_empty() && !allowed_users.contains(&user_id) {
|
||||
debug!("Telegram: ignoring message from unlisted user {user_id}");
|
||||
return None;
|
||||
}
|
||||
|
||||
let chat_id = message["chat"]["id"].as_i64()?;
|
||||
let first_name = from["first_name"].as_str().unwrap_or("Unknown");
|
||||
let last_name = from["last_name"].as_str().unwrap_or("");
|
||||
let display_name = if last_name.is_empty() {
|
||||
first_name.to_string()
|
||||
} else {
|
||||
format!("{first_name} {last_name}")
|
||||
};
|
||||
|
||||
let chat_type = message["chat"]["type"].as_str().unwrap_or("private");
|
||||
let is_group = chat_type == "group" || chat_type == "supergroup";
|
||||
|
||||
let text = message["text"].as_str()?;
|
||||
let message_id = message["message_id"].as_i64().unwrap_or(0);
|
||||
let timestamp = message["date"]
|
||||
.as_i64()
|
||||
.and_then(|ts| chrono::DateTime::from_timestamp(ts, 0))
|
||||
.unwrap_or_else(chrono::Utc::now);
|
||||
|
||||
// Parse bot commands (Telegram sends entities for /commands)
|
||||
let content = if let Some(entities) = message["entities"].as_array() {
|
||||
let is_bot_command = entities
|
||||
.iter()
|
||||
.any(|e| e["type"].as_str() == Some("bot_command") && e["offset"].as_i64() == Some(0));
|
||||
if is_bot_command {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
// Strip @botname from command (e.g. /agents@mybot -> agents)
|
||||
let cmd_name = cmd_name.split('@').next().unwrap_or(cmd_name);
|
||||
let args = if parts.len() > 1 {
|
||||
parts[1].split_whitespace().map(String::from).collect()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
// Use chat_id as the platform_id (so responses go to the right chat)
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Telegram,
|
||||
platform_message_id: message_id.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: chat_id.to_string(),
|
||||
display_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp,
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Calculate exponential backoff capped at MAX_BACKOFF.
|
||||
pub fn calculate_backoff(current: Duration) -> Duration {
|
||||
(current * 2).min(MAX_BACKOFF)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_telegram_update() {
|
||||
let update = serde_json::json!({
|
||||
"update_id": 123456,
|
||||
"message": {
|
||||
"message_id": 42,
|
||||
"from": {
|
||||
"id": 111222333,
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith"
|
||||
},
|
||||
"chat": {
|
||||
"id": 111222333,
|
||||
"type": "private"
|
||||
},
|
||||
"date": 1700000000,
|
||||
"text": "Hello, agent!"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_telegram_update(&update, &[]).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Telegram);
|
||||
assert_eq!(msg.sender.display_name, "Alice Smith");
|
||||
assert_eq!(msg.sender.platform_id, "111222333");
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello, agent!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_telegram_command() {
|
||||
let update = serde_json::json!({
|
||||
"update_id": 123457,
|
||||
"message": {
|
||||
"message_id": 43,
|
||||
"from": {
|
||||
"id": 111222333,
|
||||
"first_name": "Alice"
|
||||
},
|
||||
"chat": {
|
||||
"id": 111222333,
|
||||
"type": "private"
|
||||
},
|
||||
"date": 1700000001,
|
||||
"text": "/agent hello-world",
|
||||
"entities": [{
|
||||
"type": "bot_command",
|
||||
"offset": 0,
|
||||
"length": 6
|
||||
}]
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_telegram_update(&update, &[]).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agent");
|
||||
assert_eq!(args, &["hello-world"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_users_filter() {
|
||||
let update = serde_json::json!({
|
||||
"update_id": 123458,
|
||||
"message": {
|
||||
"message_id": 44,
|
||||
"from": {
|
||||
"id": 999,
|
||||
"first_name": "Bob"
|
||||
},
|
||||
"chat": {
|
||||
"id": 999,
|
||||
"type": "private"
|
||||
},
|
||||
"date": 1700000002,
|
||||
"text": "blocked"
|
||||
}
|
||||
});
|
||||
|
||||
// Empty allowed_users = allow all
|
||||
let msg = parse_telegram_update(&update, &[]);
|
||||
assert!(msg.is_some());
|
||||
|
||||
// Non-matching allowed_users = filter out
|
||||
let msg = parse_telegram_update(&update, &[111, 222]);
|
||||
assert!(msg.is_none());
|
||||
|
||||
// Matching allowed_users = allow
|
||||
let msg = parse_telegram_update(&update, &[999]);
|
||||
assert!(msg.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_telegram_edited_message() {
|
||||
let update = serde_json::json!({
|
||||
"update_id": 123459,
|
||||
"edited_message": {
|
||||
"message_id": 42,
|
||||
"from": {
|
||||
"id": 111222333,
|
||||
"first_name": "Alice",
|
||||
"last_name": "Smith"
|
||||
},
|
||||
"chat": {
|
||||
"id": 111222333,
|
||||
"type": "private"
|
||||
},
|
||||
"date": 1700000000,
|
||||
"edit_date": 1700000060,
|
||||
"text": "Edited message!"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_telegram_update(&update, &[]).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Telegram);
|
||||
assert_eq!(msg.sender.display_name, "Alice Smith");
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Edited message!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_backoff_calculation() {
|
||||
let b1 = calculate_backoff(Duration::from_secs(1));
|
||||
assert_eq!(b1, Duration::from_secs(2));
|
||||
|
||||
let b2 = calculate_backoff(Duration::from_secs(2));
|
||||
assert_eq!(b2, Duration::from_secs(4));
|
||||
|
||||
let b3 = calculate_backoff(Duration::from_secs(32));
|
||||
assert_eq!(b3, Duration::from_secs(60)); // capped
|
||||
|
||||
let b4 = calculate_backoff(Duration::from_secs(60));
|
||||
assert_eq!(b4, Duration::from_secs(60)); // stays at cap
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_command_with_botname() {
|
||||
let update = serde_json::json!({
|
||||
"update_id": 100,
|
||||
"message": {
|
||||
"message_id": 1,
|
||||
"from": { "id": 123, "first_name": "X" },
|
||||
"chat": { "id": 123, "type": "private" },
|
||||
"date": 1700000000,
|
||||
"text": "/agents@myopenfangbot",
|
||||
"entities": [{ "type": "bot_command", "offset": 0, "length": 17 }]
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_telegram_update(&update, &[]).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "agents");
|
||||
assert!(args.is_empty());
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
430
crates/openfang-channels/src/threema.rs
Normal file
430
crates/openfang-channels/src/threema.rs
Normal file
@@ -0,0 +1,430 @@
|
||||
//! Threema Gateway channel adapter.
|
||||
//!
|
||||
//! Uses the Threema Gateway HTTP API for sending messages and a local webhook
|
||||
//! HTTP server for receiving inbound messages. Authentication is performed via
|
||||
//! the Threema Gateway API secret. Inbound messages arrive as POST requests
|
||||
//! to the configured webhook port.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Threema Gateway API base URL for sending messages.
|
||||
const THREEMA_API_URL: &str = "https://msgapi.threema.ch";
|
||||
|
||||
/// Maximum message length for Threema messages.
|
||||
const MAX_MESSAGE_LEN: usize = 3500;
|
||||
|
||||
/// Threema Gateway channel adapter using webhook for receiving and REST API for sending.
|
||||
///
|
||||
/// Listens for inbound messages via a configurable HTTP webhook server and sends
|
||||
/// outbound messages via the Threema Gateway `send_simple` endpoint.
|
||||
pub struct ThreemaAdapter {
|
||||
/// Threema Gateway ID (8-character alphanumeric, starts with '*').
|
||||
threema_id: String,
|
||||
/// SECURITY: API secret is zeroized on drop.
|
||||
secret: Zeroizing<String>,
|
||||
/// Port for the inbound webhook HTTP listener.
|
||||
webhook_port: u16,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl ThreemaAdapter {
|
||||
/// Create a new Threema Gateway adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `threema_id` - Threema Gateway ID (e.g., "*MYGATEW").
|
||||
/// * `secret` - API secret for the Gateway ID.
|
||||
/// * `webhook_port` - Local port to bind the inbound webhook listener on.
|
||||
pub fn new(threema_id: String, secret: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
threema_id,
|
||||
secret: Zeroizing::new(secret),
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate credentials by checking the remaining credits.
|
||||
async fn validate(&self) -> Result<u64, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/credits?from={}&secret={}",
|
||||
THREEMA_API_URL,
|
||||
self.threema_id,
|
||||
self.secret.as_str()
|
||||
);
|
||||
let resp = self.client.get(&url).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Threema Gateway authentication failed".into());
|
||||
}
|
||||
|
||||
let credits: u64 = resp.text().await?.trim().parse().unwrap_or(0);
|
||||
Ok(credits)
|
||||
}
|
||||
|
||||
/// Send a simple text message to a Threema ID.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
to: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/send_simple", THREEMA_API_URL);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let params = [
|
||||
("from", self.threema_id.as_str()),
|
||||
("to", to),
|
||||
("secret", self.secret.as_str()),
|
||||
("text", chunk),
|
||||
];
|
||||
|
||||
let resp = self.client.post(&url).form(¶ms).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Threema API error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an inbound Threema webhook payload into a `ChannelMessage`.
|
||||
///
|
||||
/// The Threema Gateway delivers inbound messages as form-encoded POST requests
|
||||
/// with fields: `from`, `to`, `messageId`, `date`, `text`, `nonce`, `box`, `mac`.
|
||||
/// For the `send_simple` mode, the `text` field contains the plaintext message.
|
||||
fn parse_threema_webhook(
|
||||
payload: &HashMap<String, String>,
|
||||
own_id: &str,
|
||||
) -> Option<ChannelMessage> {
|
||||
let from = payload.get("from")?;
|
||||
let text = payload.get("text").or_else(|| payload.get("body"))?;
|
||||
let message_id = payload.get("messageId").cloned().unwrap_or_default();
|
||||
|
||||
// Skip messages from ourselves
|
||||
if from == own_id {
|
||||
return None;
|
||||
}
|
||||
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
if let Some(nonce) = payload.get("nonce") {
|
||||
metadata.insert(
|
||||
"nonce".to_string(),
|
||||
serde_json::Value::String(nonce.clone()),
|
||||
);
|
||||
}
|
||||
if let Some(mac) = payload.get("mac") {
|
||||
metadata.insert("mac".to_string(), serde_json::Value::String(mac.clone()));
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("threema".to_string()),
|
||||
platform_message_id: message_id,
|
||||
sender: ChannelUser {
|
||||
platform_id: from.clone(),
|
||||
display_name: from.clone(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false, // Threema Gateway simple mode is 1:1
|
||||
thread_id: None,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for ThreemaAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"threema"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("threema".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let credits = self.validate().await?;
|
||||
info!(
|
||||
"Threema Gateway adapter authenticated (ID: {}, credits: {credits})",
|
||||
self.threema_id
|
||||
);
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let own_id = self.threema_id.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Bind a webhook HTTP listener for inbound messages
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Threema: failed to bind webhook on port {port}: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Threema webhook listener bound on {addr}");
|
||||
|
||||
loop {
|
||||
let (stream, _peer) = tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Threema adapter shutting down");
|
||||
break;
|
||||
}
|
||||
result = listener.accept() => {
|
||||
match result {
|
||||
Ok(conn) => conn,
|
||||
Err(e) => {
|
||||
warn!("Threema: accept error: {e}");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let tx = tx.clone();
|
||||
let own_id = own_id.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
let mut reader = tokio::io::BufReader::new(stream);
|
||||
|
||||
// Read HTTP request line
|
||||
let mut request_line = String::new();
|
||||
if reader.read_line(&mut request_line).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only accept POST requests
|
||||
if !request_line.starts_with("POST") {
|
||||
let resp = b"HTTP/1.1 405 Method Not Allowed\r\nContent-Length: 0\r\n\r\n";
|
||||
let _ = reader.get_mut().write_all(resp).await;
|
||||
return;
|
||||
}
|
||||
|
||||
// Read headers
|
||||
let mut content_length: usize = 0;
|
||||
let mut content_type = String::new();
|
||||
loop {
|
||||
let mut header = String::new();
|
||||
if reader.read_line(&mut header).await.is_err() {
|
||||
return;
|
||||
}
|
||||
let trimmed = header.trim();
|
||||
if trimmed.is_empty() {
|
||||
break;
|
||||
}
|
||||
let lower = trimmed.to_lowercase();
|
||||
if let Some(val) = lower.strip_prefix("content-length:") {
|
||||
if let Ok(len) = val.trim().parse::<usize>() {
|
||||
content_length = len;
|
||||
}
|
||||
}
|
||||
if let Some(val) = lower.strip_prefix("content-type:") {
|
||||
content_type = val.trim().to_string();
|
||||
}
|
||||
}
|
||||
|
||||
// Read body (cap at 64KB)
|
||||
let read_len = content_length.min(65536);
|
||||
let mut body_buf = vec![0u8; read_len];
|
||||
if read_len > 0 && reader.read_exact(&mut body_buf[..read_len]).await.is_err() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Send 200 OK
|
||||
let resp = b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n";
|
||||
let _ = reader.get_mut().write_all(resp).await;
|
||||
|
||||
// Parse the body based on content type
|
||||
let body_str = String::from_utf8_lossy(&body_buf[..read_len]);
|
||||
let payload: HashMap<String, String> =
|
||||
if content_type.contains("application/json") {
|
||||
// JSON payload
|
||||
serde_json::from_str(&body_str).unwrap_or_default()
|
||||
} else {
|
||||
// Form-encoded payload
|
||||
url::form_urlencoded::parse(body_str.as_bytes())
|
||||
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||
.collect()
|
||||
};
|
||||
|
||||
if let Some(msg) = parse_threema_webhook(&payload, &own_id) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
info!("Threema webhook loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Threema Gateway does not support typing indicators
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_threema_adapter_creation() {
|
||||
let adapter = ThreemaAdapter::new("*MYGATEW".to_string(), "test-secret".to_string(), 8443);
|
||||
assert_eq!(adapter.name(), "threema");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("threema".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_threema_secret_zeroized() {
|
||||
let adapter =
|
||||
ThreemaAdapter::new("*MYID123".to_string(), "super-secret-key".to_string(), 8443);
|
||||
assert_eq!(adapter.secret.as_str(), "super-secret-key");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_threema_webhook_port() {
|
||||
let adapter = ThreemaAdapter::new("*TEST".to_string(), "secret".to_string(), 9090);
|
||||
assert_eq!(adapter.webhook_port, 9090);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_threema_webhook_basic() {
|
||||
let mut payload = HashMap::new();
|
||||
payload.insert("from".to_string(), "ABCDEFGH".to_string());
|
||||
payload.insert("text".to_string(), "Hello from Threema!".to_string());
|
||||
payload.insert("messageId".to_string(), "msg-001".to_string());
|
||||
|
||||
let msg = parse_threema_webhook(&payload, "*MYGATEW").unwrap();
|
||||
assert_eq!(msg.sender.platform_id, "ABCDEFGH");
|
||||
assert_eq!(msg.sender.display_name, "ABCDEFGH");
|
||||
assert!(!msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Threema!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_threema_webhook_command() {
|
||||
let mut payload = HashMap::new();
|
||||
payload.insert("from".to_string(), "SENDER01".to_string());
|
||||
payload.insert("text".to_string(), "/help me".to_string());
|
||||
|
||||
let msg = parse_threema_webhook(&payload, "*MYGATEW").unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "help");
|
||||
assert_eq!(args, &["me"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_threema_webhook_skip_self() {
|
||||
let mut payload = HashMap::new();
|
||||
payload.insert("from".to_string(), "*MYGATEW".to_string());
|
||||
payload.insert("text".to_string(), "Self message".to_string());
|
||||
|
||||
let msg = parse_threema_webhook(&payload, "*MYGATEW");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_threema_webhook_empty_text() {
|
||||
let mut payload = HashMap::new();
|
||||
payload.insert("from".to_string(), "SENDER01".to_string());
|
||||
payload.insert("text".to_string(), String::new());
|
||||
|
||||
let msg = parse_threema_webhook(&payload, "*MYGATEW");
|
||||
assert!(msg.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_threema_webhook_with_nonce_and_mac() {
|
||||
let mut payload = HashMap::new();
|
||||
payload.insert("from".to_string(), "SENDER01".to_string());
|
||||
payload.insert("text".to_string(), "Secure msg".to_string());
|
||||
payload.insert("nonce".to_string(), "abc123".to_string());
|
||||
payload.insert("mac".to_string(), "def456".to_string());
|
||||
|
||||
let msg = parse_threema_webhook(&payload, "*MYGATEW").unwrap();
|
||||
assert!(msg.metadata.contains_key("nonce"));
|
||||
assert!(msg.metadata.contains_key("mac"));
|
||||
}
|
||||
}
|
||||
603
crates/openfang-channels/src/twist.rs
Normal file
603
crates/openfang-channels/src/twist.rs
Normal file
@@ -0,0 +1,603 @@
|
||||
//! Twist API v3 channel adapter.
|
||||
//!
|
||||
//! Uses the Twist REST API v3 for sending and receiving messages. Polls the
|
||||
//! comments endpoint for new messages and posts replies via the comments/add
|
||||
//! endpoint. Authentication is performed via OAuth2 Bearer token.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Twist API v3 base URL.
|
||||
const TWIST_API_BASE: &str = "https://api.twist.com/api/v3";
|
||||
|
||||
/// Maximum message length for Twist comments.
|
||||
const MAX_MESSAGE_LEN: usize = 10000;
|
||||
|
||||
/// Polling interval in seconds for new comments.
|
||||
const POLL_INTERVAL_SECS: u64 = 5;
|
||||
|
||||
/// Twist API v3 channel adapter using REST polling.
|
||||
///
|
||||
/// Polls the Twist comments endpoint for new messages in configured channels
|
||||
/// (threads) and sends replies via the comments/add endpoint. Supports
|
||||
/// workspace-level and channel-level filtering.
|
||||
pub struct TwistAdapter {
|
||||
/// SECURITY: OAuth2 token is zeroized on drop.
|
||||
token: Zeroizing<String>,
|
||||
/// Twist workspace ID.
|
||||
workspace_id: String,
|
||||
/// Channel IDs to poll (empty = all channels in workspace).
|
||||
allowed_channels: Vec<String>,
|
||||
/// HTTP client for API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Last seen comment ID per channel for incremental polling.
|
||||
last_comment_ids: Arc<RwLock<HashMap<String, i64>>>,
|
||||
}
|
||||
|
||||
impl TwistAdapter {
|
||||
/// Create a new Twist adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token` - OAuth2 Bearer token for API authentication.
|
||||
/// * `workspace_id` - Twist workspace ID to operate in.
|
||||
/// * `allowed_channels` - Channel IDs to poll (empty = discover all).
|
||||
pub fn new(token: String, workspace_id: String, allowed_channels: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
token: Zeroizing::new(token),
|
||||
workspace_id,
|
||||
allowed_channels,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
last_comment_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching the authenticated user's info.
|
||||
async fn validate(&self) -> Result<(String, String), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/users/get_session_user", TWIST_API_BASE);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Twist authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let user_id = body["id"]
|
||||
.as_i64()
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
let name = body["name"].as_str().unwrap_or("unknown").to_string();
|
||||
|
||||
Ok((user_id, name))
|
||||
}
|
||||
|
||||
/// Fetch channels (threads) in the workspace.
|
||||
#[allow(dead_code)]
|
||||
async fn fetch_channels(&self) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/channels/get?workspace_id={}",
|
||||
TWIST_API_BASE, self.workspace_id
|
||||
);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Twist: failed to fetch channels".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let channels = match body.as_array() {
|
||||
Some(arr) => arr.clone(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
Ok(channels)
|
||||
}
|
||||
|
||||
/// Fetch threads in a channel.
|
||||
#[allow(dead_code)]
|
||||
async fn fetch_threads(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/threads/get?channel_id={}", TWIST_API_BASE, channel_id);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Twist: failed to fetch threads for channel {channel_id}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let threads = match body.as_array() {
|
||||
Some(arr) => arr.clone(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
Ok(threads)
|
||||
}
|
||||
|
||||
/// Fetch comments (messages) in a thread.
|
||||
#[allow(dead_code)]
|
||||
async fn fetch_comments(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"{}/comments/get?thread_id={}&limit=50",
|
||||
TWIST_API_BASE, thread_id
|
||||
);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(format!("Twist: failed to fetch comments for thread {thread_id}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let comments = match body.as_array() {
|
||||
Some(arr) => arr.clone(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
Ok(comments)
|
||||
}
|
||||
|
||||
/// Send a comment (message) to a Twist thread.
|
||||
async fn api_send_comment(
|
||||
&self,
|
||||
thread_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/comments/add", TWIST_API_BASE);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"thread_id": thread_id.parse::<i64>().unwrap_or(0),
|
||||
"content": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Twist API error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a new thread in a channel and post the initial message.
|
||||
#[allow(dead_code)]
|
||||
async fn api_create_thread(
|
||||
&self,
|
||||
channel_id: &str,
|
||||
title: &str,
|
||||
content: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/threads/add", TWIST_API_BASE);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"channel_id": channel_id.parse::<i64>().unwrap_or(0),
|
||||
"title": title,
|
||||
"content": content,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Twist thread create error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let result: serde_json::Value = resp.json().await?;
|
||||
let thread_id = result["id"]
|
||||
.as_i64()
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_default();
|
||||
Ok(thread_id)
|
||||
}
|
||||
|
||||
/// Check if a channel ID is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_channel(&self, channel_id: &str) -> bool {
|
||||
self.allowed_channels.is_empty() || self.allowed_channels.iter().any(|c| c == channel_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for TwistAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"twist"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("twist".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let (user_id, user_name) = self.validate().await?;
|
||||
info!("Twist adapter authenticated as {user_name} (id: {user_id})");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let token = self.token.clone();
|
||||
let workspace_id = self.workspace_id.clone();
|
||||
let own_user_id = user_id;
|
||||
let allowed_channels = self.allowed_channels.clone();
|
||||
let client = self.client.clone();
|
||||
let last_comment_ids = Arc::clone(&self.last_comment_ids);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Discover channels if not configured
|
||||
let channels_to_poll = if allowed_channels.is_empty() {
|
||||
let url = format!(
|
||||
"{}/channels/get?workspace_id={}",
|
||||
TWIST_API_BASE, workspace_id
|
||||
);
|
||||
match client.get(&url).bearer_auth(token.as_str()).send().await {
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body.as_array()
|
||||
.map(|arr| {
|
||||
arr.iter()
|
||||
.filter_map(|c| c["id"].as_i64().map(|id| id.to_string()))
|
||||
.collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Twist: failed to list channels: {e}");
|
||||
return;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
allowed_channels
|
||||
};
|
||||
|
||||
if channels_to_poll.is_empty() {
|
||||
warn!("Twist: no channels to poll");
|
||||
return;
|
||||
}
|
||||
|
||||
info!(
|
||||
"Twist: polling {} channel(s) in workspace {workspace_id}",
|
||||
channels_to_poll.len()
|
||||
);
|
||||
|
||||
let poll_interval = Duration::from_secs(POLL_INTERVAL_SECS);
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Twist adapter shutting down");
|
||||
break;
|
||||
}
|
||||
_ = tokio::time::sleep(poll_interval) => {}
|
||||
}
|
||||
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
for channel_id in &channels_to_poll {
|
||||
// Get threads in channel
|
||||
let threads_url =
|
||||
format!("{}/threads/get?channel_id={}", TWIST_API_BASE, channel_id);
|
||||
|
||||
let threads = match client
|
||||
.get(&threads_url)
|
||||
.bearer_auth(token.as_str())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body.as_array().cloned().unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Twist: thread fetch error for channel {channel_id}: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
for thread in &threads {
|
||||
let thread_id = thread["id"]
|
||||
.as_i64()
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_default();
|
||||
if thread_id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let thread_title =
|
||||
thread["title"].as_str().unwrap_or("Untitled").to_string();
|
||||
|
||||
let comments_url = format!(
|
||||
"{}/comments/get?thread_id={}&limit=20",
|
||||
TWIST_API_BASE, thread_id
|
||||
);
|
||||
|
||||
let comments = match client
|
||||
.get(&comments_url)
|
||||
.bearer_auth(token.as_str())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
body.as_array().cloned().unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Twist: comment fetch error for thread {thread_id}: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let comment_key = format!("{}:{}", channel_id, thread_id);
|
||||
let last_id = {
|
||||
let ids = last_comment_ids.read().await;
|
||||
ids.get(&comment_key).copied().unwrap_or(0)
|
||||
};
|
||||
|
||||
let mut newest_id = last_id;
|
||||
|
||||
for comment in &comments {
|
||||
let comment_id = comment["id"].as_i64().unwrap_or(0);
|
||||
|
||||
// Skip already-seen comments
|
||||
if comment_id <= last_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let creator = comment["creator"]
|
||||
.as_i64()
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
// Skip own comments
|
||||
if creator == own_user_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = comment["content"].as_str().unwrap_or("");
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if comment_id > newest_id {
|
||||
newest_id = comment_id;
|
||||
}
|
||||
|
||||
let creator_name =
|
||||
comment["creator_name"].as_str().unwrap_or("unknown");
|
||||
|
||||
let msg_content = if content.starts_with('/') {
|
||||
let parts: Vec<&str> = content.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(content.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("twist".to_string()),
|
||||
platform_message_id: comment_id.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: thread_id.clone(),
|
||||
display_name: creator_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true,
|
||||
thread_id: Some(thread_title.clone()),
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"channel_id".to_string(),
|
||||
serde_json::Value::String(channel_id.clone()),
|
||||
);
|
||||
m.insert(
|
||||
"thread_id".to_string(),
|
||||
serde_json::Value::String(thread_id.clone()),
|
||||
);
|
||||
m.insert(
|
||||
"creator_id".to_string(),
|
||||
serde_json::Value::String(creator),
|
||||
);
|
||||
m.insert(
|
||||
"workspace_id".to_string(),
|
||||
serde_json::Value::String(workspace_id.clone()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Update last seen comment ID
|
||||
if newest_id > last_id {
|
||||
last_comment_ids
|
||||
.write()
|
||||
.await
|
||||
.insert(comment_key, newest_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Twist polling loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(text) => text,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// platform_id is the thread_id
|
||||
self.api_send_comment(&user.platform_id, &text).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_in_thread(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
thread_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(text) => text,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
self.api_send_comment(thread_id, &text).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Twist does not expose a typing indicator API
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_twist_adapter_creation() {
|
||||
let adapter = TwistAdapter::new(
|
||||
"test-token".to_string(),
|
||||
"12345".to_string(),
|
||||
vec!["ch1".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "twist");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("twist".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twist_token_zeroized() {
|
||||
let adapter =
|
||||
TwistAdapter::new("secret-twist-token".to_string(), "ws1".to_string(), vec![]);
|
||||
assert_eq!(adapter.token.as_str(), "secret-twist-token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twist_workspace_id() {
|
||||
let adapter = TwistAdapter::new("tok".to_string(), "workspace-99".to_string(), vec![]);
|
||||
assert_eq!(adapter.workspace_id, "workspace-99");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twist_allowed_channels() {
|
||||
let adapter = TwistAdapter::new(
|
||||
"tok".to_string(),
|
||||
"ws1".to_string(),
|
||||
vec!["ch-1".to_string(), "ch-2".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_channel("ch-1"));
|
||||
assert!(adapter.is_allowed_channel("ch-2"));
|
||||
assert!(!adapter.is_allowed_channel("ch-3"));
|
||||
|
||||
let open = TwistAdapter::new("tok".to_string(), "ws1".to_string(), vec![]);
|
||||
assert!(open.is_allowed_channel("any-channel"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twist_constants() {
|
||||
assert_eq!(MAX_MESSAGE_LEN, 10000);
|
||||
assert_eq!(POLL_INTERVAL_SECS, 5);
|
||||
assert!(TWIST_API_BASE.starts_with("https://"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twist_poll_interval() {
|
||||
assert_eq!(POLL_INTERVAL_SECS, 5);
|
||||
}
|
||||
}
|
||||
385
crates/openfang-channels/src/twitch.rs
Normal file
385
crates/openfang-channels/src/twitch.rs
Normal file
@@ -0,0 +1,385 @@
|
||||
//! Twitch IRC channel adapter.
|
||||
//!
|
||||
//! Connects to Twitch's IRC gateway (`irc.chat.twitch.tv`) over plain TCP and
|
||||
//! implements the IRC protocol for sending and receiving chat messages. Handles
|
||||
//! PING/PONG keepalive, channel joins, and PRIVMSG parsing.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const TWITCH_IRC_HOST: &str = "irc.chat.twitch.tv";
|
||||
const TWITCH_IRC_PORT: u16 = 6667;
|
||||
const MAX_MESSAGE_LEN: usize = 500;
|
||||
|
||||
/// Twitch IRC channel adapter.
|
||||
///
|
||||
/// Connects to Twitch chat via the IRC protocol and bridges messages to the
|
||||
/// OpenFang channel system. Supports multiple channels simultaneously.
|
||||
pub struct TwitchAdapter {
|
||||
/// SECURITY: OAuth token is zeroized on drop.
|
||||
oauth_token: Zeroizing<String>,
|
||||
/// Twitch channels to join (without the '#' prefix).
|
||||
channels: Vec<String>,
|
||||
/// Bot's IRC nickname.
|
||||
nick: String,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl TwitchAdapter {
|
||||
/// Create a new Twitch adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `oauth_token` - Twitch OAuth token (without the "oauth:" prefix; it will be added).
|
||||
/// * `channels` - Channel names to join (without '#' prefix).
|
||||
/// * `nick` - Bot's IRC nickname (must match the token owner's Twitch username).
|
||||
pub fn new(oauth_token: String, channels: Vec<String>, nick: String) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
oauth_token: Zeroizing::new(oauth_token),
|
||||
channels,
|
||||
nick,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Format the OAuth token for the IRC PASS command.
|
||||
fn pass_string(&self) -> String {
|
||||
let token = self.oauth_token.as_str();
|
||||
if token.starts_with("oauth:") {
|
||||
format!("PASS {token}\r\n")
|
||||
} else {
|
||||
format!("PASS oauth:{token}\r\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse an IRC PRIVMSG line into its components.
|
||||
///
|
||||
/// Expected format: `:nick!user@host PRIVMSG #channel :message text`
|
||||
/// Returns `(nick, channel, message)` on success.
|
||||
fn parse_privmsg(line: &str) -> Option<(String, String, String)> {
|
||||
// Must start with ':'
|
||||
if !line.starts_with(':') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let without_prefix = &line[1..];
|
||||
let parts: Vec<&str> = without_prefix.splitn(2, ' ').collect();
|
||||
if parts.len() < 2 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let nick = parts[0].split('!').next()?.to_string();
|
||||
let rest = parts[1];
|
||||
|
||||
// Expect "PRIVMSG #channel :message"
|
||||
if !rest.starts_with("PRIVMSG ") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let after_cmd = &rest[8..]; // skip "PRIVMSG "
|
||||
let channel_end = after_cmd.find(' ')?;
|
||||
let channel = after_cmd[..channel_end].to_string();
|
||||
let msg_start = after_cmd[channel_end..].find(':')?;
|
||||
let message = after_cmd[channel_end + msg_start + 1..].to_string();
|
||||
|
||||
Some((nick, channel, message))
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for TwitchAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"twitch"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("twitch".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
info!("Twitch adapter connecting to {TWITCH_IRC_HOST}:{TWITCH_IRC_PORT}");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let pass = self.pass_string();
|
||||
let nick_cmd = format!("NICK {}\r\n", self.nick);
|
||||
let join_cmds: Vec<String> = self
|
||||
.channels
|
||||
.iter()
|
||||
.map(|ch| {
|
||||
let ch = ch.trim_start_matches('#');
|
||||
format!("JOIN #{ch}\r\n")
|
||||
})
|
||||
.collect();
|
||||
let bot_nick = self.nick.to_lowercase();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Connect to Twitch IRC
|
||||
let stream = match TcpStream::connect((TWITCH_IRC_HOST, TWITCH_IRC_PORT)).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!("Twitch: connection failed: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let (read_half, mut write_half) = stream.into_split();
|
||||
let mut reader = BufReader::new(read_half);
|
||||
|
||||
// Authenticate
|
||||
if write_half.write_all(pass.as_bytes()).await.is_err() {
|
||||
warn!("Twitch: failed to send PASS");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
if write_half.write_all(nick_cmd.as_bytes()).await.is_err() {
|
||||
warn!("Twitch: failed to send NICK");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Join channels
|
||||
for join in &join_cmds {
|
||||
if write_half.write_all(join.as_bytes()).await.is_err() {
|
||||
warn!("Twitch: failed to send JOIN");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
info!("Twitch IRC connected and joined channels");
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
// Read loop
|
||||
let should_reconnect = loop {
|
||||
let mut line = String::new();
|
||||
let read_result = tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Twitch adapter shutting down");
|
||||
let _ = write_half.write_all(b"QUIT :Shutting down\r\n").await;
|
||||
return;
|
||||
}
|
||||
result = reader.read_line(&mut line) => result,
|
||||
};
|
||||
|
||||
match read_result {
|
||||
Ok(0) => {
|
||||
info!("Twitch IRC connection closed");
|
||||
break true;
|
||||
}
|
||||
Ok(_) => {}
|
||||
Err(e) => {
|
||||
warn!("Twitch IRC read error: {e}");
|
||||
break true;
|
||||
}
|
||||
}
|
||||
|
||||
let line = line.trim_end_matches('\n').trim_end_matches('\r');
|
||||
|
||||
// Handle PING
|
||||
if line.starts_with("PING") {
|
||||
let pong = line.replacen("PING", "PONG", 1);
|
||||
let _ = write_half.write_all(format!("{pong}\r\n").as_bytes()).await;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse PRIVMSG
|
||||
if let Some((sender_nick, channel, message)) = parse_privmsg(line) {
|
||||
// Skip own messages
|
||||
if sender_nick.to_lowercase() == bot_nick {
|
||||
continue;
|
||||
}
|
||||
|
||||
if message.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let msg_content = if message.starts_with('/') || message.starts_with('!') {
|
||||
let trimmed = message.trim_start_matches('/').trim_start_matches('!');
|
||||
let parts: Vec<&str> = trimmed.splitn(2, ' ').collect();
|
||||
let cmd = parts[0];
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(message.clone())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("twitch".to_string()),
|
||||
platform_message_id: uuid::Uuid::new_v4().to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: channel.clone(),
|
||||
display_name: sender_nick,
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: true, // Twitch channels are always group
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
warn!("Twitch: reconnecting in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
|
||||
info!("Twitch IRC loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let channel = &user.platform_id;
|
||||
let text = match content {
|
||||
ChannelContent::Text(text) => text,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// Connect briefly to send the message
|
||||
// In production, a persistent write connection would be maintained.
|
||||
let stream = TcpStream::connect((TWITCH_IRC_HOST, TWITCH_IRC_PORT)).await?;
|
||||
let (_reader, mut writer) = stream.into_split();
|
||||
|
||||
writer.write_all(self.pass_string().as_bytes()).await?;
|
||||
writer
|
||||
.write_all(format!("NICK {}\r\n", self.nick).as_bytes())
|
||||
.await?;
|
||||
|
||||
// Wait briefly for auth to complete
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
let chunks = split_message(&text, MAX_MESSAGE_LEN);
|
||||
for chunk in chunks {
|
||||
let msg = format!("PRIVMSG {channel} :{chunk}\r\n");
|
||||
writer.write_all(msg.as_bytes()).await?;
|
||||
}
|
||||
|
||||
writer.write_all(b"QUIT\r\n").await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_twitch_adapter_creation() {
|
||||
let adapter = TwitchAdapter::new(
|
||||
"test-oauth-token".to_string(),
|
||||
vec!["testchannel".to_string()],
|
||||
"openfang_bot".to_string(),
|
||||
);
|
||||
assert_eq!(adapter.name(), "twitch");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("twitch".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twitch_pass_string_with_prefix() {
|
||||
let adapter = TwitchAdapter::new("oauth:abc123".to_string(), vec![], "bot".to_string());
|
||||
assert_eq!(adapter.pass_string(), "PASS oauth:abc123\r\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_twitch_pass_string_without_prefix() {
|
||||
let adapter = TwitchAdapter::new("abc123".to_string(), vec![], "bot".to_string());
|
||||
assert_eq!(adapter.pass_string(), "PASS oauth:abc123\r\n");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_valid() {
|
||||
let line = ":nick123!user@host PRIVMSG #channel :Hello world!";
|
||||
let (nick, channel, message) = parse_privmsg(line).unwrap();
|
||||
assert_eq!(nick, "nick123");
|
||||
assert_eq!(channel, "#channel");
|
||||
assert_eq!(message, "Hello world!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_no_message() {
|
||||
// Missing colon before message
|
||||
let line = ":nick!user@host PRIVMSG #channel";
|
||||
assert!(parse_privmsg(line).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_not_privmsg() {
|
||||
let line = ":server 001 bot :Welcome";
|
||||
assert!(parse_privmsg(line).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_command() {
|
||||
let line = ":user!u@h PRIVMSG #ch :!help me";
|
||||
let (nick, channel, message) = parse_privmsg(line).unwrap();
|
||||
assert_eq!(nick, "user");
|
||||
assert_eq!(channel, "#ch");
|
||||
assert_eq!(message, "!help me");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_privmsg_empty_prefix() {
|
||||
let line = "PING :tmi.twitch.tv";
|
||||
assert!(parse_privmsg(line).is_none());
|
||||
}
|
||||
}
|
||||
461
crates/openfang-channels/src/types.rs
Normal file
461
crates/openfang-channels/src/types.rs
Normal file
@@ -0,0 +1,461 @@
|
||||
//! Core channel bridge types.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use openfang_types::agent::AgentId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
|
||||
/// The type of messaging channel.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum ChannelType {
|
||||
Telegram,
|
||||
WhatsApp,
|
||||
Slack,
|
||||
Discord,
|
||||
Signal,
|
||||
Matrix,
|
||||
Email,
|
||||
Teams,
|
||||
Mattermost,
|
||||
WebChat,
|
||||
CLI,
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
/// A user on a messaging platform.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChannelUser {
|
||||
/// Platform-specific user ID.
|
||||
pub platform_id: String,
|
||||
/// Human-readable display name.
|
||||
pub display_name: String,
|
||||
/// Optional mapping to an OpenFang user identity.
|
||||
pub openfang_user: Option<String>,
|
||||
}
|
||||
|
||||
/// Content types that can be received from a channel.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ChannelContent {
|
||||
Text(String),
|
||||
Image {
|
||||
url: String,
|
||||
caption: Option<String>,
|
||||
},
|
||||
File {
|
||||
url: String,
|
||||
filename: String,
|
||||
},
|
||||
Voice {
|
||||
url: String,
|
||||
duration_seconds: u32,
|
||||
},
|
||||
Location {
|
||||
lat: f64,
|
||||
lon: f64,
|
||||
},
|
||||
Command {
|
||||
name: String,
|
||||
args: Vec<String>,
|
||||
},
|
||||
}
|
||||
|
||||
/// A unified message from any channel.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChannelMessage {
|
||||
/// Which channel this came from.
|
||||
pub channel: ChannelType,
|
||||
/// Platform-specific message identifier.
|
||||
pub platform_message_id: String,
|
||||
/// Who sent this message.
|
||||
pub sender: ChannelUser,
|
||||
/// The message content.
|
||||
pub content: ChannelContent,
|
||||
/// Optional target agent (if routed directly).
|
||||
pub target_agent: Option<AgentId>,
|
||||
/// When the message was sent.
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// Whether this message is from a group chat (vs DM).
|
||||
#[serde(default)]
|
||||
pub is_group: bool,
|
||||
/// Thread ID for threaded conversations (platform-specific).
|
||||
#[serde(default)]
|
||||
pub thread_id: Option<String>,
|
||||
/// Arbitrary platform metadata.
|
||||
pub metadata: HashMap<String, serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Agent lifecycle phase for UX indicators.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AgentPhase {
|
||||
/// Message is queued, waiting for agent.
|
||||
Queued,
|
||||
/// Agent is calling the LLM.
|
||||
Thinking,
|
||||
/// Agent is executing a tool.
|
||||
ToolUse {
|
||||
/// Tool being executed (max 64 chars, sanitized).
|
||||
tool_name: String,
|
||||
},
|
||||
/// Agent is streaming tokens.
|
||||
Streaming,
|
||||
/// Agent finished successfully.
|
||||
Done,
|
||||
/// Agent encountered an error.
|
||||
Error,
|
||||
}
|
||||
|
||||
impl AgentPhase {
|
||||
/// Sanitize a tool name for display (truncate to 64 chars, strip control chars).
|
||||
pub fn tool_use(name: &str) -> Self {
|
||||
let sanitized: String = name.chars().filter(|c| !c.is_control()).take(64).collect();
|
||||
Self::ToolUse {
|
||||
tool_name: sanitized,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Reaction to show in a channel (emoji-based).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LifecycleReaction {
|
||||
/// The agent phase this reaction represents.
|
||||
pub phase: AgentPhase,
|
||||
/// Channel-appropriate emoji.
|
||||
pub emoji: String,
|
||||
/// Whether to remove the previous phase reaction.
|
||||
pub remove_previous: bool,
|
||||
}
|
||||
|
||||
/// Hardcoded emoji allowlist for lifecycle reactions.
|
||||
pub const ALLOWED_REACTION_EMOJI: &[&str] = &[
|
||||
"\u{1F914}", // 🤔 thinking
|
||||
"\u{2699}\u{FE0F}", // ⚙️ tool_use
|
||||
"\u{270D}\u{FE0F}", // ✍️ streaming
|
||||
"\u{2705}", // ✅ done
|
||||
"\u{274C}", // ❌ error
|
||||
"\u{23F3}", // ⏳ queued
|
||||
"\u{1F504}", // 🔄 processing
|
||||
"\u{1F440}", // 👀 looking
|
||||
];
|
||||
|
||||
/// Get the default emoji for a given agent phase.
|
||||
pub fn default_phase_emoji(phase: &AgentPhase) -> &'static str {
|
||||
match phase {
|
||||
AgentPhase::Queued => "\u{23F3}", // ⏳
|
||||
AgentPhase::Thinking => "\u{1F914}", // 🤔
|
||||
AgentPhase::ToolUse { .. } => "\u{2699}\u{FE0F}", // ⚙️
|
||||
AgentPhase::Streaming => "\u{270D}\u{FE0F}", // ✍️
|
||||
AgentPhase::Done => "\u{2705}", // ✅
|
||||
AgentPhase::Error => "\u{274C}", // ❌
|
||||
}
|
||||
}
|
||||
|
||||
/// Delivery status for outbound messages.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum DeliveryStatus {
|
||||
/// Message was sent to the channel API.
|
||||
Sent,
|
||||
/// Message was confirmed delivered to recipient.
|
||||
Delivered,
|
||||
/// Message delivery failed.
|
||||
Failed,
|
||||
/// Best-effort delivery (no confirmation available).
|
||||
BestEffort,
|
||||
}
|
||||
|
||||
/// Receipt tracking outbound message delivery.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeliveryReceipt {
|
||||
/// Platform message ID (if available).
|
||||
pub message_id: String,
|
||||
/// Channel type this was sent through.
|
||||
pub channel: String,
|
||||
/// Sanitized recipient identifier (no PII).
|
||||
pub recipient: String,
|
||||
/// Delivery status.
|
||||
pub status: DeliveryStatus,
|
||||
/// When the delivery attempt occurred.
|
||||
pub timestamp: DateTime<Utc>,
|
||||
/// Error message (if failed — sanitized, no credentials).
|
||||
pub error: Option<String>,
|
||||
}
|
||||
|
||||
/// Health status for a channel adapter.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ChannelStatus {
|
||||
/// Whether the adapter is currently connected/running.
|
||||
pub connected: bool,
|
||||
/// When the adapter was started (ISO 8601).
|
||||
pub started_at: Option<DateTime<Utc>>,
|
||||
/// When the last message was received.
|
||||
pub last_message_at: Option<DateTime<Utc>>,
|
||||
/// Total messages received since start.
|
||||
pub messages_received: u64,
|
||||
/// Total messages sent since start.
|
||||
pub messages_sent: u64,
|
||||
/// Last error message (if any).
|
||||
pub last_error: Option<String>,
|
||||
}
|
||||
|
||||
// Re-export policy/format types from openfang-types for convenience.
|
||||
pub use openfang_types::config::{DmPolicy, GroupPolicy, OutputFormat};
|
||||
|
||||
/// Trait that every channel adapter must implement.
|
||||
///
|
||||
/// A channel adapter bridges a messaging platform to the OpenFang kernel by converting
|
||||
/// platform-specific messages into `ChannelMessage` events and sending responses back.
|
||||
#[async_trait]
|
||||
pub trait ChannelAdapter: Send + Sync {
|
||||
/// Human-readable name of this adapter.
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// The channel type this adapter handles.
|
||||
fn channel_type(&self) -> ChannelType;
|
||||
|
||||
/// Start receiving messages. Returns a stream of incoming messages.
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>;
|
||||
|
||||
/// Send a response back to a user on this channel.
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>>;
|
||||
|
||||
/// Send a typing indicator (optional — default no-op).
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a lifecycle reaction to a message (optional — default no-op).
|
||||
async fn send_reaction(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
_message_id: &str,
|
||||
_reaction: &LifecycleReaction,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Stop the adapter and clean up resources.
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>>;
|
||||
|
||||
/// Get the current health status of this adapter (optional — default returns disconnected).
|
||||
fn status(&self) -> ChannelStatus {
|
||||
ChannelStatus::default()
|
||||
}
|
||||
|
||||
/// Send a response as a thread reply (optional — default falls back to `send()`).
|
||||
async fn send_in_thread(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
_thread_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
self.send(user, content).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Split a message into chunks of at most `max_len` characters,
|
||||
/// preferring to split at newline boundaries.
|
||||
///
|
||||
/// Shared utility used by Telegram, Discord, and Slack adapters.
|
||||
pub fn split_message(text: &str, max_len: usize) -> Vec<&str> {
|
||||
if text.len() <= max_len {
|
||||
return vec![text];
|
||||
}
|
||||
let mut chunks = Vec::new();
|
||||
let mut remaining = text;
|
||||
while !remaining.is_empty() {
|
||||
if remaining.len() <= max_len {
|
||||
chunks.push(remaining);
|
||||
break;
|
||||
}
|
||||
// Try to split at a newline near the boundary (UTF-8 safe)
|
||||
let safe_end = openfang_types::truncate_str(remaining, max_len).len();
|
||||
let split_at = remaining[..safe_end].rfind('\n').unwrap_or(safe_end);
|
||||
let (chunk, rest) = remaining.split_at(split_at);
|
||||
chunks.push(chunk);
|
||||
// Skip the newline (and optional \r) we split on
|
||||
remaining = rest
|
||||
.strip_prefix("\r\n")
|
||||
.or_else(|| rest.strip_prefix('\n'))
|
||||
.unwrap_or(rest);
|
||||
}
|
||||
chunks
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_channel_message_serialization() {
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Telegram,
|
||||
platform_message_id: "123".to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: "user1".to_string(),
|
||||
display_name: "Alice".to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: ChannelContent::Text("Hello!".to_string()),
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&msg).unwrap();
|
||||
let deserialized: ChannelMessage = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(deserialized.channel, ChannelType::Telegram);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_message_short() {
|
||||
assert_eq!(split_message("hello", 100), vec!["hello"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_message_at_newlines() {
|
||||
let text = "line1\nline2\nline3";
|
||||
let chunks = split_message(text, 10);
|
||||
assert_eq!(chunks, vec!["line1", "line2", "line3"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_type_matrix_serde() {
|
||||
let ct = ChannelType::Matrix;
|
||||
let json = serde_json::to_string(&ct).unwrap();
|
||||
let back: ChannelType = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back, ChannelType::Matrix);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_type_email_serde() {
|
||||
let ct = ChannelType::Email;
|
||||
let json = serde_json::to_string(&ct).unwrap();
|
||||
let back: ChannelType = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back, ChannelType::Email);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_channel_content_variants() {
|
||||
let text = ChannelContent::Text("hello".to_string());
|
||||
let cmd = ChannelContent::Command {
|
||||
name: "status".to_string(),
|
||||
args: vec![],
|
||||
};
|
||||
let loc = ChannelContent::Location {
|
||||
lat: 40.7128,
|
||||
lon: -74.0060,
|
||||
};
|
||||
|
||||
// Just verify they serialize without panic
|
||||
serde_json::to_string(&text).unwrap();
|
||||
serde_json::to_string(&cmd).unwrap();
|
||||
serde_json::to_string(&loc).unwrap();
|
||||
}
|
||||
|
||||
// ----- AgentPhase tests -----
|
||||
|
||||
#[test]
|
||||
fn test_agent_phase_serde_roundtrip() {
|
||||
let phases = vec![
|
||||
AgentPhase::Queued,
|
||||
AgentPhase::Thinking,
|
||||
AgentPhase::tool_use("web_fetch"),
|
||||
AgentPhase::Streaming,
|
||||
AgentPhase::Done,
|
||||
AgentPhase::Error,
|
||||
];
|
||||
for phase in &phases {
|
||||
let json = serde_json::to_string(phase).unwrap();
|
||||
let back: AgentPhase = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(*phase, back);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_phase_tool_use_sanitizes() {
|
||||
let phase = AgentPhase::tool_use("hello\x00world\x01test");
|
||||
if let AgentPhase::ToolUse { tool_name } = phase {
|
||||
assert!(!tool_name.contains('\x00'));
|
||||
assert!(!tool_name.contains('\x01'));
|
||||
assert!(tool_name.contains("hello"));
|
||||
} else {
|
||||
panic!("Expected ToolUse variant");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_phase_tool_use_truncates_long_name() {
|
||||
let long_name = "a".repeat(200);
|
||||
let phase = AgentPhase::tool_use(&long_name);
|
||||
if let AgentPhase::ToolUse { tool_name } = phase {
|
||||
assert!(tool_name.len() <= 64);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_phase_emoji() {
|
||||
assert_eq!(default_phase_emoji(&AgentPhase::Thinking), "\u{1F914}");
|
||||
assert_eq!(default_phase_emoji(&AgentPhase::Done), "\u{2705}");
|
||||
assert_eq!(default_phase_emoji(&AgentPhase::Error), "\u{274C}");
|
||||
}
|
||||
|
||||
// ----- DeliveryReceipt tests -----
|
||||
|
||||
#[test]
|
||||
fn test_delivery_status_serde() {
|
||||
let statuses = vec![
|
||||
DeliveryStatus::Sent,
|
||||
DeliveryStatus::Delivered,
|
||||
DeliveryStatus::Failed,
|
||||
DeliveryStatus::BestEffort,
|
||||
];
|
||||
for status in &statuses {
|
||||
let json = serde_json::to_string(status).unwrap();
|
||||
let back: DeliveryStatus = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(*status, back);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delivery_receipt_serde() {
|
||||
let receipt = DeliveryReceipt {
|
||||
message_id: "msg-123".to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
recipient: "user-456".to_string(),
|
||||
status: DeliveryStatus::Sent,
|
||||
timestamp: Utc::now(),
|
||||
error: None,
|
||||
};
|
||||
let json = serde_json::to_string(&receipt).unwrap();
|
||||
let back: DeliveryReceipt = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(back.message_id, "msg-123");
|
||||
assert_eq!(back.status, DeliveryStatus::Sent);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delivery_receipt_with_error() {
|
||||
let receipt = DeliveryReceipt {
|
||||
message_id: "msg-789".to_string(),
|
||||
channel: "slack".to_string(),
|
||||
recipient: "channel-abc".to_string(),
|
||||
status: DeliveryStatus::Failed,
|
||||
timestamp: Utc::now(),
|
||||
error: Some("Connection refused".to_string()),
|
||||
};
|
||||
let json = serde_json::to_string(&receipt).unwrap();
|
||||
assert!(json.contains("Connection refused"));
|
||||
}
|
||||
}
|
||||
587
crates/openfang-channels/src/viber.rs
Normal file
587
crates/openfang-channels/src/viber.rs
Normal file
@@ -0,0 +1,587 @@
|
||||
//! Viber Bot API channel adapter.
|
||||
//!
|
||||
//! Uses the Viber REST API for sending messages and a webhook HTTP server for
|
||||
//! receiving inbound events. Authentication is performed via the `X-Viber-Auth-Token`
|
||||
//! header on all outbound API calls. The webhook is registered on startup via
|
||||
//! `POST https://chatapi.viber.com/pa/set_webhook`.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Viber set webhook endpoint.
|
||||
const VIBER_SET_WEBHOOK_URL: &str = "https://chatapi.viber.com/pa/set_webhook";
|
||||
|
||||
/// Viber send message endpoint.
|
||||
const VIBER_SEND_MESSAGE_URL: &str = "https://chatapi.viber.com/pa/send_message";
|
||||
|
||||
/// Viber get account info endpoint (used for validation).
|
||||
const VIBER_ACCOUNT_INFO_URL: &str = "https://chatapi.viber.com/pa/get_account_info";
|
||||
|
||||
/// Maximum Viber message text length (characters).
|
||||
const MAX_MESSAGE_LEN: usize = 7000;
|
||||
|
||||
/// Sender name shown in Viber messages from the bot.
|
||||
const DEFAULT_SENDER_NAME: &str = "OpenFang";
|
||||
|
||||
/// Viber Bot API adapter.
|
||||
///
|
||||
/// Inbound messages arrive via a webhook HTTP server that Viber pushes events to.
|
||||
/// Outbound messages are sent via the Viber send_message REST API with the
|
||||
/// `X-Viber-Auth-Token` header for authentication.
|
||||
pub struct ViberAdapter {
|
||||
/// SECURITY: Auth token is zeroized on drop to prevent memory disclosure.
|
||||
auth_token: Zeroizing<String>,
|
||||
/// Public webhook URL that Viber will POST events to.
|
||||
webhook_url: String,
|
||||
/// Port on which the inbound webhook HTTP server listens.
|
||||
webhook_port: u16,
|
||||
/// Sender name displayed in outbound messages.
|
||||
sender_name: String,
|
||||
/// Optional sender avatar URL for outbound messages.
|
||||
sender_avatar: Option<String>,
|
||||
/// HTTP client for outbound API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl ViberAdapter {
|
||||
/// Create a new Viber adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `auth_token` - Viber bot authentication token.
|
||||
/// * `webhook_url` - Public URL where Viber will send webhook events.
|
||||
/// * `webhook_port` - Local port for the inbound webhook HTTP server.
|
||||
pub fn new(auth_token: String, webhook_url: String, webhook_port: u16) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let webhook_url = webhook_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
auth_token: Zeroizing::new(auth_token),
|
||||
webhook_url,
|
||||
webhook_port,
|
||||
sender_name: DEFAULT_SENDER_NAME.to_string(),
|
||||
sender_avatar: None,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new Viber adapter with a custom sender name and avatar.
|
||||
pub fn with_sender(
|
||||
auth_token: String,
|
||||
webhook_url: String,
|
||||
webhook_port: u16,
|
||||
sender_name: String,
|
||||
sender_avatar: Option<String>,
|
||||
) -> Self {
|
||||
let mut adapter = Self::new(auth_token, webhook_url, webhook_port);
|
||||
adapter.sender_name = sender_name;
|
||||
adapter.sender_avatar = sender_avatar;
|
||||
adapter
|
||||
}
|
||||
|
||||
/// Add the Viber auth token header to a request builder.
|
||||
fn auth_header(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
builder.header("X-Viber-Auth-Token", self.auth_token.as_str())
|
||||
}
|
||||
|
||||
/// Validate the auth token by calling the get_account_info endpoint.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let resp = self
|
||||
.auth_header(self.client.post(VIBER_ACCOUNT_INFO_URL))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Viber authentication failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let status = body["status"].as_u64().unwrap_or(1);
|
||||
if status != 0 {
|
||||
let msg = body["status_message"].as_str().unwrap_or("unknown error");
|
||||
return Err(format!("Viber API error: {msg}").into());
|
||||
}
|
||||
|
||||
let name = body["name"].as_str().unwrap_or("Viber Bot").to_string();
|
||||
Ok(name)
|
||||
}
|
||||
|
||||
/// Register the webhook URL with Viber.
|
||||
async fn register_webhook(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let body = serde_json::json!({
|
||||
"url": self.webhook_url,
|
||||
"event_types": [
|
||||
"delivered",
|
||||
"seen",
|
||||
"failed",
|
||||
"subscribed",
|
||||
"unsubscribed",
|
||||
"conversation_started",
|
||||
"message"
|
||||
],
|
||||
"send_name": true,
|
||||
"send_photo": true,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.auth_header(self.client.post(VIBER_SET_WEBHOOK_URL))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Viber set_webhook failed {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
let status = resp_body["status"].as_u64().unwrap_or(1);
|
||||
if status != 0 {
|
||||
let msg = resp_body["status_message"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown error");
|
||||
return Err(format!("Viber set_webhook error: {msg}").into());
|
||||
}
|
||||
|
||||
info!("Viber webhook registered at {}", self.webhook_url);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a text message to a Viber user.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
receiver: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let mut sender = serde_json::json!({
|
||||
"name": self.sender_name,
|
||||
});
|
||||
if let Some(ref avatar) = self.sender_avatar {
|
||||
sender["avatar"] = serde_json::Value::String(avatar.clone());
|
||||
}
|
||||
|
||||
let body = serde_json::json!({
|
||||
"receiver": receiver,
|
||||
"min_api_version": 1,
|
||||
"sender": sender,
|
||||
"tracking_data": "openfang",
|
||||
"type": "text",
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.auth_header(self.client.post(VIBER_SEND_MESSAGE_URL))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Viber send_message error {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let resp_body: serde_json::Value = resp.json().await?;
|
||||
let api_status = resp_body["status"].as_u64().unwrap_or(1);
|
||||
if api_status != 0 {
|
||||
let msg = resp_body["status_message"]
|
||||
.as_str()
|
||||
.unwrap_or("unknown error");
|
||||
warn!("Viber send_message API error: {msg}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a Viber webhook event into a `ChannelMessage`.
|
||||
///
|
||||
/// Handles `message` events with text type. Returns `None` for non-message
|
||||
/// events (delivered, seen, subscribed, conversation_started, etc.).
|
||||
fn parse_viber_event(event: &serde_json::Value) -> Option<ChannelMessage> {
|
||||
let event_type = event["event"].as_str().unwrap_or("");
|
||||
if event_type != "message" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let message = event.get("message")?;
|
||||
let msg_type = message["type"].as_str().unwrap_or("");
|
||||
|
||||
// Only handle text messages
|
||||
if msg_type != "text" {
|
||||
return None;
|
||||
}
|
||||
|
||||
let text = message["text"].as_str().unwrap_or("");
|
||||
if text.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let sender = event.get("sender")?;
|
||||
let sender_id = sender["id"].as_str().unwrap_or("").to_string();
|
||||
let sender_name = sender["name"].as_str().unwrap_or("Unknown").to_string();
|
||||
let sender_avatar = sender["avatar"].as_str().unwrap_or("").to_string();
|
||||
|
||||
let message_token = event["message_token"]
|
||||
.as_u64()
|
||||
.map(|t| t.to_string())
|
||||
.unwrap_or_default();
|
||||
|
||||
let content = if text.starts_with('/') {
|
||||
let parts: Vec<&str> = text.splitn(2, ' ').collect();
|
||||
let cmd_name = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd_name.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(text.to_string())
|
||||
};
|
||||
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"sender_id".to_string(),
|
||||
serde_json::Value::String(sender_id.clone()),
|
||||
);
|
||||
if !sender_avatar.is_empty() {
|
||||
metadata.insert(
|
||||
"sender_avatar".to_string(),
|
||||
serde_json::Value::String(sender_avatar),
|
||||
);
|
||||
}
|
||||
if let Some(tracking) = message["tracking_data"].as_str() {
|
||||
metadata.insert(
|
||||
"tracking_data".to_string(),
|
||||
serde_json::Value::String(tracking.to_string()),
|
||||
);
|
||||
}
|
||||
|
||||
Some(ChannelMessage {
|
||||
channel: ChannelType::Custom("viber".to_string()),
|
||||
platform_message_id: message_token,
|
||||
sender: ChannelUser {
|
||||
platform_id: sender_id,
|
||||
display_name: sender_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group: false, // Viber bot API messages are always 1:1
|
||||
thread_id: None,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for ViberAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"viber"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("viber".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_name = self.validate().await?;
|
||||
info!("Viber adapter authenticated as {bot_name}");
|
||||
|
||||
// Register webhook
|
||||
self.register_webhook().await?;
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let tx = Arc::new(tx);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/viber/webhook",
|
||||
axum::routing::post({
|
||||
let tx = Arc::clone(&tx);
|
||||
move |body: axum::extract::Json<serde_json::Value>| {
|
||||
let tx = Arc::clone(&tx);
|
||||
async move {
|
||||
if let Some(msg) = parse_viber_event(&body.0) {
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
axum::http::StatusCode::OK
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("Viber webhook server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Viber webhook bind failed: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Viber webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Viber adapter shutting down");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
ChannelContent::Image { url, caption } => {
|
||||
let mut sender = serde_json::json!({
|
||||
"name": self.sender_name,
|
||||
});
|
||||
if let Some(ref avatar) = self.sender_avatar {
|
||||
sender["avatar"] = serde_json::Value::String(avatar.clone());
|
||||
}
|
||||
|
||||
let body = serde_json::json!({
|
||||
"receiver": user.platform_id,
|
||||
"min_api_version": 1,
|
||||
"sender": sender,
|
||||
"type": "picture",
|
||||
"text": caption.unwrap_or_default(),
|
||||
"media": url,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.auth_header(self.client.post(VIBER_SEND_MESSAGE_URL))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
warn!("Viber image send error {status}: {resp_body}");
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Viber does not support typing indicators via REST API
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_viber_adapter_creation() {
|
||||
let adapter = ViberAdapter::new(
|
||||
"auth-token-123".to_string(),
|
||||
"https://example.com/viber/webhook".to_string(),
|
||||
8443,
|
||||
);
|
||||
assert_eq!(adapter.name(), "viber");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("viber".to_string())
|
||||
);
|
||||
assert_eq!(adapter.webhook_port, 8443);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_viber_url_normalization() {
|
||||
let adapter = ViberAdapter::new(
|
||||
"tok".to_string(),
|
||||
"https://example.com/viber/webhook/".to_string(),
|
||||
8443,
|
||||
);
|
||||
assert_eq!(adapter.webhook_url, "https://example.com/viber/webhook");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_viber_with_sender() {
|
||||
let adapter = ViberAdapter::with_sender(
|
||||
"tok".to_string(),
|
||||
"https://example.com".to_string(),
|
||||
8443,
|
||||
"MyBot".to_string(),
|
||||
Some("https://example.com/avatar.png".to_string()),
|
||||
);
|
||||
assert_eq!(adapter.sender_name, "MyBot");
|
||||
assert_eq!(
|
||||
adapter.sender_avatar,
|
||||
Some("https://example.com/avatar.png".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_viber_auth_header() {
|
||||
let adapter = ViberAdapter::new(
|
||||
"my-viber-token".to_string(),
|
||||
"https://example.com".to_string(),
|
||||
8443,
|
||||
);
|
||||
let builder = adapter.client.post("https://example.com");
|
||||
let builder = adapter.auth_header(builder);
|
||||
let request = builder.build().unwrap();
|
||||
assert_eq!(
|
||||
request.headers().get("X-Viber-Auth-Token").unwrap(),
|
||||
"my-viber-token"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_viber_event_text_message() {
|
||||
let event = serde_json::json!({
|
||||
"event": "message",
|
||||
"timestamp": 1457764197627_u64,
|
||||
"message_token": 4912661846655238145_u64,
|
||||
"sender": {
|
||||
"id": "01234567890A=",
|
||||
"name": "Alice",
|
||||
"avatar": "https://example.com/avatar.jpg"
|
||||
},
|
||||
"message": {
|
||||
"type": "text",
|
||||
"text": "Hello from Viber!"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_viber_event(&event).unwrap();
|
||||
assert_eq!(msg.channel, ChannelType::Custom("viber".to_string()));
|
||||
assert_eq!(msg.sender.display_name, "Alice");
|
||||
assert_eq!(msg.sender.platform_id, "01234567890A=");
|
||||
assert!(!msg.is_group);
|
||||
assert!(matches!(msg.content, ChannelContent::Text(ref t) if t == "Hello from Viber!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_viber_event_command() {
|
||||
let event = serde_json::json!({
|
||||
"event": "message",
|
||||
"message_token": 123_u64,
|
||||
"sender": {
|
||||
"id": "sender-1",
|
||||
"name": "Bob"
|
||||
},
|
||||
"message": {
|
||||
"type": "text",
|
||||
"text": "/help agents"
|
||||
}
|
||||
});
|
||||
|
||||
let msg = parse_viber_event(&event).unwrap();
|
||||
match &msg.content {
|
||||
ChannelContent::Command { name, args } => {
|
||||
assert_eq!(name, "help");
|
||||
assert_eq!(args, &["agents"]);
|
||||
}
|
||||
other => panic!("Expected Command, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_viber_event_non_message() {
|
||||
let event = serde_json::json!({
|
||||
"event": "delivered",
|
||||
"timestamp": 1457764197627_u64,
|
||||
"message_token": 123_u64,
|
||||
"user_id": "user-1"
|
||||
});
|
||||
|
||||
assert!(parse_viber_event(&event).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_viber_event_non_text() {
|
||||
let event = serde_json::json!({
|
||||
"event": "message",
|
||||
"message_token": 123_u64,
|
||||
"sender": {
|
||||
"id": "sender-1",
|
||||
"name": "Bob"
|
||||
},
|
||||
"message": {
|
||||
"type": "picture",
|
||||
"media": "https://example.com/image.jpg"
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_viber_event(&event).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_viber_event_empty_text() {
|
||||
let event = serde_json::json!({
|
||||
"event": "message",
|
||||
"message_token": 123_u64,
|
||||
"sender": {
|
||||
"id": "sender-1",
|
||||
"name": "Bob"
|
||||
},
|
||||
"message": {
|
||||
"type": "text",
|
||||
"text": ""
|
||||
}
|
||||
});
|
||||
|
||||
assert!(parse_viber_event(&event).is_none());
|
||||
}
|
||||
}
|
||||
522
crates/openfang-channels/src/webex.rs
Normal file
522
crates/openfang-channels/src/webex.rs
Normal file
@@ -0,0 +1,522 @@
|
||||
//! Webex Bot channel adapter.
|
||||
//!
|
||||
//! Connects to the Webex platform via the Mercury WebSocket for receiving
|
||||
//! real-time message events and uses the Webex REST API for sending messages.
|
||||
//! Authentication is performed via a Bot Bearer token. Supports room filtering
|
||||
//! and automatic WebSocket reconnection.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// Webex REST API base URL.
|
||||
const WEBEX_API_BASE: &str = "https://webexapis.com/v1";
|
||||
|
||||
/// Webex Mercury WebSocket URL for device connections.
|
||||
const WEBEX_WS_URL: &str = "wss://mercury-connection-a.wbx2.com/v1/apps/wx2/registrations";
|
||||
|
||||
/// Maximum message length for Webex (official limit is 7439 characters).
|
||||
const MAX_MESSAGE_LEN: usize = 7439;
|
||||
|
||||
/// Webex Bot channel adapter using WebSocket for events and REST for sending.
|
||||
///
|
||||
/// Connects to the Webex Mercury WebSocket gateway for real-time message
|
||||
/// notifications and fetches full message content via the REST API. Outbound
|
||||
/// messages are sent directly via the REST API.
|
||||
pub struct WebexAdapter {
|
||||
/// SECURITY: Bot token is zeroized on drop.
|
||||
bot_token: Zeroizing<String>,
|
||||
/// Room IDs to listen on (empty = all rooms the bot is in).
|
||||
allowed_rooms: Vec<String>,
|
||||
/// HTTP client for REST API calls.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Cached bot identity (ID and display name).
|
||||
bot_info: Arc<RwLock<Option<(String, String)>>>,
|
||||
}
|
||||
|
||||
impl WebexAdapter {
|
||||
/// Create a new Webex adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `bot_token` - Webex Bot access token.
|
||||
/// * `allowed_rooms` - Room IDs to filter events for (empty = all).
|
||||
pub fn new(bot_token: String, allowed_rooms: Vec<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
bot_token: Zeroizing::new(bot_token),
|
||||
allowed_rooms,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
bot_info: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate credentials and retrieve bot identity.
|
||||
async fn validate(&self) -> Result<(String, String), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/people/me", WEBEX_API_BASE);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Webex authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let bot_id = body["id"].as_str().unwrap_or("unknown").to_string();
|
||||
let display_name = body["displayName"]
|
||||
.as_str()
|
||||
.unwrap_or("OpenFang Bot")
|
||||
.to_string();
|
||||
|
||||
*self.bot_info.write().await = Some((bot_id.clone(), display_name.clone()));
|
||||
|
||||
Ok((bot_id, display_name))
|
||||
}
|
||||
|
||||
/// Fetch the full message content by ID (Mercury events only include activity data).
|
||||
#[allow(dead_code)]
|
||||
async fn get_message(
|
||||
&self,
|
||||
message_id: &str,
|
||||
) -> Result<serde_json::Value, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/messages/{}", WEBEX_API_BASE, message_id);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
return Err(format!("Webex: failed to get message {message_id}: {status}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
/// Register a webhook for receiving message events (alternative to WebSocket).
|
||||
#[allow(dead_code)]
|
||||
async fn register_webhook(
|
||||
&self,
|
||||
target_url: &str,
|
||||
) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/webhooks", WEBEX_API_BASE);
|
||||
let body = serde_json::json!({
|
||||
"name": "OpenFang Bot Webhook",
|
||||
"targetUrl": target_url,
|
||||
"resource": "messages",
|
||||
"event": "created",
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Webex webhook registration failed {status}: {resp_body}").into());
|
||||
}
|
||||
|
||||
let result: serde_json::Value = resp.json().await?;
|
||||
let webhook_id = result["id"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(webhook_id)
|
||||
}
|
||||
|
||||
/// Send a text message to a Webex room.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
room_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/messages", WEBEX_API_BASE);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"roomId": room_id,
|
||||
"text": chunk,
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Webex API error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a direct message to a person by email or person ID.
|
||||
#[allow(dead_code)]
|
||||
async fn api_send_direct(
|
||||
&self,
|
||||
person_id: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/messages", WEBEX_API_BASE);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let body = if person_id.contains('@') {
|
||||
serde_json::json!({
|
||||
"toPersonEmail": person_id,
|
||||
"text": chunk,
|
||||
})
|
||||
} else {
|
||||
serde_json::json!({
|
||||
"toPersonId": person_id,
|
||||
"text": chunk,
|
||||
})
|
||||
};
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(self.bot_token.as_str())
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let resp_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Webex direct message error {status}: {resp_body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a room ID is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_room(&self, room_id: &str) -> bool {
|
||||
self.allowed_rooms.is_empty() || self.allowed_rooms.iter().any(|r| r == room_id)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for WebexAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"webex"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("webex".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials and get bot identity
|
||||
let (bot_id, bot_name) = self.validate().await?;
|
||||
info!("Webex adapter authenticated as {bot_name} ({bot_id})");
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let bot_token = self.bot_token.clone();
|
||||
let allowed_rooms = self.allowed_rooms.clone();
|
||||
let client = self.client.clone();
|
||||
let own_bot_id = bot_id;
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
if *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
// Attempt WebSocket connection to Mercury
|
||||
let mut request =
|
||||
match tokio_tungstenite::tungstenite::client::IntoClientRequest::into_client_request(WEBEX_WS_URL) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Webex: failed to build WS request: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
request.headers_mut().insert(
|
||||
"Authorization",
|
||||
format!("Bearer {}", bot_token.as_str()).parse().unwrap(),
|
||||
);
|
||||
|
||||
let ws_stream = match tokio_tungstenite::connect_async(request).await {
|
||||
Ok((stream, _resp)) => stream,
|
||||
Err(e) => {
|
||||
warn!("Webex: WebSocket connection failed: {e}, retrying in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Webex Mercury WebSocket connected");
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
use futures::StreamExt;
|
||||
let (_write, mut read) = ws_stream.split();
|
||||
|
||||
let should_reconnect = loop {
|
||||
let msg = tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Webex adapter shutting down");
|
||||
return;
|
||||
}
|
||||
msg = read.next() => msg,
|
||||
};
|
||||
|
||||
let msg = match msg {
|
||||
Some(Ok(m)) => m,
|
||||
Some(Err(e)) => {
|
||||
warn!("Webex WS read error: {e}");
|
||||
break true;
|
||||
}
|
||||
None => {
|
||||
info!("Webex WS stream ended");
|
||||
break true;
|
||||
}
|
||||
};
|
||||
|
||||
let text = match msg {
|
||||
tokio_tungstenite::tungstenite::Message::Text(t) => t,
|
||||
tokio_tungstenite::tungstenite::Message::Close(_) => {
|
||||
break true;
|
||||
}
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let event: serde_json::Value = match serde_json::from_str(&text) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
// Mercury events have a data.activity structure
|
||||
let activity = &event["data"]["activity"];
|
||||
let verb = activity["verb"].as_str().unwrap_or("");
|
||||
|
||||
// Only process "post" activities (new messages)
|
||||
if verb != "post" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let actor_id = activity["actor"]["id"].as_str().unwrap_or("");
|
||||
// Skip messages from the bot itself
|
||||
if actor_id == own_bot_id {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message_id = activity["object"]["id"].as_str().unwrap_or("");
|
||||
if message_id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let room_id = activity["target"]["id"].as_str().unwrap_or("").to_string();
|
||||
|
||||
// Filter by room if configured
|
||||
if !allowed_rooms.is_empty() && !allowed_rooms.iter().any(|r| r == &room_id) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fetch full message content via REST API
|
||||
let msg_url = format!("{}/messages/{}", WEBEX_API_BASE, message_id);
|
||||
let full_msg = match client
|
||||
.get(&msg_url)
|
||||
.bearer_auth(bot_token.as_str())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
if !resp.status().is_success() {
|
||||
warn!("Webex: failed to fetch message {message_id}");
|
||||
continue;
|
||||
}
|
||||
resp.json::<serde_json::Value>().await.unwrap_or_default()
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Webex: message fetch error: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let msg_text = full_msg["text"].as_str().unwrap_or("");
|
||||
if msg_text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let sender_email = full_msg["personEmail"].as_str().unwrap_or("unknown");
|
||||
let sender_id = full_msg["personId"].as_str().unwrap_or("").to_string();
|
||||
let full_room_id = full_msg["roomId"].as_str().unwrap_or(&room_id).to_string();
|
||||
let room_type = full_msg["roomType"].as_str().unwrap_or("group");
|
||||
let is_group = room_type == "group";
|
||||
|
||||
let msg_content = if msg_text.starts_with('/') {
|
||||
let parts: Vec<&str> = msg_text.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(msg_text.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("webex".to_string()),
|
||||
platform_message_id: message_id.to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: full_room_id,
|
||||
display_name: sender_email.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: None,
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"sender_id".to_string(),
|
||||
serde_json::Value::String(sender_id),
|
||||
);
|
||||
m.insert(
|
||||
"sender_email".to_string(),
|
||||
serde_json::Value::String(sender_email.to_string()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !should_reconnect || *shutdown_rx.borrow() {
|
||||
break;
|
||||
}
|
||||
|
||||
warn!("Webex: reconnecting in {backoff:?}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
}
|
||||
|
||||
info!("Webex WebSocket loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Webex does not expose a public typing indicator API for bots
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_webex_adapter_creation() {
|
||||
let adapter = WebexAdapter::new("test-bot-token".to_string(), vec!["room1".to_string()]);
|
||||
assert_eq!(adapter.name(), "webex");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("webex".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webex_allowed_rooms() {
|
||||
let adapter = WebexAdapter::new(
|
||||
"tok".to_string(),
|
||||
vec!["room-a".to_string(), "room-b".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_room("room-a"));
|
||||
assert!(adapter.is_allowed_room("room-b"));
|
||||
assert!(!adapter.is_allowed_room("room-c"));
|
||||
|
||||
let open = WebexAdapter::new("tok".to_string(), vec![]);
|
||||
assert!(open.is_allowed_room("any-room"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webex_token_zeroized() {
|
||||
let adapter = WebexAdapter::new("my-secret-bot-token".to_string(), vec![]);
|
||||
assert_eq!(adapter.bot_token.as_str(), "my-secret-bot-token");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webex_message_length_limit() {
|
||||
assert_eq!(MAX_MESSAGE_LEN, 7439);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webex_constants() {
|
||||
assert!(WEBEX_API_BASE.starts_with("https://"));
|
||||
assert!(WEBEX_WS_URL.starts_with("wss://"));
|
||||
}
|
||||
}
|
||||
478
crates/openfang-channels/src/webhook.rs
Normal file
478
crates/openfang-channels/src/webhook.rs
Normal file
@@ -0,0 +1,478 @@
|
||||
//! Generic HTTP webhook channel adapter.
|
||||
//!
|
||||
//! Provides a bidirectional webhook integration point. Incoming messages are
|
||||
//! received via an HTTP server that verifies `X-Webhook-Signature` (HMAC-SHA256
|
||||
//! of the request body). Outbound messages are POSTed to a configurable
|
||||
//! callback URL with the same signature scheme.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 65535;
|
||||
|
||||
/// Generic HTTP webhook channel adapter.
|
||||
///
|
||||
/// The most flexible adapter in the OpenFang channel suite. Any system that
|
||||
/// can send/receive HTTP requests with HMAC-SHA256 signatures can integrate
|
||||
/// through this adapter.
|
||||
///
|
||||
/// ## Inbound (receiving)
|
||||
///
|
||||
/// Listens on `listen_port` for `POST /webhook` (or `POST /`) requests.
|
||||
/// Each request must include an `X-Webhook-Signature` header containing
|
||||
/// `sha256=<hex-digest>` where the digest is `HMAC-SHA256(secret, body)`.
|
||||
///
|
||||
/// Expected JSON body:
|
||||
/// ```json
|
||||
/// {
|
||||
/// "sender_id": "user-123",
|
||||
/// "sender_name": "Alice",
|
||||
/// "message": "Hello!",
|
||||
/// "thread_id": "optional-thread",
|
||||
/// "is_group": false,
|
||||
/// "metadata": {}
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// ## Outbound (sending)
|
||||
///
|
||||
/// If `callback_url` is set, messages are POSTed there with the same signature
|
||||
/// scheme.
|
||||
pub struct WebhookAdapter {
|
||||
/// SECURITY: Shared secret for HMAC-SHA256 signatures (zeroized on drop).
|
||||
secret: Zeroizing<String>,
|
||||
/// Port to listen on for incoming webhooks.
|
||||
listen_port: u16,
|
||||
/// Optional callback URL for sending messages.
|
||||
callback_url: Option<String>,
|
||||
/// HTTP client for outbound requests.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl WebhookAdapter {
|
||||
/// Create a new generic webhook adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `secret` - Shared secret for HMAC-SHA256 signature verification.
|
||||
/// * `listen_port` - Port to listen for incoming webhook POST requests.
|
||||
/// * `callback_url` - Optional URL to POST outbound messages to.
|
||||
pub fn new(secret: String, listen_port: u16, callback_url: Option<String>) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
secret: Zeroizing::new(secret),
|
||||
listen_port,
|
||||
callback_url,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute HMAC-SHA256 signature of data with the shared secret.
|
||||
///
|
||||
/// Returns the hex-encoded digest prefixed with "sha256=".
|
||||
fn compute_signature(secret: &str, data: &[u8]) -> String {
|
||||
use hmac::{Hmac, Mac};
|
||||
use sha2::Sha256;
|
||||
|
||||
let mut mac =
|
||||
Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC accepts any key size");
|
||||
mac.update(data);
|
||||
let result = mac.finalize();
|
||||
let hex = hex::encode(result.into_bytes());
|
||||
format!("sha256={hex}")
|
||||
}
|
||||
|
||||
/// Verify an incoming webhook signature (constant-time comparison).
|
||||
fn verify_signature(secret: &str, body: &[u8], signature: &str) -> bool {
|
||||
let expected = Self::compute_signature(secret, body);
|
||||
if expected.len() != signature.len() {
|
||||
return false;
|
||||
}
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
let mut diff = 0u8;
|
||||
for (a, b) in expected.bytes().zip(signature.bytes()) {
|
||||
diff |= a ^ b;
|
||||
}
|
||||
diff == 0
|
||||
}
|
||||
|
||||
/// Parse an incoming webhook JSON body.
|
||||
#[allow(clippy::type_complexity)]
|
||||
fn parse_webhook_body(
|
||||
body: &serde_json::Value,
|
||||
) -> Option<(
|
||||
String,
|
||||
String,
|
||||
String,
|
||||
Option<String>,
|
||||
bool,
|
||||
HashMap<String, serde_json::Value>,
|
||||
)> {
|
||||
let message = body["message"].as_str()?.to_string();
|
||||
if message.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let sender_id = body["sender_id"]
|
||||
.as_str()
|
||||
.unwrap_or("webhook-user")
|
||||
.to_string();
|
||||
let sender_name = body["sender_name"]
|
||||
.as_str()
|
||||
.unwrap_or("Webhook User")
|
||||
.to_string();
|
||||
let thread_id = body["thread_id"].as_str().map(String::from);
|
||||
let is_group = body["is_group"].as_bool().unwrap_or(false);
|
||||
|
||||
let metadata = body["metadata"]
|
||||
.as_object()
|
||||
.map(|obj| {
|
||||
obj.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect::<HashMap<_, _>>()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Some((
|
||||
message,
|
||||
sender_id,
|
||||
sender_name,
|
||||
thread_id,
|
||||
is_group,
|
||||
metadata,
|
||||
))
|
||||
}
|
||||
|
||||
/// Check if a callback URL is configured.
|
||||
pub fn has_callback(&self) -> bool {
|
||||
self.callback_url.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for WebhookAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"webhook"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("webhook".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.listen_port;
|
||||
let secret = self.secret.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
info!("Webhook adapter starting HTTP server on port {port}");
|
||||
|
||||
tokio::spawn(async move {
|
||||
let tx_shared = Arc::new(tx);
|
||||
let secret_shared = Arc::new(secret);
|
||||
|
||||
let app = axum::Router::new().route(
|
||||
"/webhook",
|
||||
axum::routing::post({
|
||||
let tx = Arc::clone(&tx_shared);
|
||||
let secret = Arc::clone(&secret_shared);
|
||||
move |headers: axum::http::HeaderMap, body: axum::body::Bytes| {
|
||||
let tx = Arc::clone(&tx);
|
||||
let secret = Arc::clone(&secret);
|
||||
async move {
|
||||
// Extract and verify signature
|
||||
let signature = headers
|
||||
.get("X-Webhook-Signature")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if !WebhookAdapter::verify_signature(&secret, &body, signature) {
|
||||
warn!("Webhook: invalid signature");
|
||||
return (
|
||||
axum::http::StatusCode::FORBIDDEN,
|
||||
"Forbidden: invalid signature",
|
||||
);
|
||||
}
|
||||
|
||||
let json_body: serde_json::Value = match serde_json::from_slice(&body) {
|
||||
Ok(v) => v,
|
||||
Err(_) => {
|
||||
return (axum::http::StatusCode::BAD_REQUEST, "Invalid JSON");
|
||||
}
|
||||
};
|
||||
|
||||
if let Some((
|
||||
message,
|
||||
sender_id,
|
||||
sender_name,
|
||||
thread_id,
|
||||
is_group,
|
||||
metadata,
|
||||
)) = WebhookAdapter::parse_webhook_body(&json_body)
|
||||
{
|
||||
let content = if message.starts_with('/') {
|
||||
let parts: Vec<&str> = message.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(message)
|
||||
};
|
||||
|
||||
let msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("webhook".to_string()),
|
||||
platform_message_id: format!(
|
||||
"wh-{}",
|
||||
Utc::now().timestamp_millis()
|
||||
),
|
||||
sender: ChannelUser {
|
||||
platform_id: sender_id,
|
||||
display_name: sender_name,
|
||||
openfang_user: None,
|
||||
},
|
||||
content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id,
|
||||
metadata,
|
||||
};
|
||||
|
||||
let _ = tx.send(msg).await;
|
||||
}
|
||||
|
||||
(axum::http::StatusCode::OK, "ok")
|
||||
}
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
info!("Webhook HTTP server listening on {addr}");
|
||||
|
||||
let listener = match tokio::net::TcpListener::bind(addr).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
warn!("Webhook: failed to bind port {port}: {e}");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app);
|
||||
|
||||
tokio::select! {
|
||||
result = server => {
|
||||
if let Err(e) = result {
|
||||
warn!("Webhook server error: {e}");
|
||||
}
|
||||
}
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Webhook adapter shutting down");
|
||||
}
|
||||
}
|
||||
|
||||
info!("Webhook HTTP server stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let callback_url = self
|
||||
.callback_url
|
||||
.as_ref()
|
||||
.ok_or("Webhook: no callback_url configured for outbound messages")?;
|
||||
|
||||
let text = match content {
|
||||
ChannelContent::Text(t) => t,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
let chunks = split_message(&text, MAX_MESSAGE_LEN);
|
||||
let num_chunks = chunks.len();
|
||||
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"sender_id": "openfang",
|
||||
"sender_name": "OpenFang",
|
||||
"recipient_id": user.platform_id,
|
||||
"recipient_name": user.display_name,
|
||||
"message": chunk,
|
||||
"timestamp": Utc::now().to_rfc3339(),
|
||||
});
|
||||
|
||||
let body_bytes = serde_json::to_vec(&body)?;
|
||||
let signature = Self::compute_signature(&self.secret, &body_bytes);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(callback_url)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("X-Webhook-Signature", &signature)
|
||||
.body(body_bytes)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let err_body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Webhook callback error {status}: {err_body}").into());
|
||||
}
|
||||
|
||||
// Small delay between chunks for large messages
|
||||
if num_chunks > 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_typing(&self, _user: &ChannelUser) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Generic webhooks have no typing indicator concept.
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_webhook_adapter_creation() {
|
||||
let adapter = WebhookAdapter::new(
|
||||
"my-secret".to_string(),
|
||||
9000,
|
||||
Some("https://example.com/callback".to_string()),
|
||||
);
|
||||
assert_eq!(adapter.name(), "webhook");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("webhook".to_string())
|
||||
);
|
||||
assert!(adapter.has_callback());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_no_callback() {
|
||||
let adapter = WebhookAdapter::new("secret".to_string(), 9000, None);
|
||||
assert!(!adapter.has_callback());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_signature_computation() {
|
||||
let sig = WebhookAdapter::compute_signature("secret", b"hello world");
|
||||
assert!(sig.starts_with("sha256="));
|
||||
// Verify deterministic
|
||||
let sig2 = WebhookAdapter::compute_signature("secret", b"hello world");
|
||||
assert_eq!(sig, sig2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_signature_verification() {
|
||||
let secret = "test-secret";
|
||||
let body = b"test body content";
|
||||
let sig = WebhookAdapter::compute_signature(secret, body);
|
||||
assert!(WebhookAdapter::verify_signature(secret, body, &sig));
|
||||
assert!(!WebhookAdapter::verify_signature(
|
||||
secret,
|
||||
body,
|
||||
"sha256=bad"
|
||||
));
|
||||
assert!(!WebhookAdapter::verify_signature("wrong", body, &sig));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_signature_different_data() {
|
||||
let secret = "same-secret";
|
||||
let sig1 = WebhookAdapter::compute_signature(secret, b"data1");
|
||||
let sig2 = WebhookAdapter::compute_signature(secret, b"data2");
|
||||
assert_ne!(sig1, sig2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_parse_body_full() {
|
||||
let body = serde_json::json!({
|
||||
"sender_id": "user-123",
|
||||
"sender_name": "Alice",
|
||||
"message": "Hello webhook!",
|
||||
"thread_id": "thread-1",
|
||||
"is_group": true,
|
||||
"metadata": {
|
||||
"custom": "value"
|
||||
}
|
||||
});
|
||||
let result = WebhookAdapter::parse_webhook_body(&body);
|
||||
assert!(result.is_some());
|
||||
let (message, sender_id, sender_name, thread_id, is_group, metadata) = result.unwrap();
|
||||
assert_eq!(message, "Hello webhook!");
|
||||
assert_eq!(sender_id, "user-123");
|
||||
assert_eq!(sender_name, "Alice");
|
||||
assert_eq!(thread_id, Some("thread-1".to_string()));
|
||||
assert!(is_group);
|
||||
assert_eq!(
|
||||
metadata.get("custom"),
|
||||
Some(&serde_json::Value::String("value".to_string()))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_parse_body_minimal() {
|
||||
let body = serde_json::json!({
|
||||
"message": "Just a message"
|
||||
});
|
||||
let result = WebhookAdapter::parse_webhook_body(&body);
|
||||
assert!(result.is_some());
|
||||
let (message, sender_id, sender_name, thread_id, is_group, _metadata) = result.unwrap();
|
||||
assert_eq!(message, "Just a message");
|
||||
assert_eq!(sender_id, "webhook-user");
|
||||
assert_eq!(sender_name, "Webhook User");
|
||||
assert!(thread_id.is_none());
|
||||
assert!(!is_group);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_parse_body_empty_message() {
|
||||
let body = serde_json::json!({ "message": "" });
|
||||
assert!(WebhookAdapter::parse_webhook_body(&body).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_webhook_parse_body_no_message() {
|
||||
let body = serde_json::json!({ "sender_id": "user" });
|
||||
assert!(WebhookAdapter::parse_webhook_body(&body).is_none());
|
||||
}
|
||||
}
|
||||
364
crates/openfang-channels/src/whatsapp.rs
Normal file
364
crates/openfang-channels/src/whatsapp.rs
Normal file
@@ -0,0 +1,364 @@
|
||||
//! WhatsApp Cloud API channel adapter.
|
||||
//!
|
||||
//! Uses the official WhatsApp Business Cloud API to send and receive messages.
|
||||
//! Requires a webhook endpoint for incoming messages and the Cloud API for outgoing.
|
||||
|
||||
use crate::types::{ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser};
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, watch};
|
||||
use tracing::{error, info};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 4096;
|
||||
|
||||
/// WhatsApp Cloud API adapter.
|
||||
///
|
||||
/// Supports two modes:
|
||||
/// - **Cloud API mode**: Uses the official WhatsApp Business Cloud API (requires Meta dev account).
|
||||
/// - **Web/QR mode**: Routes outgoing messages through a local Baileys-based gateway process.
|
||||
///
|
||||
/// Mode is selected automatically: if `gateway_url` is set (from `WHATSAPP_WEB_GATEWAY_URL`),
|
||||
/// the adapter uses Web mode. Otherwise it falls back to Cloud API mode.
|
||||
pub struct WhatsAppAdapter {
|
||||
/// WhatsApp Business phone number ID (Cloud API mode).
|
||||
phone_number_id: String,
|
||||
/// SECURITY: Access token is zeroized on drop.
|
||||
access_token: Zeroizing<String>,
|
||||
/// SECURITY: Verify token is zeroized on drop.
|
||||
verify_token: Zeroizing<String>,
|
||||
/// Port to listen for webhook callbacks (Cloud API mode).
|
||||
webhook_port: u16,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Allowed phone numbers (empty = allow all).
|
||||
allowed_users: Vec<String>,
|
||||
/// Optional WhatsApp Web gateway URL for QR/Web mode (e.g. "http://127.0.0.1:3009").
|
||||
gateway_url: Option<String>,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl WhatsAppAdapter {
|
||||
/// Create a new WhatsApp Cloud API adapter.
|
||||
pub fn new(
|
||||
phone_number_id: String,
|
||||
access_token: String,
|
||||
verify_token: String,
|
||||
webhook_port: u16,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
phone_number_id,
|
||||
access_token: Zeroizing::new(access_token),
|
||||
verify_token: Zeroizing::new(verify_token),
|
||||
webhook_port,
|
||||
client: reqwest::Client::new(),
|
||||
allowed_users,
|
||||
gateway_url: None,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new WhatsApp adapter with gateway URL for Web/QR mode.
|
||||
///
|
||||
/// When `gateway_url` is `Some`, outgoing messages are sent via `POST {gateway_url}/message/send`
|
||||
/// instead of the Cloud API. Incoming messages are handled by the gateway itself.
|
||||
pub fn with_gateway(mut self, gateway_url: Option<String>) -> Self {
|
||||
self.gateway_url = gateway_url.filter(|u| !u.is_empty());
|
||||
self
|
||||
}
|
||||
|
||||
/// Send a text message via the WhatsApp Cloud API.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
to: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"https://graph.facebook.com/v21.0/{}/messages",
|
||||
self.phone_number_id
|
||||
);
|
||||
|
||||
// Split long messages
|
||||
let chunks = crate::types::split_message(text, MAX_MESSAGE_LEN);
|
||||
for chunk in chunks {
|
||||
let body = serde_json::json!({
|
||||
"messaging_product": "whatsapp",
|
||||
"to": to,
|
||||
"type": "text",
|
||||
"text": { "body": chunk }
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
error!("WhatsApp API error {status}: {body}");
|
||||
return Err(format!("WhatsApp API error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Mark a message as read.
|
||||
#[allow(dead_code)]
|
||||
async fn api_mark_read(&self, message_id: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!(
|
||||
"https://graph.facebook.com/v21.0/{}/messages",
|
||||
self.phone_number_id
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"messaging_product": "whatsapp",
|
||||
"status": "read",
|
||||
"message_id": message_id
|
||||
});
|
||||
|
||||
let _ = self
|
||||
.client
|
||||
.post(&url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a text message via the WhatsApp Web gateway.
|
||||
async fn gateway_send_message(
|
||||
&self,
|
||||
gateway_url: &str,
|
||||
to: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/message/send", gateway_url.trim_end_matches('/'));
|
||||
let body = serde_json::json!({ "to": to, "text": text });
|
||||
|
||||
let resp = self.client.post(&url).json(&body).send().await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
error!("WhatsApp gateway error {status}: {body}");
|
||||
return Err(format!("WhatsApp gateway error {status}: {body}").into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a phone number is allowed.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed(&self, phone: &str) -> bool {
|
||||
self.allowed_users.is_empty() || self.allowed_users.iter().any(|u| u == phone)
|
||||
}
|
||||
|
||||
/// Returns true if this adapter is configured for Web/QR gateway mode.
|
||||
#[allow(dead_code)]
|
||||
pub fn is_gateway_mode(&self) -> bool {
|
||||
self.gateway_url.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for WhatsAppAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"whatsapp"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::WhatsApp
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let (_tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let port = self.webhook_port;
|
||||
let _verify_token = self.verify_token.clone();
|
||||
let _allowed_users = self.allowed_users.clone();
|
||||
let _access_token = self.access_token.clone();
|
||||
let _phone_number_id = self.phone_number_id.clone();
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
info!("Starting WhatsApp webhook listener on port {port}");
|
||||
|
||||
tokio::spawn(async move {
|
||||
// Simple webhook polling simulation
|
||||
// In production, this would be an axum HTTP server handling webhook POSTs
|
||||
// For now, log that the webhook is ready
|
||||
info!("WhatsApp webhook ready on port {port} (verify_token configured)");
|
||||
info!("Configure your webhook URL: https://your-domain:{port}/webhook");
|
||||
|
||||
// Wait for shutdown
|
||||
let _ = shutdown_rx.changed().await;
|
||||
info!("WhatsApp adapter stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
// Web/QR gateway mode: route all messages through the gateway
|
||||
if let Some(ref gw) = self.gateway_url {
|
||||
let text = match &content {
|
||||
ChannelContent::Text(t) => t.clone(),
|
||||
ChannelContent::Image { caption, .. } => {
|
||||
caption
|
||||
.clone()
|
||||
.unwrap_or_else(|| "(Image — not supported in Web mode)".to_string())
|
||||
}
|
||||
ChannelContent::File { filename, .. } => {
|
||||
format!("(File: {filename} — not supported in Web mode)")
|
||||
}
|
||||
_ => "(Unsupported content type in Web mode)".to_string(),
|
||||
};
|
||||
// Split long messages the same way as Cloud API mode
|
||||
let chunks = crate::types::split_message(&text, MAX_MESSAGE_LEN);
|
||||
for chunk in chunks {
|
||||
self.gateway_send_message(gw, &user.platform_id, chunk)
|
||||
.await?;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Cloud API mode (default)
|
||||
match content {
|
||||
ChannelContent::Text(text) => {
|
||||
self.api_send_message(&user.platform_id, &text).await?;
|
||||
}
|
||||
ChannelContent::Image { url, caption } => {
|
||||
let body = serde_json::json!({
|
||||
"messaging_product": "whatsapp",
|
||||
"to": user.platform_id,
|
||||
"type": "image",
|
||||
"image": {
|
||||
"link": url,
|
||||
"caption": caption.unwrap_or_default()
|
||||
}
|
||||
});
|
||||
let api_url = format!(
|
||||
"https://graph.facebook.com/v21.0/{}/messages",
|
||||
self.phone_number_id
|
||||
);
|
||||
self.client
|
||||
.post(&api_url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
}
|
||||
ChannelContent::File { url, filename } => {
|
||||
let body = serde_json::json!({
|
||||
"messaging_product": "whatsapp",
|
||||
"to": user.platform_id,
|
||||
"type": "document",
|
||||
"document": {
|
||||
"link": url,
|
||||
"filename": filename
|
||||
}
|
||||
});
|
||||
let api_url = format!(
|
||||
"https://graph.facebook.com/v21.0/{}/messages",
|
||||
self.phone_number_id
|
||||
);
|
||||
self.client
|
||||
.post(&api_url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
}
|
||||
ChannelContent::Location { lat, lon } => {
|
||||
let body = serde_json::json!({
|
||||
"messaging_product": "whatsapp",
|
||||
"to": user.platform_id,
|
||||
"type": "location",
|
||||
"location": {
|
||||
"latitude": lat,
|
||||
"longitude": lon
|
||||
}
|
||||
});
|
||||
let api_url = format!(
|
||||
"https://graph.facebook.com/v21.0/{}/messages",
|
||||
self.phone_number_id
|
||||
);
|
||||
self.client
|
||||
.post(&api_url)
|
||||
.bearer_auth(&*self.access_token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
self.api_send_message(&user.platform_id, "(Unsupported content type)")
|
||||
.await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_whatsapp_adapter_creation() {
|
||||
let adapter = WhatsAppAdapter::new(
|
||||
"12345".to_string(),
|
||||
"access_token".to_string(),
|
||||
"verify_token".to_string(),
|
||||
8443,
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.name(), "whatsapp");
|
||||
assert_eq!(adapter.channel_type(), ChannelType::WhatsApp);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_allowed_users_check() {
|
||||
let adapter = WhatsAppAdapter::new(
|
||||
"12345".to_string(),
|
||||
"token".to_string(),
|
||||
"verify".to_string(),
|
||||
8443,
|
||||
vec!["+1234567890".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed("+1234567890"));
|
||||
assert!(!adapter.is_allowed("+9999999999"));
|
||||
|
||||
let open = WhatsAppAdapter::new(
|
||||
"12345".to_string(),
|
||||
"token".to_string(),
|
||||
"verify".to_string(),
|
||||
8443,
|
||||
vec![],
|
||||
);
|
||||
assert!(open.is_allowed("+anything"));
|
||||
}
|
||||
}
|
||||
266
crates/openfang-channels/src/xmpp.rs
Normal file
266
crates/openfang-channels/src/xmpp.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
//! XMPP channel adapter (stub).
|
||||
//!
|
||||
//! This is a stub adapter for XMPP/Jabber messaging. A full XMPP implementation
|
||||
//! requires the `tokio-xmpp` crate (or equivalent) for proper SASL authentication,
|
||||
//! TLS negotiation, XML stream parsing, and MUC (Multi-User Chat) support.
|
||||
//!
|
||||
//! The adapter struct is fully defined so it can be constructed and configured, but
|
||||
//! `start()` returns an error explaining that the `tokio-xmpp` dependency is needed.
|
||||
//! This allows the adapter to be wired into the channel system without adding
|
||||
//! heavyweight dependencies to the workspace.
|
||||
|
||||
use crate::types::{ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser};
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::watch;
|
||||
use tracing::warn;
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
/// XMPP/Jabber channel adapter (stub implementation).
|
||||
///
|
||||
/// Holds all configuration needed for a full XMPP client but defers actual
|
||||
/// connection to when the `tokio-xmpp` dependency is added.
|
||||
pub struct XmppAdapter {
|
||||
/// JID (Jabber ID) of the bot (e.g., "bot@example.com").
|
||||
jid: String,
|
||||
/// SECURITY: Password is zeroized on drop.
|
||||
#[allow(dead_code)]
|
||||
password: Zeroizing<String>,
|
||||
/// XMPP server hostname.
|
||||
server: String,
|
||||
/// XMPP server port (default 5222 for STARTTLS, 5223 for direct TLS).
|
||||
port: u16,
|
||||
/// MUC rooms to join (e.g., "room@conference.example.com").
|
||||
rooms: Vec<String>,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
#[allow(dead_code)]
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
}
|
||||
|
||||
impl XmppAdapter {
|
||||
/// Create a new XMPP adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `jid` - Full JID of the bot (user@domain).
|
||||
/// * `password` - XMPP account password.
|
||||
/// * `server` - Server hostname (may differ from JID domain).
|
||||
/// * `port` - Server port (typically 5222).
|
||||
/// * `rooms` - MUC room JIDs to auto-join.
|
||||
pub fn new(
|
||||
jid: String,
|
||||
password: String,
|
||||
server: String,
|
||||
port: u16,
|
||||
rooms: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
Self {
|
||||
jid,
|
||||
password: Zeroizing::new(password),
|
||||
server,
|
||||
port,
|
||||
rooms,
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the bare JID (without resource).
|
||||
#[allow(dead_code)]
|
||||
pub fn bare_jid(&self) -> &str {
|
||||
self.jid.split('/').next().unwrap_or(&self.jid)
|
||||
}
|
||||
|
||||
/// Get the configured server endpoint.
|
||||
#[allow(dead_code)]
|
||||
pub fn endpoint(&self) -> String {
|
||||
format!("{}:{}", self.server, self.port)
|
||||
}
|
||||
|
||||
/// Get the list of configured rooms.
|
||||
#[allow(dead_code)]
|
||||
pub fn rooms(&self) -> &[String] {
|
||||
&self.rooms
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for XmppAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"xmpp"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("xmpp".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
warn!(
|
||||
"XMPP adapter for {}@{}:{} cannot start: \
|
||||
full XMPP support requires the tokio-xmpp dependency which is not \
|
||||
currently included in the workspace. Add tokio-xmpp to Cargo.toml \
|
||||
and implement the SASL/TLS/XML stream handling to enable this adapter.",
|
||||
self.jid, self.server, self.port
|
||||
);
|
||||
|
||||
Err(format!(
|
||||
"XMPP adapter requires tokio-xmpp dependency (not yet added to workspace). \
|
||||
Configured for JID '{}' on {}:{} with {} room(s).",
|
||||
self.jid,
|
||||
self.server,
|
||||
self.port,
|
||||
self.rooms.len()
|
||||
)
|
||||
.into())
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
_user: &ChannelUser,
|
||||
_content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
Err("XMPP adapter not started: tokio-xmpp dependency required".into())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_xmpp_adapter_creation() {
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"secret-password".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
vec!["room@conference.example.com".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "xmpp");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("xmpp".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xmpp_bare_jid() {
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com/resource".to_string(),
|
||||
"pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.bare_jid(), "bot@example.com");
|
||||
|
||||
let adapter_no_resource = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter_no_resource.bare_jid(), "bot@example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xmpp_endpoint() {
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.endpoint(), "xmpp.example.com:5222");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xmpp_rooms() {
|
||||
let rooms = vec![
|
||||
"room1@conference.example.com".to_string(),
|
||||
"room2@conference.example.com".to_string(),
|
||||
];
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
rooms.clone(),
|
||||
);
|
||||
assert_eq!(adapter.rooms(), &rooms);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_xmpp_start_returns_error() {
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
vec!["room@conference.example.com".to_string()],
|
||||
);
|
||||
let result = adapter.start().await;
|
||||
assert!(result.is_err());
|
||||
let err = result.err().unwrap().to_string();
|
||||
assert!(err.contains("tokio-xmpp"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_xmpp_send_returns_error() {
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
vec![],
|
||||
);
|
||||
let user = ChannelUser {
|
||||
platform_id: "user@example.com".to_string(),
|
||||
display_name: "Test User".to_string(),
|
||||
openfang_user: None,
|
||||
};
|
||||
let result = adapter
|
||||
.send(&user, ChannelContent::Text("hello".to_string()))
|
||||
.await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xmpp_password_zeroized() {
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"my-secret-pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5222,
|
||||
vec![],
|
||||
);
|
||||
// Verify accessible before drop (zeroized on drop)
|
||||
assert_eq!(adapter.password.as_str(), "my-secret-pass");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_xmpp_custom_port() {
|
||||
let adapter = XmppAdapter::new(
|
||||
"bot@example.com".to_string(),
|
||||
"pass".to_string(),
|
||||
"xmpp.example.com".to_string(),
|
||||
5223,
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.port, 5223);
|
||||
assert_eq!(adapter.endpoint(), "xmpp.example.com:5223");
|
||||
}
|
||||
}
|
||||
548
crates/openfang-channels/src/zulip.rs
Normal file
548
crates/openfang-channels/src/zulip.rs
Normal file
@@ -0,0 +1,548 @@
|
||||
//! Zulip channel adapter.
|
||||
//!
|
||||
//! Uses the Zulip REST API with HTTP Basic authentication (bot email + API key).
|
||||
//! Receives messages via Zulip's event queue system (register + long-poll) and
|
||||
//! sends messages via the `/api/v1/messages` endpoint.
|
||||
|
||||
use crate::types::{
|
||||
split_message, ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Utc;
|
||||
use futures::Stream;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{mpsc, watch, RwLock};
|
||||
use tracing::{info, warn};
|
||||
use zeroize::Zeroizing;
|
||||
|
||||
const MAX_MESSAGE_LEN: usize = 10000;
|
||||
const POLL_TIMEOUT_SECS: u64 = 60;
|
||||
|
||||
/// Zulip channel adapter using REST API with event queue long-polling.
|
||||
pub struct ZulipAdapter {
|
||||
/// Zulip server URL (e.g., `"https://myorg.zulipchat.com"`).
|
||||
server_url: String,
|
||||
/// Bot email address for HTTP Basic auth.
|
||||
bot_email: String,
|
||||
/// SECURITY: API key is zeroized on drop.
|
||||
api_key: Zeroizing<String>,
|
||||
/// Stream names to listen on (empty = all).
|
||||
streams: Vec<String>,
|
||||
/// HTTP client.
|
||||
client: reqwest::Client,
|
||||
/// Shutdown signal.
|
||||
shutdown_tx: Arc<watch::Sender<bool>>,
|
||||
shutdown_rx: watch::Receiver<bool>,
|
||||
/// Current event queue ID for resuming polls.
|
||||
queue_id: Arc<RwLock<Option<String>>>,
|
||||
}
|
||||
|
||||
impl ZulipAdapter {
|
||||
/// Create a new Zulip adapter.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `server_url` - Base URL of the Zulip server.
|
||||
/// * `bot_email` - Email address of the Zulip bot.
|
||||
/// * `api_key` - API key for the bot.
|
||||
/// * `streams` - Stream names to subscribe to (empty = all public streams).
|
||||
pub fn new(
|
||||
server_url: String,
|
||||
bot_email: String,
|
||||
api_key: String,
|
||||
streams: Vec<String>,
|
||||
) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = watch::channel(false);
|
||||
let server_url = server_url.trim_end_matches('/').to_string();
|
||||
Self {
|
||||
server_url,
|
||||
bot_email,
|
||||
api_key: Zeroizing::new(api_key),
|
||||
streams,
|
||||
client: reqwest::Client::new(),
|
||||
shutdown_tx: Arc::new(shutdown_tx),
|
||||
shutdown_rx,
|
||||
queue_id: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an event queue with the Zulip server.
|
||||
async fn register_queue(&self) -> Result<(String, i64), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v1/register", self.server_url);
|
||||
|
||||
let mut params = vec![("event_types", r#"["message"]"#.to_string())];
|
||||
|
||||
// If specific streams are configured, narrow to those
|
||||
if !self.streams.is_empty() {
|
||||
let narrow: Vec<serde_json::Value> = self
|
||||
.streams
|
||||
.iter()
|
||||
.map(|s| serde_json::json!(["stream", s]))
|
||||
.collect();
|
||||
params.push(("narrow", serde_json::to_string(&narrow)?));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.basic_auth(&self.bot_email, Some(self.api_key.as_str()))
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Zulip register failed {status}: {body}").into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
|
||||
let queue_id = body["queue_id"]
|
||||
.as_str()
|
||||
.ok_or("Missing queue_id in register response")?
|
||||
.to_string();
|
||||
let last_event_id = body["last_event_id"]
|
||||
.as_i64()
|
||||
.ok_or("Missing last_event_id in register response")?;
|
||||
|
||||
Ok((queue_id, last_event_id))
|
||||
}
|
||||
|
||||
/// Validate credentials by fetching the bot's own profile.
|
||||
async fn validate(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v1/users/me", self.server_url);
|
||||
let resp = self
|
||||
.client
|
||||
.get(&url)
|
||||
.basic_auth(&self.bot_email, Some(self.api_key.as_str()))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err("Zulip authentication failed".into());
|
||||
}
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
let full_name = body["full_name"].as_str().unwrap_or("unknown").to_string();
|
||||
Ok(full_name)
|
||||
}
|
||||
|
||||
/// Send a message to a Zulip stream or direct message.
|
||||
async fn api_send_message(
|
||||
&self,
|
||||
msg_type: &str,
|
||||
to: &str,
|
||||
topic: &str,
|
||||
text: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let url = format!("{}/api/v1/messages", self.server_url);
|
||||
let chunks = split_message(text, MAX_MESSAGE_LEN);
|
||||
|
||||
for chunk in chunks {
|
||||
let mut params = vec![
|
||||
("type", msg_type.to_string()),
|
||||
("to", to.to_string()),
|
||||
("content", chunk.to_string()),
|
||||
];
|
||||
|
||||
if msg_type == "stream" {
|
||||
params.push(("topic", topic.to_string()));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.basic_auth(&self.bot_email, Some(self.api_key.as_str()))
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("Zulip send error {status}: {body}").into());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a stream name is in the allowed list.
|
||||
#[allow(dead_code)]
|
||||
fn is_allowed_stream(&self, stream: &str) -> bool {
|
||||
self.streams.is_empty() || self.streams.iter().any(|s| s == stream)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for ZulipAdapter {
|
||||
fn name(&self) -> &str {
|
||||
"zulip"
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
ChannelType::Custom("zulip".to_string())
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
// Validate credentials
|
||||
let bot_name = self.validate().await?;
|
||||
info!("Zulip adapter authenticated as {bot_name}");
|
||||
|
||||
// Register event queue
|
||||
let (initial_queue_id, initial_last_id) = self.register_queue().await?;
|
||||
info!("Zulip event queue registered: {initial_queue_id}");
|
||||
*self.queue_id.write().await = Some(initial_queue_id.clone());
|
||||
|
||||
let (tx, rx) = mpsc::channel::<ChannelMessage>(256);
|
||||
let server_url = self.server_url.clone();
|
||||
let bot_email = self.bot_email.clone();
|
||||
let api_key = self.api_key.clone();
|
||||
let streams = self.streams.clone();
|
||||
let client = self.client.clone();
|
||||
let queue_id_lock = Arc::clone(&self.queue_id);
|
||||
let mut shutdown_rx = self.shutdown_rx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut current_queue_id = initial_queue_id;
|
||||
let mut last_event_id = initial_last_id;
|
||||
let mut backoff = Duration::from_secs(1);
|
||||
|
||||
loop {
|
||||
let url = format!(
|
||||
"{}/api/v1/events?queue_id={}&last_event_id={}&dont_block=false",
|
||||
server_url, current_queue_id, last_event_id
|
||||
);
|
||||
|
||||
let resp = tokio::select! {
|
||||
_ = shutdown_rx.changed() => {
|
||||
info!("Zulip adapter shutting down");
|
||||
break;
|
||||
}
|
||||
result = client
|
||||
.get(&url)
|
||||
.basic_auth(&bot_email, Some(api_key.as_str()))
|
||||
.timeout(Duration::from_secs(POLL_TIMEOUT_SECS + 10))
|
||||
.send() => {
|
||||
match result {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!("Zulip poll error: {e}");
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
warn!("Zulip poll returned {status}");
|
||||
|
||||
// If the queue is expired (BAD_EVENT_QUEUE_ID), re-register
|
||||
if status == reqwest::StatusCode::BAD_REQUEST {
|
||||
let body: serde_json::Value = resp.json().await.unwrap_or_default();
|
||||
if body["code"].as_str() == Some("BAD_EVENT_QUEUE_ID") {
|
||||
info!("Zulip: event queue expired, re-registering");
|
||||
let register_url = format!("{}/api/v1/register", server_url);
|
||||
|
||||
let mut params = vec![("event_types", r#"["message"]"#.to_string())];
|
||||
if !streams.is_empty() {
|
||||
let narrow: Vec<serde_json::Value> = streams
|
||||
.iter()
|
||||
.map(|s| serde_json::json!(["stream", s]))
|
||||
.collect();
|
||||
if let Ok(narrow_str) = serde_json::to_string(&narrow) {
|
||||
params.push(("narrow", narrow_str));
|
||||
}
|
||||
}
|
||||
|
||||
match client
|
||||
.post(®ister_url)
|
||||
.basic_auth(&bot_email, Some(api_key.as_str()))
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(reg_resp) => {
|
||||
let reg_body: serde_json::Value =
|
||||
reg_resp.json().await.unwrap_or_default();
|
||||
if let (Some(qid), Some(lid)) = (
|
||||
reg_body["queue_id"].as_str(),
|
||||
reg_body["last_event_id"].as_i64(),
|
||||
) {
|
||||
current_queue_id = qid.to_string();
|
||||
last_event_id = lid;
|
||||
*queue_id_lock.write().await =
|
||||
Some(current_queue_id.clone());
|
||||
info!("Zulip: re-registered queue {current_queue_id}");
|
||||
backoff = Duration::from_secs(1);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Zulip: re-register failed: {e}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(backoff).await;
|
||||
backoff = (backoff * 2).min(Duration::from_secs(60));
|
||||
continue;
|
||||
}
|
||||
|
||||
backoff = Duration::from_secs(1);
|
||||
|
||||
let body: serde_json::Value = match resp.json().await {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Zulip: failed to parse events: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let events = match body["events"].as_array() {
|
||||
Some(arr) => arr,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
for event in events {
|
||||
// Update last_event_id
|
||||
if let Some(eid) = event["id"].as_i64() {
|
||||
if eid > last_event_id {
|
||||
last_event_id = eid;
|
||||
}
|
||||
}
|
||||
|
||||
let event_type = event["type"].as_str().unwrap_or("");
|
||||
if event_type != "message" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let message = &event["message"];
|
||||
let msg_type = message["type"].as_str().unwrap_or("");
|
||||
|
||||
// Filter by stream if configured
|
||||
let stream_name = message["display_recipient"].as_str().unwrap_or("");
|
||||
if msg_type == "stream"
|
||||
&& !streams.is_empty()
|
||||
&& !streams.iter().any(|s| s == stream_name)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip messages from the bot itself
|
||||
let sender_email = message["sender_email"].as_str().unwrap_or("");
|
||||
if sender_email == bot_email {
|
||||
continue;
|
||||
}
|
||||
|
||||
let content = message["content"].as_str().unwrap_or("");
|
||||
if content.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let sender_name = message["sender_full_name"].as_str().unwrap_or("unknown");
|
||||
let sender_id = message["sender_id"]
|
||||
.as_i64()
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_default();
|
||||
let msg_id = message["id"]
|
||||
.as_i64()
|
||||
.map(|id| id.to_string())
|
||||
.unwrap_or_default();
|
||||
let topic = message["subject"].as_str().unwrap_or("").to_string();
|
||||
let is_group = msg_type == "stream";
|
||||
|
||||
// Determine platform_id: stream name for stream messages,
|
||||
// sender email for DMs
|
||||
let platform_id = if is_group {
|
||||
stream_name.to_string()
|
||||
} else {
|
||||
sender_email.to_string()
|
||||
};
|
||||
|
||||
let msg_content = if content.starts_with('/') {
|
||||
let parts: Vec<&str> = content.splitn(2, ' ').collect();
|
||||
let cmd = parts[0].trim_start_matches('/');
|
||||
let args: Vec<String> = parts
|
||||
.get(1)
|
||||
.map(|a| a.split_whitespace().map(String::from).collect())
|
||||
.unwrap_or_default();
|
||||
ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args,
|
||||
}
|
||||
} else {
|
||||
ChannelContent::Text(content.to_string())
|
||||
};
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
channel: ChannelType::Custom("zulip".to_string()),
|
||||
platform_message_id: msg_id,
|
||||
sender: ChannelUser {
|
||||
platform_id,
|
||||
display_name: sender_name.to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: msg_content,
|
||||
target_agent: None,
|
||||
timestamp: Utc::now(),
|
||||
is_group,
|
||||
thread_id: if !topic.is_empty() { Some(topic) } else { None },
|
||||
metadata: {
|
||||
let mut m = HashMap::new();
|
||||
m.insert(
|
||||
"sender_id".to_string(),
|
||||
serde_json::Value::String(sender_id),
|
||||
);
|
||||
m.insert(
|
||||
"sender_email".to_string(),
|
||||
serde_json::Value::String(sender_email.to_string()),
|
||||
);
|
||||
m
|
||||
},
|
||||
};
|
||||
|
||||
if tx.send(channel_msg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Zulip event loop stopped");
|
||||
});
|
||||
|
||||
Ok(Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx)))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(text) => text,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// Determine message type based on platform_id format
|
||||
// If it looks like an email, send as direct; otherwise as stream message
|
||||
if user.platform_id.contains('@') {
|
||||
self.api_send_message("direct", &user.platform_id, "", &text)
|
||||
.await?;
|
||||
} else {
|
||||
// Use the thread_id (topic) if available, otherwise default topic
|
||||
let topic = "OpenFang";
|
||||
self.api_send_message("stream", &user.platform_id, topic, &text)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_in_thread(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
thread_id: &str,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let text = match content {
|
||||
ChannelContent::Text(text) => text,
|
||||
_ => "(Unsupported content type)".to_string(),
|
||||
};
|
||||
|
||||
// thread_id maps to Zulip "topic"
|
||||
self.api_send_message("stream", &user.platform_id, thread_id, &text)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_zulip_adapter_creation() {
|
||||
let adapter = ZulipAdapter::new(
|
||||
"https://myorg.zulipchat.com".to_string(),
|
||||
"bot@myorg.zulipchat.com".to_string(),
|
||||
"test-api-key".to_string(),
|
||||
vec!["general".to_string()],
|
||||
);
|
||||
assert_eq!(adapter.name(), "zulip");
|
||||
assert_eq!(
|
||||
adapter.channel_type(),
|
||||
ChannelType::Custom("zulip".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zulip_server_url_normalization() {
|
||||
let adapter = ZulipAdapter::new(
|
||||
"https://myorg.zulipchat.com/".to_string(),
|
||||
"bot@example.com".to_string(),
|
||||
"key".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.server_url, "https://myorg.zulipchat.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zulip_allowed_streams() {
|
||||
let adapter = ZulipAdapter::new(
|
||||
"https://zulip.example.com".to_string(),
|
||||
"bot@example.com".to_string(),
|
||||
"key".to_string(),
|
||||
vec!["general".to_string(), "dev".to_string()],
|
||||
);
|
||||
assert!(adapter.is_allowed_stream("general"));
|
||||
assert!(adapter.is_allowed_stream("dev"));
|
||||
assert!(!adapter.is_allowed_stream("random"));
|
||||
|
||||
let open = ZulipAdapter::new(
|
||||
"https://zulip.example.com".to_string(),
|
||||
"bot@example.com".to_string(),
|
||||
"key".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert!(open.is_allowed_stream("any-stream"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zulip_bot_email_stored() {
|
||||
let adapter = ZulipAdapter::new(
|
||||
"https://zulip.example.com".to_string(),
|
||||
"mybot@zulip.example.com".to_string(),
|
||||
"secret-key".to_string(),
|
||||
vec![],
|
||||
);
|
||||
assert_eq!(adapter.bot_email, "mybot@zulip.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_zulip_api_key_zeroized() {
|
||||
let adapter = ZulipAdapter::new(
|
||||
"https://zulip.example.com".to_string(),
|
||||
"bot@example.com".to_string(),
|
||||
"my-secret-api-key".to_string(),
|
||||
vec![],
|
||||
);
|
||||
// Verify the key is accessible (it will be zeroized on drop)
|
||||
assert_eq!(adapter.api_key.as_str(), "my-secret-api-key");
|
||||
}
|
||||
}
|
||||
545
crates/openfang-channels/tests/bridge_integration_test.rs
Normal file
545
crates/openfang-channels/tests/bridge_integration_test.rs
Normal file
@@ -0,0 +1,545 @@
|
||||
//! Integration tests for the BridgeManager dispatch pipeline.
|
||||
//!
|
||||
//! These tests create a mock channel adapter (with injectable messages)
|
||||
//! and a mock kernel handle, wire them through the real BridgeManager,
|
||||
//! and verify the full dispatch pipeline works end-to-end.
|
||||
//!
|
||||
//! No external services are contacted — all communication is in-process
|
||||
//! via real tokio channels and tasks.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use openfang_channels::bridge::{BridgeManager, ChannelBridgeHandle};
|
||||
use openfang_channels::router::AgentRouter;
|
||||
use openfang_channels::types::{
|
||||
ChannelAdapter, ChannelContent, ChannelMessage, ChannelType, ChannelUser,
|
||||
};
|
||||
use openfang_types::agent::AgentId;
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::{mpsc, watch};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock Adapter — injects test messages, captures sent responses
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct MockAdapter {
|
||||
name: String,
|
||||
channel_type: ChannelType,
|
||||
/// Receiver consumed by start() — wrapped as a Stream.
|
||||
rx: Mutex<Option<mpsc::Receiver<ChannelMessage>>>,
|
||||
/// Captures all messages sent via send().
|
||||
sent: Arc<Mutex<Vec<(String, String)>>>,
|
||||
shutdown_tx: watch::Sender<bool>,
|
||||
}
|
||||
|
||||
impl MockAdapter {
|
||||
/// Create a new mock adapter. Returns (adapter, sender) — use the sender
|
||||
/// to inject test messages into the adapter's stream.
|
||||
fn new(name: &str, channel_type: ChannelType) -> (Arc<Self>, mpsc::Sender<ChannelMessage>) {
|
||||
let (tx, rx) = mpsc::channel(256);
|
||||
let (shutdown_tx, _shutdown_rx) = watch::channel(false);
|
||||
|
||||
let adapter = Arc::new(Self {
|
||||
name: name.to_string(),
|
||||
channel_type,
|
||||
rx: Mutex::new(Some(rx)),
|
||||
sent: Arc::new(Mutex::new(Vec::new())),
|
||||
shutdown_tx,
|
||||
});
|
||||
(adapter, tx)
|
||||
}
|
||||
|
||||
/// Get a copy of all sent responses as (platform_id, text) pairs.
|
||||
fn get_sent(&self) -> Vec<(String, String)> {
|
||||
self.sent.lock().unwrap().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelAdapter for MockAdapter {
|
||||
fn name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
fn channel_type(&self) -> ChannelType {
|
||||
self.channel_type.clone()
|
||||
}
|
||||
|
||||
async fn start(
|
||||
&self,
|
||||
) -> Result<Pin<Box<dyn Stream<Item = ChannelMessage> + Send>>, Box<dyn std::error::Error>>
|
||||
{
|
||||
let rx = self
|
||||
.rx
|
||||
.lock()
|
||||
.unwrap()
|
||||
.take()
|
||||
.expect("start() called more than once");
|
||||
let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn send(
|
||||
&self,
|
||||
user: &ChannelUser,
|
||||
content: ChannelContent,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if let ChannelContent::Text(text) = content {
|
||||
self.sent
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((user.platform_id.clone(), text));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn stop(&self) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _ = self.shutdown_tx.send(true);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mock Kernel Handle — echoes messages, serves agent lists
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct MockHandle {
|
||||
agents: Mutex<Vec<(AgentId, String)>>,
|
||||
/// Records all messages sent to agents: (agent_id, message).
|
||||
received: Arc<Mutex<Vec<(AgentId, String)>>>,
|
||||
}
|
||||
|
||||
impl MockHandle {
|
||||
fn new(agents: Vec<(AgentId, String)>) -> Self {
|
||||
Self {
|
||||
agents: Mutex::new(agents),
|
||||
received: Arc::new(Mutex::new(Vec::new())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ChannelBridgeHandle for MockHandle {
|
||||
async fn send_message(&self, agent_id: AgentId, message: &str) -> Result<String, String> {
|
||||
self.received
|
||||
.lock()
|
||||
.unwrap()
|
||||
.push((agent_id, message.to_string()));
|
||||
Ok(format!("Echo: {message}"))
|
||||
}
|
||||
|
||||
async fn find_agent_by_name(&self, name: &str) -> Result<Option<AgentId>, String> {
|
||||
let agents = self.agents.lock().unwrap();
|
||||
Ok(agents.iter().find(|(_, n)| n == name).map(|(id, _)| *id))
|
||||
}
|
||||
|
||||
async fn list_agents(&self) -> Result<Vec<(AgentId, String)>, String> {
|
||||
Ok(self.agents.lock().unwrap().clone())
|
||||
}
|
||||
|
||||
async fn spawn_agent_by_name(&self, _manifest_name: &str) -> Result<AgentId, String> {
|
||||
Err("mock: spawn not implemented".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper to create a ChannelMessage
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn make_text_msg(channel: ChannelType, user_id: &str, text: &str) -> ChannelMessage {
|
||||
ChannelMessage {
|
||||
channel,
|
||||
platform_message_id: "msg1".to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: user_id.to_string(),
|
||||
display_name: "TestUser".to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: ChannelContent::Text(text.to_string()),
|
||||
target_agent: None,
|
||||
timestamp: chrono::Utc::now(),
|
||||
is_group: false,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn make_command_msg(
|
||||
channel: ChannelType,
|
||||
user_id: &str,
|
||||
cmd: &str,
|
||||
args: Vec<&str>,
|
||||
) -> ChannelMessage {
|
||||
ChannelMessage {
|
||||
channel,
|
||||
platform_message_id: "msg1".to_string(),
|
||||
sender: ChannelUser {
|
||||
platform_id: user_id.to_string(),
|
||||
display_name: "TestUser".to_string(),
|
||||
openfang_user: None,
|
||||
},
|
||||
content: ChannelContent::Command {
|
||||
name: cmd.to_string(),
|
||||
args: args.into_iter().map(String::from).collect(),
|
||||
},
|
||||
target_agent: None,
|
||||
timestamp: chrono::Utc::now(),
|
||||
is_group: false,
|
||||
thread_id: None,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Test that text messages are dispatched to the correct agent and responses
|
||||
/// are sent back through the adapter.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_dispatch_text_message() {
|
||||
let agent_id = AgentId::new();
|
||||
let handle = Arc::new(MockHandle::new(vec![(agent_id, "coder".to_string())]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
|
||||
// Pre-route the user to the agent
|
||||
router.set_user_default("user1".to_string(), agent_id);
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("test-adapter", ChannelType::Telegram);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle.clone(), router);
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
// Inject a text message
|
||||
tx.send(make_text_msg(
|
||||
ChannelType::Telegram,
|
||||
"user1",
|
||||
"Hello agent!",
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Give the async dispatch loop time to process
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
// Verify: adapter received the echo response
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 1, "Expected 1 response, got {}", sent.len());
|
||||
assert_eq!(sent[0].0, "user1");
|
||||
assert_eq!(sent[0].1, "Echo: Hello agent!");
|
||||
|
||||
// Verify: handle received the message
|
||||
{
|
||||
let received = handle.received.lock().unwrap();
|
||||
assert_eq!(received.len(), 1);
|
||||
assert_eq!(received[0].0, agent_id);
|
||||
assert_eq!(received[0].1, "Hello agent!");
|
||||
}
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test that /agents command returns the list of running agents.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_dispatch_agents_command() {
|
||||
let agent_id = AgentId::new();
|
||||
let handle = Arc::new(MockHandle::new(vec![
|
||||
(agent_id, "coder".to_string()),
|
||||
(AgentId::new(), "researcher".to_string()),
|
||||
]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("test-adapter", ChannelType::Discord);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle.clone(), router);
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
// Send /agents command as ChannelContent::Command
|
||||
tx.send(make_command_msg(
|
||||
ChannelType::Discord,
|
||||
"user1",
|
||||
"agents",
|
||||
vec![],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(
|
||||
sent[0].1.contains("coder"),
|
||||
"Response should list 'coder', got: {}",
|
||||
sent[0].1
|
||||
);
|
||||
assert!(
|
||||
sent[0].1.contains("researcher"),
|
||||
"Response should list 'researcher', got: {}",
|
||||
sent[0].1
|
||||
);
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test the /help command returns help text.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_dispatch_help_command() {
|
||||
let handle = Arc::new(MockHandle::new(vec![]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("test-adapter", ChannelType::Slack);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle, router);
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
tx.send(make_command_msg(
|
||||
ChannelType::Slack,
|
||||
"user1",
|
||||
"help",
|
||||
vec![],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(sent[0].1.contains("/agents"), "Help should mention /agents");
|
||||
assert!(sent[0].1.contains("/agent"), "Help should mention /agent");
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test /agent <name> command selects the agent and updates the router.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_dispatch_agent_select_command() {
|
||||
let agent_id = AgentId::new();
|
||||
let handle = Arc::new(MockHandle::new(vec![(agent_id, "coder".to_string())]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("test-adapter", ChannelType::Telegram);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle, router.clone());
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
// User selects "coder" agent
|
||||
tx.send(make_command_msg(
|
||||
ChannelType::Telegram,
|
||||
"user42",
|
||||
"agent",
|
||||
vec!["coder"],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(
|
||||
sent[0].1.contains("Now talking to agent: coder"),
|
||||
"Expected selection confirmation, got: {}",
|
||||
sent[0].1
|
||||
);
|
||||
|
||||
// Verify router was updated — user42 should now route to agent_id
|
||||
let resolved = router.resolve(&ChannelType::Telegram, "user42", None);
|
||||
assert_eq!(resolved, Some(agent_id));
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test that unrouted messages (no agent assigned) get a helpful error.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_dispatch_no_agent_assigned() {
|
||||
let handle = Arc::new(MockHandle::new(vec![]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("test-adapter", ChannelType::Telegram);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle, router);
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
// Send message with no agent routed
|
||||
tx.send(make_text_msg(ChannelType::Telegram, "user1", "hello"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(
|
||||
sent[0].1.contains("No agent assigned"),
|
||||
"Expected 'No agent assigned' message, got: {}",
|
||||
sent[0].1
|
||||
);
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test that slash commands embedded in text (/agents, /help) are handled as commands.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_dispatch_slash_command_in_text() {
|
||||
let agent_id = AgentId::new();
|
||||
let handle = Arc::new(MockHandle::new(vec![(agent_id, "writer".to_string())]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("test-adapter", ChannelType::Telegram);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle, router);
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
// Send "/agents" as plain text (not as a Command variant)
|
||||
tx.send(make_text_msg(ChannelType::Telegram, "user1", "/agents"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(
|
||||
sent[0].1.contains("writer"),
|
||||
"Should list the 'writer' agent, got: {}",
|
||||
sent[0].1
|
||||
);
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test /status command returns uptime info.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_dispatch_status_command() {
|
||||
let handle = Arc::new(MockHandle::new(vec![
|
||||
(AgentId::new(), "a".to_string()),
|
||||
(AgentId::new(), "b".to_string()),
|
||||
]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("test-adapter", ChannelType::Telegram);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle, router);
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
tx.send(make_command_msg(
|
||||
ChannelType::Telegram,
|
||||
"user1",
|
||||
"status",
|
||||
vec![],
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 1);
|
||||
assert!(
|
||||
sent[0].1.contains("2 agent(s) running"),
|
||||
"Expected uptime info, got: {}",
|
||||
sent[0].1
|
||||
);
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test the full lifecycle: start adapter, send messages, stop adapter.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_manager_lifecycle() {
|
||||
let agent_id = AgentId::new();
|
||||
let handle = Arc::new(MockHandle::new(vec![(agent_id, "bot".to_string())]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
router.set_user_default("user1".to_string(), agent_id);
|
||||
|
||||
let (adapter, tx) = MockAdapter::new("lifecycle-adapter", ChannelType::WebChat);
|
||||
let adapter_ref = adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle, router);
|
||||
manager.start_adapter(adapter.clone()).await.unwrap();
|
||||
|
||||
// Send multiple messages
|
||||
for i in 0..5 {
|
||||
tx.send(make_text_msg(
|
||||
ChannelType::WebChat,
|
||||
"user1",
|
||||
&format!("message {i}"),
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||
|
||||
let sent = adapter_ref.get_sent();
|
||||
assert_eq!(sent.len(), 5, "Expected 5 responses, got {}", sent.len());
|
||||
|
||||
for (i, (_, text)) in sent.iter().enumerate() {
|
||||
assert_eq!(*text, format!("Echo: message {i}"));
|
||||
}
|
||||
|
||||
// Stop — should complete without hanging
|
||||
manager.stop().await;
|
||||
}
|
||||
|
||||
/// Test multiple adapters running simultaneously in the same BridgeManager.
|
||||
#[tokio::test]
|
||||
async fn test_bridge_multiple_adapters() {
|
||||
let agent_id = AgentId::new();
|
||||
let handle = Arc::new(MockHandle::new(vec![(agent_id, "multi".to_string())]));
|
||||
let router = Arc::new(AgentRouter::new());
|
||||
router.set_user_default("tg_user".to_string(), agent_id);
|
||||
router.set_user_default("dc_user".to_string(), agent_id);
|
||||
|
||||
let (tg_adapter, tg_tx) = MockAdapter::new("telegram", ChannelType::Telegram);
|
||||
let (dc_adapter, dc_tx) = MockAdapter::new("discord", ChannelType::Discord);
|
||||
let tg_ref = tg_adapter.clone();
|
||||
let dc_ref = dc_adapter.clone();
|
||||
|
||||
let mut manager = BridgeManager::new(handle, router);
|
||||
manager.start_adapter(tg_adapter).await.unwrap();
|
||||
manager.start_adapter(dc_adapter).await.unwrap();
|
||||
|
||||
// Send to Telegram adapter
|
||||
tg_tx
|
||||
.send(make_text_msg(
|
||||
ChannelType::Telegram,
|
||||
"tg_user",
|
||||
"from telegram",
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Send to Discord adapter
|
||||
dc_tx
|
||||
.send(make_text_msg(
|
||||
ChannelType::Discord,
|
||||
"dc_user",
|
||||
"from discord",
|
||||
))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;
|
||||
|
||||
let tg_sent = tg_ref.get_sent();
|
||||
assert_eq!(tg_sent.len(), 1);
|
||||
assert_eq!(tg_sent[0].1, "Echo: from telegram");
|
||||
|
||||
let dc_sent = dc_ref.get_sent();
|
||||
assert_eq!(dc_sent.len(), 1);
|
||||
assert_eq!(dc_sent[0].1, "Echo: from discord");
|
||||
|
||||
manager.stop().await;
|
||||
}
|
||||
33
crates/openfang-cli/Cargo.toml
Normal file
33
crates/openfang-cli/Cargo.toml
Normal file
@@ -0,0 +1,33 @@
|
||||
[package]
|
||||
name = "openfang-cli"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "CLI tool for the OpenFang Agent OS"
|
||||
|
||||
[[bin]]
|
||||
name = "openfang"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
openfang-types = { path = "../openfang-types" }
|
||||
openfang-kernel = { path = "../openfang-kernel" }
|
||||
openfang-api = { path = "../openfang-api" }
|
||||
openfang-migrate = { path = "../openfang-migrate" }
|
||||
openfang-skills = { path = "../openfang-skills" }
|
||||
openfang-extensions = { path = "../openfang-extensions" }
|
||||
zeroize = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
clap_complete = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["blocking"] }
|
||||
openfang-runtime = { path = "../openfang-runtime" }
|
||||
uuid = { workspace = true }
|
||||
ratatui = { workspace = true }
|
||||
colored = { workspace = true }
|
||||
56
crates/openfang-cli/src/bundled_agents.rs
Normal file
56
crates/openfang-cli/src/bundled_agents.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
//! Compile-time embedded agent templates.
|
||||
//!
|
||||
//! All 30 bundled agent templates are embedded into the binary via `include_str!`.
|
||||
//! This ensures `openfang agent new` works immediately after install — no filesystem
|
||||
//! discovery needed.
|
||||
|
||||
/// Returns all bundled agent templates as `(name, toml_content)` pairs.
|
||||
pub fn bundled_agents() -> Vec<(&'static str, &'static str)> {
|
||||
vec![
|
||||
("analyst", include_str!("../../../agents/analyst/agent.toml")),
|
||||
("architect", include_str!("../../../agents/architect/agent.toml")),
|
||||
("assistant", include_str!("../../../agents/assistant/agent.toml")),
|
||||
("coder", include_str!("../../../agents/coder/agent.toml")),
|
||||
("code-reviewer", include_str!("../../../agents/code-reviewer/agent.toml")),
|
||||
("customer-support", include_str!("../../../agents/customer-support/agent.toml")),
|
||||
("data-scientist", include_str!("../../../agents/data-scientist/agent.toml")),
|
||||
("debugger", include_str!("../../../agents/debugger/agent.toml")),
|
||||
("devops-lead", include_str!("../../../agents/devops-lead/agent.toml")),
|
||||
("doc-writer", include_str!("../../../agents/doc-writer/agent.toml")),
|
||||
("email-assistant", include_str!("../../../agents/email-assistant/agent.toml")),
|
||||
("health-tracker", include_str!("../../../agents/health-tracker/agent.toml")),
|
||||
("hello-world", include_str!("../../../agents/hello-world/agent.toml")),
|
||||
("home-automation", include_str!("../../../agents/home-automation/agent.toml")),
|
||||
("legal-assistant", include_str!("../../../agents/legal-assistant/agent.toml")),
|
||||
("meeting-assistant", include_str!("../../../agents/meeting-assistant/agent.toml")),
|
||||
("ops", include_str!("../../../agents/ops/agent.toml")),
|
||||
("orchestrator", include_str!("../../../agents/orchestrator/agent.toml")),
|
||||
("personal-finance", include_str!("../../../agents/personal-finance/agent.toml")),
|
||||
("planner", include_str!("../../../agents/planner/agent.toml")),
|
||||
("recruiter", include_str!("../../../agents/recruiter/agent.toml")),
|
||||
("researcher", include_str!("../../../agents/researcher/agent.toml")),
|
||||
("sales-assistant", include_str!("../../../agents/sales-assistant/agent.toml")),
|
||||
("security-auditor", include_str!("../../../agents/security-auditor/agent.toml")),
|
||||
("social-media", include_str!("../../../agents/social-media/agent.toml")),
|
||||
("test-engineer", include_str!("../../../agents/test-engineer/agent.toml")),
|
||||
("translator", include_str!("../../../agents/translator/agent.toml")),
|
||||
("travel-planner", include_str!("../../../agents/travel-planner/agent.toml")),
|
||||
("tutor", include_str!("../../../agents/tutor/agent.toml")),
|
||||
("writer", include_str!("../../../agents/writer/agent.toml")),
|
||||
]
|
||||
}
|
||||
|
||||
/// Install bundled agent templates to `~/.openfang/agents/`.
|
||||
/// Skips any template that already exists on disk (user customization preserved).
|
||||
pub fn install_bundled_agents(agents_dir: &std::path::Path) {
|
||||
for (name, content) in bundled_agents() {
|
||||
let dest_dir = agents_dir.join(name);
|
||||
let dest_file = dest_dir.join("agent.toml");
|
||||
if dest_file.exists() {
|
||||
continue; // Preserve user customization
|
||||
}
|
||||
if std::fs::create_dir_all(&dest_dir).is_ok() {
|
||||
let _ = std::fs::write(&dest_file, content);
|
||||
}
|
||||
}
|
||||
}
|
||||
241
crates/openfang-cli/src/dotenv.rs
Normal file
241
crates/openfang-cli/src/dotenv.rs
Normal file
@@ -0,0 +1,241 @@
|
||||
//! Minimal `.env` file loader/saver for `~/.openfang/.env`.
|
||||
//!
|
||||
//! No external crate needed — hand-rolled for simplicity.
|
||||
//! Format: `KEY=VALUE` lines, `#` comments, optional quotes.
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Return the path to `~/.openfang/.env`.
|
||||
pub fn env_file_path() -> Option<PathBuf> {
|
||||
dirs::home_dir().map(|h| h.join(".openfang").join(".env"))
|
||||
}
|
||||
|
||||
/// Load `~/.openfang/.env` and `~/.openfang/secrets.env` into `std::env`.
|
||||
///
|
||||
/// System env vars take priority — existing vars are NOT overridden.
|
||||
/// `secrets.env` is loaded second so `.env` values take priority over secrets
|
||||
/// (but both yield to system env vars).
|
||||
/// Silently does nothing if the files don't exist.
|
||||
pub fn load_dotenv() {
|
||||
load_env_file(env_file_path());
|
||||
// Also load secrets.env (written by dashboard "Set API Key" button)
|
||||
load_env_file(secrets_env_path());
|
||||
}
|
||||
|
||||
/// Return the path to `~/.openfang/secrets.env`.
|
||||
pub fn secrets_env_path() -> Option<PathBuf> {
|
||||
dirs::home_dir().map(|h| h.join(".openfang").join("secrets.env"))
|
||||
}
|
||||
|
||||
fn load_env_file(path: Option<PathBuf>) {
|
||||
let path = match path {
|
||||
Some(p) => p,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let content = match std::fs::read_to_string(&path) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return,
|
||||
};
|
||||
|
||||
for line in content.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() || trimmed.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((key, value)) = parse_env_line(trimmed) {
|
||||
if std::env::var(&key).is_err() {
|
||||
std::env::set_var(&key, &value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Upsert a key in `~/.openfang/.env`.
|
||||
///
|
||||
/// Creates the file if missing. Sets 0600 permissions on Unix.
|
||||
/// Also sets the key in the current process environment.
|
||||
pub fn save_env_key(key: &str, value: &str) -> Result<(), String> {
|
||||
let path = env_file_path().ok_or("Could not determine home directory")?;
|
||||
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| format!("Failed to create directory: {e}"))?;
|
||||
}
|
||||
|
||||
let mut entries = read_env_file(&path);
|
||||
entries.insert(key.to_string(), value.to_string());
|
||||
write_env_file(&path, &entries)?;
|
||||
|
||||
// Also set in current process
|
||||
std::env::set_var(key, value);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Remove a key from `~/.openfang/.env`.
|
||||
///
|
||||
/// Also removes it from the current process environment.
|
||||
pub fn remove_env_key(key: &str) -> Result<(), String> {
|
||||
let path = env_file_path().ok_or("Could not determine home directory")?;
|
||||
|
||||
let mut entries = read_env_file(&path);
|
||||
entries.remove(key);
|
||||
write_env_file(&path, &entries)?;
|
||||
|
||||
std::env::remove_var(key);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List key names (without values) from `~/.openfang/.env`.
|
||||
#[allow(dead_code)]
|
||||
pub fn list_env_keys() -> Vec<String> {
|
||||
let path = match env_file_path() {
|
||||
Some(p) => p,
|
||||
None => return Vec::new(),
|
||||
};
|
||||
|
||||
read_env_file(&path).into_keys().collect()
|
||||
}
|
||||
|
||||
/// Check if the `.env` file exists.
|
||||
#[allow(dead_code)]
|
||||
pub fn env_file_exists() -> bool {
|
||||
env_file_path().map(|p| p.exists()).unwrap_or(false)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Internal helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Parse a single `KEY=VALUE` line. Handles optional quotes.
|
||||
fn parse_env_line(line: &str) -> Option<(String, String)> {
|
||||
let eq_pos = line.find('=')?;
|
||||
let key = line[..eq_pos].trim().to_string();
|
||||
let mut value = line[eq_pos + 1..].trim().to_string();
|
||||
|
||||
if key.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Strip matching quotes
|
||||
if ((value.starts_with('"') && value.ends_with('"'))
|
||||
|| (value.starts_with('\'') && value.ends_with('\'')))
|
||||
&& value.len() >= 2
|
||||
{
|
||||
value = value[1..value.len() - 1].to_string();
|
||||
}
|
||||
|
||||
Some((key, value))
|
||||
}
|
||||
|
||||
/// Read all key-value pairs from the .env file.
|
||||
fn read_env_file(path: &PathBuf) -> BTreeMap<String, String> {
|
||||
let mut map = BTreeMap::new();
|
||||
|
||||
let content = match std::fs::read_to_string(path) {
|
||||
Ok(c) => c,
|
||||
Err(_) => return map,
|
||||
};
|
||||
|
||||
for line in content.lines() {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() || trimmed.starts_with('#') {
|
||||
continue;
|
||||
}
|
||||
if let Some((key, value)) = parse_env_line(trimmed) {
|
||||
map.insert(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
map
|
||||
}
|
||||
|
||||
/// Write key-value pairs back to the .env file with a header comment.
|
||||
fn write_env_file(path: &PathBuf, entries: &BTreeMap<String, String>) -> Result<(), String> {
|
||||
let mut content =
|
||||
String::from("# OpenFang environment — managed by `openfang config set-key`\n");
|
||||
content.push_str("# Do not edit while the daemon is running.\n\n");
|
||||
|
||||
for (key, value) in entries {
|
||||
// Quote values that contain spaces or special characters
|
||||
if value.contains(' ') || value.contains('#') || value.contains('"') {
|
||||
content.push_str(&format!("{key}=\"{}\"\n", value.replace('"', "\\\"")));
|
||||
} else {
|
||||
content.push_str(&format!("{key}={value}\n"));
|
||||
}
|
||||
}
|
||||
|
||||
std::fs::write(path, &content).map_err(|e| format!("Failed to write .env file: {e}"))?;
|
||||
|
||||
// Set 0600 permissions on Unix
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_simple() {
|
||||
let (k, v) = parse_env_line("FOO=bar").unwrap();
|
||||
assert_eq!(k, "FOO");
|
||||
assert_eq!(v, "bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_quoted() {
|
||||
let (k, v) = parse_env_line("KEY=\"hello world\"").unwrap();
|
||||
assert_eq!(k, "KEY");
|
||||
assert_eq!(v, "hello world");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_single_quoted() {
|
||||
let (k, v) = parse_env_line("KEY='value'").unwrap();
|
||||
assert_eq!(k, "KEY");
|
||||
assert_eq!(v, "value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_spaces() {
|
||||
let (k, v) = parse_env_line(" KEY = value ").unwrap();
|
||||
assert_eq!(k, "KEY");
|
||||
assert_eq!(v, "value");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_no_value() {
|
||||
let (k, v) = parse_env_line("KEY=").unwrap();
|
||||
assert_eq!(k, "KEY");
|
||||
assert_eq!(v, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_comment() {
|
||||
assert!(
|
||||
parse_env_line("# comment").is_none()
|
||||
|| parse_env_line("# comment").unwrap().0.starts_with('#')
|
||||
);
|
||||
// Comments are filtered before reaching parse_env_line in production code
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_no_equals() {
|
||||
assert!(parse_env_line("NOEQUALS").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_env_line_empty_key() {
|
||||
assert!(parse_env_line("=value").is_none());
|
||||
}
|
||||
}
|
||||
600
crates/openfang-cli/src/launcher.rs
Normal file
600
crates/openfang-cli/src/launcher.rs
Normal file
@@ -0,0 +1,600 @@
|
||||
//! Interactive launcher — lightweight Ratatui one-shot menu.
|
||||
//!
|
||||
//! Shown when `openfang` is run with no subcommand in a TTY.
|
||||
//! Full-width left-aligned layout, adapts for first-time vs returning users.
|
||||
|
||||
use ratatui::crossterm::event::{self, Event as CtEvent, KeyCode, KeyEventKind};
|
||||
use ratatui::layout::{Constraint, Layout, Rect};
|
||||
use ratatui::style::{Modifier, Style};
|
||||
use ratatui::text::{Line, Span};
|
||||
use ratatui::widgets::{List, ListItem, ListState, Paragraph};
|
||||
|
||||
use crate::tui::theme;
|
||||
use crate::ui;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
// ── Provider detection ──────────────────────────────────────────────────────
|
||||
|
||||
const PROVIDER_ENV_VARS: &[(&str, &str)] = &[
|
||||
("ANTHROPIC_API_KEY", "Anthropic"),
|
||||
("OPENAI_API_KEY", "OpenAI"),
|
||||
("DEEPSEEK_API_KEY", "DeepSeek"),
|
||||
("GEMINI_API_KEY", "Gemini"),
|
||||
("GOOGLE_API_KEY", "Gemini"),
|
||||
("GROQ_API_KEY", "Groq"),
|
||||
("OPENROUTER_API_KEY", "OpenRouter"),
|
||||
("TOGETHER_API_KEY", "Together"),
|
||||
("MISTRAL_API_KEY", "Mistral"),
|
||||
("FIREWORKS_API_KEY", "Fireworks"),
|
||||
];
|
||||
|
||||
fn detect_provider() -> Option<(&'static str, &'static str)> {
|
||||
for &(var, name) in PROVIDER_ENV_VARS {
|
||||
if std::env::var(var).is_ok() {
|
||||
return Some((name, var));
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn is_first_run() -> bool {
|
||||
let home = match dirs::home_dir() {
|
||||
Some(h) => h,
|
||||
None => return true,
|
||||
};
|
||||
!home.join(".openfang").join("config.toml").exists()
|
||||
}
|
||||
|
||||
fn has_openclaw() -> bool {
|
||||
// Quick check: does ~/.openclaw exist?
|
||||
dirs::home_dir()
|
||||
.map(|h| h.join(".openclaw").exists())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
// ── Types ───────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LauncherChoice {
|
||||
GetStarted,
|
||||
Chat,
|
||||
Dashboard,
|
||||
DesktopApp,
|
||||
TerminalUI,
|
||||
ShowHelp,
|
||||
Quit,
|
||||
}
|
||||
|
||||
struct MenuItem {
|
||||
label: &'static str,
|
||||
hint: &'static str,
|
||||
choice: LauncherChoice,
|
||||
}
|
||||
|
||||
// Menu for first-time users: "Get started" is first and prominent
|
||||
const MENU_FIRST_RUN: &[MenuItem] = &[
|
||||
MenuItem {
|
||||
label: "Get started",
|
||||
hint: "Providers, API keys, models, migration",
|
||||
choice: LauncherChoice::GetStarted,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Chat with an agent",
|
||||
hint: "Quick chat in the terminal",
|
||||
choice: LauncherChoice::Chat,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Open dashboard",
|
||||
hint: "Launch the web UI in your browser",
|
||||
choice: LauncherChoice::Dashboard,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Open desktop app",
|
||||
hint: "Launch the native desktop app",
|
||||
choice: LauncherChoice::DesktopApp,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Launch terminal UI",
|
||||
hint: "Full interactive TUI dashboard",
|
||||
choice: LauncherChoice::TerminalUI,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Show all commands",
|
||||
hint: "Print full --help output",
|
||||
choice: LauncherChoice::ShowHelp,
|
||||
},
|
||||
];
|
||||
|
||||
// Menu for returning users: action-first, setup at the bottom
|
||||
const MENU_RETURNING: &[MenuItem] = &[
|
||||
MenuItem {
|
||||
label: "Chat with an agent",
|
||||
hint: "Quick chat in the terminal",
|
||||
choice: LauncherChoice::Chat,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Open dashboard",
|
||||
hint: "Launch the web UI in your browser",
|
||||
choice: LauncherChoice::Dashboard,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Launch terminal UI",
|
||||
hint: "Full interactive TUI dashboard",
|
||||
choice: LauncherChoice::TerminalUI,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Open desktop app",
|
||||
hint: "Launch the native desktop app",
|
||||
choice: LauncherChoice::DesktopApp,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Settings",
|
||||
hint: "Providers, API keys, models, routing",
|
||||
choice: LauncherChoice::GetStarted,
|
||||
},
|
||||
MenuItem {
|
||||
label: "Show all commands",
|
||||
hint: "Print full --help output",
|
||||
choice: LauncherChoice::ShowHelp,
|
||||
},
|
||||
];
|
||||
|
||||
// ── Launcher state ──────────────────────────────────────────────────────────
|
||||
|
||||
struct LauncherState {
|
||||
list: ListState,
|
||||
daemon_url: Option<String>,
|
||||
daemon_agents: u64,
|
||||
detecting: bool,
|
||||
tick: usize,
|
||||
first_run: bool,
|
||||
openclaw_detected: bool,
|
||||
}
|
||||
|
||||
impl LauncherState {
|
||||
fn new() -> Self {
|
||||
let first_run = is_first_run();
|
||||
let openclaw_detected = first_run && has_openclaw();
|
||||
let mut list = ListState::default();
|
||||
list.select(Some(0));
|
||||
Self {
|
||||
list,
|
||||
daemon_url: None,
|
||||
daemon_agents: 0,
|
||||
detecting: true,
|
||||
tick: 0,
|
||||
first_run,
|
||||
openclaw_detected,
|
||||
}
|
||||
}
|
||||
|
||||
fn menu(&self) -> &'static [MenuItem] {
|
||||
if self.first_run {
|
||||
MENU_FIRST_RUN
|
||||
} else {
|
||||
MENU_RETURNING
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Entry point ─────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn run(_config: Option<PathBuf>) -> LauncherChoice {
|
||||
let mut terminal = ratatui::init();
|
||||
|
||||
// Panic hook: restore terminal on panic (set AFTER init succeeds)
|
||||
let original_hook = std::panic::take_hook();
|
||||
std::panic::set_hook(Box::new(move |info| {
|
||||
let _ = ratatui::try_restore();
|
||||
original_hook(info);
|
||||
}));
|
||||
|
||||
let mut state = LauncherState::new();
|
||||
|
||||
// Spawn background daemon detection (catch_unwind protects against thread panics)
|
||||
let (daemon_tx, daemon_rx) = std::sync::mpsc::channel();
|
||||
std::thread::spawn(move || {
|
||||
let _ = std::panic::catch_unwind(|| {
|
||||
let result = crate::find_daemon();
|
||||
let agent_count = result.as_ref().map_or(0, |base| {
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
.build()
|
||||
.ok();
|
||||
client
|
||||
.and_then(|c| c.get(format!("{base}/api/agents")).send().ok())
|
||||
.and_then(|r| r.json::<serde_json::Value>().ok())
|
||||
.and_then(|v| v.as_array().map(|a| a.len() as u64))
|
||||
.unwrap_or(0)
|
||||
});
|
||||
let _ = daemon_tx.send((result, agent_count));
|
||||
});
|
||||
});
|
||||
|
||||
let choice;
|
||||
|
||||
loop {
|
||||
// Check for daemon detection result
|
||||
if state.detecting {
|
||||
if let Ok((url, agents)) = daemon_rx.try_recv() {
|
||||
state.daemon_url = url;
|
||||
state.daemon_agents = agents;
|
||||
state.detecting = false;
|
||||
}
|
||||
}
|
||||
|
||||
state.tick = state.tick.wrapping_add(1);
|
||||
|
||||
// Draw (gracefully handle render failures)
|
||||
if terminal.draw(|frame| draw(frame, &mut state)).is_err() {
|
||||
choice = LauncherChoice::Quit;
|
||||
break;
|
||||
}
|
||||
|
||||
// Poll for input (50ms = 20fps spinner)
|
||||
if event::poll(Duration::from_millis(50)).unwrap_or(false) {
|
||||
if let Ok(CtEvent::Key(key)) = event::read() {
|
||||
if key.kind != KeyEventKind::Press {
|
||||
continue;
|
||||
}
|
||||
let menu = state.menu();
|
||||
if menu.is_empty() {
|
||||
choice = LauncherChoice::Quit;
|
||||
break;
|
||||
}
|
||||
match key.code {
|
||||
KeyCode::Char('q') | KeyCode::Esc => {
|
||||
choice = LauncherChoice::Quit;
|
||||
break;
|
||||
}
|
||||
KeyCode::Up | KeyCode::Char('k') => {
|
||||
let i = state.list.selected().unwrap_or(0);
|
||||
let next = if i == 0 { menu.len() - 1 } else { i - 1 };
|
||||
state.list.select(Some(next));
|
||||
}
|
||||
KeyCode::Down | KeyCode::Char('j') => {
|
||||
let i = state.list.selected().unwrap_or(0);
|
||||
let next = (i + 1) % menu.len();
|
||||
state.list.select(Some(next));
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
if let Some(i) = state.list.selected() {
|
||||
if i < menu.len() {
|
||||
choice = menu[i].choice;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = ratatui::try_restore();
|
||||
choice
|
||||
}
|
||||
|
||||
// ── Drawing ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Left margin for content alignment.
|
||||
const MARGIN_LEFT: u16 = 3;
|
||||
|
||||
/// Constrain content to a readable area within the terminal.
|
||||
fn content_area(area: Rect) -> Rect {
|
||||
if area.width < 10 || area.height < 5 {
|
||||
// Terminal too small — use full area with no margin
|
||||
return area;
|
||||
}
|
||||
let margin = MARGIN_LEFT.min(area.width.saturating_sub(10));
|
||||
let w = 80u16.min(area.width.saturating_sub(margin));
|
||||
Rect {
|
||||
x: area.x.saturating_add(margin),
|
||||
y: area.y,
|
||||
width: w,
|
||||
height: area.height,
|
||||
}
|
||||
}
|
||||
|
||||
fn draw(frame: &mut ratatui::Frame, state: &mut LauncherState) {
|
||||
let area = frame.area();
|
||||
|
||||
// Fill background
|
||||
frame.render_widget(
|
||||
ratatui::widgets::Block::default().style(Style::default().bg(theme::BG_PRIMARY)),
|
||||
area,
|
||||
);
|
||||
|
||||
let content = content_area(area);
|
||||
let version = env!("CARGO_PKG_VERSION");
|
||||
let has_provider = detect_provider().is_some();
|
||||
let menu = state.menu();
|
||||
|
||||
// Compute dynamic heights
|
||||
let header_h: u16 = if state.first_run { 3 } else { 1 }; // welcome text or just title
|
||||
let status_h: u16 = if state.detecting {
|
||||
1
|
||||
} else if has_provider {
|
||||
2
|
||||
} else {
|
||||
3
|
||||
};
|
||||
let migration_hint_h: u16 = if state.first_run && state.openclaw_detected {
|
||||
2
|
||||
} else {
|
||||
0
|
||||
};
|
||||
let menu_h = menu.len() as u16;
|
||||
|
||||
let total_needed = 1 + header_h + 1 + status_h + 1 + menu_h + migration_hint_h + 1;
|
||||
|
||||
// Vertical centering: place content block in the upper-third area
|
||||
let top_pad = if area.height > total_needed + 2 {
|
||||
((area.height - total_needed) / 3).max(1)
|
||||
} else {
|
||||
1
|
||||
};
|
||||
|
||||
let chunks = Layout::vertical([
|
||||
Constraint::Length(top_pad), // top space
|
||||
Constraint::Length(header_h), // header / welcome
|
||||
Constraint::Length(1), // separator
|
||||
Constraint::Length(status_h), // status indicators
|
||||
Constraint::Length(1), // separator
|
||||
Constraint::Length(menu_h), // menu items
|
||||
Constraint::Length(migration_hint_h), // openclaw migration hint (if any)
|
||||
Constraint::Length(1), // keybind hints
|
||||
Constraint::Min(0), // remaining space
|
||||
])
|
||||
.split(content);
|
||||
|
||||
// ── Header ──────────────────────────────────────────────────────────────
|
||||
if state.first_run {
|
||||
let header_lines = vec![
|
||||
Line::from(vec![
|
||||
Span::styled(
|
||||
"OpenFang",
|
||||
Style::default()
|
||||
.fg(theme::ACCENT)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::styled(
|
||||
format!(" v{version}"),
|
||||
Style::default().fg(theme::TEXT_TERTIARY),
|
||||
),
|
||||
]),
|
||||
Line::from(""),
|
||||
Line::from(vec![Span::styled(
|
||||
"Welcome! Let's get you set up.",
|
||||
Style::default().fg(theme::TEXT_PRIMARY),
|
||||
)]),
|
||||
];
|
||||
frame.render_widget(Paragraph::new(header_lines), chunks[1]);
|
||||
} else {
|
||||
let header = Line::from(vec![
|
||||
Span::styled(
|
||||
"OpenFang",
|
||||
Style::default()
|
||||
.fg(theme::ACCENT)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::styled(
|
||||
format!(" v{version}"),
|
||||
Style::default().fg(theme::TEXT_TERTIARY),
|
||||
),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(header), chunks[1]);
|
||||
}
|
||||
|
||||
// ── Separator ───────────────────────────────────────────────────────────
|
||||
render_separator(frame, chunks[2]);
|
||||
|
||||
// ── Status block ────────────────────────────────────────────────────────
|
||||
if state.detecting {
|
||||
let spinner = theme::SPINNER_FRAMES[state.tick % theme::SPINNER_FRAMES.len()];
|
||||
let line = Line::from(vec![
|
||||
Span::styled(format!("{spinner} "), Style::default().fg(theme::YELLOW)),
|
||||
Span::styled("Checking for daemon\u{2026}", theme::dim_style()),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(line), chunks[3]);
|
||||
} else {
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
// Daemon status
|
||||
if let Some(ref url) = state.daemon_url {
|
||||
let agent_suffix = if state.daemon_agents > 0 {
|
||||
format!(
|
||||
" ({} agent{})",
|
||||
state.daemon_agents,
|
||||
if state.daemon_agents == 1 { "" } else { "s" }
|
||||
)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(
|
||||
"\u{25cf} ",
|
||||
Style::default()
|
||||
.fg(theme::GREEN)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::styled(
|
||||
format!("Daemon running at {url}"),
|
||||
Style::default().fg(theme::TEXT_PRIMARY),
|
||||
),
|
||||
Span::styled(agent_suffix, Style::default().fg(theme::GREEN)),
|
||||
]));
|
||||
} else {
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled("\u{25cb} ", theme::dim_style()),
|
||||
Span::styled("No daemon running", theme::dim_style()),
|
||||
]));
|
||||
}
|
||||
|
||||
// Provider status
|
||||
if let Some((provider, env_var)) = detect_provider() {
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(
|
||||
"\u{2714} ",
|
||||
Style::default()
|
||||
.fg(theme::GREEN)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::styled(
|
||||
format!("Provider: {provider}"),
|
||||
Style::default().fg(theme::TEXT_PRIMARY),
|
||||
),
|
||||
Span::styled(format!(" ({env_var})"), theme::dim_style()),
|
||||
]));
|
||||
} else {
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled("\u{25cb} ", Style::default().fg(theme::YELLOW)),
|
||||
Span::styled("No API keys detected", Style::default().fg(theme::YELLOW)),
|
||||
]));
|
||||
if !state.first_run {
|
||||
lines.push(Line::from(vec![Span::styled(
|
||||
" Run 'Re-run setup' to configure a provider",
|
||||
theme::hint_style(),
|
||||
)]));
|
||||
} else {
|
||||
lines.push(Line::from(vec![Span::styled(
|
||||
" Select 'Get started' to configure",
|
||||
theme::hint_style(),
|
||||
)]));
|
||||
}
|
||||
}
|
||||
|
||||
frame.render_widget(Paragraph::new(lines), chunks[3]);
|
||||
}
|
||||
|
||||
// ── Separator 2 ─────────────────────────────────────────────────────────
|
||||
render_separator(frame, chunks[4]);
|
||||
|
||||
// ── Menu ────────────────────────────────────────────────────────────────
|
||||
let items: Vec<ListItem> = menu
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, item)| {
|
||||
// Highlight "Get started" for first-run users
|
||||
let is_primary = state.first_run && i == 0;
|
||||
let label_style = if is_primary {
|
||||
Style::default()
|
||||
.fg(theme::ACCENT)
|
||||
.add_modifier(Modifier::BOLD)
|
||||
} else {
|
||||
Style::default().fg(theme::TEXT_PRIMARY)
|
||||
};
|
||||
|
||||
ListItem::new(Line::from(vec![
|
||||
Span::styled(format!("{:<26}", item.label), label_style),
|
||||
Span::styled(item.hint, theme::dim_style()),
|
||||
]))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let list = List::new(items)
|
||||
.highlight_style(
|
||||
Style::default()
|
||||
.fg(theme::ACCENT)
|
||||
.bg(theme::BG_HOVER)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
)
|
||||
.highlight_symbol("\u{25b8} ");
|
||||
|
||||
frame.render_stateful_widget(list, chunks[5], &mut state.list);
|
||||
|
||||
// ── OpenClaw migration hint ─────────────────────────────────────────────
|
||||
if state.first_run && state.openclaw_detected {
|
||||
let hint_lines = vec![
|
||||
Line::from(""),
|
||||
Line::from(vec![
|
||||
Span::styled("\u{2192} ", Style::default().fg(theme::BLUE)),
|
||||
Span::styled("Coming from OpenClaw? ", Style::default().fg(theme::BLUE)),
|
||||
Span::styled(
|
||||
"'Get started' includes automatic migration.",
|
||||
theme::hint_style(),
|
||||
),
|
||||
]),
|
||||
];
|
||||
frame.render_widget(Paragraph::new(hint_lines), chunks[6]);
|
||||
}
|
||||
|
||||
// ── Keybind hints ───────────────────────────────────────────────────────
|
||||
let hints = Line::from(vec![Span::styled(
|
||||
"\u{2191}\u{2193} navigate enter select q quit",
|
||||
theme::hint_style(),
|
||||
)]);
|
||||
frame.render_widget(Paragraph::new(hints), chunks[7]);
|
||||
}
|
||||
|
||||
fn render_separator(frame: &mut ratatui::Frame, area: Rect) {
|
||||
let w = (area.width as usize).min(60);
|
||||
let line = Line::from(vec![Span::styled(
|
||||
"\u{2500}".repeat(w),
|
||||
Style::default().fg(theme::BORDER),
|
||||
)]);
|
||||
frame.render_widget(Paragraph::new(line), area);
|
||||
}
|
||||
|
||||
// ── Desktop app launcher ────────────────────────────────────────────────────
|
||||
|
||||
pub fn launch_desktop_app() {
|
||||
let desktop_bin = {
|
||||
let exe = std::env::current_exe().ok();
|
||||
let dir = exe.as_ref().and_then(|e| e.parent());
|
||||
|
||||
#[cfg(windows)]
|
||||
let name = "openfang-desktop.exe";
|
||||
#[cfg(not(windows))]
|
||||
let name = "openfang-desktop";
|
||||
|
||||
// Check sibling of current exe first
|
||||
let sibling = dir.map(|d| d.join(name));
|
||||
|
||||
match sibling {
|
||||
Some(ref path) if path.exists() => sibling,
|
||||
_ => which_lookup(name),
|
||||
}
|
||||
};
|
||||
|
||||
match desktop_bin {
|
||||
Some(ref path) if path.exists() => {
|
||||
match std::process::Command::new(path)
|
||||
.stdin(std::process::Stdio::null())
|
||||
.stdout(std::process::Stdio::null())
|
||||
.stderr(std::process::Stdio::null())
|
||||
.spawn()
|
||||
{
|
||||
Ok(_) => {
|
||||
ui::success("Desktop app launched.");
|
||||
}
|
||||
Err(e) => {
|
||||
ui::error_with_fix(
|
||||
&format!("Failed to launch desktop app: {e}"),
|
||||
"Build it: cargo build -p openfang-desktop",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
ui::error_with_fix(
|
||||
"Desktop app not found",
|
||||
"Build it: cargo build -p openfang-desktop",
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple PATH lookup for a binary name.
|
||||
fn which_lookup(name: &str) -> Option<PathBuf> {
|
||||
let path_var = std::env::var("PATH").ok()?;
|
||||
let separator = if cfg!(windows) { ';' } else { ':' };
|
||||
for dir in path_var.split(separator) {
|
||||
let candidate = PathBuf::from(dir).join(name);
|
||||
if candidate.exists() {
|
||||
return Some(candidate);
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
5663
crates/openfang-cli/src/main.rs
Normal file
5663
crates/openfang-cli/src/main.rs
Normal file
File diff suppressed because it is too large
Load Diff
426
crates/openfang-cli/src/mcp.rs
Normal file
426
crates/openfang-cli/src/mcp.rs
Normal file
@@ -0,0 +1,426 @@
|
||||
//! MCP (Model Context Protocol) server for OpenFang.
|
||||
//!
|
||||
//! Exposes running agents as MCP tools over JSON-RPC 2.0 stdio.
|
||||
//! Each agent becomes a callable tool named `openfang_agent_{name}`.
|
||||
//!
|
||||
//! Protocol: Content-Length framing over stdin/stdout.
|
||||
//! Connects to running daemon via HTTP, falls back to in-process kernel.
|
||||
|
||||
use openfang_kernel::OpenFangKernel;
|
||||
use serde_json::{json, Value};
|
||||
use std::io::{self, BufRead, Write};
|
||||
|
||||
/// Backend for MCP: either a running daemon or an in-process kernel.
|
||||
enum McpBackend {
|
||||
Daemon {
|
||||
base_url: String,
|
||||
client: reqwest::blocking::Client,
|
||||
},
|
||||
InProcess {
|
||||
kernel: Box<OpenFangKernel>,
|
||||
rt: tokio::runtime::Runtime,
|
||||
},
|
||||
}
|
||||
|
||||
impl McpBackend {
|
||||
fn list_agents(&self) -> Vec<(String, String, String)> {
|
||||
// Returns (id, name, description) triples
|
||||
match self {
|
||||
McpBackend::Daemon { base_url, client } => {
|
||||
let resp = client
|
||||
.get(format!("{base_url}/api/agents"))
|
||||
.send()
|
||||
.ok()
|
||||
.and_then(|r| r.json::<Value>().ok());
|
||||
match resp.and_then(|v| v.as_array().cloned()) {
|
||||
Some(agents) => agents
|
||||
.iter()
|
||||
.map(|a| {
|
||||
(
|
||||
a["id"].as_str().unwrap_or("").to_string(),
|
||||
a["name"].as_str().unwrap_or("").to_string(),
|
||||
a["description"].as_str().unwrap_or("").to_string(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
McpBackend::InProcess { kernel, .. } => kernel
|
||||
.registry
|
||||
.list()
|
||||
.iter()
|
||||
.map(|e| {
|
||||
(
|
||||
e.id.to_string(),
|
||||
e.name.clone(),
|
||||
e.manifest.description.clone(),
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn send_message(&self, agent_id: &str, message: &str) -> Result<String, String> {
|
||||
match self {
|
||||
McpBackend::Daemon { base_url, client } => {
|
||||
let resp = client
|
||||
.post(format!("{base_url}/api/agents/{agent_id}/message"))
|
||||
.json(&json!({"message": message}))
|
||||
.send()
|
||||
.map_err(|e| format!("HTTP error: {e}"))?;
|
||||
let body: Value = resp.json().map_err(|e| format!("Parse error: {e}"))?;
|
||||
if let Some(response) = body["response"].as_str() {
|
||||
Ok(response.to_string())
|
||||
} else {
|
||||
Err(body["error"]
|
||||
.as_str()
|
||||
.unwrap_or("Unknown error")
|
||||
.to_string())
|
||||
}
|
||||
}
|
||||
McpBackend::InProcess { kernel, rt } => {
|
||||
let aid: openfang_types::agent::AgentId =
|
||||
agent_id.parse().map_err(|_| "Invalid agent ID")?;
|
||||
let result = rt
|
||||
.block_on(kernel.send_message(aid, message))
|
||||
.map_err(|e| format!("{e}"))?;
|
||||
Ok(result.response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Find agent ID by tool name (strip `openfang_agent_` prefix, match by name).
|
||||
fn resolve_tool_agent(&self, tool_name: &str) -> Option<String> {
|
||||
let agent_name = tool_name.strip_prefix("openfang_agent_")?.replace('_', "-");
|
||||
let agents = self.list_agents();
|
||||
// Try exact match first (with underscores replaced by hyphens)
|
||||
for (id, name, _) in &agents {
|
||||
if name.replace(' ', "-").to_lowercase() == agent_name.to_lowercase() {
|
||||
return Some(id.clone());
|
||||
}
|
||||
}
|
||||
// Try with underscores
|
||||
let agent_name_underscore = tool_name.strip_prefix("openfang_agent_")?;
|
||||
for (id, name, _) in &agents {
|
||||
if name.replace('-', "_").to_lowercase() == agent_name_underscore.to_lowercase() {
|
||||
return Some(id.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the MCP server over stdio.
|
||||
pub fn run_mcp_server(config: Option<std::path::PathBuf>) {
|
||||
let backend = create_backend(config);
|
||||
|
||||
let stdin = io::stdin();
|
||||
let stdout = io::stdout();
|
||||
let mut reader = stdin.lock();
|
||||
let mut writer = stdout.lock();
|
||||
|
||||
loop {
|
||||
match read_message(&mut reader) {
|
||||
Ok(Some(msg)) => {
|
||||
let response = handle_message(&backend, &msg);
|
||||
if let Some(resp) = response {
|
||||
write_message(&mut writer, &resp);
|
||||
}
|
||||
}
|
||||
Ok(None) => break,
|
||||
Err(_) => break,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn create_backend(config: Option<std::path::PathBuf>) -> McpBackend {
|
||||
// Try daemon first
|
||||
if let Some(base_url) = super::find_daemon() {
|
||||
let client = reqwest::blocking::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.build()
|
||||
.expect("Failed to build HTTP client");
|
||||
return McpBackend::Daemon { base_url, client };
|
||||
}
|
||||
|
||||
// Fall back to in-process kernel
|
||||
let kernel = match OpenFangKernel::boot(config.as_deref()) {
|
||||
Ok(k) => k,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to boot kernel for MCP: {e}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
let rt = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
|
||||
McpBackend::InProcess {
|
||||
kernel: Box::new(kernel),
|
||||
rt,
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a Content-Length framed JSON-RPC message from the reader.
|
||||
fn read_message(reader: &mut impl BufRead) -> io::Result<Option<Value>> {
|
||||
// Read headers until empty line
|
||||
let mut content_length: usize = 0;
|
||||
loop {
|
||||
let mut header = String::new();
|
||||
let bytes_read = reader.read_line(&mut header)?;
|
||||
if bytes_read == 0 {
|
||||
return Ok(None); // EOF
|
||||
}
|
||||
|
||||
let trimmed = header.trim();
|
||||
if trimmed.is_empty() {
|
||||
break; // End of headers
|
||||
}
|
||||
|
||||
if let Some(len_str) = trimmed.strip_prefix("Content-Length: ") {
|
||||
content_length = len_str.parse().unwrap_or(0);
|
||||
}
|
||||
}
|
||||
|
||||
if content_length == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// SECURITY: Reject oversized messages to prevent OOM.
|
||||
const MAX_MCP_MESSAGE_SIZE: usize = 10 * 1024 * 1024; // 10MB
|
||||
if content_length > MAX_MCP_MESSAGE_SIZE {
|
||||
// Drain the oversized body to avoid stream desync
|
||||
let mut discard = [0u8; 4096];
|
||||
let mut remaining = content_length;
|
||||
while remaining > 0 {
|
||||
let to_read = remaining.min(4096);
|
||||
if reader.read_exact(&mut discard[..to_read]).is_err() {
|
||||
break;
|
||||
}
|
||||
remaining -= to_read;
|
||||
}
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("MCP message too large: {content_length} bytes (max {MAX_MCP_MESSAGE_SIZE})"),
|
||||
));
|
||||
}
|
||||
|
||||
// Read the body
|
||||
let mut body = vec![0u8; content_length];
|
||||
reader.read_exact(&mut body)?;
|
||||
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(_) => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a Content-Length framed JSON-RPC response to the writer.
|
||||
fn write_message(writer: &mut impl Write, msg: &Value) {
|
||||
let body = serde_json::to_string(msg).unwrap_or_default();
|
||||
let _ = write!(writer, "Content-Length: {}\r\n\r\n{}", body.len(), body);
|
||||
let _ = writer.flush();
|
||||
}
|
||||
|
||||
/// Handle a JSON-RPC message and return an optional response.
|
||||
fn handle_message(backend: &McpBackend, msg: &Value) -> Option<Value> {
|
||||
let method = msg["method"].as_str().unwrap_or("");
|
||||
let id = msg.get("id").cloned();
|
||||
|
||||
match method {
|
||||
"initialize" => {
|
||||
let result = json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"tools": {}
|
||||
},
|
||||
"serverInfo": {
|
||||
"name": "openfang",
|
||||
"version": "0.1.0"
|
||||
}
|
||||
});
|
||||
Some(jsonrpc_response(id?, result))
|
||||
}
|
||||
|
||||
"notifications/initialized" => None, // Notification, no response
|
||||
|
||||
"tools/list" => {
|
||||
let agents = backend.list_agents();
|
||||
let tools: Vec<Value> = agents
|
||||
.iter()
|
||||
.map(|(_, name, description)| {
|
||||
let tool_name = format!("openfang_agent_{}", name.replace('-', "_"));
|
||||
let desc = if description.is_empty() {
|
||||
format!("Send a message to OpenFang agent '{name}'")
|
||||
} else {
|
||||
description.clone()
|
||||
};
|
||||
json!({
|
||||
"name": tool_name,
|
||||
"description": desc,
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "Message to send to the agent"
|
||||
}
|
||||
},
|
||||
"required": ["message"]
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
Some(jsonrpc_response(id?, json!({ "tools": tools })))
|
||||
}
|
||||
|
||||
"tools/call" => {
|
||||
let params = &msg["params"];
|
||||
let tool_name = params["name"].as_str().unwrap_or("");
|
||||
let message = params["arguments"]["message"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
if message.is_empty() {
|
||||
return Some(jsonrpc_error(id?, -32602, "Missing 'message' argument"));
|
||||
}
|
||||
|
||||
let agent_id = match backend.resolve_tool_agent(tool_name) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return Some(jsonrpc_error(
|
||||
id?,
|
||||
-32602,
|
||||
&format!("Unknown tool: {tool_name}"),
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
match backend.send_message(&agent_id, &message) {
|
||||
Ok(response) => Some(jsonrpc_response(
|
||||
id?,
|
||||
json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": response
|
||||
}]
|
||||
}),
|
||||
)),
|
||||
Err(e) => Some(jsonrpc_response(
|
||||
id?,
|
||||
json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": format!("Error: {e}")
|
||||
}],
|
||||
"isError": true
|
||||
}),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
_ => {
|
||||
// Unknown method
|
||||
id.map(|id| jsonrpc_error(id, -32601, &format!("Method not found: {method}")))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn jsonrpc_response(id: Value, result: Value) -> Value {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": result
|
||||
})
|
||||
}
|
||||
|
||||
fn jsonrpc_error(id: Value, code: i32, message: &str) -> Value {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": {
|
||||
"code": code,
|
||||
"message": message
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_handle_initialize() {
|
||||
let msg = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {}
|
||||
});
|
||||
// We can't easily create a backend in tests without a kernel,
|
||||
// but we can test the protocol handling
|
||||
let backend = McpBackend::Daemon {
|
||||
base_url: "http://localhost:9999".to_string(),
|
||||
client: reqwest::blocking::Client::new(),
|
||||
};
|
||||
let resp = handle_message(&backend, &msg).unwrap();
|
||||
assert_eq!(resp["id"], 1);
|
||||
assert_eq!(resp["result"]["protocolVersion"], "2024-11-05");
|
||||
assert_eq!(resp["result"]["serverInfo"]["name"], "openfang");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_notifications_initialized() {
|
||||
let msg = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized"
|
||||
});
|
||||
let backend = McpBackend::Daemon {
|
||||
base_url: "http://localhost:9999".to_string(),
|
||||
client: reqwest::blocking::Client::new(),
|
||||
};
|
||||
let resp = handle_message(&backend, &msg);
|
||||
assert!(resp.is_none()); // No response for notifications
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_handle_unknown_method() {
|
||||
let msg = json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "unknown/method"
|
||||
});
|
||||
let backend = McpBackend::Daemon {
|
||||
base_url: "http://localhost:9999".to_string(),
|
||||
client: reqwest::blocking::Client::new(),
|
||||
};
|
||||
let resp = handle_message(&backend, &msg).unwrap();
|
||||
assert_eq!(resp["error"]["code"], -32601);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_response() {
|
||||
let resp = jsonrpc_response(json!(1), json!({"status": "ok"}));
|
||||
assert_eq!(resp["jsonrpc"], "2.0");
|
||||
assert_eq!(resp["id"], 1);
|
||||
assert_eq!(resp["result"]["status"], "ok");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_jsonrpc_error() {
|
||||
let resp = jsonrpc_error(json!(2), -32601, "Not found");
|
||||
assert_eq!(resp["jsonrpc"], "2.0");
|
||||
assert_eq!(resp["id"], 2);
|
||||
assert_eq!(resp["error"]["code"], -32601);
|
||||
assert_eq!(resp["error"]["message"], "Not found");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_read_message() {
|
||||
let body = r#"{"jsonrpc":"2.0","id":1,"method":"initialize"}"#;
|
||||
let input = format!("Content-Length: {}\r\n\r\n{}", body.len(), body);
|
||||
let mut reader = io::BufReader::new(input.as_bytes());
|
||||
let msg = read_message(&mut reader).unwrap().unwrap();
|
||||
assert_eq!(msg["method"], "initialize");
|
||||
assert_eq!(msg["id"], 1);
|
||||
}
|
||||
}
|
||||
322
crates/openfang-cli/src/progress.rs
Normal file
322
crates/openfang-cli/src/progress.rs
Normal file
@@ -0,0 +1,322 @@
|
||||
//! Progress bars and spinners for CLI output.
|
||||
//!
|
||||
//! Uses raw ANSI escape sequences (no external dependency). Supports:
|
||||
//! - Percentage progress bar with visual block characters
|
||||
//! - Spinner with label
|
||||
//! - OSC 9;4 terminal progress protocol (ConEmu/Windows Terminal/iTerm2)
|
||||
//! - Delay suppression for fast operations
|
||||
|
||||
use std::io::{self, Write};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Default progress bar width (in characters).
|
||||
const DEFAULT_BAR_WIDTH: usize = 30;
|
||||
|
||||
/// Minimum elapsed time before showing progress output. Operations that
|
||||
/// complete faster than this threshold produce no visual noise.
|
||||
const DELAY_SUPPRESS_MS: u64 = 200;
|
||||
|
||||
/// Block characters for the progress bar.
|
||||
const FILLED: char = '\u{2588}'; // █
|
||||
const EMPTY: char = '\u{2591}'; // ░
|
||||
|
||||
/// Spinner animation frames.
|
||||
const SPINNER_FRAMES: &[char] = &[
|
||||
'\u{280b}', '\u{2819}', '\u{2839}', '\u{2838}', '\u{283c}', '\u{2834}', '\u{2826}', '\u{2827}',
|
||||
'\u{2807}', '\u{280f}',
|
||||
];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// OSC 9;4 progress protocol
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Emit an OSC 9;4 progress sequence (supported by Windows Terminal, ConEmu,
|
||||
/// iTerm2). `state`: 1 = set progress, 2 = error, 3 = indeterminate, 0 = clear.
|
||||
fn osc_progress(state: u8, percent: u8) {
|
||||
// ESC ] 9 ; 4 ; state ; percent ST
|
||||
// ST = ESC \ (string terminator)
|
||||
let _ = write!(io::stderr(), "\x1b]9;4;{state};{percent}\x1b\\");
|
||||
let _ = io::stderr().flush();
|
||||
}
|
||||
|
||||
/// Clear the OSC 9;4 progress indicator.
|
||||
fn osc_progress_clear() {
|
||||
osc_progress(0, 0);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ProgressBar
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// A simple percentage-based progress bar.
|
||||
///
|
||||
/// ```text
|
||||
/// Downloading [████████████░░░░░░░░░░░░░░░░░░] 40% (4/10)
|
||||
/// ```
|
||||
pub struct ProgressBar {
|
||||
label: String,
|
||||
total: u64,
|
||||
current: u64,
|
||||
width: usize,
|
||||
start: Instant,
|
||||
suppress_until: Duration,
|
||||
visible: bool,
|
||||
use_osc: bool,
|
||||
}
|
||||
|
||||
impl ProgressBar {
|
||||
/// Create a new progress bar.
|
||||
///
|
||||
/// `label`: text shown before the bar.
|
||||
/// `total`: the 100% value.
|
||||
pub fn new(label: &str, total: u64) -> Self {
|
||||
Self {
|
||||
label: label.to_string(),
|
||||
total: total.max(1),
|
||||
current: 0,
|
||||
width: DEFAULT_BAR_WIDTH,
|
||||
start: Instant::now(),
|
||||
suppress_until: Duration::from_millis(DELAY_SUPPRESS_MS),
|
||||
visible: false,
|
||||
use_osc: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the bar width in characters.
|
||||
pub fn width(mut self, w: usize) -> Self {
|
||||
self.width = w.max(5);
|
||||
self
|
||||
}
|
||||
|
||||
/// Disable delay suppression (always show immediately).
|
||||
pub fn no_delay(mut self) -> Self {
|
||||
self.suppress_until = Duration::ZERO;
|
||||
self
|
||||
}
|
||||
|
||||
/// Disable OSC 9;4 terminal progress protocol.
|
||||
pub fn no_osc(mut self) -> Self {
|
||||
self.use_osc = false;
|
||||
self
|
||||
}
|
||||
|
||||
/// Update progress to `n`.
|
||||
pub fn set(&mut self, n: u64) {
|
||||
self.current = n.min(self.total);
|
||||
self.draw();
|
||||
}
|
||||
|
||||
/// Increment progress by `delta`.
|
||||
pub fn inc(&mut self, delta: u64) {
|
||||
self.current = (self.current + delta).min(self.total);
|
||||
self.draw();
|
||||
}
|
||||
|
||||
/// Mark as finished and clear the line.
|
||||
pub fn finish(&mut self) {
|
||||
self.current = self.total;
|
||||
self.draw();
|
||||
if self.visible {
|
||||
// Move to next line
|
||||
eprintln!();
|
||||
}
|
||||
if self.use_osc {
|
||||
osc_progress_clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark as finished with a message replacing the bar.
|
||||
pub fn finish_with_message(&mut self, msg: &str) {
|
||||
self.current = self.total;
|
||||
if self.visible {
|
||||
eprint!("\r\x1b[2K{msg}");
|
||||
eprintln!();
|
||||
} else if self.start.elapsed() >= self.suppress_until {
|
||||
eprintln!("{msg}");
|
||||
}
|
||||
if self.use_osc {
|
||||
osc_progress_clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn draw(&mut self) {
|
||||
// Delay suppression: don't render if op is still fast
|
||||
if self.start.elapsed() < self.suppress_until && self.current < self.total {
|
||||
return;
|
||||
}
|
||||
|
||||
self.visible = true;
|
||||
|
||||
let pct = (self.current as f64 / self.total as f64 * 100.0) as u8;
|
||||
let filled = (self.current as f64 / self.total as f64 * self.width as f64) as usize;
|
||||
let empty = self.width.saturating_sub(filled);
|
||||
|
||||
let bar: String = std::iter::repeat_n(FILLED, filled)
|
||||
.chain(std::iter::repeat_n(EMPTY, empty))
|
||||
.collect();
|
||||
|
||||
eprint!(
|
||||
"\r\x1b[2K{:<14} [{}] {:>3}% ({}/{})",
|
||||
self.label, bar, pct, self.current, self.total
|
||||
);
|
||||
let _ = io::stderr().flush();
|
||||
|
||||
if self.use_osc {
|
||||
osc_progress(1, pct);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for ProgressBar {
|
||||
fn drop(&mut self) {
|
||||
if self.use_osc && self.visible {
|
||||
osc_progress_clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Spinner
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// An indeterminate spinner for operations without known total.
|
||||
///
|
||||
/// ```text
|
||||
/// ⠋ Loading models...
|
||||
/// ```
|
||||
pub struct Spinner {
|
||||
label: String,
|
||||
frame: usize,
|
||||
start: Instant,
|
||||
suppress_until: Duration,
|
||||
visible: bool,
|
||||
use_osc: bool,
|
||||
}
|
||||
|
||||
impl Spinner {
|
||||
/// Create a spinner with the given label.
|
||||
pub fn new(label: &str) -> Self {
|
||||
Self {
|
||||
label: label.to_string(),
|
||||
frame: 0,
|
||||
start: Instant::now(),
|
||||
suppress_until: Duration::from_millis(DELAY_SUPPRESS_MS),
|
||||
visible: false,
|
||||
use_osc: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Disable delay suppression.
|
||||
pub fn no_delay(mut self) -> Self {
|
||||
self.suppress_until = Duration::ZERO;
|
||||
self
|
||||
}
|
||||
|
||||
/// Disable OSC 9;4 protocol.
|
||||
pub fn no_osc(mut self) -> Self {
|
||||
self.use_osc = false;
|
||||
self
|
||||
}
|
||||
|
||||
/// Advance the spinner by one frame and redraw.
|
||||
pub fn tick(&mut self) {
|
||||
if self.start.elapsed() < self.suppress_until {
|
||||
return;
|
||||
}
|
||||
|
||||
self.visible = true;
|
||||
let ch = SPINNER_FRAMES[self.frame % SPINNER_FRAMES.len()];
|
||||
self.frame += 1;
|
||||
|
||||
eprint!("\r\x1b[2K{ch} {}", self.label);
|
||||
let _ = io::stderr().flush();
|
||||
|
||||
if self.use_osc {
|
||||
osc_progress(3, 0);
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the label text.
|
||||
pub fn set_label(&mut self, label: &str) {
|
||||
self.label = label.to_string();
|
||||
}
|
||||
|
||||
/// Stop the spinner and clear the line.
|
||||
pub fn finish(&self) {
|
||||
if self.visible {
|
||||
eprint!("\r\x1b[2K");
|
||||
let _ = io::stderr().flush();
|
||||
}
|
||||
if self.use_osc {
|
||||
osc_progress_clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Stop the spinner and print a final message.
|
||||
pub fn finish_with_message(&self, msg: &str) {
|
||||
if self.visible {
|
||||
eprint!("\r\x1b[2K");
|
||||
}
|
||||
eprintln!("{msg}");
|
||||
if self.use_osc {
|
||||
osc_progress_clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for Spinner {
|
||||
fn drop(&mut self) {
|
||||
if self.use_osc && self.visible {
|
||||
osc_progress_clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn progress_bar_percentage() {
|
||||
let mut pb = ProgressBar::new("Test", 10).no_delay().no_osc();
|
||||
pb.set(5);
|
||||
assert_eq!(pb.current, 5);
|
||||
pb.inc(3);
|
||||
assert_eq!(pb.current, 8);
|
||||
// Cannot exceed total
|
||||
pb.inc(100);
|
||||
assert_eq!(pb.current, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn progress_bar_zero_total_no_panic() {
|
||||
// total of 0 should be clamped to 1 to avoid division by zero
|
||||
let mut pb = ProgressBar::new("Empty", 0).no_delay().no_osc();
|
||||
pb.set(0);
|
||||
pb.finish();
|
||||
assert_eq!(pb.total, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spinner_frame_advance() {
|
||||
let mut sp = Spinner::new("Loading").no_delay().no_osc();
|
||||
sp.tick();
|
||||
assert_eq!(sp.frame, 1);
|
||||
sp.tick();
|
||||
assert_eq!(sp.frame, 2);
|
||||
sp.finish();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn delay_suppression() {
|
||||
// With default suppress_until, a freshly-created bar should NOT
|
||||
// become visible on the first draw (elapsed < 200ms).
|
||||
let mut pb = ProgressBar::new("Quick", 10).no_osc();
|
||||
pb.set(1);
|
||||
assert!(!pb.visible);
|
||||
}
|
||||
}
|
||||
248
crates/openfang-cli/src/table.rs
Normal file
248
crates/openfang-cli/src/table.rs
Normal file
@@ -0,0 +1,248 @@
|
||||
//! ASCII table renderer with Unicode box-drawing borders for CLI output.
|
||||
//!
|
||||
//! Supports column alignment, auto-width, header styling, and optional colored
|
||||
//! output via the `colored` crate.
|
||||
|
||||
use colored::Colorize;
|
||||
|
||||
/// Column alignment.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Align {
|
||||
Left,
|
||||
Right,
|
||||
Center,
|
||||
}
|
||||
|
||||
/// A table builder that collects headers and rows, then renders to a
|
||||
/// Unicode box-drawing string.
|
||||
pub struct Table {
|
||||
headers: Vec<String>,
|
||||
alignments: Vec<Align>,
|
||||
rows: Vec<Vec<String>>,
|
||||
}
|
||||
|
||||
impl Table {
|
||||
/// Create a new table with the given column headers.
|
||||
/// All columns default to left-alignment.
|
||||
pub fn new(headers: &[&str]) -> Self {
|
||||
let headers: Vec<String> = headers.iter().map(|h| h.to_string()).collect();
|
||||
let alignments = vec![Align::Left; headers.len()];
|
||||
Self {
|
||||
headers,
|
||||
alignments,
|
||||
rows: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Override the alignment for a specific column (0-indexed).
|
||||
/// Out-of-range indices are silently ignored.
|
||||
pub fn align(mut self, col: usize, alignment: Align) -> Self {
|
||||
if col < self.alignments.len() {
|
||||
self.alignments[col] = alignment;
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
/// Add a row. Extra cells are truncated; missing cells are filled with "".
|
||||
pub fn add_row(&mut self, cells: &[&str]) {
|
||||
let row: Vec<String> = (0..self.headers.len())
|
||||
.map(|i| cells.get(i).unwrap_or(&"").to_string())
|
||||
.collect();
|
||||
self.rows.push(row);
|
||||
}
|
||||
|
||||
/// Compute the display width of each column (max of header and all cells).
|
||||
fn column_widths(&self) -> Vec<usize> {
|
||||
let mut widths: Vec<usize> = self.headers.iter().map(|h| h.len()).collect();
|
||||
for row in &self.rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
if i < widths.len() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
widths
|
||||
}
|
||||
|
||||
/// Pad a string to the given width according to alignment.
|
||||
fn pad(text: &str, width: usize, alignment: Align) -> String {
|
||||
let len = text.len();
|
||||
if len >= width {
|
||||
return text.to_string();
|
||||
}
|
||||
let diff = width - len;
|
||||
match alignment {
|
||||
Align::Left => format!("{text}{}", " ".repeat(diff)),
|
||||
Align::Right => format!("{}{text}", " ".repeat(diff)),
|
||||
Align::Center => {
|
||||
let left = diff / 2;
|
||||
let right = diff - left;
|
||||
format!("{}{text}{}", " ".repeat(left), " ".repeat(right))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a horizontal border line.
|
||||
/// `left`, `mid`, `right` are the corner/junction characters.
|
||||
fn border(widths: &[usize], left: &str, mid: &str, right: &str) -> String {
|
||||
let segments: Vec<String> = widths.iter().map(|w| "\u{2500}".repeat(w + 2)).collect();
|
||||
format!("{left}{}{right}", segments.join(mid))
|
||||
}
|
||||
|
||||
/// Render the table to a string with Unicode box-drawing borders.
|
||||
///
|
||||
/// Layout:
|
||||
/// ```text
|
||||
/// ┌──────┬───────┐
|
||||
/// │ Name │ Value │
|
||||
/// ├──────┼───────┤
|
||||
/// │ foo │ bar │
|
||||
/// └──────┴───────┘
|
||||
/// ```
|
||||
pub fn render(&self) -> String {
|
||||
let widths = self.column_widths();
|
||||
|
||||
let top = Self::border(&widths, "\u{250c}", "\u{252c}", "\u{2510}");
|
||||
let sep = Self::border(&widths, "\u{251c}", "\u{253c}", "\u{2524}");
|
||||
let bot = Self::border(&widths, "\u{2514}", "\u{2534}", "\u{2518}");
|
||||
|
||||
let mut lines = Vec::new();
|
||||
|
||||
// Top border
|
||||
lines.push(top);
|
||||
|
||||
// Header row (bold)
|
||||
let header_cells: Vec<String> = self
|
||||
.headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, h)| format!(" {} ", Self::pad(h, widths[i], self.alignments[i]).bold()))
|
||||
.collect();
|
||||
lines.push(format!("\u{2502}{}\u{2502}", header_cells.join("\u{2502}")));
|
||||
|
||||
// Separator
|
||||
lines.push(sep);
|
||||
|
||||
// Data rows
|
||||
for row in &self.rows {
|
||||
let cells: Vec<String> = row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, cell)| format!(" {} ", Self::pad(cell, widths[i], self.alignments[i])))
|
||||
.collect();
|
||||
lines.push(format!("\u{2502}{}\u{2502}", cells.join("\u{2502}")));
|
||||
}
|
||||
|
||||
// Bottom border
|
||||
lines.push(bot);
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
/// Render the table and print it to stdout.
|
||||
pub fn print(&self) {
|
||||
println!("{}", self.render());
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn basic_table() {
|
||||
let mut t = Table::new(&["Name", "Age", "City"]);
|
||||
t.add_row(&["Alice", "30", "London"]);
|
||||
t.add_row(&["Bob", "25", "Paris"]);
|
||||
|
||||
let rendered = t.render();
|
||||
let lines: Vec<&str> = rendered.lines().collect();
|
||||
|
||||
// 5 lines: top, header, sep, 2 rows, bottom = 6
|
||||
assert_eq!(lines.len(), 6);
|
||||
|
||||
// Top border uses box-drawing
|
||||
assert!(lines[0].starts_with('\u{250c}'));
|
||||
assert!(lines[0].ends_with('\u{2510}'));
|
||||
|
||||
// Bottom border
|
||||
assert!(lines[5].starts_with('\u{2514}'));
|
||||
assert!(lines[5].ends_with('\u{2518}'));
|
||||
|
||||
// Header line contains column names (ignore ANSI codes for bold)
|
||||
assert!(lines[1].contains("Name"));
|
||||
assert!(lines[1].contains("Age"));
|
||||
assert!(lines[1].contains("City"));
|
||||
|
||||
// Data rows contain cell values
|
||||
assert!(lines[3].contains("Alice"));
|
||||
assert!(lines[3].contains("30"));
|
||||
assert!(lines[3].contains("London"));
|
||||
assert!(lines[4].contains("Bob"));
|
||||
assert!(lines[4].contains("25"));
|
||||
assert!(lines[4].contains("Paris"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn right_alignment() {
|
||||
let mut t = Table::new(&["Item", "Count"]);
|
||||
t = t.align(1, Align::Right);
|
||||
t.add_row(&["apples", "5"]);
|
||||
t.add_row(&["oranges", "123"]);
|
||||
|
||||
let rendered = t.render();
|
||||
// The "5" should be right-padded on the left within its column
|
||||
// Find the data line with "5"
|
||||
let line = rendered.lines().find(|l| l.contains("apples")).unwrap();
|
||||
// After the second box char, the number should be right-aligned
|
||||
assert!(line.contains(" 5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn center_alignment() {
|
||||
let pad = Table::pad("hi", 6, Align::Center);
|
||||
assert_eq!(pad, " hi ");
|
||||
|
||||
let pad_odd = Table::pad("hi", 7, Align::Center);
|
||||
assert_eq!(pad_odd, " hi ");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_table() {
|
||||
let t = Table::new(&["A", "B"]);
|
||||
let rendered = t.render();
|
||||
let lines: Vec<&str> = rendered.lines().collect();
|
||||
// top, header, sep, bottom = 4 lines (no data rows)
|
||||
assert_eq!(lines.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn missing_cells_filled() {
|
||||
let mut t = Table::new(&["X", "Y", "Z"]);
|
||||
t.add_row(&["only-one"]);
|
||||
|
||||
let rendered = t.render();
|
||||
// Row should still have 3 columns; missing ones are empty
|
||||
let data_line = rendered.lines().nth(3).unwrap();
|
||||
// Count box-drawing vertical bars in data line
|
||||
let bars = data_line.matches('\u{2502}').count();
|
||||
assert_eq!(bars, 4); // left + 2 inner + right
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wide_cells_auto_width() {
|
||||
let mut t = Table::new(&["ID", "Description"]);
|
||||
t.add_row(&["1", "A very long description string"]);
|
||||
|
||||
let rendered = t.render();
|
||||
assert!(rendered.contains("A very long description string"));
|
||||
// The top border should be wide enough to contain the description
|
||||
let top = rendered.lines().next().unwrap();
|
||||
// At minimum: 2 padding + description length for second column
|
||||
assert!(top.len() > 30);
|
||||
}
|
||||
}
|
||||
130
crates/openfang-cli/src/templates.rs
Normal file
130
crates/openfang-cli/src/templates.rs
Normal file
@@ -0,0 +1,130 @@
|
||||
//! Discover and load agent templates from the agents directory.
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// A discovered agent template.
|
||||
pub struct AgentTemplate {
|
||||
/// Template name (directory name).
|
||||
pub name: String,
|
||||
/// Description from the manifest.
|
||||
pub description: String,
|
||||
/// Raw TOML content.
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// Discover template directories. Checks:
|
||||
/// 1. The repo `agents/` dir (for dev builds)
|
||||
/// 2. `~/.openfang/agents/` (installed templates)
|
||||
/// 3. `OPENFANG_AGENTS_DIR` env var
|
||||
pub fn discover_template_dirs() -> Vec<PathBuf> {
|
||||
let mut dirs = Vec::new();
|
||||
|
||||
// Dev: repo agents/ directory (relative to the binary)
|
||||
if let Ok(exe) = std::env::current_exe() {
|
||||
// Walk up from the binary to find the workspace root
|
||||
let mut dir = exe.as_path();
|
||||
for _ in 0..5 {
|
||||
if let Some(parent) = dir.parent() {
|
||||
let agents = parent.join("agents");
|
||||
if agents.is_dir() {
|
||||
dirs.push(agents);
|
||||
break;
|
||||
}
|
||||
dir = parent;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Installed templates
|
||||
if let Some(home) = dirs::home_dir() {
|
||||
let agents = home.join(".openfang").join("agents");
|
||||
if agents.is_dir() && !dirs.contains(&agents) {
|
||||
dirs.push(agents);
|
||||
}
|
||||
}
|
||||
|
||||
// Environment override
|
||||
if let Ok(env_dir) = std::env::var("OPENFANG_AGENTS_DIR") {
|
||||
let p = PathBuf::from(env_dir);
|
||||
if p.is_dir() && !dirs.contains(&p) {
|
||||
dirs.push(p);
|
||||
}
|
||||
}
|
||||
|
||||
dirs
|
||||
}
|
||||
|
||||
/// Load all templates from discovered directories, falling back to bundled templates.
|
||||
pub fn load_all_templates() -> Vec<AgentTemplate> {
|
||||
let mut templates = Vec::new();
|
||||
let mut seen_names = std::collections::HashSet::new();
|
||||
|
||||
// First: load from filesystem (user-installed or dev repo)
|
||||
for dir in discover_template_dirs() {
|
||||
if let Ok(entries) = std::fs::read_dir(&dir) {
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if !path.is_dir() {
|
||||
continue;
|
||||
}
|
||||
let manifest = path.join("agent.toml");
|
||||
if !manifest.exists() {
|
||||
continue;
|
||||
}
|
||||
let name = entry.file_name().to_string_lossy().to_string();
|
||||
if name == "custom" || !seen_names.insert(name.clone()) {
|
||||
continue;
|
||||
}
|
||||
if let Ok(content) = std::fs::read_to_string(&manifest) {
|
||||
let description = extract_description(&content);
|
||||
templates.push(AgentTemplate {
|
||||
name,
|
||||
description,
|
||||
content,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: load bundled templates for any not found on disk
|
||||
for (name, content) in crate::bundled_agents::bundled_agents() {
|
||||
if seen_names.insert(name.to_string()) {
|
||||
let description = extract_description(content);
|
||||
templates.push(AgentTemplate {
|
||||
name: name.to_string(),
|
||||
description,
|
||||
content: content.to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
templates.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
templates
|
||||
}
|
||||
|
||||
/// Extract the `description` field from raw TOML without full parsing.
|
||||
fn extract_description(toml_str: &str) -> String {
|
||||
for line in toml_str.lines() {
|
||||
let trimmed = line.trim();
|
||||
if let Some(rest) = trimmed.strip_prefix("description") {
|
||||
if let Some(rest) = rest.trim_start().strip_prefix('=') {
|
||||
let val = rest.trim().trim_matches('"');
|
||||
return val.to_string();
|
||||
}
|
||||
}
|
||||
}
|
||||
String::new()
|
||||
}
|
||||
|
||||
/// Format a template description as a hint for cliclack select items.
|
||||
pub fn template_display_hint(t: &AgentTemplate) -> String {
|
||||
if t.description.is_empty() {
|
||||
String::new()
|
||||
} else if t.description.chars().count() > 60 {
|
||||
let truncated: String = t.description.chars().take(57).collect();
|
||||
format!("{truncated}...")
|
||||
} else {
|
||||
t.description.clone()
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user