Explicit shutdown support #3

Open · wants to merge 2 commits into `main`

59 changes: 48 additions & 11 deletions src/lib.rs
@@ -137,12 +137,13 @@ mod driver {
#[derive(Debug)]
pub(crate) enum Event {
SleepUntil(NodeDesc),
Shutdown,
}

pub(crate) fn execute<const D: usize>(this: Builder, rx: channel::Receiver<Event>) {
let mut nodes = DaryHeap::<Node, D>::new();
let pivot = Instant::now();
let to_usec = |x: Instant| x.duration_since(pivot).as_micros() as u64;
let to_usec = |x: Instant| x.saturating_duration_since(pivot).as_micros() as u64;
let resolution_usec = this.schedule_resolution.as_micros() as u64;

// As each node always increments the `gc_counter` by 1 when dropped, and the worker
@@ -156,10 +157,10 @@ mod driver {
let now = to_usec(now_ts);
let mut event = if let Some(node) = nodes.peek() {
let remain = node.timeout_usec.saturating_sub(now);
if remain > resolution_usec {
let system_sleep_for = remain - resolution_usec;
let timeout = Duration::from_micros(system_sleep_for);
let deadline = now_ts + timeout;
if let Some(system_sleep_for) = remain.checked_sub(resolution_usec) {
// Would only panic if the node timeout was representable as an Instant, but this
// intermediate point, which is sooner than the node timeout, was not.
let deadline = now_ts + Duration::from_micros(system_sleep_for);

let Ok(event) = rx.recv_deadline(deadline).map_err(|e| match e {
channel::RecvTimeoutError::Timeout => (),
@@ -181,10 +182,12 @@ mod driver {
'busy_wait: loop {
let now = to_usec(Instant::now());
if now >= node.timeout_usec {
let node = nodes.pop().unwrap();
let node = nodes.pop().expect("node presence checked via peek()");

if let Some(waker) = node.weak_waker.upgrade() {
waker.value.lock().take().expect("logic error").wake();
if let Some(waker) =
node.weak_waker.upgrade().and_then(|wn| wn.value.lock().take())
{
waker.wake();
}

let n_garbage = gc_counter.fetch_sub(1, Ordering::Release);
@@ -227,6 +230,21 @@ mod driver {
match event {
Event::SleepUntil(desc) => nodes
.push(Node { timeout_usec: to_usec(desc.timeout), weak_waker: desc.waker }),
Event::Shutdown => {
// Wake remaining timers so they aren't stuck forever. This should cause
// CompletedEarly for the affected timers.

for n in nodes {
if let Some(waker) =
n.weak_waker.upgrade().and_then(|wn| wn.value.lock().take())
{
waker.wake();
}
}

// Return rather than break, to avoid hitting the asserts below, which may not hold in this case.
return;
}
};

event = match rx.try_recv() {
@@ -328,6 +346,14 @@ impl Handle {
pub fn interval(&self, interval: Duration) -> util::Interval {
util::Interval { handle: self.clone(), wakeup_time: Instant::now() + interval, interval }
}

/// Signal the driver to shut down, causing the thread it is running in to exit gracefully.
///
/// Existing timers will fire immediately.
pub fn shutdown(self) {
// if rx is already dropped, we're done anyway
let _ = self.tx.send(driver::Event::Shutdown);
}
}

pub mod util {
@@ -523,7 +549,9 @@ pub enum Report {
/// Timer has not been requested as the timeout is already expired.
ExpiredTimer(Duration),

/// We woke up a bit earlier than required. It is usually hundreads of nanoseconds.
/// We woke up a bit earlier than required. It is usually hundreds of nanoseconds.
///
/// This is also produced by extant timers when the driver is shut down.
CompletedEarly(Duration),
}

@@ -588,13 +616,22 @@ impl std::future::Future for SleepFuture {
waker: Arc::downgrade(&waker),
});

tx.send(event).expect("timer driver instance dropped!");
if tx.send(event).is_err() {
// Driver has gone away, so this task will never be woken later.
// `timeout` > now here, otherwise it would have matched the check above.
self.state = SleepState::Woken;
return Poll::Ready(Report::CompletedEarly(
self.timeout.saturating_duration_since(now),
));
}
self.state = SleepState::Sleeping(waker);
} else if let SleepState::Sleeping(node) = &self.state {
// We woke up too early. Check if it is due to broken clock monotonicity.
if node.is_expired() {
self.state = SleepState::Woken;
return Poll::Ready(Report::CompletedEarly(self.timeout - now));
return Poll::Ready(Report::CompletedEarly(
self.timeout.saturating_duration_since(now),
));
} else {
// If not, this is a spurious wakeup. We should sleep again.
// - XXX: Should we re-register wakeup timer here?
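
Not part of the diff: a minimal usage sketch of the new API for reviewers, using only the entry points the test below exercises (`async_spin_sleep::create()`, `Handle::sleep_for()`, the new `Handle::shutdown()`, and the `Report` variants shown above). The 60-second timeout and variable names are illustrative.

```rust
// Usage sketch (assumes a tokio runtime, as in the tests below).
use async_spin_sleep::Report;
use std::time::Duration;

#[tokio::main]
async fn main() {
    let (handle, driver) = async_spin_sleep::create();
    let driver_thread = std::thread::spawn(driver);

    // A timer far in the future; shutdown() below wakes it immediately.
    let pending = tokio::spawn(handle.sleep_for(Duration::from_secs(60)));

    // Consume the handle to signal the driver, then join its thread.
    handle.shutdown();
    driver_thread.join().unwrap();

    match pending.await.unwrap() {
        // Shut-down timers report how early they were woken.
        Report::CompletedEarly(remaining) => println!("woken {remaining:?} early"),
        other => println!("unexpected report: {other:?}"),
    }
}
```

Whether the future registers with the driver before or after the shutdown event, it resolves to `CompletedEarly`: either the driver wakes it while draining its nodes, or the later `send` fails and the future completes itself via the new branch in `SleepFuture::poll`.
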
62 changes: 62 additions & 0 deletions tests/shutdown.rs
@@ -0,0 +1,62 @@
use async_spin_sleep::Report;
use futures::future::join_all;
use std::time;

#[tokio::test]
async fn shutdown_with_no_timers_works() {
let (handle, driver) = async_spin_sleep::create();
let driver_handle = std::thread::spawn(driver);

let start = time::Instant::now();

handle.shutdown();

driver_handle.join().unwrap();

let elapsed = start.elapsed();
assert!(elapsed < time::Duration::from_millis(100), "elapsed: {elapsed:?}");
}

#[tokio::test]
async fn shutdown_fires_existing_later_timers() {
let (handle, driver) = async_spin_sleep::create();
let driver_handle = std::thread::spawn(driver);

// Schedule tasks whose timers would fire far in the future, well after the driver is shut down.
let handles = (1..=100)
.map(|s| tokio::spawn(handle.sleep_for(time::Duration::from_secs(s))))
.collect::<Vec<_>>();

let before_shutdown = time::Instant::now();

handle.shutdown();

driver_handle.join().unwrap();

let driver_join_elapsed = before_shutdown.elapsed();
assert!(
driver_join_elapsed < time::Duration::from_millis(100),
"elapsed: {driver_join_elapsed:?}"
);

// all existing timers should fire quickly
let reports =
join_all(handles.into_iter()).await.into_iter().map(|res| res.unwrap()).collect::<Vec<_>>();
let timers_complete_elapsed = before_shutdown.elapsed();
assert!(
timers_complete_elapsed < time::Duration::from_millis(100),
"elapsed: {timers_complete_elapsed:?}"
);

for rep in reports {
match rep {
Report::CompletedEarly(dur) => {
// the soonest timer was for 1s, so all should have fired early
assert!(dur > time::Duration::from_millis(200))
}
Report::Completed(_) | Report::ExpiredTimer(_) => {
panic!("Unexpected report: {rep:?}")
}
}
}
}
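
A possible follow-up pattern, not proposed in this PR: because `shutdown(self)` consumes the handle and the existing `interval()` implementation shows `Handle` is `Clone`, a caller can keep one clone aside and tie shutdown plus thread join to scope exit. The `DriverGuard` type below is hypothetical, and it assumes `Handle` is exported at the crate root and that the driver closure returns `()` (as the inner `execute()` does).

```rust
// Hypothetical convenience wrapper, not part of this PR: ties driver shutdown
// and join to scope exit.
use std::thread::JoinHandle;

struct DriverGuard {
    handle: Option<async_spin_sleep::Handle>,
    thread: Option<JoinHandle<()>>,
}

impl DriverGuard {
    fn spawn() -> (async_spin_sleep::Handle, Self) {
        let (handle, driver) = async_spin_sleep::create();
        let guard = Self {
            handle: Some(handle.clone()),
            thread: Some(std::thread::spawn(driver)),
        };
        (handle, guard)
    }
}

impl Drop for DriverGuard {
    fn drop(&mut self) {
        if let Some(handle) = self.handle.take() {
            // Wakes every remaining timer, which then reports `CompletedEarly`.
            handle.shutdown();
        }
        if let Some(thread) = self.thread.take() {
            let _ = thread.join();
        }
    }
}
```
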