diff --git a/src/runtime/execution.rs b/src/runtime/execution.rs
index 23aa12b..7a99895 100644
--- a/src/runtime/execution.rs
+++ b/src/runtime/execution.rs
@@ -1,4 +1,5 @@
 use crate::runtime::failure::{init_panic_hook, persist_failure, persist_task_failure};
+use crate::runtime::storage::{StorageKey, StorageMap};
 use crate::runtime::task::clock::VectorClock;
 use crate::runtime::task::{Task, TaskId, TaskState, DEFAULT_INLINE_TASKS};
 use crate::runtime::thread::continuation::PooledContinuation;
@@ -171,6 +172,9 @@ pub(crate) struct ExecutionState {
     // the number of scheduling decisions made so far
     context_switches: usize,

+    // static values for the current execution
+    storage: StorageMap,
+
     scheduler: Rc<RefCell<dyn Scheduler>>,

     current_schedule: Schedule,
@@ -213,6 +217,7 @@ impl ExecutionState {
             next_task: ScheduledTask::None,
             has_yielded: false,
             context_switches: 0,
+            storage: StorageMap::new(),
             scheduler,
             current_schedule: initial_schedule,
             current_span_entered: None,
@@ -320,6 +325,8 @@ impl ExecutionState {
                 .expect("couldn't cleanup a future");
         }

+        while Self::with(|state| state.storage.pop()).is_some() {}
+
         #[cfg(debug_assertions)]
         Self::with(|state| state.has_cleaned_up = true);
     }
@@ -429,6 +436,16 @@ impl ExecutionState {
         Self::with(|state| state.context_switches)
     }

+    pub(crate) fn get_storage<K: Into<StorageKey>, T: 'static>(&self, key: K) -> Option<&T> {
+        self.storage
+            .get(key.into())
+            .map(|result| result.expect("global storage is never destructed"))
+    }
+
+    pub(crate) fn init_storage<K: Into<StorageKey>, T: 'static>(&mut self, key: K, value: T) {
+        self.storage.init(key.into(), value);
+    }
+
     pub(crate) fn get_clock(&self, id: TaskId) -> &VectorClock {
         &self.tasks.get(id.0).unwrap().clock
     }
@@ -437,21 +454,21 @@ impl ExecutionState {
         &mut self.tasks.get_mut(id.0).unwrap().clock
     }

-    // Increment the current thread's clock entry and update its clock with the one provided.
+    /// Increment the current thread's clock entry and update its clock with the one provided.
     pub(crate) fn update_clock(&mut self, clock: &VectorClock) {
         let task = self.current_mut();
         task.clock.increment(task.id);
         task.clock.update(clock);
     }

-    // Increment the current thread's clock and return a shared reference to it
+    /// Increment the current thread's clock and return a shared reference to it
     pub(crate) fn increment_clock(&mut self) -> &VectorClock {
         let task = self.current_mut();
         task.clock.increment(task.id);
         &task.clock
     }

-    // Increment the current thread's clock and return a mutable reference to it
+    /// Increment the current thread's clock and return a mutable reference to it
     pub(crate) fn increment_clock_mut(&mut self) -> &mut VectorClock {
         let task = self.current_mut();
         task.clock.increment(task.id);
diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs
index e91a372..2e4ed15 100644
--- a/src/runtime/mod.rs
+++ b/src/runtime/mod.rs
@@ -1,5 +1,6 @@
 pub(crate) mod execution;
 mod failure;
 pub(crate) mod runner;
+pub(crate) mod storage;
 pub(crate) mod task;
 pub(crate) mod thread;
diff --git a/src/runtime/storage.rs b/src/runtime/storage.rs
new file mode 100644
index 0000000..0683336
--- /dev/null
+++ b/src/runtime/storage.rs
@@ -0,0 +1,60 @@
+use std::any::Any;
+use std::collections::{HashMap, VecDeque};
+
+/// A unique identifier for a storage slot
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub(crate) struct StorageKey(pub usize, pub usize); // (identifier, type)
+
+/// A map of storage values.
+///
+/// We remember the insertion order into the storage HashMap so that destruction is deterministic.
+/// Values are Option<_> because we need to be able to incrementally destruct them, as it's valid
+/// for TLS destructors to initialize new TLS slots. When a slot is destructed, its key is removed
+/// from `order` and its value is replaced with None.
+pub(crate) struct StorageMap {
+    locals: HashMap<StorageKey, Option<Box<dyn Any>>>,
+    order: VecDeque<StorageKey>,
+}
+
+impl StorageMap {
+    pub fn new() -> Self {
+        Self {
+            locals: HashMap::new(),
+            order: VecDeque::new(),
+        }
+    }
+
+    pub fn get<T: 'static>(&self, key: StorageKey) -> Option<Result<&T, AlreadyDestructedError>> {
+        self.locals.get(&key).map(|val| {
+            val.as_ref()
+                .map(|val| {
+                    Ok(val
+                        .downcast_ref::<T>()
+                        .expect("local value must downcast to expected type"))
+                })
+                .unwrap_or(Err(AlreadyDestructedError))
+        })
+    }
+
+    pub fn init<T: 'static>(&mut self, key: StorageKey, value: T) {
+        let result = self.locals.insert(key, Some(Box::new(value)));
+        assert!(result.is_none(), "cannot reinitialize a storage slot");
+        self.order.push_back(key);
+    }
+
+    /// Return ownership of the next still-initialized storage slot.
+    pub fn pop(&mut self) -> Option<Box<dyn Any>> {
+        let key = self.order.pop_front()?;
+        let value = self
+            .locals
+            .get_mut(&key)
+            .expect("keys in `order` must exist")
+            .take()
+            .expect("keys in `order` must not yet be destructed");
+        Some(value)
+    }
+}
+
+#[derive(Debug)]
+#[non_exhaustive]
+pub(crate) struct AlreadyDestructedError;
diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs
index d0323ad..ff47479 100644
--- a/src/runtime/task/mod.rs
+++ b/src/runtime/task/mod.rs
@@ -1,4 +1,5 @@
 use crate::runtime::execution::ExecutionState;
+use crate::runtime::storage::{AlreadyDestructedError, StorageKey, StorageMap};
 use crate::runtime::task::clock::VectorClock;
 use crate::runtime::thread;
 use crate::runtime::thread::continuation::{ContinuationPool, PooledContinuation};
@@ -8,7 +9,6 @@
 use bitvec::vec::BitVec;
 use futures::{task::Waker, Future};
 use std::any::Any;
 use std::cell::RefCell;
-use std::collections::{HashMap, VecDeque};
 use std::fmt::Debug;
 use std::rc::Rc;
 use std::task::Context;
@@ -56,7 +56,7 @@ pub(crate) struct Task {

     name: Option<String>,

-    local_storage: LocalMap,
+    local_storage: StorageMap,
 }

 impl Task {
@@ -80,7 +80,7 @@ impl Task {
             woken_by_self: false,
             detached: false,
             name,
-            local_storage: LocalMap::new(),
+            local_storage: StorageMap::new(),
         }
     }

@@ -210,14 +210,14 @@ impl Task {
     /// Returns Some(Err(_)) if the slot has already been destructed. Returns None if the slot has
     /// not yet been initialized.
     pub(crate) fn local<T: 'static>(&self, key: &'static LocalKey<T>) -> Option<Result<&T, AlreadyDestructedError>> {
-        self.local_storage.get(key)
+        self.local_storage.get(key.into())
     }

     /// Initialize the given thread-local storage slot with a new value.
     ///
     /// Panics if the slot has already been initialized.
     pub(crate) fn init_local<T: 'static>(&mut self, key: &'static LocalKey<T>, value: T) {
-        self.local_storage.init(key, value)
+        self.local_storage.init(key.into(), value)
     }

     /// Return ownership of the next still-initialized thread-local storage slot, to be used when
@@ -316,67 +316,8 @@ impl Debug for TaskSet {
     }
 }

-/// A unique identifier for a [`LocalKey`](crate::thread::LocalKey)
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-struct LocalKeyId(usize);
-
-impl LocalKeyId {
-    fn new<T: 'static>(key: &'static LocalKey<T>) -> Self {
-        Self(key as *const _ as usize)
+impl<T: 'static> From<&'static LocalKey<T>> for StorageKey {
+    fn from(key: &'static LocalKey<T>) -> Self {
+        Self(key as *const _ as usize, 0x1)
     }
 }
-
-/// A map of thread-local storage values.
-///
-/// We remember the insertion order into the storage HashMap so that destruction is deterministic.
-/// Values are Option<_> because we need to be able to incrementally destruct them, as it's valid
-/// for TLS destructors to initialize new TLS slots. When a slot is destructed, its key is removed
-/// from `order` and its value is replaced with None.
-struct LocalMap {
-    locals: HashMap<LocalKeyId, Option<Box<dyn Any>>>,
-    order: VecDeque<LocalKeyId>,
-}
-
-impl LocalMap {
-    fn new() -> Self {
-        Self {
-            locals: HashMap::new(),
-            order: VecDeque::new(),
-        }
-    }
-
-    fn get<T: 'static>(&self, key: &'static LocalKey<T>) -> Option<Result<&T, AlreadyDestructedError>> {
-        self.locals.get(&LocalKeyId::new(key)).map(|val| {
-            val.as_ref()
-                .map(|val| {
-                    Ok(val
-                        .downcast_ref::<T>()
-                        .expect("local value must downcast to expected type"))
-                })
-                .unwrap_or(Err(AlreadyDestructedError))
-        })
-    }
-
-    fn init<T: 'static>(&mut self, key: &'static LocalKey<T>, value: T) {
-        let key = LocalKeyId::new(key);
-        let result = self.locals.insert(key, Some(Box::new(value)));
-        assert!(result.is_none(), "cannot reinitialize a TLS slot");
-        self.order.push_back(key);
-    }
-
-    /// Return ownership of the next still-initialized TLS slot.
-    fn pop(&mut self) -> Option<Box<dyn Any>> {
-        let key = self.order.pop_front()?;
-        let value = self
-            .locals
-            .get_mut(&key)
-            .expect("keys in `order` must exist")
-            .take()
-            .expect("keys in `order` must not yet be destructed");
-        Some(value)
-    }
-}
-
-#[derive(Debug)]
-#[non_exhaustive]
-pub(crate) struct AlreadyDestructedError;
diff --git a/src/sync/mod.rs b/src/sync/mod.rs
index de6be5d..f607002 100644
--- a/src/sync/mod.rs
+++ b/src/sync/mod.rs
@@ -5,6 +5,7 @@ mod barrier;
 mod condvar;
 pub mod mpsc;
 mod mutex;
+mod once;
 mod rwlock;

 pub use barrier::{Barrier, BarrierWaitResult};
@@ -13,6 +14,9 @@ pub use condvar::{Condvar, WaitTimeoutResult};
 pub use mutex::Mutex;
 pub use mutex::MutexGuard;

+pub use once::Once;
+pub use once::OnceState;
+
 pub use rwlock::RwLock;
 pub use rwlock::RwLockReadGuard;
 pub use rwlock::RwLockWriteGuard;
diff --git a/src/sync/once.rs b/src/sync/once.rs
new file mode 100644
index 0000000..6ad9f5c
--- /dev/null
+++ b/src/sync/once.rs
@@ -0,0 +1,170 @@
+use crate::runtime::execution::ExecutionState;
+use crate::runtime::storage::StorageKey;
+use crate::runtime::task::clock::VectorClock;
+use crate::sync::Mutex;
+use std::cell::RefCell;
+use std::rc::Rc;
+use tracing::trace;
+
+/// A synchronization primitive which can be used to run a one-time global initialization. Useful
+/// for one-time initialization for FFI or related functionality. This type can only be constructed
+/// with [`Once::new()`].
+#[derive(Debug)]
+pub struct Once {
+    // We use the address of the `Once` as an identifier, so it can't be zero-sized even though all
+    // its state is stored in ExecutionState storage
+    _dummy: usize,
+}
+
+/// A `Once` cell can either be `Running`, in which case a `Mutex` mediates racing threads trying to
+/// invoke `call_once`, or `Complete` once an initializer has completed, in which case the `Mutex`
+/// is no longer necessary.
+enum OnceInitState {
+    Running(Rc<Mutex<bool>>),
+    Complete(VectorClock),
+}
+
+impl std::fmt::Debug for OnceInitState {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match self {
+            Self::Running(_) => write!(f, "Running"),
+            Self::Complete(_) => write!(f, "Complete"),
+        }
+    }
+}
+
+impl Once {
+    /// Creates a new `Once` value.
+    pub const fn new() -> Self {
+        Self { _dummy: 0 }
+    }
+
+    /// Performs an initialization routine once and only once. The given closure will be executed
+    /// if this is the first time `call_once` has been called, and otherwise the routine will *not*
+    /// be invoked.
+    ///
+    /// This method will block the calling thread if another initialization routine is currently
+    /// running.
+    ///
+    /// When this function returns, it is guaranteed that some initialization has run and completed
+    /// (it may not be the closure specified).
+    pub fn call_once<F>(&self, f: F)
+    where
+        F: FnOnce(),
+    {
+        self.call_once_inner(|_state| f(), false);
+    }
+
+    /// Performs the same function as [`Once::call_once()`] except ignores poisoning.
+    ///
+    /// If the cell has previously been poisoned, this function will still attempt to call the given
+    /// closure. If the closure does not panic, the cell will no longer be poisoned.
+    pub fn call_once_force<F>(&self, f: F)
+    where
+        F: FnOnce(&OnceState),
+    {
+        self.call_once_inner(f, true);
+    }
+
+    /// Returns `true` if some [`Once::call_once()`] call has completed successfully.
+    pub fn is_completed(&self) -> bool {
+        ExecutionState::with(|state| {
+            let init = match self.get_state(state) {
+                Some(init) => init,
+                None => return false,
+            };
+            let init_state = init.borrow();
+            match &*init_state {
+                OnceInitState::Complete(clock) => {
+                    let clock = clock.clone();
+                    drop(init_state);
+                    state.update_clock(&clock);
+                    true
+                }
+                _ => false,
+            }
+        })
+    }
+
+    fn call_once_inner<F>(&self, f: F, ignore_poisoning: bool)
+    where
+        F: FnOnce(&OnceState),
+    {
+        let lock = ExecutionState::with(|state| {
+            // Initialize the state of the `Once` cell if we're the first thread to try
+            if self.get_state(state).is_none() {
+                self.init_state(state, OnceInitState::Running(Rc::new(Mutex::new(false))));
+            }
+
+            let init = self.get_state(state).expect("must be initialized by this point");
+            let init_state = init.borrow();
+            trace!(state=?init_state, "call_once on cell {:p}", self);
+            match &*init_state {
+                OnceInitState::Complete(clock) => {
+                    // If already complete, just update the clock from the thread that inited
+                    let clock = clock.clone();
+                    drop(init_state);
+                    state.update_clock(&clock);
+                    None
+                }
+                OnceInitState::Running(lock) => Some(Rc::clone(lock)),
+            }
+        });
+
+        // If there's a lock, then we need to try racing on it to decide who gets to run their
+        // initialization closure.
+        if let Some(lock) = lock {
+            let (mut flag, is_poisoned) = match lock.lock() {
+                Ok(flag) => (flag, false),
+                Err(_) if !ignore_poisoning => panic!("Once instance has previously been poisoned"),
+                Err(err) => (err.into_inner(), true),
+            };
+            if *flag {
+                return;
+            }
+
+            trace!("won the call_once race for cell {:p}", self);
+            f(&OnceState(is_poisoned));
+
+            *flag = true;
+            // We were the thread that won the race, so remember our current clock to establish
+            // causality with future threads that try (and fail) to run `call_once`. The threads
+            // that were racing with us will get causality through acquiring the `Mutex`.
+            ExecutionState::with(|state| {
+                let clock = state.increment_clock().clone();
+                *self
+                    .get_state(state)
+                    .expect("must be initialized by this point")
+                    .borrow_mut() = OnceInitState::Complete(clock);
+            });
+        }
+    }
+
+    fn get_state<'a>(&self, from: &'a ExecutionState) -> Option<&'a RefCell<OnceInitState>> {
+        from.get_storage::<_, RefCell<OnceInitState>>(self)
+    }
+
+    fn init_state(&self, into: &mut ExecutionState, new_state: OnceInitState) {
+        into.init_storage::<_, RefCell<OnceInitState>>(self, RefCell::new(new_state));
+    }
+}
+
+/// State yielded to [`Once::call_once_force()`]'s closure parameter. The state can be used to query
+/// the poison status of the [`Once`].
+#[derive(Debug)]
+#[non_exhaustive]
+pub struct OnceState(bool);
+
+impl OnceState {
+    /// Returns `true` if the associated [`Once`] was poisoned prior to the invocation of the
+    /// closure passed to [`Once::call_once_force()`].
+    pub fn is_poisoned(&self) -> bool {
+        self.0
+    }
+}
+
+impl From<&Once> for StorageKey {
+    fn from(once: &Once) -> Self {
+        StorageKey(once as *const _ as usize, 0x2)
+    }
+}
diff --git a/tests/basic/clocks.rs b/tests/basic/clocks.rs
index 4c386ea..a4a59dd 100644
--- a/tests/basic/clocks.rs
+++ b/tests/basic/clocks.rs
@@ -1,11 +1,9 @@
-use shuttle::sync::{
-    atomic::{AtomicBool, AtomicU32},
-    mpsc::{channel, sync_channel},
-    Barrier, Condvar, Mutex, RwLock,
-};
+use shuttle::sync::atomic::{AtomicBool, AtomicU32, Ordering};
+use shuttle::sync::mpsc::{channel, sync_channel};
+use shuttle::sync::{Barrier, Condvar, Mutex, Once, RwLock};
 use shuttle::{check_dfs, check_pct, thread};
 use std::collections::HashSet;
-use std::sync::{atomic::Ordering, Arc};
+use std::sync::Arc;
 use test_env_log::test;

 pub fn me() -> usize {
@@ -15,7 +13,12 @@
 // TODO Maybe make this a macro so backtraces are more informative
 pub fn check_clock(f: impl Fn(usize, u32) -> bool) {
     for (i, &c) in shuttle::my_clock().iter().enumerate() {
-        assert!(f(i, c));
+        assert!(
+            f(i, c),
+            "clock {:?} doesn't satisfy predicate at {}",
+            shuttle::my_clock(),
+            i
+        );
     }
 }
@@ -437,3 +440,37 @@ fn clock_fetch_update() {
         None,
     );
 }
+
+fn clock_once(num_threads: usize) {
+    let once = Arc::new(Once::new());
+    let init = Arc::new(std::sync::atomic::AtomicUsize::new(0));
+
+    let threads = (0..num_threads)
+        .map(|_| {
+            let once = Arc::clone(&once);
+            let init = Arc::clone(&init);
+            thread::spawn(move || {
+                check_clock(|i, c| (c > 0) == (i == 0));
+                once.call_once(|| init.store(me(), std::sync::atomic::Ordering::SeqCst));
+                let who_inited = init.load(std::sync::atomic::Ordering::SeqCst);
+                // should have inhaled the clock of the thread that inited the Once, but might also
+                // have inhaled the clocks of threads that we were racing with for initialization
+                check_clock(|i, c| !(i == who_inited || i == 0 || i == me()) || c > 0);
+            })
+        })
+        .collect::<Vec<_>>();
+
+    for thd in threads {
+        thd.join().unwrap();
+    }
+}
+
+#[test]
+fn clock_once_dfs() {
+    check_dfs(|| clock_once(2), None);
+}
+
+#[test]
+fn clock_once_pct() {
+    check_pct(|| clock_once(20), 10_000, 3);
+}
diff --git a/tests/basic/mod.rs b/tests/basic/mod.rs
index 5c5e1e3..46fab73 100644
--- a/tests/basic/mod.rs
+++ b/tests/basic/mod.rs
@@ -8,6 +8,7 @@ mod execution;
 mod metrics;
 mod mpsc;
 mod mutex;
+mod once;
 mod panic;
 mod pct;
 mod portfolio;
diff --git a/tests/basic/once.rs b/tests/basic/once.rs
new file mode 100644
index 0000000..7c3e6e6
--- /dev/null
+++ b/tests/basic/once.rs
@@ -0,0 +1,251 @@
+use shuttle::scheduler::DfsScheduler;
+use shuttle::sync::Once;
+use shuttle::{check_dfs, check_pct, thread, Runner};
+use std::collections::HashSet;
+use std::panic;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use test_env_log::test;
+
+fn basic<F>(num_threads: usize, checker: F)
+where
+    F: FnOnce(Box<dyn Fn() + Send + Sync>),
+{
+    let initializer = Arc::new(std::sync::Mutex::new(HashSet::new()));
+    let initializer_clone = Arc::clone(&initializer);
+
+    checker(Box::new(move || {
+        let once = Arc::new(Once::new());
+        let counter = Arc::new(AtomicUsize::new(0));
+
+        assert!(!once.is_completed());
+
+        let threads = (0..num_threads)
+            .map(|_| {
+                let once = Arc::clone(&once);
+                let counter = Arc::clone(&counter);
+                let initializer = Arc::clone(&initializer);
+                thread::spawn(move || {
+                    once.call_once(|| {
+                        counter.fetch_add(1, Ordering::SeqCst);
+                        initializer.lock().unwrap().insert(thread::current().id());
+                    });
+
+                    assert!(once.is_completed());
+                    assert_eq!(counter.load(Ordering::SeqCst), 1);
+                })
+            })
+            .collect::<Vec<_>>();
+
+        for thread in threads {
+            thread.join().unwrap();
+        }
+
+        assert!(once.is_completed());
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }));
+
+    assert_eq!(initializer_clone.lock().unwrap().len(), num_threads);
+}
+
+#[test]
+fn basic_dfs() {
+    basic(2, |f| check_dfs(f, None));
+}
+
+#[test]
+fn basic_pct() {
+    basic(10, |f| check_pct(f, 1000, 3));
+}
+
+// Same as `basic`, but with a static Once. Static synchronization primitives should be reset across
+// executions, so this test should work exactly the same way.
+fn basic_static<F>(num_threads: usize, checker: F)
+where
+    F: FnOnce(Box<dyn Fn() + Send + Sync>),
+{
+    static O: Once = Once::new();
+
+    let initializer = Arc::new(std::sync::Mutex::new(HashSet::new()));
+    let initializer_clone = Arc::clone(&initializer);
+
+    checker(Box::new(move || {
+        let counter = Arc::new(AtomicUsize::new(0));
+
+        assert!(!O.is_completed());
+
+        let threads = (0..num_threads)
+            .map(|_| {
+                let counter = Arc::clone(&counter);
+                let initializer = Arc::clone(&initializer);
+                thread::spawn(move || {
+                    O.call_once(|| {
+                        counter.fetch_add(1, Ordering::SeqCst);
+                        initializer.lock().unwrap().insert(thread::current().id());
+                    });
+
+                    assert!(O.is_completed());
+                    assert_eq!(counter.load(Ordering::SeqCst), 1);
+                })
+            })
+            .collect::<Vec<_>>();
+
+        for thread in threads {
+            thread.join().unwrap();
+        }
+
+        assert!(O.is_completed());
+        assert_eq!(counter.load(Ordering::SeqCst), 1);
+    }));
+
+    assert_eq!(initializer_clone.lock().unwrap().len(), num_threads);
+}
+
+#[test]
+fn basic_static_dfs() {
+    basic_static(2, |f| check_dfs(f, None));
+}
+
+#[test]
+fn basic_static_pct() {
+    basic_static(10, |f| check_pct(f, 1000, 3));
+}
+
+// Test that multiple Once cells race for initialization independently
+#[test]
+fn multiple() {
+    static O1: Once = Once::new();
+    static O2: Once = Once::new();
+
+    let initializer = Arc::new(std::sync::Mutex::new(HashSet::new()));
+    let initializer_clone = Arc::clone(&initializer);
+
+    check_dfs(
+        move || {
+            let counter = Arc::new(AtomicUsize::new(0));
+
+            let thd = {
+                let counter = Arc::clone(&counter);
+                thread::spawn(move || {
+                    O1.call_once(|| {
+                        counter.fetch_add(1, Ordering::SeqCst);
+                    });
+                    O2.call_once(|| {
+                        counter.fetch_add(4, Ordering::SeqCst);
+                    });
+                })
+            };
+
+            O1.call_once(|| {
+                counter.fetch_add(2, Ordering::SeqCst);
+            });
+            O2.call_once(|| {
+                counter.fetch_add(8, Ordering::SeqCst);
+            });
+
+            thd.join().unwrap();
+
+            let counter = counter.load(Ordering::SeqCst);
+            // lower two bits for O1, upper two bits for O2
+            assert!(
+                counter & (1 + 2) == 1 || counter & (1 + 2) == 2,
+                "exactly one of the O1 calls should have run"
+            );
+            assert!(
+                counter & (4 + 8) == 4 || counter & (4 + 8) == 8,
+                "exactly one of the O2 calls should have run"
+            );
+            initializer.lock().unwrap().insert(counter);
+        },
+        None,
+    );
+
+    let initializer = Arc::try_unwrap(initializer_clone).unwrap().into_inner().unwrap();
+    assert_eq!(initializer.len(), 4);
+    assert!(initializer.contains(&5));
+    assert!(initializer.contains(&6));
+    assert!(initializer.contains(&9));
+    assert!(initializer.contains(&10));
+}
+
+// Ensure that concurrent Shuttle tests see an isolated version of a static Once cell. This test is
+// best effort, as it spawns OS threads and hopes they race.
+#[test]
+fn shared_static() {
+    static O: Once = Once::new();
+
+    let counter = Arc::new(AtomicUsize::new(0));
+    let mut total_executions = 0;
+
+    // Try a bunch of times to provoke the race
+    for _ in 0..50 {
+        #[allow(clippy::needless_collect)] // https://github.com/rust-lang/rust-clippy/issues/7207
+        let threads = (0..3)
+            .map(|_| {
+                let counter = Arc::clone(&counter);
+                std::thread::spawn(move || {
+                    let scheduler = DfsScheduler::new(None, false);
+                    let runner = Runner::new(scheduler, Default::default());
+                    runner.run(move || {
+                        let thds = (0..2)
+                            .map(|_| {
+                                let counter = Arc::clone(&counter);
+                                thread::spawn(move || {
+                                    O.call_once(|| {
+                                        counter.fetch_add(1, Ordering::SeqCst);
+                                    });
+                                })
+                            })
+                            .collect::<Vec<_>>();
+
+                        for thd in thds {
+                            thd.join().unwrap();
+                        }
+                    })
+                })
+            })
+            .collect::<Vec<_>>();
+
+        total_executions += threads.into_iter().map(|handle| handle.join().unwrap()).sum::<usize>();
+    }
+
+    // The Once cell should be initialized exactly once per test execution, otherwise the tests are
+    // incorrectly sharing the Once cell
+    assert_eq!(total_executions, counter.load(Ordering::SeqCst));
+}
+
+#[test]
+fn poison() {
+    static O: Once = Once::new();
+
+    check_dfs(
+        || {
+            let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
+                O.call_once(|| {
+                    panic!("expected panic");
+                })
+            }));
+            assert!(result.is_err(), "`call_once` should panic");
+
+            let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
+                O.call_once(|| {
+                    // no-op
+                });
+            }));
+            assert!(result.is_err(), "cell should be poisoned");
+
+            let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
+                O.call_once_force(|state| {
+                    assert!(state.is_poisoned());
+                });
+            }));
+            assert!(result.is_ok(), "`call_once_force` ignores poison");
+
+            let result = panic::catch_unwind(panic::AssertUnwindSafe(|| {
+                O.call_once(|| unreachable!("previous call should have initialized the cell"));
+            }));
+            assert!(result.is_ok(), "cell should no longer be poisoned");
+        },
+        None,
+    );
+}
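
Not part of the diff above: a minimal sketch of how the new `shuttle::sync::Once` is expected to be used from a Shuttle test, mirroring the `basic_static` and `shared_static` tests. The `INIT`, `CALLS`, and `expensive_setup` names are illustrative, not part of this change.

use shuttle::sync::Once;
use shuttle::{check_dfs, thread};
use std::sync::atomic::{AtomicUsize, Ordering};

// A static Once whose state lives in ExecutionState storage, so it is reset
// between Shuttle executions and each explored interleaving starts uninitialized.
static INIT: Once = Once::new();
// A std atomic (not reset between executions) used only to count initializer runs.
static CALLS: AtomicUsize = AtomicUsize::new(0);

fn expensive_setup() {
    CALLS.fetch_add(1, Ordering::SeqCst);
}

#[test]
fn once_runs_exactly_once_per_execution() {
    check_dfs(
        || {
            let before = CALLS.load(Ordering::SeqCst);

            let t = thread::spawn(|| INIT.call_once(expensive_setup));
            INIT.call_once(expensive_setup);
            t.join().unwrap();

            // Whichever thread won the race, the initializer ran exactly once in
            // this execution, and both threads observe the completed cell.
            assert!(INIT.is_completed());
            assert_eq!(CALLS.load(Ordering::SeqCst), before + 1);
        },
        None,
    );
}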