From 89dbbf784ea93f239665a06ebfd61d8389bff97d Mon Sep 17 00:00:00 2001 From: Rajeev Joshi <{ID}+{username}@users.noreply.github.com> Date: Sat, 2 Mar 2024 06:07:48 -0800 Subject: [PATCH] Add support for typed task Labels. [PR98](https://github.com/awslabs/shuttle/pull/98) added support for associating a single untyped `Tag` with each task and thread. As we've gained experience with Tags, we've increasingly felt a need to have a mechanism that is both better typed, and allows more than one tag to be associated with tasks. This commit introduces `Labels`, which are inspired by `Extensions` in [the http crate](https://docs.rs/http/latest/http/struct.Extensions.html). Users can attach any set of Labels to a task or thread, with the only caveat being that there can be at most one Label for a given type T. This is not too onerous a restriction, since one can use the common [newtype idiom](https://doc.rust-lang.org/rust-by-example/generics/new_types.html) to easily work around this. For tracing, we also provide a newtype `TaskName` that can be converted to and from a `String`. If the `TaskName` label is set for a task, tracing output will show the `TaskName` (in addition to the `TaskId`) to make logs easier to read. Since the current functionality provided by `Tag` is superseded by `Labels`, we also mark `Tag` as deprecated. They will be removed in a future release. --- src/current.rs | 72 +++++++- src/runtime/execution.rs | 55 +++++- src/runtime/task/labels.rs | 333 ++++++++++++++++++++++++++++++++++++ src/runtime/task/mod.rs | 90 +++++++++- tests/basic/labels.rs | 337 +++++++++++++++++++++++++++++++++++++ tests/basic/mod.rs | 1 + tests/basic/tag.rs | 3 + tests/basic/thread.rs | 3 + 8 files changed, 878 insertions(+), 16 deletions(-) create mode 100644 src/runtime/task/labels.rs create mode 100644 tests/basic/labels.rs diff --git a/src/current.rs b/src/current.rs index a23213be..bb38247d 100644 --- a/src/current.rs +++ b/src/current.rs @@ -5,9 +5,15 @@ //! 
example, a tool that wants to check linearizability might want access to a global timestamp for //! events, which the [`context_switches`] function provides. -use crate::runtime::execution::{ExecutionState, TASK_ID_TO_TAGS}; +#[allow(deprecated)] +use crate::runtime::execution::TASK_ID_TO_TAGS; +use crate::runtime::execution::{ExecutionState, LABELS}; use crate::runtime::task::clock::VectorClock; -pub use crate::runtime::task::{Tag, Taggable, TaskId}; +pub use crate::runtime::task::labels::Labels; +pub use crate::runtime::task::{ChildLabelFn, TaskId, TaskName}; +#[allow(deprecated)] +pub use crate::runtime::task::{Tag, Taggable}; +use std::fmt::Debug; use std::sync::Arc; /// The number of context switches that happened so far in the current Shuttle execution. @@ -34,23 +40,74 @@ pub fn clock_for(task_id: TaskId) -> VectorClock { ExecutionState::with(|state| state.get_clock(task_id).clone()) } +/// Apply the given function to the Labels for the specified task +pub fn with_labels_for_task(task_id: TaskId, f: F) -> T +where + F: FnOnce(&mut Labels) -> T, +{ + LABELS.with(|cell| { + let mut map = cell.borrow_mut(); + let m = map.entry(task_id).or_default(); + f(m) + }) +} + +/// Get a label of the given type for the specified task, if any +#[inline] +pub fn get_label_for_task(task_id: TaskId) -> Option { + with_labels_for_task(task_id, |labels| labels.get().cloned()) +} + +/// Add the given label to the specified task, returning the old label for the type, if any +#[inline] +pub fn set_label_for_task(task_id: TaskId, value: T) -> Option { + with_labels_for_task(task_id, |labels| labels.insert(value)) +} + +/// Remove a label of the given type for the specified task, returning the old label for the type, if any +#[inline] +pub fn remove_label_for_task(task_id: TaskId) -> Option { + with_labels_for_task(task_id, |labels| labels.remove()) +} + +/// Get the debug name for a task +#[inline] +pub fn get_name_for_task(task_id: TaskId) -> Option { + 
get_label_for_task::(task_id) +} + +/// Set the debug name for a task, returning the old name, if any +#[inline] +pub fn set_name_for_task(task_id: TaskId, task_name: impl Into) -> Option { + set_label_for_task::(task_id, task_name.into()) +} + +/// Gets the `TaskId` of the current task, or `None` if there is no current task. +pub fn get_current_task() -> Option { + ExecutionState::with(|s| Some(s.try_current()?.id())) +} + +/// Get the `TaskId` of the current task. Panics if there is no current task. +#[inline] +pub fn me() -> TaskId { + get_current_task().unwrap() +} + /// Sets the `tag` field of the current task. /// Returns the `tag` which was there previously. +#[allow(deprecated)] pub fn set_tag_for_current_task(tag: Arc) -> Option> { ExecutionState::set_tag_for_current_task(tag) } /// Gets the `tag` field of the current task. +#[allow(deprecated)] pub fn get_tag_for_current_task() -> Option> { ExecutionState::get_tag_for_current_task() } -/// Gets the `TaskId` of the current task, or `None` if there is no current task. -pub fn get_current_task() -> Option { - ExecutionState::with(|s| Some(s.try_current()?.id())) -} - /// Gets the `tag` field of the specified task. +#[allow(deprecated)] pub fn get_tag_for_task(task_id: TaskId) -> Option> { TASK_ID_TO_TAGS.with(|cell| { let map = cell.borrow(); @@ -59,6 +116,7 @@ pub fn get_tag_for_task(task_id: TaskId) -> Option> { } /// Sets the `tag` field of the specified task. 
+#[allow(deprecated)] pub fn set_tag_for_task(task: TaskId, tag: Arc) -> Option> { ExecutionState::set_tag_for_task(task, tag) } diff --git a/src/runtime/execution.rs b/src/runtime/execution.rs index 0071b884..ae74fec0 100644 --- a/src/runtime/execution.rs +++ b/src/runtime/execution.rs @@ -1,7 +1,8 @@ use crate::runtime::failure::{init_panic_hook, persist_failure, persist_task_failure}; use crate::runtime::storage::{StorageKey, StorageMap}; use crate::runtime::task::clock::VectorClock; -use crate::runtime::task::{Task, TaskId, DEFAULT_INLINE_TASKS}; +use crate::runtime::task::labels::Labels; +use crate::runtime::task::{ChildLabelFn, Task, TaskId, TaskName, DEFAULT_INLINE_TASKS}; use crate::runtime::thread::continuation::PooledContinuation; use crate::scheduler::{Schedule, Scheduler}; use crate::thread::thread_fn; @@ -18,6 +19,7 @@ use std::rc::Rc; use std::sync::Arc; use tracing::{trace, Span}; +#[allow(deprecated)] use super::task::Tag; // We use this scoped TLS to smuggle the ExecutionState, which is not 'static, across tasks that @@ -28,9 +30,14 @@ scoped_thread_local! { thread_local! { #[allow(clippy::complexity)] + #[allow(deprecated)] pub(crate) static TASK_ID_TO_TAGS: RefCell>> = RefCell::new(HashMap::new()); } +thread_local! { + pub(crate) static LABELS: RefCell> = RefCell::new(HashMap::new()); +} + /// An `Execution` encapsulates a single run of a function under test against a chosen scheduler. /// Its only useful method is `Execution::run`, which executes the function to completion. 
/// @@ -68,6 +75,11 @@ impl Execution { self.initial_schedule.clone(), )); + // Clear all Labels at the beginning of each execution + LABELS.with(|cell| { + cell.borrow_mut().clear(); + }); + let _guard = init_panic_hook(config.clone()); EXECUTION_STATE.set(&state, move || { @@ -327,6 +339,33 @@ impl ExecutionState { Self::with(|s| s.current().id()) } + fn set_labels_for_new_task(state: &ExecutionState, task_id: TaskId, name: Option) { + LABELS.with(|cell| { + let mut map = cell.borrow_mut(); + + // If parent has labels, inherit them + if let Some(parent_task_id) = state.try_current().map(|t| t.id()) { + let parent_map = map.get(&parent_task_id); + if let Some(parent_map) = parent_map { + let mut child_map = parent_map.clone(); + + // If the parent has a `ChildLabelFn` set, use that to update the child's Labels + if let Some(gen) = parent_map.get::() { + (gen.0)(&mut child_map); + } + + map.insert(task_id, child_map); + } + } + + // Add any name assigned to the task to its set of Labels + if let Some(name) = name { + let m = map.entry(task_id).or_default(); + m.insert(TaskName::from(name)); + } + }); + } + /// Spawn a new task for a future. This doesn't create a yield point; the caller should do that /// if it wants to give the new task a chance to run immediately. 
pub(crate) fn spawn_future(future: F, stack_size: usize, name: Option) -> TaskId @@ -339,6 +378,9 @@ impl ExecutionState { let task_id = TaskId(state.tasks.len()); let tag = state.get_tag_or_default_for_current_task(); + + Self::set_labels_for_new_task(state, task_id, name.clone()); + let clock = state.increment_clock_mut(); // Increment the parent's clock clock.extend(task_id); // and extend it with an entry for the new task @@ -372,6 +414,9 @@ impl ExecutionState { let parent_span_id = state.top_level_span.id(); let task_id = TaskId(state.tasks.len()); let tag = state.get_tag_or_default_for_current_task(); + + Self::set_labels_for_new_task(state, task_id, name.clone()); + let clock = if let Some(ref mut clock) = initial_clock { clock } else { @@ -424,6 +469,7 @@ impl ExecutionState { while Self::with(|state| state.storage.pop()).is_some() {} TASK_ID_TO_TAGS.with(|cell| cell.borrow_mut().clear()); + LABELS.with(|cell| cell.borrow_mut().clear()); #[cfg(debug_assertions)] Self::with(|state| state.has_cleaned_up = true); @@ -651,6 +697,9 @@ impl ExecutionState { // 2) It creates a visual separation of scheduling decisions and `Task`-induced tracing. // Note that there is a case to be made for not `in_scope`-ing it, as that makes seeing the context // of the context switch clearer. + // + // Note also that changing this trace! statement requires changing the test `basic::labels::test_tracing_with_label_fn` + // which relies on this trace reporting the `runnable` tasks. self.top_level_span .in_scope(|| trace!(?runnable, next_task=?self.next_task)); @@ -669,18 +718,22 @@ impl ExecutionState { // Sets the `tag` field of the current task. // Returns the `tag` which was there previously. 
+ #[allow(deprecated)] pub(crate) fn set_tag_for_current_task(tag: Arc) -> Option> { ExecutionState::with(|s| s.current_mut().set_tag(tag)) } + #[allow(deprecated)] fn get_tag_or_default_for_current_task(&self) -> Option> { self.try_current().and_then(|current| current.get_tag()) } + #[allow(deprecated)] pub(crate) fn get_tag_for_current_task() -> Option> { ExecutionState::with(|s| s.get_tag_or_default_for_current_task()) } + #[allow(deprecated)] pub(crate) fn set_tag_for_task(task: TaskId, tag: Arc) -> Option> { ExecutionState::with(|s| s.get_mut(task).set_tag(tag)) } diff --git a/src/runtime/task/labels.rs b/src/runtime/task/labels.rs new file mode 100644 index 00000000..22905a58 --- /dev/null +++ b/src/runtime/task/labels.rs @@ -0,0 +1,333 @@ +/* +** Code directly copied from https://github.com/hyperium/http/blob/master/src/extensions.rs +** but renaming 'Extensions' to 'Labels' +** +** The key idea is to keep a HashMap (named `AnyMap`) that maps the `TypeId` for a type +** to its associated value, so a `get::()` is translated to `get(TypeId::of::())`. +*/ + +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::fmt; +use std::hash::{BuildHasherDefault, Hasher}; + +type AnyMap = HashMap, BuildHasherDefault>; + +// With TypeIds as keys, there's no need to hash them. They are already hashes +// themselves, coming from the compiler. The IdHasher just holds the u64 of +// the TypeId, and then returns it, instead of doing any bit fiddling. +#[derive(Default)] +struct IdHasher(u64); + +impl Hasher for IdHasher { + fn write(&mut self, _: &[u8]) { + unreachable!("TypeId calls write_u64"); + } + + #[inline] + fn write_u64(&mut self, id: u64) { + self.0 = id; + } + + #[inline] + fn finish(&self) -> u64 { + self.0 + } +} + +/// A collections of assigned Labels +/// +/// `Labels` can be used to store extra data associated with running tasks. 
+#[derive(Clone, Default)] +pub struct Labels { + // If Labels are never used, no need to carry around an empty HashMap. + // That's 3 words. Instead, this is only 1 word. + map: Option>, +} + +impl Labels { + /// Create an empty `Labels`. + #[inline] + pub fn new() -> Labels { + Labels { map: None } + } + + /// Insert a type into this `Labels`. + /// + /// If a label of this type already existed, it will + /// be returned. + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// assert!(ext.insert(5i32).is_none()); + /// assert!(ext.insert(4u8).is_none()); + /// assert_eq!(ext.insert(9i32), Some(5i32)); + /// ``` + pub fn insert(&mut self, val: T) -> Option { + self.map + .get_or_insert_with(Box::default) + .insert(TypeId::of::(), Box::new(val)) + .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed)) + } + + /// Get a reference to a type previously inserted on this `Labels`. + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// assert!(ext.get::().is_none()); + /// ext.insert(5i32); + /// + /// assert_eq!(ext.get::(), Some(&5i32)); + /// ``` + pub fn get(&self) -> Option<&T> { + self.map + .as_ref() + .and_then(|map| map.get(&TypeId::of::())) + .and_then(|boxed| (**boxed).as_any().downcast_ref()) + } + + /// Get a mutable reference to a type previously inserted on this `Labels`. 
+ /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// ext.insert(String::from("Hello")); + /// ext.get_mut::().unwrap().push_str(" World"); + /// + /// assert_eq!(ext.get::().unwrap(), "Hello World"); + /// ``` + pub fn get_mut(&mut self) -> Option<&mut T> { + self.map + .as_mut() + .and_then(|map| map.get_mut(&TypeId::of::())) + .and_then(|boxed| (**boxed).as_any_mut().downcast_mut()) + } + + /// Get a mutable reference to a type, inserting `value` if not already present on this + /// `Labels`. + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// *ext.get_or_insert(1i32) += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 3); + /// ``` + pub fn get_or_insert(&mut self, value: T) -> &mut T { + self.get_or_insert_with(|| value) + } + + /// Get a mutable reference to a type, inserting the value created by `f` if not already present + /// on this `Labels`. + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// *ext.get_or_insert_with(|| 1i32) += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 3); + /// ``` + pub fn get_or_insert_with T>(&mut self, f: F) -> &mut T { + let out = self + .map + .get_or_insert_with(Box::default) + .entry(TypeId::of::()) + .or_insert_with(|| Box::new(f())); + (**out).as_any_mut().downcast_mut().unwrap() + } + + /// Get a mutable reference to a type, inserting the type's default value if not already present + /// on this `Labels`. + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// *ext.get_or_insert_default::() += 2; + /// + /// assert_eq!(*ext.get::().unwrap(), 2); + /// ``` + pub fn get_or_insert_default(&mut self) -> &mut T { + self.get_or_insert_with(T::default) + } + + /// Remove a type from this `Labels`. + /// + /// If a label of this type existed, it will be returned. 
+ /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// ext.insert(5i32); + /// assert_eq!(ext.remove::(), Some(5i32)); + /// assert!(ext.get::().is_none()); + /// ``` + pub fn remove(&mut self) -> Option { + self.map + .as_mut() + .and_then(|map| map.remove(&TypeId::of::())) + .and_then(|boxed| boxed.into_any().downcast().ok().map(|boxed| *boxed)) + } + + /// Clear all inserted labels + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// ext.insert(5i32); + /// ext.clear(); + /// + /// assert!(ext.get::().is_none()); + /// ``` + #[inline] + pub fn clear(&mut self) { + if let Some(ref mut map) = self.map { + map.clear(); + } + } + + /// Check whether the label set is empty or not. + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// assert!(ext.is_empty()); + /// ext.insert(5i32); + /// assert!(!ext.is_empty()); + /// ``` + #[inline] + pub fn is_empty(&self) -> bool { + self.map.as_ref().map_or(true, |map| map.is_empty()) + } + + /// Get the number of Labels available. + /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext = Labels::new(); + /// assert_eq!(ext.len(), 0); + /// ext.insert(5i32); + /// assert_eq!(ext.len(), 1); + /// ``` + #[inline] + pub fn len(&self) -> usize { + self.map.as_ref().map_or(0, |map| map.len()) + } + + /// Extends `self` with another `Labels`. + /// + /// If an instance of a specific type exists in both, the one in `self` is overwritten with the + /// one from `other`. 
+ /// + /// # Example + /// + /// ``` + /// # use shuttle::current::Labels; + /// let mut ext_a = Labels::new(); + /// ext_a.insert(8u8); + /// ext_a.insert(16u16); + /// + /// let mut ext_b = Labels::new(); + /// ext_b.insert(4u8); + /// ext_b.insert("hello"); + /// + /// ext_a.extend(ext_b); + /// assert_eq!(ext_a.len(), 3); + /// assert_eq!(ext_a.get::(), Some(&4u8)); + /// assert_eq!(ext_a.get::(), Some(&16u16)); + /// assert_eq!(ext_a.get::<&'static str>().copied(), Some("hello")); + /// ``` + pub fn extend(&mut self, other: Self) { + if let Some(other) = other.map { + if let Some(map) = &mut self.map { + map.extend(*other); + } else { + self.map = Some(other); + } + } + } +} + +impl fmt::Debug for Labels { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Labels").finish() + } +} + +trait AnyClone: Any { + fn clone_box(&self) -> Box; + fn as_any(&self) -> &dyn Any; + fn as_any_mut(&mut self) -> &mut dyn Any; + fn into_any(self: Box) -> Box; +} + +impl AnyClone for T { + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + fn into_any(self: Box) -> Box { + self + } +} + +impl Clone for Box { + fn clone(&self) -> Self { + (**self).clone_box() + } +} + +#[test] +fn test_labels() { + #[derive(Clone, Debug, PartialEq)] + struct MyType(i32); + + let mut labels = Labels::new(); + + labels.insert(5i32); + labels.insert(MyType(10)); + + assert_eq!(labels.get(), Some(&5i32)); + assert_eq!(labels.get_mut(), Some(&mut 5i32)); + + let ext2 = labels.clone(); + + assert_eq!(labels.remove::(), Some(5i32)); + assert!(labels.get::().is_none()); + + // clone still has it + assert_eq!(ext2.get(), Some(&5i32)); + assert_eq!(ext2.get(), Some(&MyType(10))); + + assert_eq!(labels.get::(), None); + assert_eq!(labels.get(), Some(&MyType(10))); +} diff --git a/src/runtime/task/mod.rs b/src/runtime/task/mod.rs index 7a385a66..e85851f4 100644 --- 
a/src/runtime/task/mod.rs +++ b/src/runtime/task/mod.rs @@ -1,6 +1,8 @@ +use crate::current::get_name_for_task; use crate::runtime::execution::{ExecutionState, TASK_ID_TO_TAGS}; use crate::runtime::storage::{AlreadyDestructedError, StorageKey, StorageMap}; use crate::runtime::task::clock::VectorClock; +use crate::runtime::task::labels::Labels; use crate::runtime::thread; use crate::runtime::thread::continuation::{ContinuationPool, PooledContinuation}; use crate::thread::LocalKey; @@ -15,6 +17,7 @@ use std::task::{Context, Waker}; use tracing::{error_span, event, field, Level, Span}; pub(crate) mod clock; +pub(crate) mod labels; pub(crate) mod waker; use waker::make_waker; @@ -38,6 +41,73 @@ use waker::make_waker; pub(crate) const DEFAULT_INLINE_TASKS: usize = 16; +/// To make debugging easier, if a task is assigned a `TaskName(s)` Label, +/// Shuttle will display the String `s` in addition to the `TaskId` in debug output. +#[derive(Clone, PartialEq, Eq)] +pub struct TaskName(String); + +impl From for TaskName { + fn from(s: String) -> Self { + Self(s) + } +} + +impl From<&str> for TaskName { + fn from(s: &str) -> Self { + Self(String::from(s)) + } +} + +impl std::fmt::Debug for TaskName { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From for String { + fn from(task_name: TaskName) -> Self { + task_name.0 + } +} + +impl<'a> From<&'a TaskName> for &'a String { + fn from(task_name: &'a TaskName) -> Self { + &task_name.0 + } +} + +/// By default, when a task or thread T is spawned, it inherits all labels from its parent. +/// It's often useful to modify or add new Labels to T. One approach is to put label changes +/// at the beginning of the closure that is passed to `spawn`, but this approach has the drawback +/// that the changes are applied only when T is first selected for execution, and the closure +/// is invoked. To overcome this drawback, we introduce the `ChildLabelFn` label. 
If a parent +/// task or thread has a `ChildLabelFn` set when it spawns a new child task or thread, the +/// child's label set at spawn time will be modified by applying the function inside the `ChildLabelFn`. +/// +/// # Example +/// The following example shows how a `ChildLabelFn` can be used to set up names for the next child(ren) +/// that will be spawned by a parent task. +/// ``` +/// # use shuttle::current::{me, set_label_for_task, get_name_for_task, ChildLabelFn, TaskName}; +/// # use std::sync::Arc; +/// // In the parent, set up a `ChildLabelFn` that assigns a name to the child task +/// shuttle::check_dfs(|| { +/// set_label_for_task(me(), ChildLabelFn(Arc::new(|labels| { labels.insert(TaskName::from("ChildTask")); }))); +/// shuttle::thread::spawn(|| { +/// assert_eq!(get_name_for_task(me()).unwrap(), TaskName::from("ChildTask")); // child task already has the name +/// // ... rest of child +/// }).join().unwrap(); +/// }, None); +/// ``` +#[derive(Clone)] +pub struct ChildLabelFn(pub Arc); + +impl Debug for ChildLabelFn { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "ChildLabelFn") + } +} + /// A `Tag` is an optional piece of metadata associated with a task (a thread or spawned future) to /// aid debugging. /// @@ -47,18 +117,21 @@ pub(crate) const DEFAULT_INLINE_TASKS: usize = 16; /// identify tasks in failing Shuttle tests. A task's [Tag] can be set with the /// [set_tag_for_current_task](crate::current::set_tag_for_current_task) function. Newly spawned /// threads and futures inherit the tag of their parent at spawn time. +#[deprecated] +#[allow(deprecated)] pub trait Tag: Taggable { /// Return the tag as `Any`, typically so that it can be downcast to a known concrete type fn as_any(&self) -> &dyn Any; } - /// `Taggable` is a marker trait which types implementing `Tag` have to implement. 
/// It exists since we both want to provide a blanket implementation of `as_any`, and have users /// opt in to a type being able to be used as a tag. If we did not have this trait, then `Tag` /// would be automatically implemented for most types (as most types are `Debug + Any`), which /// opens up for accidentally using a type which was not intended to be used as a tag as a tag. +#[deprecated] pub trait Taggable: Debug {} +#[allow(deprecated)] impl Tag for T where T: Taggable + Any, @@ -108,9 +181,11 @@ pub(crate) struct Task { pub(super) span_stack: Vec, // Arbitrarily settable tag which is inherited from the parent. + #[allow(deprecated)] tag: Option>, } +#[allow(deprecated)] impl Task { /// Create a task from a continuation #[allow(clippy::too_many_arguments)] @@ -464,14 +539,13 @@ pub(crate) struct ParkState { pub struct TaskId(pub(super) usize); impl Debug for TaskId { + // If the `TaskName` label is set, use that when generating the Debug string fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - TASK_ID_TO_TAGS.with(|cell| { - let map = cell.borrow(); - match map.get(self) { - Some(tag) => f.write_str(&format!("{:?}({})", tag, self.0)), - None => f.debug_tuple("TaskId").field(&self.0).finish(), - } - }) + if let Some(name) = get_name_for_task(*self) { + f.write_str(&format!("{:?}({})", name, self.0)) + } else { + f.debug_tuple("TaskId").field(&self.0).finish() + } } } diff --git a/tests/basic/labels.rs b/tests/basic/labels.rs new file mode 100644 index 00000000..210a4c55 --- /dev/null +++ b/tests/basic/labels.rs @@ -0,0 +1,337 @@ +use shuttle::{ + check_dfs, check_random, + current::{get_label_for_task, me, set_label_for_task, set_name_for_task, ChildLabelFn, TaskName}, + future, thread, +}; +use std::collections::HashSet; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use test_log::test; +use tracing::field::{Field, Visit}; +use tracing::span::{Attributes, Record}; +use tracing::{Event, Id, Metadata, Subscriber}; + 
+#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct Ident(usize); + +#[derive(Clone, Debug, PartialEq, Eq, Hash)] +struct Parent(usize); + +// This test does the following: spawns a tree of tasks with integer ids +// Each task has labels `Parent(i)` and `Ident(j)` where i is the parent's id, and j is the task's id. +// Here is the tree of tasks spawned: +// (1) +// / \ +// (2) (3) +// / \ / \ +// (4) (5) (6) (7) +// +// Each leaf task increments a global AtomicUsize (initially 0) by 1 +// Each task returns its (Parent, Ident) as a usize pair when it completes. +// The root waits for all tasks to complete, and collects the returned usize pairs. It then checks +// - that the final AtomicUsize value is 4 +// - that each (parent, child) pair in the tree above is reported exactly once +// +// This test checks the following properties +// - tasks inherit the Labels of their parent +// - changing child Labels doesn't affect parent Labels +// - child Labels are independent of each other +async fn spawn_tasks(counter: Arc) -> HashSet<(usize, usize)> { + set_label_for_task(me(), Ident(1)); + #[allow(clippy::type_complexity)] + let handles: Vec>)>> = (0..2) + .map(|i| { + let counter2 = counter.clone(); + future::spawn(async move { + // Set this task's Ident to be (2b + i) where (b) is the Ident of the parent + let Ident(parent) = get_label_for_task(me()).unwrap(); + set_label_for_task(me(), Parent(parent)); + set_label_for_task(me(), Ident(2 * parent + i)); + + let handles: Vec> = (0..2) + .map(|j| { + let counter3 = counter2.clone(); + future::spawn(async move { + let Ident(parent) = get_label_for_task(me()).unwrap(); + set_label_for_task(me(), Parent(parent)); + set_label_for_task(me(), Ident(2 * parent + j)); + // Increment global counter + counter3.fetch_add(1usize, Ordering::SeqCst); + // Return (Parent, Ident) for this task + let Parent(p) = get_label_for_task::(me()).unwrap(); + let Ident(c) = get_label_for_task::(me()).unwrap(); + (p, c) + }) + }) + 
.collect::<_>(); + // Read labels again after children have been spawned + let Parent(p) = get_label_for_task::(me()).unwrap(); + let Ident(c) = get_label_for_task::(me()).unwrap(); + (p, c, handles) + }) + }) + .collect::>(); + + let mut values = HashSet::new(); + for h in handles.into_iter() { + let (a, b, handles) = h.await.unwrap(); + for h2 in handles.into_iter() { + let v2 = h2.await.unwrap(); + assert!(values.insert(v2)); + } + assert!(values.insert((a, b))); + } + + // Validate that root Labels didn't change + let Ident(c) = get_label_for_task(me()).unwrap(); + assert_eq!(c, 1); + + assert_eq!(get_label_for_task::(me()), None); + + values +} + +#[test] +fn task_inheritance() { + check_random( + || { + let counter = Arc::new(AtomicUsize::new(0)); + let counter2 = counter.clone(); + + let seen_values = future::block_on(async move { spawn_tasks(counter2).await }); + + // Check final counter value + assert_eq!(counter.load(Ordering::SeqCst), 2usize.pow(2)); + + // Check that we saw all labels + let expected_values = HashSet::from([(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]); + assert_eq!(seen_values, expected_values); + }, + 10_000, + ); +} + +// Same test as above, except with threads +fn spawn_threads(counter: Arc) -> HashSet<(usize, usize)> { + set_label_for_task(me(), Ident(1)); + #[allow(clippy::type_complexity)] + let handles: Vec>)>> = (0..2) + .map(|i| { + let counter2 = counter.clone(); + thread::spawn(move || { + // Set this task's Ident to be (2b + i) where (b) is the Ident of the parent + let Ident(parent) = get_label_for_task(me()).unwrap(); + set_label_for_task(me(), Parent(parent)); + set_label_for_task(me(), Ident(2 * parent + i)); + + let handles: Vec> = (0..2) + .map(|j| { + let counter3 = counter2.clone(); + thread::spawn(move || { + let Ident(parent) = get_label_for_task(me()).unwrap(); + set_label_for_task(me(), Parent(parent)); + set_label_for_task(me(), Ident(2 * parent + j)); + // Increment global counter + 
counter3.fetch_add(1usize, Ordering::SeqCst); + // Return (Parent, Ident) for this task + let Parent(p) = get_label_for_task::(me()).unwrap(); + let Ident(c) = get_label_for_task::(me()).unwrap(); + (p, c) + }) + }) + .collect::<_>(); + // Read labels again after children have been spawned + let Parent(p) = get_label_for_task::(me()).unwrap(); + let Ident(c) = get_label_for_task::(me()).unwrap(); + (p, c, handles) + }) + }) + .collect::<_>(); + + let mut values = HashSet::new(); + for h in handles.into_iter() { + let (a, b, handles) = h.join().unwrap(); + for h2 in handles.into_iter() { + let (c, d) = h2.join().unwrap(); + assert!(values.insert((c, d))); + } + assert!(values.insert((a, b))); + } + + // Validate that root Labels didn't change + let Ident(c) = get_label_for_task(me()).unwrap(); + assert_eq!(c, 1); + + assert_eq!(get_label_for_task::(me()), None); + + values +} + +#[test] +fn thread_inheritance() { + check_random( + || { + let counter = Arc::new(AtomicUsize::new(0)); + let counter2 = counter.clone(); + + let seen_values = spawn_threads(counter2); + + // Check final counter value + assert_eq!(counter.load(Ordering::SeqCst), 2usize.pow(2)); + + // Check that we saw all labels + let expected_values = HashSet::from([(1, 2), (1, 3), (2, 4), (2, 5), (3, 6), (3, 7)]); + assert_eq!(seen_values, expected_values); + }, + 10_000, + ); +} + +// Check that a task can modify another task's Label; in this example, +// the spawned task modifies its parent's Label. 
+#[test] +fn label_modify() { + check_dfs( + || { + // Start with a known label for current task + set_label_for_task(me(), Ident(0)); + + let parent_id = me(); + + let child = thread::spawn(move || { + // Set the label for the other thread + set_label_for_task(parent_id, Ident(1)); + assert_ne!(me(), parent_id); + // Return my id + get_label_for_task::(me()).unwrap() + }); + + let child_id = child.join().unwrap(); + let my_label = get_label_for_task::(me()).unwrap(); + + assert_eq!(my_label, Ident(1)); // parent id has changed + assert_eq!(child_id, Ident(0)); // child id is the parent's original id + }, + None, + ); +} + +// The following tests exercise the functionality provided by `ChildLabelFn`. +// The main task (which is always named "main-thread(0)" by Shuttle) creates 3 child +// tasks. We test two scenarios: +// (1) the parent sets up a `ChildLabelFn` to assign the child names at creation time +// (so that the names are in effect as soon as the child is created) +// (2) the child tasks assign names to themselves by calling `set_name_for_task` as the +// first statement in their code. +// +// We use a custom tracing subscriber to monitor the list of `runnable` tasks that +// is logged by Shuttle at each scheduling point. (This also allows us to check +// that Shuttle uses the user-assigned `TaskName` when generating debug output.) 
+// +// The two tests below check that: +// - In scenario (1), the only names seen in the `runnable` list are `main-thread(0)` +// or names starting with `Child(` +// - In scenario (2) the `runnable` list contains other names + +async fn label_fn_inner(set_name_before_spawn: bool) { + // Spawn 3 children + let handles = (0..3).map(|_| { + if set_name_before_spawn { + // Test scenario (1) + set_label_for_task( + me(), + ChildLabelFn(Arc::new(|labels| { + labels.insert(TaskName::from("Child")); + })), + ); + } + future::spawn(async move { + if !set_name_before_spawn { + // Test scenario (2) + set_name_for_task(me(), TaskName::from("Child")); + } + shuttle::future::yield_now().await; + }) + }); + + for h in handles { + h.await.unwrap(); + } +} + +#[test] +fn test_tracing_with_label_fn() { + let metrics = RunnableSubscriber {}; + let _guard = tracing::subscriber::set_default(metrics); + + check_random( + || { + future::block_on(async { label_fn_inner(true).await }); + }, + 10, + ); +} + +#[test] +#[should_panic(expected = "assertion failed")] +fn test_tracing_without_label_fn() { + let metrics = RunnableSubscriber {}; + let _guard = tracing::subscriber::set_default(metrics); + + check_random( + || { + future::block_on(async { label_fn_inner(false).await }); + }, + 1, // even one execution is enough to fail the assertion + ); +} + +// Custom Subscriber implementation to monitor and check debug output generated by Shuttle. 
+struct RunnableSubscriber; + +impl Subscriber for RunnableSubscriber { + fn enabled(&self, _metadata: &Metadata<'_>) -> bool { + true + } + + fn new_span(&self, _span: &Attributes<'_>) -> Id { + // We don't care about span equality so just use the same identity for everything + Id::from_u64(1) + } + + fn record(&self, _span: &Id, _values: &Record<'_>) {} + fn record_follows_from(&self, _span: &Id, _follows: &Id) {} + + fn event(&self, event: &Event<'_>) { + let metadata = event.metadata(); + let target = metadata.target(); + if target.contains("shuttle") && target.ends_with("::runtime::execution") { + let fields: &tracing::field::FieldSet = metadata.fields(); + if fields.iter().any(|f| f.name() == "runnable") { + struct CheckRunnableSubscriber; + impl Visit for CheckRunnableSubscriber { + fn record_debug(&mut self, field: &Field, value: &dyn std::fmt::Debug) { + if field.name() == "runnable" { + // The following code relies on the fact that the list of runnable tasks is a SmallVec which is reported + // in debug output in the format "[main-thread(0), first-task-name(3), other-task-name(17)]" etc. + let value = format!("{:?}", value).replace('[', "").replace(']', ""); + let v1 = value.split(','); + // The following assertion fails if a `ChildLabelFn` is not used to set child task names. 
+ assert!(v1 + .map(|s| s.trim()) + .all(|s| (s == "main-thread(0)") || s.starts_with("Child("))); + } + } + } + + let mut visitor = CheckRunnableSubscriber {}; + event.record(&mut visitor); + } + } + } + + fn enter(&self, _span: &Id) {} + fn exit(&self, _span: &Id) {} +} diff --git a/tests/basic/mod.rs b/tests/basic/mod.rs index 14b2c44e..faae7d9a 100644 --- a/tests/basic/mod.rs +++ b/tests/basic/mod.rs @@ -5,6 +5,7 @@ mod condvar; mod config; mod dfs; mod execution; +mod labels; mod lazy_static; mod metrics; mod mpsc; diff --git a/tests/basic/tag.rs b/tests/basic/tag.rs index 884eab40..4117106b 100644 --- a/tests/basic/tag.rs +++ b/tests/basic/tag.rs @@ -1,3 +1,5 @@ +#![allow(deprecated)] + use futures::future::join_all; use shuttle::{ check_dfs, check_random, @@ -279,6 +281,7 @@ impl Subscriber for RunnableSubscriber { fn exit(&self, _span: &Id) {} } +#[ignore] // This test doesn't work anymore, since we don't use tags for tracing output anymore #[test] fn tracing_tags() { let metrics = RunnableSubscriber::new(); diff --git a/tests/basic/thread.rs b/tests/basic/thread.rs index 24179f98..d1c93197 100644 --- a/tests/basic/thread.rs +++ b/tests/basic/thread.rs @@ -1,3 +1,4 @@ +use shuttle::current::{get_name_for_task, me}; use shuttle::sync::{Barrier, Condvar, Mutex}; use shuttle::{check_dfs, check_random, thread}; use std::collections::HashSet; @@ -75,6 +76,8 @@ fn thread_builder_name() { let builder = thread::Builder::new().name("producer".into()); let handle = builder .spawn(|| { + let name = String::from(get_name_for_task(me()).unwrap()); + assert_eq!(name, "producer"); thread::yield_now(); }) .unwrap();