diff --git a/Cargo.lock b/Cargo.lock index c1f5fbe..95854aa 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -49,7 +49,7 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rand", + "rand 0.9.2", "sha1", "smallvec", "tokio", @@ -201,7 +201,7 @@ dependencies = [ "mime", "percent-encoding", "pin-project-lite", - "rand", + "rand 0.9.2", "rustls 0.20.9", "serde", "serde_json", @@ -232,6 +232,20 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "backoff" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b62ddb9cb1ec0a098ad4bbf9344d0713fa193ae1a80af55febcff2627b6a00c1" +dependencies = [ + "futures-core", + "getrandom 0.2.16", + "instant", + "pin-project-lite", + "rand 0.8.5", + "tokio", +] + [[package]] name = "base64" version = "0.22.1" @@ -768,6 +782,15 @@ dependencies = [ "hashbrown", ] +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + [[package]] name = "itertools" version = "0.13.0" @@ -1026,14 +1049,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha", - "rand_core", + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", ] [[package]] @@ -1043,7 +1087,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", ] [[package]] @@ -1842,6 +1895,7 @@ dependencies = [ "anyhow", "async-stomp", "awc", + "backoff", "bytes", "futures-util", "pretty-readme", diff --git a/Cargo.toml b/Cargo.toml index 797b82e..a582707 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ actix-http = "3" anyhow = "1" async-stomp = "0.6" awc = "3" +backoff = { version = "0.4", features = ["tokio"] } bytes = "1" futures-util = "0.3" pretty-readme = "0.1" diff --git a/README.md b/README.md index 080fb77..0e2721d 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,11 @@ This crate provides a simple client to connect to a STOMP-enabled WebSocket serv * Connects to STOMP servers over WebSocket using [`awc`](https://crates.io/crates/awc). * Handles all STOMP protocol encoding and decoding via [`async-stomp`](https://crates.io/crates/async-stomp). * Manages WebSocket ping/pong heartbeats automatically in a background task. -* Provides a simple `tokio::mpsc` channel-based API (`WStompClient`) for sending and receiving STOMP frames. +* Provides a simple `tokio::mpsc` channel-based API ([`WStompClient`]) for sending and receiving STOMP frames. * Connection helpers for various authentication methods: - * `connect`: Anonymous connection. - * `connect_with_pass`: Login and passcode authentication. - * `connect_with_token`: Authentication using an authorization token header. + * [`connect`]: Anonymous connection. + * [`connect_with_pass`]: Login and passcode authentication. + * [`connect_with_token`]: Authentication using an authorization token header. * Optional `rustls` feature for SSL connections, with helpers that force HTTP/1.1 for compatibility with servers like SockJS. ## Installation @@ -24,8 +24,6 @@ Add this to your `Cargo.toml`: [dependencies] wstomp = "0.1.0" # Replace with the actual version actix-rt = "2.0" -tokio = { version = "1", features = ["macros", "rt-multi-thread"] } -futures-util = "0.3" ``` For SSL support, enable the `rustls` feature: @@ -109,7 +107,6 @@ async fn main() -> Result<(), Box> { other => println!("Received other frame: {:?}", other), } } - WStompEvent::WebsocketClosed(reason) => break, // Handle errors WStompEvent::Error(err) => { match err { @@ -176,11 +173,41 @@ async fn main() { } ``` +### Auto-reconnect + +Use [`WStompConfig::build_and_connect_with_reconnection_cb`] method to automatically perform a full reconnect upon errors. + +```rust,no_run +use wstomp::{WStompClient, WStompConfig, WStompConnectError}; + +#[actix_rt::main] +async fn main() { + let url = "wss://secure-server.com/ws"; + let session_token = "session_token"; + + let cb = { + move |wstomp_client_res: Result| { + async move { + // Unwrap wstomp client here or react to an error. + // Upon an error you can return from the callback to make wstomp library a re-connection attempt + } + } + }; + + let res = WStompConfig::new(url) + .ssl() + .auth_token(session_token) + .build_and_connect_with_reconnection_cb(cb); + + // ... do different stuff here, but don't exit immediately as this will terminate wstomp loop. +} +``` + ## Error Handling -The connection functions (`connect`, `connect_ssl`, etc.) return a `Result`. +The connection functions ([`connect`], [`connect_ssl`], etc.) return a `Result`. -Once connected, the `WStompClient::rx` channel produces `WStompEvent` items, it may be a message, websocket closing, or `WStompError`. +Once connected, the `WStompClient::rx` channel produces [`WStompEvent`] items, it may be a message or [`WStompError`]. * **`WStompConnectError`**: An error that occurs during the initial WebSocket and STOMP `CONNECT` handshake. @@ -188,6 +215,9 @@ Once connected, the `WStompClient::rx` channel produces `WStompEvent` items, it * `WsReceive` / `WsSend`: A WebSocket protocol error. * `StompDecoding` / `StompEncoding`: A STOMP frame decoding/encoding error. * `IncompleteStompFrame`: A warning indicating that data was received but was not enough to form a complete STOMP frame. The client has dropped this data. This is often safe to ignore or log as a warning. + * `WebsocketClosed`: WebSocket was closed, possibly a reason from `awc` library is inside. + * `PingFailed`: Couldn't send ping through the WebSocket protocol. + * `PingTimeout`: There was no pong for last ping. ## License diff --git a/src/client.rs b/src/client.rs index 8dc2131..5a09ffe 100644 --- a/src/client.rs +++ b/src/client.rs @@ -31,6 +31,9 @@ impl WStompClient { /// You can use this struct directly by passing the `Framed` object you get from `awc` into this constructor. /// This will create a background worker in actix system (on current thread), which will encode and decode STOMP messages for you. /// It also manages websocket ping-pong heartbeat. + /// + /// NOTE: This method does not perform automatic reconnection. + /// Use [WStompConfig::build_and_connect_with_reconnection_cb] to auto-reconnect. pub fn from_framed(ws_framed: Framed) -> Self { // Channel for you to send STOMP frames to the handler task let (app_tx, app_rx) = mpsc::channel::>(100); diff --git a/src/config.rs b/src/config.rs index 87e9eb9..0cf575b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,7 @@ pub struct WStompConfig { opts: WStompConfigOpts, } -#[derive(Default)] +#[derive(Clone)] pub struct WStompConfigOpts { #[cfg(feature = "rustls")] pub ssl: bool, @@ -12,6 +12,30 @@ pub struct WStompConfigOpts { pub passcode: Option, pub additional_headers: Vec<(String, String)>, pub client: Option, + + // Reconnection opts in seconds + pub retry_initial_interval: u64, + pub retry_max_interval: u64, + pub retry_multiplier: f64, + pub retry_max_elapsed_time: Option, +} + +impl Default for WStompConfigOpts { + fn default() -> Self { + Self { + ssl: Default::default(), + auth_token: Default::default(), + login: Default::default(), + passcode: Default::default(), + additional_headers: Default::default(), + client: Default::default(), + + retry_initial_interval: 3, + retry_max_interval: 60, + retry_multiplier: 1.2, + retry_max_elapsed_time: None, + } + } } impl WStompConfig { @@ -22,47 +46,102 @@ impl WStompConfig { } } + /// Get url to which this config is assigned to use. pub fn get_url(&self) -> &U { &self.url } + /// Get options for this config. pub fn get_opts(&self) -> &WStompConfigOpts { &self.opts } - pub fn into_inner(self) -> (U, WStompConfigOpts) { + /// De-couple url and options in this config. + pub fn into_parts(self) -> (U, WStompConfigOpts) { (self.url, self.opts) } // Setters + /// Enables TLS/SSL encryption for the connection. + /// + /// When set, the client will attempt to perform a secure handshake + /// (typically for `wss://` schemes). pub fn ssl(mut self) -> Self { self.opts.ssl = true; self } + /// Sets the authentication token for the connection. pub fn auth_token(mut self, auth_token: impl Into) -> Self { self.opts.auth_token = Some(auth_token.into()); self } + /// Sets the `login` header for STOMP authentication. pub fn login(mut self, login: impl Into) -> Self { self.opts.login = Some(login.into()); self } + /// Sets the `passcode` header for STOMP authentication. pub fn passcode(mut self, passcode: impl Into) -> Self { self.opts.passcode = Some(passcode.into()); self } + /// Appends a list of custom headers to the connection configuration. + /// + /// These headers will be included in the STOMP `CONNECT` frame. + /// This method does not replace existing headers; it extends the list. pub fn add_headers(mut self, additional_headers: Vec<(String, String)>) -> Self { self.opts.additional_headers.extend(additional_headers); self } + /// Sets a custom `awc::Client` instance. + /// + /// Use this if you need to provide a pre-configured HTTP client (e.g., + /// with custom timeouts, proxy settings, or connector configurations) + /// instead of letting the library create a default one. pub fn client(mut self, client: awc::Client) -> Self { self.opts.client = Some(client); self } + + /// If [Self::build_and_connect_with_reconnection_cb] method is used, + /// sets the initial retry interval in seconds. + /// + /// Example: Start retrying after 3 seconds. + pub fn retry_initial_interval(mut self, seconds: u64) -> Self { + self.opts.retry_initial_interval = seconds; + self + } + + /// If [Self::build_and_connect_with_reconnection_cb] method is used, + /// sets the maximum retry interval in seconds. + /// + /// Example: Cap the wait time at 30 seconds. + pub fn retry_max_interval(mut self, seconds: u64) -> Self { + self.opts.retry_max_interval = seconds; + self + } + + /// If [Self::build_and_connect_with_reconnection_cb] method is used, + /// sets the multiplier for the backoff. + /// + /// Example: 2.0 doubles the wait time after every failure. + pub fn retry_multiplier(mut self, multiplier: f64) -> Self { + self.opts.retry_multiplier = multiplier; + self + } + + /// If [Self::build_and_connect_with_reconnection_cb] method is used, + /// sets a maximum total time to try reconnecting before giving up. + /// + /// Defaults to no limit if method not invoked. + pub fn retry_max_elapsed_time(mut self, seconds: u64) -> Self { + self.opts.retry_max_elapsed_time = Some(seconds); + self + } } diff --git a/src/connect.rs b/src/connect.rs index ccbaadb..30e6269 100644 --- a/src/connect.rs +++ b/src/connect.rs @@ -4,10 +4,15 @@ use awc::{ error::{HttpError, WsClientError}, ws::WebsocketsRequest, }; +use backoff::{ExponentialBackoffBuilder, backoff::Backoff}; +use std::time::Duration; +use tokio::time::sleep; -use crate::{WStompClient, WStompConfig, WStompConnectError}; +use crate::{WStompClient, WStompConfig, WStompConnectError, config::WStompConfigOpts}; /// Connect to STOMP server without additional parameters +/// +/// Creates and builds the client automatically. pub async fn connect(url: U) -> Result where Uri: TryFrom, @@ -17,6 +22,8 @@ where } /// Connect to STOMP server using authorization token +/// +/// Creates and builds the client automatically. pub async fn connect_with_token( url: U, auth_token: impl Into, @@ -32,6 +39,8 @@ where } /// Connect to STOMP server using password +/// +/// Creates and builds the client automatically. pub async fn connect_with_pass( url: U, login: impl Into, @@ -66,69 +75,122 @@ impl StompConnect for WebsocketsRequest { } } -impl WStompConfig { - pub async fn build_and_connect(self) -> Result +impl WStompConfig +where + Uri: TryFrom, + >::Error: Into, +{ + /// Build the client and connect (once). + pub async fn build_and_connect(self) -> Result { + let (url, opts) = self.into_parts(); + + let uri = Uri::try_from(url).map_err(|e| { + let err: HttpError = e.into(); + WStompConnectError::WsClientError(WsClientError::from(err)) + })?; + + inner_connect(uri, opts).await + } + + /// Build the client and spawns connect procedure with reconnection mechanism. + /// The result from the connection procedure and all subsequent reconnection attempts is passed into the callback. + pub fn build_and_connect_with_reconnection_cb( + self, + cb: F, + ) -> Result<(), WStompConnectError> where - Uri: TryFrom, - >::Error: Into, + F: Fn(Result) -> R + 'static, + R: Future, { - let (url, opts) = self.into_inner(); - - let client = if let Some(client) = opts.client { - client - } else { - #[cfg(feature = "rustls")] - if opts.ssl { - crate::connect_ssl::create_ssl_client() - } else { - awc::Client::default() - } - #[cfg(not(feature = "rustls"))] - awc::Client::default() - }; + let (url, opts) = self.into_parts(); let uri = Uri::try_from(url).map_err(|e| { let err: HttpError = e.into(); WStompConnectError::WsClientError(WsClientError::from(err)) })?; - let (authority, host_name) = uri - .authority() - .map(|a| (a.to_string(), a.host().to_string())) - .unwrap_or_default(); + let mut backoff = ExponentialBackoffBuilder::new() + .with_initial_interval(Duration::from_secs(opts.retry_initial_interval)) + .with_max_interval(Duration::from_secs(opts.retry_max_interval)) + .with_multiplier(opts.retry_multiplier) + .with_max_elapsed_time(opts.retry_max_elapsed_time.map(Duration::from_secs)) + .build(); + + actix_rt::spawn(async move { + loop { + let tx = inner_connect(uri.clone(), opts.clone()).await; + + if tx.is_ok() { + backoff.reset(); + } else if let Some(duration) = backoff.next_backoff() { + sleep(duration).await; + } else { + cb(Err(WStompConnectError::ReconnectionLimit)).await; + break; + } + + cb(tx).await; + } + }); + + Ok(()) + } +} - let mut headers = opts.additional_headers; +pub(crate) fn headers_for_token(auth_token: impl Into) -> Vec<(String, String)> { + vec![("Authorization".to_string(), auth_token.into())] +} - if let Some(auth_token) = opts.auth_token { - headers.extend(headers_for_token(auth_token)); +async fn inner_connect( + uri: Uri, + opts: WStompConfigOpts, +) -> Result { + let client = if let Some(client) = opts.client { + client + } else { + #[cfg(feature = "rustls")] + if opts.ssl { + crate::connect_ssl::create_ssl_client() + } else { + awc::Client::default() } + #[cfg(not(feature = "rustls"))] + awc::Client::default() + }; - let stomp_client = client.ws::(uri).stomp_connect().await?; + let (authority, host_name) = uri + .authority() + .map(|a| (a.to_string(), a.host().to_string())) + .unwrap_or_default(); - let connect_msg = Connector::builder() - .server(authority.clone()) - .virtualhost(authority) - .headers(headers) - .use_tls(true) - .tls_server_name(host_name); + let mut headers = opts.additional_headers; - let connect_msg = if let Some(login) = opts.login - && let Some(passcode) = opts.passcode - { - connect_msg.login(login).passcode(passcode).msg() - } else { - connect_msg.msg() - }; + if let Some(auth_token) = opts.auth_token { + headers.extend(headers_for_token(auth_token)); + } - stomp_client - .send(connect_msg) - .await - .map_err(WStompConnectError::ConnectMessageFailed)?; + let stomp_client = client.ws::(uri).stomp_connect().await?; - Ok(stomp_client) - } -} + let connect_msg = Connector::builder() + .server(authority.clone()) + .virtualhost(authority) + .headers(headers) + .use_tls(true) + .tls_server_name(host_name); -pub(crate) fn headers_for_token(auth_token: impl Into) -> Vec<(String, String)> { - vec![("Authorization".to_string(), auth_token.into())] + let connect_msg = if let Some(login) = opts.login + && let Some(passcode) = opts.passcode + { + connect_msg.login(login).passcode(passcode).msg() + } else { + connect_msg.msg() + }; + + stomp_client + .send(connect_msg) + .await + .map_err(Box::new) + .map_err(WStompConnectError::ConnectMessageFailed)?; + + Ok(stomp_client) } diff --git a/src/connect_ssl.rs b/src/connect_ssl.rs index af57a3b..1f6650f 100644 --- a/src/connect_ssl.rs +++ b/src/connect_ssl.rs @@ -14,7 +14,9 @@ where WStompConfig::new(url).ssl().build_and_connect().await } -/// Connect to STOMP server through SSL using authorization token +/// Connect to STOMP server through SSL using authorization token. +/// +/// Creates and builds the client automatically. pub async fn connect_ssl_with_token( url: U, auth_token: impl Into, @@ -30,7 +32,9 @@ where .await } -/// Connect to STOMP server through SSL using password +/// Connect to STOMP server through SSL using password. +/// +/// Creates and builds the client automatically. pub async fn connect_ssl_with_pass( url: U, login: String, diff --git a/src/stomp_handler.rs b/src/stomp_handler.rs index 590f39c..0ab660e 100644 --- a/src/stomp_handler.rs +++ b/src/stomp_handler.rs @@ -30,6 +30,9 @@ pub(crate) async fn stomp_handler_task( let mut interval = actix_rt::time::interval(Duration::from_secs(20)); + let mut pings_sent = 0; + let mut pongs_received = 0; + loop { select! { // Received a message from the WebSocket server @@ -51,10 +54,10 @@ pub(crate) async fn stomp_handler_task( finished_reading = true; } WsFrame::Close(reason) => { - let _ = stomp_tx.send(WStompEvent::WebsocketClosed(reason)).await; + let _ = stomp_tx.send(WStompEvent::Error(WStompError::WebsocketClosed(reason))).await; break; } - WsFrame::Pong(_) => {} + WsFrame::Pong(_) => pongs_received += 1, WsFrame::Continuation(item) => { match item { WsItem::FirstText(bytes) => { @@ -121,7 +124,17 @@ pub(crate) async fn stomp_handler_task( } _ = interval.tick() => { - let _ = ws_sink.send(WsMessage::Ping(Bytes::from_static(b"wstomp"))).await; + if pongs_received < pings_sent { + let _ = stomp_tx.send(WStompEvent::Error(WStompError::PingTimeout)).await; + break; + } + match ws_sink.send(WsMessage::Ping(Bytes::from_static(b"wstomp"))).await { + Ok(_) => pings_sent += 1, + Err(err) => { + let _ = stomp_tx.send(WStompEvent::Error(WStompError::PingFailed(err.into()))).await; + break; + } + } } // 3. Both streams closed, exit loop diff --git a/src/wstomp_event.rs b/src/wstomp_event.rs index d87e60f..29b7409 100644 --- a/src/wstomp_event.rs +++ b/src/wstomp_event.rs @@ -8,7 +8,8 @@ use tokio::sync::mpsc::error::SendError; #[derive(Debug)] pub enum WStompConnectError { WsClientError(WsClientError), - ConnectMessageFailed(SendError>), + ConnectMessageFailed(Box>>), + ReconnectionLimit, } /// Custom enum combine events in WebSocket and STOMP @@ -16,8 +17,6 @@ pub enum WStompConnectError { pub enum WStompEvent { /// Regular message from STOMP protocol Message(Message), - /// Websocket closed connection (with reason) - WebsocketClosed(Option), /// WebSocket or STOMP error combined Error(WStompError), } @@ -44,6 +43,12 @@ pub enum WStompError { /// This is a warning that WebSocket protocol finished receiving data, but STOMP protocol /// doesn't recognize it as a full STOMP message. Should not happen, can be ignored in most cases. IncompleteStompFrame, + /// Websocket closed connection (with reason) + WebsocketClosed(Option), + /// Can't send ping, probably network problems + PingFailed(anyhow::Error), + /// Haven't received pong from last ping + PingTimeout, } impl std::fmt::Display for WStompConnectError { @@ -51,6 +56,7 @@ impl std::fmt::Display for WStompConnectError { match self { Self::WsClientError(err) => write!(f, "WebSocket receive error: {}", err), Self::ConnectMessageFailed(msg) => write!(f, "WebSocket receive error: {}", msg), + Self::ReconnectionLimit => write!(f, "Reconnection retry limit reached"), } } } @@ -61,12 +67,22 @@ impl std::fmt::Display for WStompError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::WsReceive(err) => write!(f, "WebSocket receive error: {}", err), + Self::WsSend(err) => write!(f, "WebSocket send error: {}", err), Self::StompDecoding(err) => write!(f, "STOMP decoding error: {}", err), Self::StompEncoding(err) => write!(f, "STOMP encoding error: {}", err), Self::IncompleteStompFrame => { write!(f, "STOMP decoding warning: Dropped incomplete frame") } - Self::WsSend(err) => write!(f, "WebSocket send error: {}", err), + Self::WebsocketClosed(reason) => write!( + f, + "Websocket closed {}", + reason + .as_ref() + .map(|r| r.description.clone().unwrap_or_default()) + .unwrap_or_default() + ), + Self::PingFailed(err) => write!(f, "Websocket ping failed: {err}"), + Self::PingTimeout => write!(f, "Websocket ping timeout"), } } }