Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Thumbs.db
/Cargo.lock
/target
/.cargo
/tmp

# wasm example
/examples/wasm/dist
Expand Down
45 changes: 36 additions & 9 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,29 +24,56 @@ json = ["dep:serde", "dep:serde_json"]

[dependencies]
# pin version, see https://github.com/jgraef/reqwest-websocket/pull/33
futures-util = { version = ">=0.3.31", default-features = false, features = ["sink"] }
futures-util = { version = ">=0.3.31", default-features = false, features = [
"sink",
] }
reqwest = { version = "0.12", default-features = false }
thiserror = "2"
tracing = "0.1"
serde = { version = "1.0", default-features = false, optional = true }
serde_json = { version = "1.0", default-features = false, optional = true, features = ["alloc"] }
serde_json = { version = "1.0", default-features = false, optional = true, features = [
"alloc",
] }
bytes = "1.10.1"

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
async-tungstenite = { version = "0.31.0", default-features = false, features = ["futures-03-sink"] }
tokio-util = { version = "0.7", default-features = false, features = ["compat"] }
tungstenite = { version = "0.27", default-features = false, features = ["handshake"] }
async-tungstenite = { version = "0.32.0", default-features = false, features = [
"futures-03-sink",
] }
tokio-util = { version = "0.7", default-features = false, features = [
"compat",
] }
tungstenite = { version = "0.28", default-features = false, features = [
"handshake",
] }

[target.'cfg(target_arch = "wasm32")'.dependencies]
web-sys = { version = "0.3", features = ["WebSocket", "CloseEvent", "ErrorEvent", "Event", "MessageEvent", "BinaryType"] }
tokio = { version = "1", default-features = false, features = ["sync", "macros"] }
web-sys = { version = "0.3", features = [
"WebSocket",
"CloseEvent",
"ErrorEvent",
"Event",
"MessageEvent",
"BinaryType",
] }
tokio = { version = "1", default-features = false, features = [
"sync",
"macros",
] }

[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt"] }
reqwest = { version = "0.12", features = ["default-tls"] }
serde = { version = "1.0", features = ["derive"] }
futures-util = { version = "0.3", default-features = false, features = ["sink", "alloc"] }
futures-util = { version = "0.3", default-features = false, features = [
"sink",
"alloc",
] }

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
tokio = { version = "1", features = ["macros", "rt"] }
wasm-bindgen-test = "0.3"
wasm-bindgen-futures = "0.4"

[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
axum = { version = "0.8.7", features = ["ws"] }
tokio = { version = "1", features = ["macros", "rt", "net"] }
158 changes: 147 additions & 11 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -338,16 +338,130 @@ impl Sink<Message> for WebSocket {
}

#[cfg(test)]
pub mod tests {
mod tests {
use futures_util::{SinkExt, TryStreamExt};
use reqwest::Client;
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test;

use crate::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};

#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);

use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub struct TestServer {
shutdown_sender: Option<tokio::sync::oneshot::Sender<()>>,
http_url: String,
ws_url: String,
}

#[cfg(not(target_arch = "wasm32"))]
impl TestServer {
pub async fn new() -> Self {
async fn handle_connection(mut socket: axum::extract::ws::WebSocket) {
if let Some(protocol) = socket.protocol() {
if let Ok(protocol) = protocol.to_str() {
println!("server/protocol: {protocol:?}");
if let Err(error) = socket
.send(axum::extract::ws::Message::Text(
format!("protocol: {protocol}").into(),
))
.await
{
eprintln!("server/send: {error}");
return;
}
} else {
println!("server/protocol: could not convert to utf-8");
}
}

while let Some(message) = socket.recv().await {
match message {
Ok(message) => match &message {
axum::extract::ws::Message::Text(_)
| axum::extract::ws::Message::Binary(_) => {
if let Err(error) = socket.send(message).await {
eprintln!("server/send: {error}");
break;
}
}
_ => {}
},
Err(error) => {
eprintln!("server/recv: {error}");
break;
}
}
}
}

let (shutdown_sender, shutdown_receiver) = tokio::sync::oneshot::channel();
let listener = tokio::net::TcpListener::bind(("localhost", 0))
.await
.unwrap();
let port = listener.local_addr().unwrap().port();
let app = axum::Router::new().route(
"/",
axum::routing::any(|ws: axum::extract::ws::WebSocketUpgrade| async move {
ws.protocols(["chat"]).on_upgrade(handle_connection)
}),
);

// todo: I think we'll need to spawn this on a proper thread (for which we create a separate runtime) for this to be shared across multiple tests
let _join_handle = tokio::spawn(async move {
axum::serve(listener, app)
.with_graceful_shutdown(async move {
let _ = shutdown_receiver.await;
})
.await
.unwrap();
});
Self {
shutdown_sender: Some(shutdown_sender),
http_url: format!("http://localhost:{port}/"),
ws_url: format!("ws://localhost:{port}/"),
}
}

pub fn http_url(&self) -> &str {
&self.http_url
}

pub fn ws_url(&self) -> &str {
&self.ws_url
}
}

#[cfg(not(target_arch = "wasm32"))]
impl Drop for TestServer {
fn drop(&mut self) {
if let Some(shutdown_sender) = self.shutdown_sender.take() {
println!("Shutting down server");
let _ = shutdown_sender.send(());
}
}
}

#[cfg(target_arch = "wasm32")]
pub struct TestServer;

#[cfg(target_arch = "wasm32")]
impl TestServer {
pub async fn new() -> Self {
Self
}

pub fn http_url(&self) -> &str {
"https://echo.websocket.org/"
}

pub fn ws_url(&self) -> &str {
"wss://echo.websocket.org/"
}
}

async fn test_websocket(mut websocket: WebSocket) {
let text = "Hello, World!";
Expand All @@ -370,8 +484,10 @@ pub mod tests {
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_with_request_builder() {
let echo = TestServer::new().await;

let websocket = Client::default()
.get("https://echo.websocket.org/")
.get(echo.http_url())
.upgrade()
.send()
.await
Expand All @@ -386,22 +502,27 @@ pub mod tests {
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_shorthand() {
let websocket = websocket("https://echo.websocket.org/").await.unwrap();
let echo = TestServer::new().await;

let websocket = websocket(echo.http_url()).await.unwrap();
test_websocket(websocket).await;
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_with_ws_scheme() {
let websocket = websocket("wss://echo.websocket.org/").await.unwrap();
let echo = TestServer::new().await;
let websocket = websocket(echo.ws_url()).await.unwrap();

test_websocket(websocket).await;
}

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_close() {
let websocket = websocket("https://echo.websocket.org/").await.unwrap();
let echo = TestServer::new().await;

let websocket = websocket(echo.http_url()).await.unwrap();
websocket
.close(CloseCode::Normal, Some("test"))
.await
Expand All @@ -411,7 +532,9 @@ pub mod tests {
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
async fn test_send_close_frame() {
let mut websocket = websocket("https://echo.websocket.org/").await.unwrap();
let echo = TestServer::new().await;

let mut websocket = websocket(echo.http_url()).await.unwrap();
websocket
.send(Message::Close {
code: CloseCode::Normal,
Expand All @@ -436,10 +559,15 @@ pub mod tests {

#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
#[ignore = "https://echo.websocket.org/ ignores subprotocols"]
#[cfg_attr(
target_arch = "wasm32",
ignore = "echo.websocket.org ignores subprotocols"
)]
async fn test_with_subprotocol() {
let websocket = Client::default()
.get("https://echo.websocket.org/")
let echo = TestServer::new().await;

let mut websocket = Client::default()
.get(echo.http_url())
.upgrade()
.protocols(["chat"])
.send()
Expand All @@ -451,7 +579,15 @@ pub mod tests {

assert_eq!(websocket.protocol(), Some("chat"));

test_websocket(websocket).await;
let message = websocket.try_next().await.unwrap().unwrap();
match message {
Message::Text(s) => {
assert_eq!(s, "protocol: chat");
}
_ => {
panic!("Expected text message with selected protocol");
}
}
}

#[test]
Expand Down