Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "reqwest-websocket"
version = "0.5.1"
version = "0.6.0"
edition = "2021"
authors = ["Janosch Gräf <janosch.graef@gmail.com>"]
description = "WebSocket connections with reqwest"
Expand All @@ -20,7 +20,10 @@ all-features = true
rustdoc-args = ["--cfg", "docsrs"]

[features]
default = []
full = ["json", "middleware"]
json = ["dep:serde", "dep:serde_json"]
middleware = ["dep:reqwest-middleware"]

[dependencies]
# pin version, see https://github.com/jgraef/reqwest-websocket/pull/33
Expand All @@ -35,6 +38,7 @@ serde_json = { version = "1.0", default-features = false, optional = true, featu
"alloc",
] }
bytes = "1.10.1"
reqwest-middleware = { version = "0.5.0", optional = true }

[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
async-tungstenite = { version = "0.32", default-features = false, features = [
Expand Down Expand Up @@ -68,6 +72,8 @@ futures-util = { version = "0.3", default-features = false, features = [
"sink",
"alloc",
] }
async-trait = "0.1.89"
http = "1.4.0"

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
tokio = { version = "1", features = ["macros", "rt"] }
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

Extension for [`reqwest`][2] to allow [websocket][1] connections.

This crate contains the extension trait [`RequestBuilderExt`][4], which adds an
This crate contains the extension trait [`Upgrade`][4], which adds an
`upgrade` method to `reqwest::RequestBuilder` that prepares the HTTP request to
upgrade the connection to a WebSocket. After you call `upgrade()`, you can send
your upgraded request as usual with `send()`, which will return an
Expand All @@ -22,7 +22,7 @@ For a full example take a look at [`hello_world.rs`](examples/hello_world.rs).

```rust
// Extends the `reqwest::RequestBuilder` to allow WebSocket upgrades.
use reqwest_websocket::RequestBuilderExt;
use reqwest_websocket::Upgrade;

// Creates a GET request, upgrades and sends it.
let response = Client::default()
Expand Down Expand Up @@ -56,4 +56,4 @@ request.
[1]: https://en.wikipedia.org/wiki/WebSocket
[2]: https://docs.rs/reqwest/latest/reqwest/index.html
[3]: https://docs.rs/web-sys/latest/web_sys/struct.WebSocket.html
[4]: https://docs.rs/reqwest-websocket/0.1.0/reqwest_websocket/trait.RequestBuilderExt.html
[4]: https://docs.rs/reqwest-websocket/latest/reqwest_websocket/trait.Upgrade.html
6 changes: 5 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# Upcoming Version: 0.5.2
# Upcoming Version: 0.6

Document new features here. Document whether your changes are *breaking* semver-compatibility.

- Update `reqwest` to 0.13
- Update `tungstenite` to 0.28
- Update `async-tungstenite` to 0.32
- tests: use local test server for native tests
- Add `Client` and `RequestBuilder` traits that abstract over the specific implementation of these types.
Rename `RequestBuilderExt` to `Upgrade`. `Upgrade` is a blanket extension for anything that implements our `RequestBuilder` trait.
- Add support for `request_middleware` behind `middleware` flag.

# 0.5.1

Expand Down
2 changes: 1 addition & 1 deletion examples/hello_world.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use reqwest::Client;
use reqwest_websocket::{Error, Message, RequestBuilderExt};
use reqwest_websocket::{Error, Message, Upgrade};

#[tokio::main(flavor = "current_thread")]
async fn main() -> Result<(), Error> {
Expand Down
95 changes: 76 additions & 19 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! #
//! # async fn run() -> Result<(), Error> {
//! // Extends the `reqwest::RequestBuilder` to allow WebSocket upgrades.
//! use reqwest_websocket::RequestBuilderExt;
//! use reqwest_websocket::Upgrade;
//!
//! // Creates a GET request, upgrades and sends it.
//! let response = Client::default()
Expand Down Expand Up @@ -49,13 +49,16 @@

#[cfg(feature = "json")]
mod json;
#[cfg(feature = "middleware")]
mod middleware;
#[cfg(not(target_arch = "wasm32"))]
mod native;
mod protocol;
#[cfg(target_arch = "wasm32")]
mod wasm;

use std::{
future::Future,
pin::Pin,
task::{ready, Context, Poll},
};
Expand All @@ -66,7 +69,7 @@ pub use crate::native::HandshakeError;
pub use crate::protocol::{CloseCode, Message};
pub use bytes::Bytes;
use futures_util::{Sink, SinkExt, Stream, StreamExt};
use reqwest::{Client, ClientBuilder, IntoUrl, RequestBuilder};
use reqwest::IntoUrl;

/// Errors returned by `reqwest_websocket`.
#[derive(Debug, thiserror::Error)]
Expand All @@ -91,10 +94,14 @@ pub enum Error {
WebSys(#[from] wasm::WebSysError),

/// Error during serialization/deserialization.
#[error("serde_json error")]
#[cfg(feature = "json")]
#[cfg_attr(docsrs, doc(cfg(feature = "json")))]
#[error("serde_json error")]
Json(#[from] serde_json::Error),

#[cfg(feature = "middleware")]
#[error("reqwest_middleware error")]
ReqwestMiddleware(#[from] reqwest_middleware::Error),
}

/// Opens a `WebSocket` connection at the specified `URL`.
Expand All @@ -105,7 +112,7 @@ pub enum Error {
/// [`Request`]: reqwest::Request
/// [`Response`]: reqwest::Response
pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {
builder_http1_only(Client::builder())
builder_http1_only(reqwest::Client::builder())
.build()?
.get(url)
.upgrade()
Expand All @@ -117,42 +124,89 @@ pub async fn websocket(url: impl IntoUrl) -> Result<WebSocket, Error> {

#[inline]
#[cfg(not(target_arch = "wasm32"))]
fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
fn builder_http1_only(builder: reqwest::ClientBuilder) -> reqwest::ClientBuilder {
builder.http1_only()
}

#[inline]
#[cfg(target_arch = "wasm32")]
fn builder_http1_only(builder: ClientBuilder) -> ClientBuilder {
fn builder_http1_only(builder: reqwest::ClientBuilder) -> reqwest::ClientBuilder {
builder
}

/// Trait that extends [`reqwest::RequestBuilder`] with an `upgrade` method.
pub trait RequestBuilderExt {
/// A generic client.
///
/// This is needed by [`RequestBuilder`] to be generic over the specific implementation of a client.
/// Its only requirement is to be able to execute [`reqwest::Request`]s.
///
/// This is implemented for [`reqwest::Client`] and [`reqwest_middleware::ClientWithMiddleware`] (with `middleware` feature).
/// It provides a single interface for executing a [`reqwest::Request`].
pub trait Client {
fn execute(
&self,
request: reqwest::Request,
) -> impl Future<Output = Result<reqwest::Response, Error>> + '_;
}

impl Client for reqwest::Client {
async fn execute(&self, request: reqwest::Request) -> Result<reqwest::Response, Error> {
self.execute(request).await.map_err(Into::into)
}
}

/// A generic request builder.
///
/// This is needed by [`Upgraded`] to be generic over the specific implementation of a request (and client).
/// Its only requirements are that it provides the specific client type, and can build itself into a client and a [`reqwest::Request`].
pub trait RequestBuilder {
type Client: Client;

fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>);
}

impl RequestBuilder for reqwest::RequestBuilder {
type Client = reqwest::Client;

fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>) {
let (client, request) = reqwest::RequestBuilder::build_split(self);
(client, request.map_err(Into::into))
}
}

/// Extension trait for requests builders that can be upgraded to a websocket connection.
///
/// This is automatically implemented for anything that implements our [`RequestBuilder`] trait.
pub trait Upgrade: Sized {
/// Upgrades the [`RequestBuilder`] to perform a `WebSocket` handshake.
///
/// This returns a wrapped type, so you must do this after you set up
/// your request, and just before sending the request.
fn upgrade(self) -> UpgradedRequestBuilder;
fn upgrade(self) -> Upgraded<Self>;
}

impl RequestBuilderExt for RequestBuilder {
fn upgrade(self) -> UpgradedRequestBuilder {
UpgradedRequestBuilder::new(self)
impl<R> Upgrade for R
where
R: RequestBuilder,
{
fn upgrade(self) -> Upgraded<Self> {
Upgraded::new(self)
}
}

/// Wrapper for a [`reqwest::RequestBuilder`] that performs the
/// `WebSocket` handshake when sent.
pub struct UpgradedRequestBuilder {
inner: RequestBuilder,
pub struct Upgraded<R> {
inner: R,
protocols: Vec<String>,
#[cfg(not(target_arch = "wasm32"))]
web_socket_config: Option<tungstenite::protocol::WebSocketConfig>,
}

impl UpgradedRequestBuilder {
pub(crate) fn new(inner: RequestBuilder) -> Self {
impl<R> Upgraded<R>
where
R: RequestBuilder,
{
pub(crate) fn new(inner: R) -> Self {
Self {
inner,
protocols: vec![],
Expand Down Expand Up @@ -181,7 +235,10 @@ impl UpgradedRequestBuilder {
let inner = native::send_request(self.inner, &self.protocols).await?;

#[cfg(target_arch = "wasm32")]
let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?;
let inner = {
let request = self.inner.build_split().1?;
wasm::WebSysWebSocketStream::new(request, &self.protocols).await?
};

Ok(UpgradeResponse {
inner,
Expand Down Expand Up @@ -344,7 +401,7 @@ mod tests {
#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::wasm_bindgen_test;

use crate::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket};
use crate::{websocket, CloseCode, Message, Upgrade, WebSocket};

#[cfg(target_arch = "wasm32")]
wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
Expand Down Expand Up @@ -463,7 +520,7 @@ mod tests {
}
}

async fn test_websocket(mut websocket: WebSocket) {
pub async fn test_websocket(mut websocket: WebSocket) {
let text = "Hello, World!";
websocket.send(Message::Text(text.into())).await.unwrap();

Expand Down
82 changes: 82 additions & 0 deletions src/middleware.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use crate::{Client, Error, RequestBuilder};

impl Client for reqwest_middleware::ClientWithMiddleware {
async fn execute(&self, request: reqwest::Request) -> Result<reqwest::Response, Error> {
self.execute(request).await.map_err(Into::into)
}
}

impl RequestBuilder for reqwest_middleware::RequestBuilder {
type Client = reqwest_middleware::ClientWithMiddleware;

fn build_split(self) -> (Self::Client, Result<reqwest::Request, Error>) {
let (client, request) = reqwest_middleware::RequestBuilder::build_split(self);
(client, request.map_err(Into::into))
}
}

#[cfg(test)]
#[cfg(not(target_arch = "wasm32"))]
mod tests {
use crate::{
tests::{test_websocket, TestServer},
Upgrade,
};
use std::sync::{Arc, Mutex};

#[derive(Debug)]
struct TestMiddleware {
did_run: Arc<Mutex<bool>>,
}

#[async_trait::async_trait]
impl reqwest_middleware::Middleware for TestMiddleware {
async fn handle(
&self,
req: reqwest::Request,
extensions: &mut http::Extensions,
next: reqwest_middleware::Next<'_>,
) -> Result<reqwest::Response, reqwest_middleware::Error> {
{
let mut did_run = self.did_run.lock().unwrap();
*did_run = true;
}
next.run(req, extensions).await
}
}

//#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
//#[cfg_attr(not(target_arch = "wasm32"), tokio::test)]
#[tokio::test]
async fn websocket_with_middleware() {
let echo = TestServer::new().await;

let did_run = Arc::new(Mutex::new(false));
let middleware = TestMiddleware {
did_run: did_run.clone(),
};

let client = reqwest::Client::builder().http1_only().build().unwrap();
let client = reqwest_middleware::ClientBuilder::new(client)
.with(middleware)
.build();

let websocket = client
.get(echo.http_url())
.upgrade()
.send()
.await
.unwrap()
.into_websocket()
.await
.unwrap();

test_websocket(websocket).await;

let did_run = {
let did_run = did_run.lock().unwrap();
*did_run
};
assert!(did_run);
}
}
13 changes: 8 additions & 5 deletions src/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,21 @@ use std::borrow::Cow;

use crate::{
protocol::{CloseCode, Message},
Error,
Client, Error, RequestBuilder,
};
use reqwest::{
header::{HeaderName, HeaderValue},
RequestBuilder, Response, StatusCode, Version,
Response, StatusCode, Version,
};
use tungstenite::protocol::WebSocketConfig;

pub async fn send_request(
request_builder: RequestBuilder,
pub async fn send_request<R>(
request_builder: R,
protocols: &[String],
) -> Result<WebSocketResponse, Error> {
) -> Result<WebSocketResponse, Error>
where
R: RequestBuilder,
{
let (client, request_result) = request_builder.build_split();
let mut request = request_result?;

Expand Down