diff --git a/src/driver.rs b/src/driver.rs index 168d55d..76f23eb 100644 --- a/src/driver.rs +++ b/src/driver.rs @@ -1,13 +1,15 @@ -use std::cell::Cell; +use std::cell::{Cell, RefCell}; use std::future::Future; use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; +use std::task::Waker; use std::task::{Context, Poll}; use std::thread; use std::time::{Duration, Instant}; use async_lock::OnceCell; use futures_lite::pin; +use parking::Parker; use waker_fn::waker_fn; use crate::reactor::Reactor; @@ -120,112 +122,145 @@ pub fn block_on(future: impl Future) -> T { unparker().unpark(); }); - // Parker and unparker for notifying the current thread. - let (p, u) = parking::pair(); - // This boolean is set to `true` when the current thread is blocked on I/O. - let io_blocked = Arc::new(AtomicBool::new(false)); + // Creates a parker and an associated waker that unparks it. + fn parker_and_waker() -> (Parker, Waker, Arc) { + // Parker and unparker for notifying the current thread. + let (p, u) = parking::pair(); + + // This boolean is set to `true` when the current thread is blocked on I/O. + let io_blocked = Arc::new(AtomicBool::new(false)); + + // Prepare the waker. + let waker = waker_fn({ + let io_blocked = io_blocked.clone(); + move || { + if u.unpark() { + // Check if waking from another thread and if currently blocked on I/O. + if !IO_POLLING.with(Cell::get) && io_blocked.load(Ordering::SeqCst) { + Reactor::get().notify(); + } + } + } + }); + + (p, waker, io_blocked) + } thread_local! { + // Cached parker and waker for efficiency. + static CACHE: RefCell<(Parker, Waker, Arc)> = RefCell::new(parker_and_waker()); + // Indicates that the current thread is polling I/O, but not necessarily blocked on it. static IO_POLLING: Cell = Cell::new(false); } - // Prepare the waker. - let waker = waker_fn({ - let io_blocked = io_blocked.clone(); - move || { - if u.unpark() { - // Check if waking from another thread and if currently blocked on I/O. - if !IO_POLLING.with(Cell::get) && io_blocked.load(Ordering::SeqCst) { - Reactor::get().notify(); - } + CACHE.with(|cache| { + // Try grabbing the cached parker and waker. + let tmp_cached; + let tmp_fresh; + let (p, waker, io_blocked) = match cache.try_borrow_mut() { + Ok(cache) => { + // Use the cached parker and waker. + tmp_cached = cache; + &*tmp_cached } - } - }); - let cx = &mut Context::from_waker(&waker); - pin!(future); + Err(_) => { + // Looks like this is a recursive `block_on()` call. + // Create a fresh parker and waker. + tmp_fresh = parker_and_waker(); + &tmp_fresh + } + }; - loop { - // Poll the future. - if let Poll::Ready(t) = future.as_mut().poll(cx) { - tracing::trace!("completed"); - return t; - } + pin!(future); - // Check if a notification was received. - if p.park_timeout(Duration::from_secs(0)) { - tracing::trace!("notified"); + let cx = &mut Context::from_waker(waker); - // Try grabbing a lock on the reactor to process I/O events. - if let Some(mut reactor_lock) = Reactor::get().try_lock() { - // First let wakers know this parker is processing I/O events. - IO_POLLING.with(|io| io.set(true)); - let _guard = CallOnDrop(|| { - IO_POLLING.with(|io| io.set(false)); - }); - - // Process available I/O events. - reactor_lock.react(Some(Duration::from_secs(0))).ok(); + loop { + // Poll the future. + if let Poll::Ready(t) = future.as_mut().poll(cx) { + // Ensure the cached parker is reset to the unnotified state for future block_on calls, + // in case this future called wake and then immediately returned Poll::Ready. + p.park_timeout(Duration::from_secs(0)); + tracing::trace!("completed"); + return t; } - continue; - } - // Try grabbing a lock on the reactor to wait on I/O. - if let Some(mut reactor_lock) = Reactor::get().try_lock() { - // Record the instant at which the lock was grabbed. - let start = Instant::now(); - - loop { - // First let wakers know this parker is blocked on I/O. - IO_POLLING.with(|io| io.set(true)); - io_blocked.store(true, Ordering::SeqCst); - let _guard = CallOnDrop(|| { - IO_POLLING.with(|io| io.set(false)); - io_blocked.store(false, Ordering::SeqCst); - }); - - // Check if a notification has been received before `io_blocked` was updated - // because in that case the reactor won't receive a wakeup. - if p.park_timeout(Duration::from_secs(0)) { - tracing::trace!("notified"); - break; - } + // Check if a notification was received. + if p.park_timeout(Duration::from_secs(0)) { + tracing::trace!("notified"); - // Wait for I/O events. - tracing::trace!("waiting on I/O"); - reactor_lock.react(None).ok(); + // Try grabbing a lock on the reactor to process I/O events. + if let Some(mut reactor_lock) = Reactor::get().try_lock() { + // First let wakers know this parker is processing I/O events. + IO_POLLING.with(|io| io.set(true)); + let _guard = CallOnDrop(|| { + IO_POLLING.with(|io| io.set(false)); + }); - // Check if a notification has been received. - if p.park_timeout(Duration::from_secs(0)) { - tracing::trace!("notified"); - break; + // Process available I/O events. + reactor_lock.react(Some(Duration::from_secs(0))).ok(); } + continue; + } - // Check if this thread been handling I/O events for a long time. - if start.elapsed() > Duration::from_micros(500) { - tracing::trace!("stops hogging the reactor"); - - // This thread is clearly processing I/O events for some other threads - // because it didn't get a notification yet. It's best to stop hogging the - // reactor and give other threads a chance to process I/O events for - // themselves. - drop(reactor_lock); - - // Unpark the "async-io" thread in case no other thread is ready to start - // processing I/O events. This way we prevent a potential latency spike. - unparker().unpark(); - - // Wait for a notification. - p.park(); - break; + // Try grabbing a lock on the reactor to wait on I/O. + if let Some(mut reactor_lock) = Reactor::get().try_lock() { + // Record the instant at which the lock was grabbed. + let start = Instant::now(); + + loop { + // First let wakers know this parker is blocked on I/O. + IO_POLLING.with(|io| io.set(true)); + io_blocked.store(true, Ordering::SeqCst); + let _guard = CallOnDrop(|| { + IO_POLLING.with(|io| io.set(false)); + io_blocked.store(false, Ordering::SeqCst); + }); + + // Check if a notification has been received before `io_blocked` was updated + // because in that case the reactor won't receive a wakeup. + if p.park_timeout(Duration::from_secs(0)) { + tracing::trace!("notified"); + break; + } + + // Wait for I/O events. + tracing::trace!("waiting on I/O"); + reactor_lock.react(None).ok(); + + // Check if a notification has been received. + if p.park_timeout(Duration::from_secs(0)) { + tracing::trace!("notified"); + break; + } + + // Check if this thread been handling I/O events for a long time. + if start.elapsed() > Duration::from_micros(500) { + tracing::trace!("stops hogging the reactor"); + + // This thread is clearly processing I/O events for some other threads + // because it didn't get a notification yet. It's best to stop hogging the + // reactor and give other threads a chance to process I/O events for + // themselves. + drop(reactor_lock); + + // Unpark the "async-io" thread in case no other thread is ready to start + // processing I/O events. This way we prevent a potential latency spike. + unparker().unpark(); + + // Wait for a notification. + p.park(); + break; + } } + } else { + // Wait for an actual notification. + tracing::trace!("sleep until notification"); + p.park(); } - } else { - // Wait for an actual notification. - tracing::trace!("sleep until notification"); - p.park(); } - } + }) } /// Runs a closure when dropped. diff --git a/tests/block_on.rs b/tests/block_on.rs new file mode 100644 index 0000000..70241f0 --- /dev/null +++ b/tests/block_on.rs @@ -0,0 +1,178 @@ +use async_io::block_on; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll, Waker}, + time::{Duration, Instant}, +}; + +#[test] +fn doesnt_poll_after_ready() { + #[derive(Default)] + struct Bomb { + returned_ready: bool, + } + impl Future for Bomb { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + if self.returned_ready { + panic!("Future was polled again after returning Poll::Ready"); + } else { + self.returned_ready = true; + Poll::Ready(()) + } + } + } + + block_on(Bomb::default()) +} + +#[test] +fn recursive_wakers_are_different() { + struct Outer; + impl Future for Outer { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let outer_waker = cx.waker(); + block_on(Inner { outer_waker }); + Poll::Ready(()) + } + } + + struct Inner<'a> { + pub outer_waker: &'a Waker, + } + impl Future for Inner<'_> { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let inner_waker = cx.waker(); + assert!(!inner_waker.will_wake(self.outer_waker)); + Poll::Ready(()) + } + } + + block_on(Outer); +} + +#[test] +fn inner_cannot_wake_outer() { + #[derive(Default)] + struct Outer { + elapsed: Option, + } + impl Future for Outer { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(elapsed) = self.elapsed { + assert!(elapsed.elapsed() >= Duration::from_secs(1)); + Poll::Ready(()) + } else { + let outer_waker = cx.waker().clone(); + block_on(Inner); + std::thread::spawn(|| { + std::thread::sleep(Duration::from_secs(1)); + outer_waker.wake(); + }); + self.elapsed = Some(Instant::now()); + Poll::Pending + } + } + } + + struct Inner; + impl Future for Inner { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let inner_waker = cx.waker(); + inner_waker.wake_by_ref(); + Poll::Ready(()) + } + } + + block_on(Outer::default()); +} + +#[test] +fn outer_cannot_wake_inner() { + struct Outer; + impl Future for Outer { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let outer_waker = cx.waker(); + outer_waker.wake_by_ref(); + block_on(Inner::default()); + Poll::Ready(()) + } + } + + #[derive(Default)] + struct Inner { + elapsed: Option, + } + impl Future for Inner { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(elapsed) = self.elapsed { + assert!(elapsed.elapsed() >= Duration::from_secs(1)); + Poll::Ready(()) + } else { + let inner_waker = cx.waker().clone(); + std::thread::spawn(|| { + std::thread::sleep(Duration::from_secs(1)); + inner_waker.wake(); + }); + self.elapsed = Some(Instant::now()); + Poll::Pending + } + } + } + + block_on(Outer); +} + +#[test] +fn first_cannot_wake_second() { + struct First; + impl Future for First { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let first_waker = cx.waker(); + first_waker.wake_by_ref(); + Poll::Ready(()) + } + } + + #[derive(Default)] + struct Second { + elapsed: Option, + } + impl Future for Second { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if let Some(elapsed) = self.elapsed { + assert!(elapsed.elapsed() >= Duration::from_secs(1)); + Poll::Ready(()) + } else { + let second_waker = cx.waker().clone(); + std::thread::spawn(|| { + std::thread::sleep(Duration::from_secs(1)); + second_waker.wake(); + }); + self.elapsed = Some(Instant::now()); + Poll::Pending + } + } + } + + block_on(First); + block_on(Second::default()); +}