From 2d3d35ba6982569320a0663e7338986a03bb8594 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Puel?= Date: Mon, 16 Dec 2024 16:29:45 +0000 Subject: [PATCH 1/2] Removed unnecessary memory allocations on Win implementation Two arcs and one Box were all merged into one single Pin> by using a self referential structure. Also, a dynamic dyspatch as removed in favor of using generics --- src/win.rs | 71 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/src/win.rs b/src/win.rs index b0a2bce..7977696 100644 --- a/src/win.rs +++ b/src/win.rs @@ -8,7 +8,6 @@ use std::ffi::c_void; use std::io::{Error, ErrorKind, Result}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; use std::task::{Context, Poll}; use windows::Win32::Foundation::{BOOLEAN, HANDLE}; use windows::Win32::NetworkManagement::IpHelper::{ @@ -43,24 +42,25 @@ pub struct IfWatcher { queue: VecDeque, #[allow(unused)] notif: IpChangeNotification, - waker: Arc, - resync: Arc, + shared: Pin>, } impl IfWatcher { /// Create a watcher. pub fn new() -> Result { - let resync = Arc::new(AtomicBool::new(true)); - let waker = Arc::new(AtomicWaker::new()); + let shared = IfWatcherShared { + resync: true.into(), + waker: Default::default(), + }; + let shared = Box::pin(shared); Ok(Self { addrs: Default::default(), queue: Default::default(), - waker: waker.clone(), - resync: resync.clone(), - notif: IpChangeNotification::new(Box::new(move |_, _| { - resync.store(true, Ordering::Relaxed); - waker.wake(); - }))?, + // Safety: + // Self referential structure, `shared` will be dropped + // after `notif` + notif: unsafe { IpChangeNotification::new(shared.as_ref())? }, + shared, }) } @@ -96,8 +96,8 @@ impl IfWatcher { if let Some(event) = self.queue.pop_front() { return Poll::Ready(Ok(event)); } - if !self.resync.swap(false, Ordering::Relaxed) { - self.waker.register(cx.waker()); + if !self.shared.resync.swap(false, Ordering::Relaxed) { + self.shared.waker.register(cx.waker()); return Poll::Pending; } if let Err(error) = self.resync() { @@ -137,10 +137,22 @@ fn ifaddr_to_ipnet(addr: IfAddr) -> IpNet { } } +#[derive(Debug)] +struct IfWatcherShared { + waker: AtomicWaker, + resync: AtomicBool, +} + +impl IpChangeCallback for IfWatcherShared { + fn callback(&self, _row: &MIB_IPINTERFACE_ROW, _notification_type: MIB_NOTIFICATION_TYPE) { + self.resync.store(true, Ordering::Relaxed); + self.waker.wake(); + } +} + /// IP change notifications struct IpChangeNotification { handle: HANDLE, - callback: *mut IpChangeCallback, } impl std::fmt::Debug for IpChangeNotification { @@ -149,31 +161,37 @@ impl std::fmt::Debug for IpChangeNotification { } } -type IpChangeCallback = Box; - impl IpChangeNotification { /// Register for route change notifications - fn new(cb: IpChangeCallback) -> Result { - unsafe extern "system" fn global_callback( + /// + /// Safety: C must outlive the resulting Self + unsafe fn new(cb: Pin<&C>) -> Result + where + C: IpChangeCallback + Send + Sync, + { + unsafe extern "system" fn global_callback( caller_context: *const c_void, row: *const MIB_IPINTERFACE_ROW, notification_type: MIB_NOTIFICATION_TYPE, - ) { - (**(caller_context as *const IpChangeCallback))(&*row, notification_type) + ) where + C: IpChangeCallback + Send + Sync, + { + let caller_context = &*(caller_context as *const C); + caller_context.callback(&*row, notification_type) } let mut handle = HANDLE::default(); - let callback = Box::into_raw(Box::new(cb)); + let callback = cb.get_ref() as *const C; unsafe { NotifyIpInterfaceChange( AF_UNSPEC, - Some(global_callback), - Some(callback as _), + Some(global_callback::), + Some(callback as *const c_void), BOOLEAN(0), &mut handle as _, ) .map_err(|err| Error::new(ErrorKind::Other, err.to_string()))?; } - Ok(Self { callback, handle }) + Ok(Self { handle }) } } @@ -183,9 +201,12 @@ impl Drop for IpChangeNotification { if let Err(err) = CancelMibChangeNotify2(self.handle) { log::error!("error deregistering notification: {}", err); } - drop(Box::from_raw(self.callback)); } } } unsafe impl Send for IpChangeNotification {} + +trait IpChangeCallback { + fn callback(&self, row: &MIB_IPINTERFACE_ROW, notification_type: MIB_NOTIFICATION_TYPE); +} From ce893b00fa039cb5d7adba96b911cc74c30be4e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Puel?= Date: Mon, 16 Dec 2024 22:06:55 +0000 Subject: [PATCH 2/2] Fixed race condition on Win poll It was possible that a thread would read resync as false before another thread set it to true and subsequently register a waker after the other thread set resync to true and already tried calling AtomicWaker::wake(). To fix this issue, we first register the waker, and only then we check the resync value. If resync is true and we didn't have to register the waker, we just take it back. --- src/win.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/win.rs b/src/win.rs index 7977696..64cfe0e 100644 --- a/src/win.rs +++ b/src/win.rs @@ -96,10 +96,13 @@ impl IfWatcher { if let Some(event) = self.queue.pop_front() { return Poll::Ready(Ok(event)); } - if !self.shared.resync.swap(false, Ordering::Relaxed) { - self.shared.waker.register(cx.waker()); + + self.shared.waker.register(cx.waker()); + if !self.shared.resync.swap(false, Ordering::AcqRel) { return Poll::Pending; } + self.shared.waker.take(); + if let Err(error) = self.resync() { return Poll::Ready(Err(error)); } @@ -145,7 +148,7 @@ struct IfWatcherShared { impl IpChangeCallback for IfWatcherShared { fn callback(&self, _row: &MIB_IPINTERFACE_ROW, _notification_type: MIB_NOTIFICATION_TYPE) { - self.resync.store(true, Ordering::Relaxed); + self.resync.store(true, Ordering::Release); self.waker.wake(); } }