diff --git a/packages/blitz-net/src/lib.rs b/packages/blitz-net/src/lib.rs index 857b8f6f9..764782a97 100644 --- a/packages/blitz-net/src/lib.rs +++ b/packages/blitz-net/src/lib.rs @@ -2,9 +2,10 @@ //! //! Provides an implementation of the [`blitz_traits::net::NetProvider`] trait. -use blitz_traits::net::{Body, Bytes, NetHandler, NetProvider, NetWaker, Request}; +// use blitz_traits::net::{Body, Bytes, NetHandler, NetProvider, NetWaker, Request}; +use blitz_traits::net::{AbortSignal, Body, Bytes, NetHandler, NetProvider, NetWaker, Request}; use data_url::DataUrl; -use std::sync::Arc; +use std::{marker::PhantomData, pin::Pin, sync::Arc, task::Poll}; use tokio::runtime::Handle; #[cfg(feature = "cache")] @@ -102,16 +103,6 @@ impl Provider { }) } - async fn fetch_with_handler( - client: Client, - request: Request, - handler: Box, - ) -> Result<(), ProviderError> { - let (response_url, bytes) = Self::fetch_inner(client, request).await?; - handler.bytes(response_url, bytes); - Ok(()) - } - #[allow(clippy::type_complexity)] pub fn fetch_with_callback( &self, @@ -155,7 +146,7 @@ impl Provider { } impl NetProvider for Provider { - fn fetch(&self, doc_id: usize, request: Request, handler: Box) { + fn fetch(&self, doc_id: usize, mut request: Request, handler: Box) { let client = self.client.clone(); #[cfg(feature = "debug_log")] @@ -166,23 +157,80 @@ impl NetProvider for Provider { #[cfg(feature = "debug_log")] let url = request.url.to_string(); - let _res = Self::fetch_with_handler(client, request, handler).await; - - #[cfg(feature = "debug_log")] - if let Err(e) = _res { - eprintln!("Error fetching {url}: {e:?}"); + let signal = request.signal.take(); + let result = if let Some(signal) = signal { + AbortFetch::new( + signal, + Box::pin(async move { Self::fetch_inner(client, request).await }), + ) + .await } else { - println!("Success {url}"); - } + Self::fetch_inner(client, request).await + }; // Call the waker to notify of completed network request - waker.wake(doc_id) + waker.wake(doc_id); + + match result { + Ok((response_url, bytes)) => { + handler.bytes(response_url, bytes); + #[cfg(feature = "debug_log")] + println!("Success {url}"); + } + Err(e) => { + #[cfg(feature = "debug_log")] + eprintln!("Error fetching {url}: {e:?}"); + #[cfg(not(feature = "debug_log"))] + let _ = e; + } + }; }); } } +/// A future that is cancellable using an AbortSignal +struct AbortFetch { + signal: AbortSignal, + future: F, + _rt: PhantomData, +} + +impl AbortFetch { + fn new(signal: AbortSignal, future: F) -> Self { + Self { + signal, + future, + _rt: PhantomData, + } + } +} + +impl Future for AbortFetch +where + F: Future + Unpin + Send + 'static, + F::Output: Send + Into> + 'static, + T: Unpin, +{ + type Output = Result; + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + if self.signal.aborted() { + return Poll::Ready(Err(ProviderError::Abort)); + } + + match Pin::new(&mut self.future).poll(cx) { + Poll::Ready(output) => Poll::Ready(output.into()), + Poll::Pending => Poll::Pending, + } + } +} + #[derive(Debug)] pub enum ProviderError { + Abort, Io(std::io::Error), DataUrl(data_url::DataUrlError), DataUrlBase64(data_url::forgiving_base64::InvalidBase64), diff --git a/packages/blitz-traits/src/navigation.rs b/packages/blitz-traits/src/navigation.rs index cab906a09..84317b857 100644 --- a/packages/blitz-traits/src/navigation.rs +++ b/packages/blitz-traits/src/navigation.rs @@ -62,6 +62,7 @@ impl NavigationOptions { content_type: self.content_type, headers: HeaderMap::new(), body: self.document_resource, + signal: None, } } } diff --git a/packages/blitz-traits/src/net.rs b/packages/blitz-traits/src/net.rs index 7f05c4246..cf2321a4c 100644 --- a/packages/blitz-traits/src/net.rs +++ b/packages/blitz-traits/src/net.rs @@ -6,6 +6,10 @@ use serde::{ Serialize, ser::{SerializeSeq, SerializeTuple}, }; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; use std::{ops::Deref, path::PathBuf}; pub use url::Url; @@ -43,6 +47,7 @@ pub struct Request { pub content_type: String, pub headers: HeaderMap, pub body: Body, + pub signal: Option, } impl Request { /// A get request to the specified Url and an empty body @@ -53,8 +58,14 @@ impl Request { content_type: String::new(), headers: HeaderMap::new(), body: Body::Empty, + signal: None, } } + + pub fn signal(mut self, signal: AbortSignal) -> Self { + self.signal = Some(signal); + self + } } #[derive(Debug, Clone)] @@ -148,3 +159,42 @@ pub struct DummyNetProvider; impl NetProvider for DummyNetProvider { fn fetch(&self, _doc_id: usize, _request: Request, _handler: Box) {} } + +/// The AbortController interface represents a controller object that +/// allows you to abort one or more Web requests as and when desired. +/// +/// https://developer.mozilla.org/en-US/docs/Web/API/AbortController +#[derive(Debug, Default)] +pub struct AbortController { + pub signal: AbortSignal, +} + +impl AbortController { + /// The abort() method of the AbortController interface aborts + /// an asynchronous operation before it has completed. + /// This is able to abort fetch requests. + /// + /// https://developer.mozilla.org/en-US/docs/Web/API/AbortController/abort + pub fn abort(self) { + self.signal.0.store(true, Ordering::SeqCst); + } +} + +/// The AbortSignal interface represents a signal object that allows you to +/// communicate with an asynchronous operation (such as a fetch request) and +/// abort it if required via an AbortController object. +/// +/// https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal +#[derive(Debug, Default, Clone)] +pub struct AbortSignal(Arc); + +impl AbortSignal { + /// The aborted read-only property returns a value that indicates whether + /// the asynchronous operations the signal is communicating with are + /// aborted (true) or not (false). + /// + /// https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal/aborted + pub fn aborted(&self) -> bool { + self.0.load(Ordering::SeqCst) + } +}