From 8ee042a1341af3711469bd4fd0b640fbbe6e0a09 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 13:26:21 -0500 Subject: [PATCH 01/15] IR fuzz target with Arbitrary implementations --- Cargo.toml | 4 ++- fuzz/Cargo.toml | 8 +++++- fuzz/fuzz_targets/ir.rs | 10 ++++++++ fuzz/fuzz_targets/spv_parser.rs | 2 +- src/arena.rs | 42 +++++++++++++++++++++++++++--- src/block.rs | 1 + src/lib.rs | 45 +++++++++++++++++++++++++++++++++ src/span.rs | 1 + 8 files changed, 106 insertions(+), 7 deletions(-) create mode 100644 fuzz/fuzz_targets/ir.rs diff --git a/Cargo.toml b/Cargo.toml index bb93b3b32f..a0dcec5703 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,12 @@ resolver = "2" [package.metadata.docs.rs] all-features = true -[dependencies] # MSRV warnings: # - bitflags 1.3 requires Rust-1.46 # - indexmap 1.7 requires Rust-1.49 + +[dependencies] +arbitrary = { version = "1", features = ["derive"], optional = true } bitflags = "1" bit-set = "0.5" codespan-reporting = { version = "0.11.0", optional = true } diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index ceb30ea551..06c3e5a961 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -14,7 +14,7 @@ libfuzzer-sys = "0.4" [dependencies.naga] path = ".." -features = ["spv-in", "wgsl-in", "glsl-in"] +features = ["arbitrary", "spv-in", "wgsl-in", "glsl-in", "validate"] # Prevent this from interfering with workspaces [workspace] @@ -37,3 +37,9 @@ name = "glsl_parser" path = "fuzz_targets/glsl_parser.rs" test = false doc = false + +[[bin]] +name = "ir" +path = "fuzz_targets/ir.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/ir.rs b/fuzz/fuzz_targets/ir.rs new file mode 100644 index 0000000000..86b7a6aa6e --- /dev/null +++ b/fuzz/fuzz_targets/ir.rs @@ -0,0 +1,10 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|module: naga::Module| { + use naga::valid as v; + // Check if the module validates without errors. + //TODO: may also fuzz the flags and capabilities + let mut validator = v::Validator::new(v::ValidationFlags::all(), v::Capabilities::empty()); + let _result = validator.validate(&module); +}); diff --git a/fuzz/fuzz_targets/spv_parser.rs b/fuzz/fuzz_targets/spv_parser.rs index 3b8af0fbb4..0fda11f439 100644 --- a/fuzz/fuzz_targets/spv_parser.rs +++ b/fuzz/fuzz_targets/spv_parser.rs @@ -1,6 +1,6 @@ #![no_main] use libfuzzer_sys::fuzz_target; -use naga::front::spv::{Parser, Options}; +use naga::front::spv::{Options, Parser}; fuzz_target!(|data: Vec| { // Ensure the parser can handle potentially malformed data without crashing. diff --git a/src/arena.rs b/src/arena.rs index 4884010d3c..56f70e94d2 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -17,6 +17,7 @@ use indexmap::set::IndexSet; any(feature = "serialize", feature = "deserialize"), serde(transparent) )] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Handle { index: Index, #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] @@ -110,6 +111,7 @@ impl Handle { any(feature = "serialize", feature = "deserialize"), serde(transparent) )] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Range { inner: ops::Range, #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] @@ -154,6 +156,7 @@ impl Iterator for Range { /// a reference to the stored item. #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "serialize", serde(transparent))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] #[cfg_attr(test, derive(PartialEq))] pub struct Arena { /// Values of this arena. @@ -543,8 +546,7 @@ impl ops::Index> for UniqueArena { #[cfg(feature = "serialize")] impl serde::Serialize for UniqueArena where - T: Eq + hash::Hash, - T: serde::Serialize, + T: Eq + hash::Hash + serde::Serialize, { fn serialize(&self, serializer: S) -> Result where @@ -557,8 +559,7 @@ where #[cfg(feature = "deserialize")] impl<'de, T> serde::Deserialize<'de> for UniqueArena where - T: Eq + hash::Hash, - T: serde::Deserialize<'de>, + T: Eq + hash::Hash + serde::Deserialize<'de>, { fn deserialize(deserializer: D) -> Result where @@ -575,3 +576,36 @@ where }) } } + +//Note: largely borrowed from `HashSet` implementation +#[cfg(feature = "arbitrary")] +impl<'a, T> arbitrary::Arbitrary<'a> for UniqueArena +where + T: Eq + hash::Hash + arbitrary::Arbitrary<'a>, +{ + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result { + let mut arena = Self::default(); + for elem in u.arbitrary_iter()? { + arena.set.insert(elem?); + #[cfg(feature = "span")] + arena.span_info.push(Span::UNDEFINED); + } + Ok(arena) + } + + fn arbitrary_take_rest(u: arbitrary::Unstructured<'a>) -> arbitrary::Result { + let mut arena = Self::default(); + for elem in u.arbitrary_take_rest_iter()? { + arena.set.insert(elem?); + #[cfg(feature = "span")] + arena.span_info.push(Span::UNDEFINED); + } + Ok(arena) + } + + #[inline] + fn size_hint(depth: usize) -> (usize, Option) { + let depth_hint = ::size_hint(depth); + arbitrary::size_hint::and(depth_hint, (0, None)) + } +} diff --git a/src/block.rs b/src/block.rs index 6a31301e11..e70202b8d1 100644 --- a/src/block.rs +++ b/src/block.rs @@ -5,6 +5,7 @@ use std::ops::{Deref, DerefMut, RangeBounds}; #[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "serialize", serde(transparent))] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Block { body: Vec, #[cfg(feature = "span")] diff --git a/src/lib.rs b/src/lib.rs index 6e7083082f..d6e6aacd67 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -217,6 +217,8 @@ pub mod valid; pub use crate::arena::{Arena, Handle, Range, UniqueArena}; pub use crate::span::{Span, SpanContext, WithSpan}; +#[cfg(feature = "arbitrary")] +use arbitrary::Arbitrary; #[cfg(feature = "deserialize")] use serde::Deserialize; #[cfg(feature = "serialize")] @@ -248,6 +250,7 @@ pub(crate) type NamedExpressions = FastHashMap, String>; #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct EarlyDepthTest { conservative: Option, } @@ -264,6 +267,7 @@ pub struct EarlyDepthTest { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ConservativeDepth { /// Shader may rewrite depth only with a value greater than calculated; GreaterEqual, @@ -279,6 +283,7 @@ pub enum ConservativeDepth { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[allow(missing_docs)] // The names are self evident pub enum ShaderStage { Vertex, @@ -290,6 +295,7 @@ pub enum ShaderStage { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum StorageClass { /// Function locals. Function, @@ -311,6 +317,7 @@ pub enum StorageClass { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum BuiltIn { Position, ViewIndex, @@ -345,6 +352,7 @@ pub type Bytes = u8; #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum VectorSize { /// 2D vector Bi = 2, @@ -359,6 +367,7 @@ pub enum VectorSize { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ScalarKind { /// Signed integer type. Sint, @@ -375,6 +384,7 @@ pub enum ScalarKind { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ArraySize { /// The array size is constant. Constant(Handle), @@ -386,6 +396,7 @@ pub enum ArraySize { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Interpolation { /// The value will be interpolated in a perspective-correct fashion. /// Also known as "smooth" in glsl. @@ -402,6 +413,7 @@ pub enum Interpolation { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Sampling { /// Interpolate the value at the center of the pixel. Center, @@ -421,6 +433,7 @@ pub enum Sampling { #[derive(Clone, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct StructMember { pub name: Option, /// Type of the field. @@ -435,6 +448,7 @@ pub struct StructMember { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ImageDimension { /// 1D image D1, @@ -450,6 +464,7 @@ bitflags::bitflags! { /// Flags describing an image. #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] + #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Default)] pub struct StorageAccess: u32 { /// Storage can be used as a source for load ops. @@ -463,6 +478,7 @@ bitflags::bitflags! { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum StorageFormat { // 8-bit formats R8Unorm, @@ -513,6 +529,7 @@ pub enum StorageFormat { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ImageClass { /// Regular sampled image. Sampled { @@ -540,6 +557,7 @@ pub enum ImageClass { #[derive(Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Type { /// The name of the type, if any. pub name: Option, @@ -551,6 +569,7 @@ pub struct Type { #[derive(Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum TypeInner { /// Number of integral or floating-point kind. Scalar { kind: ScalarKind, width: Bytes }, @@ -680,6 +699,7 @@ pub enum TypeInner { #[derive(Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Constant { pub name: Option, pub specialization: Option, @@ -690,6 +710,7 @@ pub struct Constant { #[derive(Debug, Clone, Copy, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ScalarValue { Sint(i64), Uint(u64), @@ -701,6 +722,7 @@ pub enum ScalarValue { #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ConstantInner { Scalar { width: Bytes, @@ -716,6 +738,7 @@ pub enum ConstantInner { #[derive(Clone, Debug, Eq, PartialEq, Hash)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Binding { /// Built-in shader variable. BuiltIn(BuiltIn), @@ -747,6 +770,7 @@ pub enum Binding { #[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct ResourceBinding { /// The bind group index. pub group: u32, @@ -758,6 +782,7 @@ pub struct ResourceBinding { #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct GlobalVariable { /// Name of the variable, if any. pub name: Option, @@ -775,6 +800,7 @@ pub struct GlobalVariable { #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct LocalVariable { /// Name of the variable, if any. pub name: Option, @@ -788,6 +814,7 @@ pub struct LocalVariable { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum UnaryOperator { Negate, Not, @@ -797,6 +824,7 @@ pub enum UnaryOperator { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum BinaryOperator { Add, Subtract, @@ -827,6 +855,7 @@ pub enum BinaryOperator { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum AtomicFunction { Add, Subtract, @@ -842,6 +871,7 @@ pub enum AtomicFunction { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum DerivativeAxis { X, Y, @@ -852,6 +882,7 @@ pub enum DerivativeAxis { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum RelationalFunction { All, Any, @@ -865,6 +896,7 @@ pub enum RelationalFunction { #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum MathFunction { // comparison Abs, @@ -948,6 +980,7 @@ pub enum MathFunction { #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum SampleLevel { Auto, Zero, @@ -963,6 +996,7 @@ pub enum SampleLevel { #[derive(Clone, Copy, Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum ImageQuery { /// Get the size at the specified level. Size { @@ -982,6 +1016,7 @@ pub enum ImageQuery { #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum SwizzleComponent { /// X = 0, @@ -997,6 +1032,7 @@ bitflags::bitflags! { /// Memory barrier flags. #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] + #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[derive(Default)] pub struct Barrier: u32 { /// Barrier affects all `StorageClass::Storage` accesses. @@ -1013,6 +1049,7 @@ bitflags::bitflags! { #[cfg_attr(test, derive(PartialEq))] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Expression { /// Array access with a computed index. /// @@ -1301,6 +1338,7 @@ pub use block::Block; #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum SwitchValue { Integer(i32), Default, @@ -1311,6 +1349,7 @@ pub enum SwitchValue { #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct SwitchCase { /// Value, upon which the case is considered true. pub value: SwitchValue, @@ -1327,6 +1366,7 @@ pub struct SwitchCase { #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum Statement { /// Emit a range of expressions, visible to all statements that follow in this block. /// @@ -1470,6 +1510,7 @@ pub enum Statement { #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct FunctionArgument { /// Name of the argument, if any. pub name: Option, @@ -1483,6 +1524,7 @@ pub struct FunctionArgument { #[derive(Clone, Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct FunctionResult { /// Type of the result. pub ty: Handle, @@ -1495,6 +1537,7 @@ pub struct FunctionResult { #[derive(Debug, Default)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Function { /// Name of the function, if any. pub name: Option, @@ -1558,6 +1601,7 @@ pub struct Function { #[derive(Debug)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct EntryPoint { /// Name of this entry point, visible externally. /// @@ -1587,6 +1631,7 @@ pub struct EntryPoint { #[derive(Debug, Default)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub struct Module { /// Storage for the types defined in this module. pub types: UniqueArena, diff --git a/src/span.rs b/src/span.rs index 6d541210b9..51ea481016 100644 --- a/src/span.rs +++ b/src/span.rs @@ -3,6 +3,7 @@ use std::{error::Error, fmt, ops::Range}; /// A source code span, used for error reporting. #[derive(Clone, Copy, Debug, PartialEq, Default)] +#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] pub struct Span { start: u32, end: u32, From 1eddd990da5e4c083bf46cfe0f956e1f66ccb464 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 13:26:44 -0500 Subject: [PATCH 02/15] layouter: handle width=0 gracefully --- src/proc/layouter.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 5f6559f08e..39e8177eeb 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -74,7 +74,7 @@ impl Layouter { let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => TypeLayout { size, - alignment: Alignment::new(width as u32).unwrap(), + alignment: Alignment::new(width as u32).ok_or(InvalidBaseType(ty_handle))?, }, Ti::Vector { size: vec_size, @@ -88,7 +88,7 @@ impl Layouter { } else { 2 }; - Alignment::new((count * width) as u32).unwrap() + Alignment::new((count * width) as u32).ok_or(InvalidBaseType(ty_handle))? }, }, Ti::Matrix { @@ -99,7 +99,7 @@ impl Layouter { size, alignment: { let count = if rows >= crate::VectorSize::Tri { 4 } else { 2 }; - Alignment::new((count * width) as u32).unwrap() + Alignment::new((count * width) as u32).ok_or(InvalidBaseType(ty_handle))? }, }, Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { From b06bf61d875ebd005a574f22547909a6dcbdeb8e Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 13:27:15 -0500 Subject: [PATCH 03/15] Fix multiplication overflow in span() computation --- src/proc/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 13883df909..00ec69c605 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -101,7 +101,7 @@ impl super::TypeInner { size, kind: _, width, - } => (size as u8 * width) as u32, + } => size as u32 * width as u32, // matrices are treated as arrays of aligned columns Self::Matrix { columns, From bb1c5000f25ff62bdc084147d8ae5c12dead89ad Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 14:06:14 -0500 Subject: [PATCH 04/15] validate: check constant composite type --- src/valid/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 6b2c804757..f703bb916c 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -240,7 +240,7 @@ impl Validator { } } crate::ConstantInner::Composite { ty, ref components } => { - match types[ty].inner { + match types.get_handle(ty).ok_or(ConstantError::InvalidType)?.inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. From 9678d453b501f382233e2e0cd642bf2a83cad57b Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 14:15:22 -0500 Subject: [PATCH 05/15] typifier: check local vars, global vars, and function arguments to exist --- src/proc/typifier.rs | 24 +++++++++++++++++++----- src/valid/function.rs | 8 +------- src/valid/mod.rs | 6 +++++- 3 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 7a1d7a46e0..6e237ee5a1 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -189,10 +189,14 @@ pub enum ResolveError { FunctionNotDefined { name: String }, #[error("Function without return type")] FunctionReturnsVoid, - #[error("Type is not found in the given immutable arena")] - TypeNotFound, #[error("Incompatible operands: {0}")] IncompatibleOperands(String), + #[error("Local var {0:?} doesn't exist")] + LocalVariableNotFound(Handle), + #[error("Global var {0:?} doesn't exist")] + GlobalVariableNotFound(Handle), + #[error("Function argument {0} doesn't exist")] + FunctionArgumentNotFound(u32), } pub struct ResolveContext<'a> { @@ -430,10 +434,17 @@ impl<'a> ResolveContext<'a> { }, crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty), crate::Expression::FunctionArgument(index) => { - TypeResolution::Handle(self.arguments[index as usize].ty) + let arg = self + .arguments + .get(index as usize) + .ok_or(ResolveError::FunctionArgumentNotFound(index))?; + TypeResolution::Handle(arg.ty) } crate::Expression::GlobalVariable(h) => { - let var = &self.global_vars[h]; + let var = self + .global_vars + .try_get(h) + .ok_or(ResolveError::GlobalVariableNotFound(h))?; if var.class == crate::StorageClass::Handle { TypeResolution::Handle(var.ty) } else { @@ -444,7 +455,10 @@ impl<'a> ResolveContext<'a> { } } crate::Expression::LocalVariable(h) => { - let var = &self.local_vars[h]; + let var = self + .local_vars + .try_get(h) + .ok_or(ResolveError::LocalVariableNotFound(h))?; TypeResolution::Value(Ti::Pointer { base: var.ty, class: crate::StorageClass::Function, diff --git a/src/valid/function.rs b/src/valid/function.rs index 0254775ec5..66f5965dbb 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -819,13 +819,7 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, ) -> Result> { - #[cfg(feature = "validate")] - let mut info = mod_info - .process_function(fun, module, self.flags) - .map_err(WithSpan::into_other)?; - - #[cfg(not(feature = "validate"))] - let info = mod_info.process_function(fun, module, self.flags)?; + let mut info = mod_info.process_function(fun, module, self.flags)?; #[cfg(feature = "validate")] for (var_handle, var) in fun.local_variables.iter() { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index f703bb916c..a7e57a8667 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -240,7 +240,11 @@ impl Validator { } } crate::ConstantInner::Composite { ty, ref components } => { - match types.get_handle(ty).ok_or(ConstantError::InvalidType)?.inner { + match types + .get_handle(ty) + .ok_or(ConstantError::InvalidType)? + .inner + { crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. From 1795d1c052e53f80b7231aea1570ba2c8937a008 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 14:27:39 -0500 Subject: [PATCH 06/15] validate: check global var type to be in range --- src/valid/interface.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 72c6561abd..5665afaced 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -332,7 +332,10 @@ impl super::Validator { use super::TypeFlags; log::debug!("var {:?}", var); - let type_info = &self.types[var.ty.index()]; + let type_info = self + .types + .get(var.ty.index()) + .ok_or(GlobalVariableError::InvalidType)?; let (required_type_flags, is_resource) = match var.class { crate::StorageClass::Function => return Err(GlobalVariableError::InvalidUsage), From d85e0e5a7b3fca8d3f3da9b5518289a9314fe938 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 14:30:06 -0500 Subject: [PATCH 07/15] layouter: handle width multiplication overflows --- src/proc/layouter.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 39e8177eeb..d9b1f2f159 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -88,7 +88,7 @@ impl Layouter { } else { 2 }; - Alignment::new((count * width) as u32).ok_or(InvalidBaseType(ty_handle))? + Alignment::new(count * width as u32).ok_or(InvalidBaseType(ty_handle))? }, }, Ti::Matrix { @@ -99,7 +99,7 @@ impl Layouter { size, alignment: { let count = if rows >= crate::VectorSize::Tri { 4 } else { 2 }; - Alignment::new((count * width) as u32).ok_or(InvalidBaseType(ty_handle))? + Alignment::new(count * width as u32).ok_or(InvalidBaseType(ty_handle))? }, }, Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { From aaea6f71bdbb603ef6459428655524df097ed3f6 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 14:35:34 -0500 Subject: [PATCH 08/15] validate: check function argument type to be in range --- src/valid/function.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/valid/function.rs b/src/valid/function.rs index 66f5965dbb..411ef7ce18 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -837,7 +837,14 @@ impl super::Validator { #[cfg(feature = "validate")] for (index, argument) in fun.arguments.iter().enumerate() { - match module.types[argument.ty].inner.pointer_class() { + let ty = module.types.get_handle(argument.ty).ok_or( + FunctionError::InvalidArgumentType { + index, + name: argument.name.clone().unwrap_or_default(), + } + .with_span_handle(argument.ty, &module.types), + )?; + match ty.inner.pointer_class() { Some(crate::StorageClass::Private) | Some(crate::StorageClass::Function) | Some(crate::StorageClass::WorkGroup) From bd54b0eeace7e72a5634a4fd342c890137383ec7 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 14:39:00 -0500 Subject: [PATCH 09/15] validate: check local var type to be in range --- src/valid/function.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/valid/function.rs b/src/valid/function.rs index 411ef7ce18..d0eeab1b5c 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -786,12 +786,17 @@ impl super::Validator { constants: &Arena, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); - if !self.types[var.ty.index()] + let type_info = self + .types + .get(var.ty.index()) + .ok_or(LocalVariableError::InvalidType(var.ty))?; + if !type_info .flags .contains(super::TypeFlags::DATA | super::TypeFlags::SIZED) { return Err(LocalVariableError::InvalidType(var.ty)); } + if let Some(const_handle) = var.init { match constants[const_handle].inner { crate::ConstantInner::Scalar { width, ref value } => { From 2a745c7e1848a2bac693f2816e119c91a0e7c92f Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 16:27:28 -0500 Subject: [PATCH 10/15] layouter: rich and careful errors --- src/back/hlsl/storage.rs | 4 +-- src/back/msl/writer.rs | 17 ++++++------ src/front/glsl/functions.rs | 5 +++- src/front/glsl/types.rs | 10 +++---- src/front/spv/mod.rs | 4 +-- src/proc/layouter.rs | 54 +++++++++++++++++++++++++++++-------- src/proc/mod.rs | 24 +++++++++-------- src/valid/function.rs | 7 ++--- src/valid/mod.rs | 6 ++--- src/valid/type.rs | 7 +++-- 10 files changed, 90 insertions(+), 48 deletions(-) diff --git a/src/back/hlsl/storage.rs b/src/back/hlsl/storage.rs index b235a061fb..08909802b1 100644 --- a/src/back/hlsl/storage.rs +++ b/src/back/hlsl/storage.rs @@ -148,7 +148,7 @@ impl super::Writer<'_, W> { } => { write!(self.out, "{{")?; let count = module.constants[const_handle].to_array_length().unwrap(); - let stride = module.types[base].inner.span(&module.constants); + let stride = module.types[base].inner.size(&module.constants).unwrap(); let iter = (0..count).map(|i| (TypeResolution::Handle(base), stride * i)); self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; write!(self.out, "}}")?; @@ -311,7 +311,7 @@ impl super::Writer<'_, W> { writeln!(self.out, ";")?; // then iterate the stores let count = module.constants[const_handle].to_array_length().unwrap(); - let stride = module.types[base].inner.span(&module.constants); + let stride = module.types[base].inner.size(&module.constants).unwrap(); for i in 0..count { self.temp_access_chain.push(SubAccess::Offset(i * stride)); let sv = StoreValue::TempIndex { diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index dd4e93b222..9533e5076f 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -349,7 +349,7 @@ fn should_pack_struct_member( } let ty_inner = &module.types[member.ty].inner; - let last_offset = member.offset + ty_inner.span(&module.constants); + let last_offset = member.offset + ty_inner.size(&module.constants).unwrap(); let next_offset = match members.get(index + 1) { Some(next) => next.offset, None => span, @@ -750,22 +750,23 @@ impl Writer { None => return Err(Error::Validation), }; - let (span, stride) = match context.module.types[array_ty].inner { + let (size, stride) = match context.module.types[array_ty].inner { crate::TypeInner::Array { base, stride, .. } => ( context.module.types[base] .inner - .span(&context.module.constants), + .size(&context.module.constants) + .unwrap(), stride, ), _ => return Err(Error::Validation), }; - // When the stride length is larger than the span, the final element's stride of + // When the stride length is larger than the size, the final element's stride of // bytes would have padding following the value. But the buffer size in // `buffer_sizes.sizeN` may not include this padding - it only needs to be large // enough to hold the actual values' bytes. // - // So subtract off the span to get a byte size that falls at the start or within + // So subtract off the size to get a byte size that falls at the start or within // the final element. Then divide by the stride size, to get one less than the // length, and then add one. This works even if the buffer size does include the // stride padding, since division rounds towards zero (MSL 2.4 ยง6.1). It will fail @@ -774,10 +775,10 @@ impl Writer { // prevent that. write!( self.out, - "(_buffer_sizes.size{idx} - {offset} - {span}) / {stride}", + "(_buffer_sizes.size{idx} - {offset} - {size}) / {stride}", idx = handle.index(), offset = offset, - span = span, + size = size, stride = stride, )?; Ok(()) @@ -2379,7 +2380,7 @@ impl Writer { writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?; } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset + ty_inner.span(&module.constants); + last_offset = member.offset + ty_inner.size(&module.constants).unwrap(); let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 14f30bdfb7..d565b00222 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1293,7 +1293,10 @@ impl Parser { offset: span, }); - span += self.module.types[ty].inner.span(&self.module.constants); + span += self.module.types[ty] + .inner + .size(&self.module.constants) + .unwrap(); let len = expressions.len(); let load = expressions.append(Expression::Load { pointer }, Default::default()); diff --git a/src/front/glsl/types.rs b/src/front/glsl/types.rs index e008ec3822..106104378e 100644 --- a/src/front/glsl/types.rs +++ b/src/front/glsl/types.rs @@ -262,14 +262,14 @@ impl Parser { array_specifier .map(|(size, size_meta)| { meta.subsume(size_meta); + let stride = self.module.types[base] + .inner + .size(&self.module.constants) + .unwrap(); self.module.types.insert( Type { name: None, - inner: TypeInner::Array { - base, - size, - stride: self.module.types[base].inner.span(&self.module.constants), - }, + inner: TypeInner::Array { base, size, stride }, }, meta, ) diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index e1de8bdaf6..f6f7db14dd 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -4078,7 +4078,7 @@ impl> Parser { size: crate::ArraySize::Constant(length_const.handle), stride: match decor.array_stride { Some(stride) => stride.get(), - None => module.types[base].inner.span(&module.constants), + None => module.types[base].inner.size(&module.constants).unwrap(), }, }; self.lookup_type.insert( @@ -4115,7 +4115,7 @@ impl> Parser { size: crate::ArraySize::Dynamic, stride: match decor.array_stride { Some(stride) => stride.get(), - None => module.types[base].inner.span(&module.constants), + None => module.types[base].inner.size(&module.constants).unwrap(), }, }; self.lookup_type.insert( diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index d9b1f2f159..637257c17e 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -30,8 +30,29 @@ impl ops::Index> for Layouter { } #[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] -#[error("Base type {0:?} is out of bounds")] -pub struct InvalidBaseType(pub Handle); +pub enum TypeLayoutError { + #[error("Array element type {0:?} doesn't exist")] + InvalidArrayElementType(Handle), + #[error("Struct member[{0}] type {1:?} doesn't exist")] + InvalidStructMemberType(u32, Handle), + #[error("Zero width is not supported")] + ZeroWidth, + #[error(transparent)] + Size(#[from] super::ProcError), +} + +#[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] +#[error("Error laying out type {ty:?}: {inner}")] +pub struct LayoutError { + pub ty: Handle, + pub inner: TypeLayoutError, +} + +impl TypeLayoutError { + fn with(self, ty: Handle) -> LayoutError { + LayoutError { ty, inner: self } + } +} impl Layouter { pub fn clear(&mut self) { @@ -62,19 +83,24 @@ impl Layouter { (start..start + span, alignment) } + #[allow(clippy::or_fun_call)] pub fn update( &mut self, types: &UniqueArena, constants: &Arena, - ) -> Result<(), InvalidBaseType> { + ) -> Result<(), LayoutError> { use crate::TypeInner as Ti; for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { - let size = ty.inner.span(constants); + let size = ty + .inner + .size(constants) + .map_err(|error| TypeLayoutError::Size(error).with(ty_handle))?; let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => TypeLayout { size, - alignment: Alignment::new(width as u32).ok_or(InvalidBaseType(ty_handle))?, + alignment: Alignment::new(width as u32) + .ok_or(TypeLayoutError::ZeroWidth.with(ty_handle))?, }, Ti::Vector { size: vec_size, @@ -88,7 +114,8 @@ impl Layouter { } else { 2 }; - Alignment::new(count * width as u32).ok_or(InvalidBaseType(ty_handle))? + Alignment::new(count * width as u32) + .ok_or(TypeLayoutError::ZeroWidth.with(ty_handle))? }, }, Ti::Matrix { @@ -99,7 +126,8 @@ impl Layouter { size, alignment: { let count = if rows >= crate::VectorSize::Tri { 4 } else { 2 }; - Alignment::new(count * width as u32).ok_or(InvalidBaseType(ty_handle))? + Alignment::new(count * width as u32) + .ok_or(TypeLayoutError::ZeroWidth.with(ty_handle))? }, }, Ti::Pointer { .. } | Ti::ValuePointer { .. } => TypeLayout { @@ -115,16 +143,20 @@ impl Layouter { alignment: if base < ty_handle { self[base].alignment } else { - return Err(InvalidBaseType(base)); + return Err(TypeLayoutError::InvalidArrayElementType(base).with(ty_handle)); }, }, Ti::Struct { span, ref members } => { let mut alignment = Alignment::new(1).unwrap(); - for member in members { + for (index, member) in members.iter().enumerate() { alignment = if member.ty < ty_handle { alignment.max(self[member.ty].alignment) } else { - return Err(InvalidBaseType(member.ty)); + return Err(TypeLayoutError::InvalidStructMemberType( + index as u32, + member.ty, + ) + .with(ty_handle)); }; } TypeLayout { @@ -137,7 +169,7 @@ impl Layouter { alignment: Alignment::new(1).unwrap(), }, }; - debug_assert!(ty.inner.span(constants) <= layout.size); + debug_assert!(size <= layout.size); self.layouts.push(layout); } diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 00ec69c605..79fd078ecd 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -9,16 +9,16 @@ mod typifier; use std::cmp::PartialEq; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength}; -pub use layouter::{Alignment, InvalidBaseType, Layouter, TypeLayout}; +pub use layouter::{Alignment, LayoutError, Layouter, TypeLayout, TypeLayoutError}; pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; pub use typifier::{ResolveContext, ResolveError, TypeResolution}; -#[derive(Clone, Debug, thiserror::Error, PartialEq)] +#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] pub enum ProcError { - #[error("type is not indexable, and has no length (validation error)")] + #[error("Type is not indexable, and has no length (validation error)")] TypeNotIndexable, - #[error("array length is wrong kind of constant (validation error)")] + #[error("Array length {0:?} is wrong kind of constant (validation error)")] InvalidArraySizeConstant(crate::Handle), } @@ -94,8 +94,8 @@ impl super::TypeInner { } } - pub fn span(&self, constants: &super::Arena) -> u32 { - match *self { + pub fn size(&self, constants: &super::Arena) -> Result { + Ok(match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { size, @@ -119,8 +119,10 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(handle) => { - // Bad array lengths will be caught during validation. - constants[handle].to_array_length().unwrap_or(1) + let constant = constants + .try_get(handle) + .ok_or(ProcError::InvalidArraySizeConstant(handle))?; + constant.to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, @@ -129,7 +131,7 @@ impl super::TypeInner { } Self::Struct { span, .. } => span, Self::Image { .. } | Self::Sampler { .. } => 0, - } + }) } /// Return the canoncal form of `self`, or `None` if it's already in @@ -447,7 +449,7 @@ fn test_matrix_size() { rows: crate::VectorSize::Tri, width: 4 } - .span(&constants), - 48 + .size(&constants), + Ok(48), ); } diff --git a/src/valid/function.rs b/src/valid/function.rs index d0eeab1b5c..491cee54bb 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -824,6 +824,7 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, ) -> Result> { + #[cfg_attr(not(feature = "validate"), allow(unused_mut))] let mut info = mod_info.process_function(fun, module, self.flags)?; #[cfg(feature = "validate")] @@ -842,13 +843,13 @@ impl super::Validator { #[cfg(feature = "validate")] for (index, argument) in fun.arguments.iter().enumerate() { - let ty = module.types.get_handle(argument.ty).ok_or( + let ty = module.types.get_handle(argument.ty).ok_or_else(|| { FunctionError::InvalidArgumentType { index, name: argument.name.clone().unwrap_or_default(), } - .with_span_handle(argument.ty, &module.types), - )?; + .with_span_handle(argument.ty, &module.types) + })?; match ty.inner.pointer_class() { Some(crate::StorageClass::Private) | Some(crate::StorageClass::Function) diff --git a/src/valid/mod.rs b/src/valid/mod.rs index a7e57a8667..91807e002e 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -10,7 +10,7 @@ use crate::arena::{Arena, UniqueArena}; use crate::{ arena::Handle, - proc::{InvalidBaseType, Layouter}, + proc::{LayoutError, Layouter}, FastHashSet, }; use bit_set::BitSet; @@ -127,7 +127,7 @@ pub enum ConstantError { #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { #[error(transparent)] - Layouter(#[from] InvalidBaseType), + Layouter(#[from] LayoutError), #[error("Type {handle:?} '{name}' is invalid")] Type { handle: Handle, @@ -278,7 +278,7 @@ impl Validator { self.layouter .update(&module.types, &module.constants) .map_err(|e| { - let InvalidBaseType(handle) = e; + let handle = e.ty; ValidationError::from(e).with_span_handle(handle, &module.types) })?; diff --git a/src/valid/type.rs b/src/valid/type.rs index cda1fa64a0..570e8efcc8 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -326,7 +326,8 @@ impl super::Validator { return Err(TypeError::InvalidArrayBaseType(base)); } - let base_size = types[base].inner.span(constants); + //Note: `unwrap()` is fine, since `Layouter` goes first and calls it + let base_size = types[base].inner.size(constants).unwrap(); if stride < base_size { return Err(TypeError::InsufficientArrayStride { stride, base_size }); } @@ -478,7 +479,9 @@ impl super::Validator { }); } } - let base_size = types[member.ty].inner.span(constants); + + //Note: `unwrap()` is fine because `Layouter` goes first and checks this + let base_size = types[member.ty].inner.size(constants).unwrap(); min_offset = member.offset + base_size; if min_offset > span { return Err(TypeError::MemberOutOfBounds { From 884d44817a866595549a440b283275e951b0ae3c Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 16:45:09 -0500 Subject: [PATCH 11/15] analyzer: skip invalid expressions --- src/valid/analyzer.rs | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index ab4d77c421..65d72572ed 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -266,7 +266,10 @@ impl FunctionInfo { handle: Handle, global_use: GlobalUse, ) -> NonUniformResult { - let info = &mut self.expressions[handle.index()]; + //Note: if the expression doesn't exist, this function + // will return `None`, but the later validation of + // expressions should detect this and error properly. + let info = self.expressions.get_mut(handle.index())?; info.ref_count += 1; // mark the used global as read if let Some(global) = info.assignable_global { @@ -290,7 +293,8 @@ impl FunctionInfo { handle: Handle, assignable_global: &mut Option>, ) -> NonUniformResult { - let info = &mut self.expressions[handle.index()]; + //Note: similarly to `add_ref_impl`, this ignores invalid expressions. + let info = self.expressions.get_mut(handle.index())?; info.ref_count += 1; // propagate the assignable global up the chain, till it either hits // a value-type expression, or the assignment statement. @@ -629,7 +633,10 @@ impl FunctionInfo { S::Emit(ref range) => { let mut requirements = UniformityRequirements::empty(); for expr in range.clone() { - let req = self.expressions[expr.index()].uniformity.requirements; + let req = match self.expressions.get(expr.index()) { + Some(expr) => expr.uniformity.requirements, + None => UniformityRequirements::empty(), + }; #[cfg(feature = "validate")] if self .flags From 5d4310d5b98217442872d06e1dc2936b40fbc296 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 18:27:43 -0500 Subject: [PATCH 12/15] typifier: handle forward expression dependencies --- src/front/mod.rs | 6 +- src/proc/typifier.rs | 190 +++++++++++++++++++++--------------------- src/valid/analyzer.rs | 9 +- 3 files changed, 108 insertions(+), 97 deletions(-) diff --git a/src/front/mod.rs b/src/front/mod.rs index 3264f0c352..80dfa87ac0 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -94,7 +94,8 @@ impl Typifier { ) -> Result<(), ResolveError> { if self.resolutions.len() <= expr_handle.index() { for (eh, expr) in expressions.iter().skip(self.resolutions.len()) { - let resolution = ctx.resolve(expr, |h| &self.resolutions[h.index()])?; + //Note: the closure can't `Err` by construction + let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?; log::debug!("Resolving {:?} = {:?} : {:?}", eh, expr, resolution); self.resolutions.push(resolution); } @@ -116,7 +117,8 @@ impl Typifier { self.grow(expr_handle, expressions, ctx) } else { let expr = &expressions[expr_handle]; - let resolution = ctx.resolve(expr, |h| &self.resolutions[h.index()])?; + //Note: the closure can't `Err` by construction + let resolution = ctx.resolve(expr, |h| Ok(&self.resolutions[h.index()]))?; self.resolutions[expr_handle.index()] = resolution; Ok(()) } diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 6e237ee5a1..220936df63 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -197,6 +197,8 @@ pub enum ResolveError { GlobalVariableNotFound(Handle), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), + #[error("Expression {0:?} depends on expressions that follow")] + ExpressionForwardDependency(Handle), } pub struct ResolveContext<'a> { @@ -227,12 +229,12 @@ impl<'a> ResolveContext<'a> { pub fn resolve( &self, expr: &crate::Expression, - past: impl Fn(Handle) -> &'a TypeResolution, + past: impl Fn(Handle) -> Result<&'a TypeResolution, ResolveError>, ) -> Result { use crate::TypeInner as Ti; let types = self.types; Ok(match *expr { - crate::Expression::Access { base, .. } => match *past(base).inner_with(types) { + crate::Expression::Access { base, .. } => match *past(base)?.inner_with(types) { // Arrays and matrices can only be indexed dynamically behind a // pointer, but that's a validation error, not a type error, so // go ahead provide a type here. @@ -299,106 +301,108 @@ impl<'a> ResolveContext<'a> { }); } }, - crate::Expression::AccessIndex { base, index } => match *past(base).inner_with(types) { - Ti::Vector { size, kind, width } => { - if index >= size as u32 { - return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); - } - TypeResolution::Value(Ti::Scalar { kind, width }) - } - Ti::Matrix { - columns, - rows, - width, - } => { - if index >= columns as u32 { - return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); - } - TypeResolution::Value(crate::TypeInner::Vector { - size: rows, - kind: crate::ScalarKind::Float, - width, - }) - } - Ti::Array { base, .. } => TypeResolution::Handle(base), - Ti::Struct { ref members, .. } => { - let member = members - .get(index as usize) - .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?; - TypeResolution::Handle(member.ty) - } - Ti::ValuePointer { - size: Some(size), - kind, - width, - class, - } => { - if index >= size as u32 { - return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); - } - TypeResolution::Value(Ti::ValuePointer { - size: None, - kind, - width, - class, - }) - } - Ti::Pointer { - base: ty_base, - class, - } => TypeResolution::Value(match types[ty_base].inner { - Ti::Array { base, .. } => Ti::Pointer { base, class }, + crate::Expression::AccessIndex { base, index } => { + match *past(base)?.inner_with(types) { Ti::Vector { size, kind, width } => { if index >= size as u32 { return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); } - Ti::ValuePointer { - size: None, - kind, - width, - class, - } + TypeResolution::Value(Ti::Scalar { kind, width }) } Ti::Matrix { - rows, columns, + rows, width, } => { if index >= columns as u32 { return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); } - Ti::ValuePointer { - size: Some(rows), + TypeResolution::Value(crate::TypeInner::Vector { + size: rows, kind: crate::ScalarKind::Float, width, - class, - } + }) } + Ti::Array { base, .. } => TypeResolution::Handle(base), Ti::Struct { ref members, .. } => { let member = members .get(index as usize) .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?; - Ti::Pointer { - base: member.ty, - class, + TypeResolution::Handle(member.ty) + } + Ti::ValuePointer { + size: Some(size), + kind, + width, + class, + } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); } + TypeResolution::Value(Ti::ValuePointer { + size: None, + kind, + width, + class, + }) } + Ti::Pointer { + base: ty_base, + class, + } => TypeResolution::Value(match types[ty_base].inner { + Ti::Array { base, .. } => Ti::Pointer { base, class }, + Ti::Vector { size, kind, width } => { + if index >= size as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + Ti::ValuePointer { + size: None, + kind, + width, + class, + } + } + Ti::Matrix { + rows, + columns, + width, + } => { + if index >= columns as u32 { + return Err(ResolveError::OutOfBoundsIndex { expr: base, index }); + } + Ti::ValuePointer { + size: Some(rows), + kind: crate::ScalarKind::Float, + width, + class, + } + } + Ti::Struct { ref members, .. } => { + let member = members + .get(index as usize) + .ok_or(ResolveError::OutOfBoundsIndex { expr: base, index })?; + Ti::Pointer { + base: member.ty, + class, + } + } + ref other => { + log::error!("Access index sub-type {:?}", other); + return Err(ResolveError::InvalidSubAccess { + ty: ty_base, + indexed: true, + }); + } + }), ref other => { - log::error!("Access index sub-type {:?}", other); - return Err(ResolveError::InvalidSubAccess { - ty: ty_base, + log::error!("Access index type {:?}", other); + return Err(ResolveError::InvalidAccess { + expr: base, indexed: true, }); } - }), - ref other => { - log::error!("Access index type {:?}", other); - return Err(ResolveError::InvalidAccess { - expr: base, - indexed: true, - }); } - }, + } crate::Expression::Constant(h) => match self.constants[h].inner { crate::ConstantInner::Scalar { width, ref value } => { TypeResolution::Value(Ti::Scalar { @@ -408,7 +412,7 @@ impl<'a> ResolveContext<'a> { } crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), }, - crate::Expression::Splat { size, value } => match *past(value).inner_with(types) { + crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { Ti::Scalar { kind, width } => { TypeResolution::Value(Ti::Vector { size, kind, width }) } @@ -421,7 +425,7 @@ impl<'a> ResolveContext<'a> { size, vector, pattern: _, - } => match *past(vector).inner_with(types) { + } => match *past(vector)?.inner_with(types) { Ti::Vector { size: _, kind, @@ -464,7 +468,7 @@ impl<'a> ResolveContext<'a> { class: crate::StorageClass::Function, }) } - crate::Expression::Load { pointer } => match *past(pointer).inner_with(types) { + crate::Expression::Load { pointer } => match *past(pointer)?.inner_with(types) { Ti::Pointer { base, class: _ } => { if let Ti::Atomic { kind, width } = types[base].inner { TypeResolution::Value(Ti::Scalar { kind, width }) @@ -490,7 +494,7 @@ impl<'a> ResolveContext<'a> { image, gather: Some(_), .. - } => match *past(image).inner_with(types) { + } => match *past(image)?.inner_with(types) { Ti::Image { class, .. } => TypeResolution::Value(Ti::Vector { kind: match class { crate::ImageClass::Sampled { kind, multi: _ } => kind, @@ -505,7 +509,7 @@ impl<'a> ResolveContext<'a> { } }, crate::Expression::ImageSample { image, .. } - | crate::Expression::ImageLoad { image, .. } => match *past(image).inner_with(types) { + | crate::Expression::ImageLoad { image, .. } => match *past(image)?.inner_with(types) { Ti::Image { class, .. } => TypeResolution::Value(match class { crate::ImageClass::Depth { multi: _ } => Ti::Scalar { kind: crate::ScalarKind::Float, @@ -528,7 +532,7 @@ impl<'a> ResolveContext<'a> { } }, crate::Expression::ImageQuery { image, query } => TypeResolution::Value(match query { - crate::ImageQuery::Size { level: _ } => match *past(image).inner_with(types) { + crate::ImageQuery::Size { level: _ } => match *past(image)?.inner_with(types) { Ti::Image { dim, .. } => match dim { crate::ImageDimension::D1 => Ti::Scalar { kind: crate::ScalarKind::Sint, @@ -557,14 +561,14 @@ impl<'a> ResolveContext<'a> { width: 4, }, }), - crate::Expression::Unary { expr, .. } => past(expr).clone(), + crate::Expression::Unary { expr, .. } => past(expr)?.clone(), crate::Expression::Binary { op, left, right } => match op { crate::BinaryOperator::Add | crate::BinaryOperator::Subtract | crate::BinaryOperator::Divide - | crate::BinaryOperator::Modulo => past(left).clone(), + | crate::BinaryOperator::Modulo => past(left)?.clone(), crate::BinaryOperator::Multiply => { - let (res_left, res_right) = (past(left), past(right)); + let (res_left, res_right) = (past(left)?, past(right)?); match (res_left.inner_with(types), res_right.inner_with(types)) { ( &Ti::Matrix { @@ -623,7 +627,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::LogicalOr => { let kind = crate::ScalarKind::Bool; let width = crate::BOOL_WIDTH; - let inner = match *past(left).inner_with(types) { + let inner = match *past(left)?.inner_with(types) { Ti::Scalar { .. } => Ti::Scalar { kind, width }, Ti::Vector { size, .. } => Ti::Vector { size, kind, width }, ref other => { @@ -639,7 +643,7 @@ impl<'a> ResolveContext<'a> { | crate::BinaryOperator::ExclusiveOr | crate::BinaryOperator::InclusiveOr | crate::BinaryOperator::ShiftLeft - | crate::BinaryOperator::ShiftRight => past(left).clone(), + | crate::BinaryOperator::ShiftRight => past(left)?.clone(), }, crate::Expression::AtomicResult { kind, @@ -656,8 +660,8 @@ impl<'a> ResolveContext<'a> { TypeResolution::Value(Ti::Scalar { kind, width }) } } - crate::Expression::Select { accept, .. } => past(accept).clone(), - crate::Expression::Derivative { axis: _, expr } => past(expr).clone(), + crate::Expression::Select { accept, .. } => past(accept)?.clone(), + crate::Expression::Derivative { axis: _, expr } => past(expr)?.clone(), crate::Expression::Relational { fun, argument } => match fun { crate::RelationalFunction::All | crate::RelationalFunction::Any => { TypeResolution::Value(Ti::Scalar { @@ -668,7 +672,7 @@ impl<'a> ResolveContext<'a> { crate::RelationalFunction::IsNan | crate::RelationalFunction::IsInf | crate::RelationalFunction::IsFinite - | crate::RelationalFunction::IsNormal => match *past(argument).inner_with(types) { + | crate::RelationalFunction::IsNormal => match *past(argument)?.inner_with(types) { Ti::Scalar { .. } => TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Bool, width: crate::BOOL_WIDTH, @@ -694,7 +698,7 @@ impl<'a> ResolveContext<'a> { arg3: _, } => { use crate::MathFunction as Mf; - let res_arg = past(arg); + let res_arg = past(arg)?; match fun { // comparison Mf::Abs | @@ -748,7 +752,7 @@ impl<'a> ResolveContext<'a> { let arg1 = arg1.ok_or_else(|| ResolveError::IncompatibleOperands( format!("{:?}(_, None)", fun) ))?; - match (res_arg.inner_with(types), past(arg1).inner_with(types)) { + match (res_arg.inner_with(types), past(arg1)?.inner_with(types)) { (&Ti::Vector {kind: _, size: columns,width}, &Ti::Vector{ size: rows, .. }) => TypeResolution::Value(Ti::Matrix { columns, rows, width }), (left, right) => return Err(ResolveError::IncompatibleOperands( @@ -847,7 +851,7 @@ impl<'a> ResolveContext<'a> { expr, kind, convert, - } => match *past(expr).inner_with(types) { + } => match *past(expr)?.inner_with(types) { Ti::Scalar { kind: _, width } => TypeResolution::Value(Ti::Scalar { kind, width: convert.unwrap_or(width), diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index 65d72572ed..ee3ace7fd9 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -10,7 +10,7 @@ use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, use crate::span::{AddSpan as _, WithSpan}; use crate::{ arena::{Arena, Handle}, - proc::{ResolveContext, TypeResolution}, + proc::{ResolveContext, ResolveError, TypeResolution}, }; use std::ops; @@ -598,7 +598,12 @@ impl FunctionInfo { }, }; - let ty = resolve_context.resolve(expression, |h| &self.expressions[h.index()].ty)?; + let ty = resolve_context.resolve(expression, |h| { + self.expressions + .get(h.index()) + .map(|ei| &ei.ty) + .ok_or(ResolveError::ExpressionForwardDependency(h)) + })?; self.expressions[handle.index()] = ExpressionInfo { uniformity, ref_count: 0, From 1c78441947a4f607e5f9681ca7de88c5c8f67ee5 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 18:52:47 -0500 Subject: [PATCH 13/15] typifier: handle non-existing constants --- src/proc/typifier.rs | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 220936df63..1ed0e8fdac 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -197,6 +197,8 @@ pub enum ResolveError { GlobalVariableNotFound(Handle), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), + #[error("Constant {0:?} doesn't exist")] + ConstantNotFound(Handle), #[error("Expression {0:?} depends on expressions that follow")] ExpressionForwardDependency(Handle), } @@ -403,15 +405,23 @@ impl<'a> ResolveContext<'a> { } } } - crate::Expression::Constant(h) => match self.constants[h].inner { - crate::ConstantInner::Scalar { width, ref value } => { - TypeResolution::Value(Ti::Scalar { - kind: value.scalar_kind(), - width, - }) + crate::Expression::Constant(h) => { + let constant = self + .constants + .try_get(h) + .ok_or(ResolveError::ConstantNotFound(h))?; + match constant.inner { + crate::ConstantInner::Scalar { width, ref value } => { + TypeResolution::Value(Ti::Scalar { + kind: value.scalar_kind(), + width, + }) + } + crate::ConstantInner::Composite { ty, components: _ } => { + TypeResolution::Handle(ty) + } } - crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), - }, + } crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { Ti::Scalar { kind, width } => { TypeResolution::Value(Ti::Vector { size, kind, width }) From c4e46ed40a5798db955412d499bdd3d5d18bdfcf Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 13 Jan 2022 18:57:57 -0500 Subject: [PATCH 14/15] Fix arbitrary for MSRV --- .github/workflows/pipeline.yml | 30 ++++++++++++++++++------- Cargo.toml | 41 +++++++++++++++++----------------- src/front/glsl/variables.rs | 27 +++++++++------------- src/front/spv/function.rs | 4 ++-- 4 files changed, 55 insertions(+), 47 deletions(-) diff --git a/.github/workflows/pipeline.yml b/.github/workflows/pipeline.yml index c92721c188..d7f84efda5 100644 --- a/.github/workflows/pipeline.yml +++ b/.github/workflows/pipeline.yml @@ -3,31 +3,45 @@ name: pipeline on: [push, pull_request] jobs: - test: - name: Test + test-msrv: + name: Test MSRV runs-on: ubuntu-latest - strategy: - matrix: - rust: ["1.43.0", nightly] steps: - uses: actions/checkout@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal - toolchain: ${{ matrix.rust }} + toolchain: "1.43.0" override: true - uses: actions-rs/cargo@v1 name: Downgrade bitflags to MSRV - if: ${{ matrix.rust }} == "1.43.0" with: command: update args: -p bitflags --precise 1.2.1 - uses: actions-rs/cargo@v1 name: Downgrade indexmap to MSRV - if: ${{ matrix.rust }} == "1.43.0" with: command: update args: -p indexmap --precise 1.6.2 + - uses: actions-rs/cargo@v1 + name: Test all features + with: + # `cli` already enables most features, so let's add the rest, + # except for `arbitrary`, which requires Rust-1.51 + command: test + args: --workspace --features serialize,deserialize + - name: Check snapshots + run: git diff --exit-code -- tests/out + test: + name: Test Nightly + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + override: true - uses: actions-rs/cargo@v1 name: Default test with: diff --git a/Cargo.toml b/Cargo.toml index a0dcec5703..daaefb6c51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,26 +14,6 @@ resolver = "2" [package.metadata.docs.rs] all-features = true -# MSRV warnings: -# - bitflags 1.3 requires Rust-1.46 -# - indexmap 1.7 requires Rust-1.49 - -[dependencies] -arbitrary = { version = "1", features = ["derive"], optional = true } -bitflags = "1" -bit-set = "0.5" -codespan-reporting = { version = "0.11.0", optional = true } -rustc-hash = "1.1.0" -indexmap = "1.6" -log = "0.4" -num-traits = "0.2" -spirv = { version = "0.2", optional = true } -thiserror = "1.0.21" -serde = { version = "1.0", features = ["derive"], optional = true } -petgraph = { version ="0.6", optional = true } -pp-rs = { version = "0.2.1", optional = true } -hexf-parse = { version = "0.2.1", optional = true } - [profile.release] panic = "abort" @@ -57,6 +37,27 @@ hlsl-out = [] span = ["codespan-reporting"] validate = [] +# MSRV warnings: +# - arbitrary 1.0.3 requires Rust-1.51 +# - bitflags 1.3 requires Rust-1.46 +# - indexmap 1.7 requires Rust-1.49 + +[dependencies] +arbitrary = { version = "1", features = ["derive"], optional = true } +bitflags = "1" +bit-set = "0.5" +codespan-reporting = { version = "0.11.0", optional = true } +rustc-hash = "1.1.0" +indexmap = "1.6" +log = "0.4" +num-traits = "0.2" +spirv = { version = "0.2", optional = true } +thiserror = "1.0.21" +serde = { version = "1.0", features = ["derive"], optional = true } +petgraph = { version ="0.6", optional = true } +pp-rs = { version = "0.2.1", optional = true } +hexf-parse = { version = "0.2.1", optional = true } + [dev-dependencies] diff = "0.1" ron = "0.7" diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index d4043c033f..9a226916f1 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -654,21 +654,14 @@ impl Parser { &mut self, ctx: &mut Context, body: &mut Block, - #[cfg_attr(not(feature = "glsl-validate"), allow(unused_variables))] - VarDeclaration { - qualifiers, - ty, - name, - init, - meta, - }: VarDeclaration, + decl: VarDeclaration, ) -> Result> { #[cfg(feature = "glsl-validate")] - if let Some(ref name) = name { + if let Some(ref name) = decl.name { if ctx.lookup_local_var_current_scope(name).is_some() { self.errors.push(Error { kind: ErrorKind::VariableAlreadyDeclared(name.clone()), - meta, + meta: decl.meta, }) } } @@ -676,7 +669,7 @@ impl Parser { let mut mutable = true; let mut precision = None; - for &(ref qualifier, meta) in qualifiers { + for &(ref qualifier, meta) in decl.qualifiers { match *qualifier { TypeQualifier::StorageQualifier(StorageQualifier::Const) => { if !mutable { @@ -707,15 +700,15 @@ impl Parser { let handle = ctx.locals.append( LocalVariable { - name: name.clone(), - ty, - init, + name: decl.name.clone(), + ty: decl.ty, + init: decl.init, }, - meta, + decl.meta, ); - let expr = ctx.add_expression(Expression::LocalVariable(handle), meta, body); + let expr = ctx.add_expression(Expression::LocalVariable(handle), decl.meta, body); - if let Some(name) = name { + if let Some(name) = decl.name { ctx.add_local_var(name, expr, mutable); } diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 23c47688e7..efc7853047 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -464,9 +464,9 @@ impl> super::Parser { *component = function.expressions.append(load_expr, span); } - match &members[..] { + match members[..] { [] => {} - [member] => { + [ref member] => { function.body.extend(emitter.finish(&function.expressions)); let span = function.expressions.get_span(components[0]); function.body.push( From 59a0ca930261a682182ad5e99ea7dfcdd9a08584 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 14 Jan 2022 11:33:05 -0500 Subject: [PATCH 15/15] Introduce BadHandle error --- src/arena.rs | 14 ++++++++++-- src/back/spv/index.rs | 10 ++++++--- src/back/spv/mod.rs | 2 -- src/proc/index.rs | 49 +++++++++++++++++++++++------------------ src/proc/layouter.rs | 8 +++---- src/proc/mod.rs | 19 +++++----------- src/proc/typifier.rs | 25 +++++---------------- src/valid/expression.rs | 29 +++++++----------------- src/valid/function.rs | 16 ++++++-------- src/valid/type.rs | 23 ++++++++++--------- 10 files changed, 91 insertions(+), 104 deletions(-) diff --git a/src/arena.rs b/src/arena.rs index 56f70e94d2..8b5ba84b47 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -8,6 +8,13 @@ type Index = NonZeroU32; use crate::Span; use indexmap::set::IndexSet; +#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] +#[error("Handle {index} of {kind} is either not present, or inaccessible yet")] +pub struct BadHandle { + pub kind: &'static str, + pub index: usize, +} + /// A strongly typed reference to an arena item. /// /// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`]. @@ -265,8 +272,11 @@ impl Arena { self.fetch_if_or_append(value, span, T::eq) } - pub fn try_get(&self, handle: Handle) -> Option<&T> { - self.data.get(handle.index()) + pub fn try_get(&self, handle: Handle) -> Result<&T, BadHandle> { + self.data.get(handle.index()).ok_or_else(|| BadHandle { + kind: std::any::type_name::(), + index: handle.index(), + }) } /// Get a mutable reference to an element in the arena. diff --git a/src/back/spv/index.rs b/src/back/spv/index.rs index c2b198048a..8edde1f5fe 100644 --- a/src/back/spv/index.rs +++ b/src/back/spv/index.rs @@ -79,14 +79,18 @@ impl<'w> BlockContext<'w> { block: &mut Block, ) -> Result, Error> { let sequence_ty = self.fun_info[sequence].ty.inner_with(&self.ir_module.types); - match sequence_ty.indexable_length(self.ir_module)? { - crate::proc::IndexableLength::Known(known_length) => { + match sequence_ty.indexable_length(self.ir_module) { + Ok(crate::proc::IndexableLength::Known(known_length)) => { Ok(MaybeKnown::Known(known_length)) } - crate::proc::IndexableLength::Dynamic => { + Ok(crate::proc::IndexableLength::Dynamic) => { let length_id = self.write_runtime_array_length(sequence, block)?; Ok(MaybeKnown::Computed(length_id)) } + Err(err) => { + log::error!("Sequence length for {:?} failed: {}", sequence, err); + Err(Error::Validation("indexable length")) + } } } diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index 8dc574dfe8..469fa13d67 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -66,8 +66,6 @@ pub enum Error { FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), - #[error(transparent)] - Proc(#[from] crate::proc::ProcError), } #[derive(Default)] diff --git a/src/proc/index.rs b/src/proc/index.rs index ae20cc5d6e..e35537bd3c 100644 --- a/src/proc/index.rs +++ b/src/proc/index.rs @@ -1,8 +1,6 @@ //! Definitions for index bounds checking. -use super::ProcError; -use crate::valid; -use crate::{Handle, UniqueArena}; +use crate::{valid, Handle, UniqueArena}; use bit_set::BitSet; /// How should code generated by Naga do bounds checks? @@ -300,6 +298,14 @@ impl GuardedIndex { } } +#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] +pub enum IndexableLengthError { + #[error("Type is not indexable, and has no length (validation error)")] + TypeNotIndexable, + #[error("Array length constant {0:?} is invalid")] + InvalidArrayLength(Handle), +} + impl crate::TypeInner { /// Return the length of a subscriptable type. /// @@ -312,7 +318,10 @@ impl crate::TypeInner { /// The value returned is appropriate for bounds checks on subscripting. /// /// Return an error if `self` does not describe a subscriptable type at all. - pub fn indexable_length(&self, module: &crate::Module) -> Result { + pub fn indexable_length( + &self, + module: &crate::Module, + ) -> Result { use crate::TypeInner as Ti; let known_length = match *self { Ti::Vector { size, .. } => size as _, @@ -332,10 +341,10 @@ impl crate::TypeInner { Ti::Vector { size, .. } => size as _, Ti::Matrix { columns, .. } => columns as _, Ti::Array { size, .. } => return size.to_indexable_length(module), - _ => return Err(ProcError::TypeNotIndexable), + _ => return Err(IndexableLengthError::TypeNotIndexable), } } - _ => return Err(ProcError::TypeNotIndexable), + _ => return Err(IndexableLengthError::TypeNotIndexable), }; Ok(IndexableLength::Known(known_length)) } @@ -355,25 +364,23 @@ pub enum IndexableLength { } impl crate::ArraySize { - pub fn to_indexable_length(self, module: &crate::Module) -> Result { - use crate::Constant as K; + pub fn to_indexable_length( + self, + module: &crate::Module, + ) -> Result { Ok(match self { - Self::Constant(k) => match module.constants[k] { - K { - specialization: Some(_), - .. - } => { + Self::Constant(k) => { + let constant = &module.constants[k]; + if constant.specialization.is_some() { // Specializable constants are not supported as array lengths. // See valid::TypeError::UnsupportedSpecializedArrayLength. - return Err(ProcError::InvalidArraySizeConstant(k)); + return Err(IndexableLengthError::InvalidArrayLength(k)); } - ref unspecialized => { - let length = unspecialized - .to_array_length() - .ok_or(ProcError::InvalidArraySizeConstant(k))?; - IndexableLength::Known(length) - } - }, + let length = constant + .to_array_length() + .ok_or(IndexableLengthError::InvalidArrayLength(k))?; + IndexableLength::Known(length) + } Self::Dynamic => IndexableLength::Dynamic, }) } diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 637257c17e..1f3651d8e3 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, Handle, UniqueArena}; +use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; use std::{num::NonZeroU32, ops}; pub type Alignment = NonZeroU32; @@ -37,8 +37,8 @@ pub enum TypeLayoutError { InvalidStructMemberType(u32, Handle), #[error("Zero width is not supported")] ZeroWidth, - #[error(transparent)] - Size(#[from] super::ProcError), + #[error("Array size is a bad handle")] + BadHandle(#[from] BadHandle), } #[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] @@ -95,7 +95,7 @@ impl Layouter { let size = ty .inner .size(constants) - .map_err(|error| TypeLayoutError::Size(error).with(ty_handle))?; + .map_err(|error| TypeLayoutError::BadHandle(error).with(ty_handle))?; let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => TypeLayout { size, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 79fd078ecd..015694a37f 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -8,20 +8,12 @@ mod typifier; use std::cmp::PartialEq; -pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength}; +pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; pub use layouter::{Alignment, LayoutError, Layouter, TypeLayout, TypeLayoutError}; pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; pub use typifier::{ResolveContext, ResolveError, TypeResolution}; -#[derive(Clone, Copy, Debug, thiserror::Error, PartialEq)] -pub enum ProcError { - #[error("Type is not indexable, and has no length (validation error)")] - TypeNotIndexable, - #[error("Array length {0:?} is wrong kind of constant (validation error)")] - InvalidArraySizeConstant(crate::Handle), -} - impl From for super::ScalarKind { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; @@ -94,7 +86,10 @@ impl super::TypeInner { } } - pub fn size(&self, constants: &super::Arena) -> Result { + pub fn size( + &self, + constants: &super::Arena, + ) -> Result { Ok(match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { @@ -119,9 +114,7 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(handle) => { - let constant = constants - .try_get(handle) - .ok_or(ProcError::InvalidArraySizeConstant(handle))?; + let constant = constants.try_get(handle)?; constant.to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 1ed0e8fdac..0879c1f6fc 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, Handle, UniqueArena}; +use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; use thiserror::Error; @@ -162,6 +162,8 @@ impl crate::ConstantInner { #[derive(Clone, Debug, Error, PartialEq)] pub enum ResolveError { + #[error(transparent)] + BadHandle(#[from] BadHandle), #[error("Index {index} is out of bounds for expression {expr:?}")] OutOfBoundsIndex { expr: Handle, @@ -191,14 +193,8 @@ pub enum ResolveError { FunctionReturnsVoid, #[error("Incompatible operands: {0}")] IncompatibleOperands(String), - #[error("Local var {0:?} doesn't exist")] - LocalVariableNotFound(Handle), - #[error("Global var {0:?} doesn't exist")] - GlobalVariableNotFound(Handle), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), - #[error("Constant {0:?} doesn't exist")] - ConstantNotFound(Handle), #[error("Expression {0:?} depends on expressions that follow")] ExpressionForwardDependency(Handle), } @@ -406,10 +402,7 @@ impl<'a> ResolveContext<'a> { } } crate::Expression::Constant(h) => { - let constant = self - .constants - .try_get(h) - .ok_or(ResolveError::ConstantNotFound(h))?; + let constant = self.constants.try_get(h)?; match constant.inner { crate::ConstantInner::Scalar { width, ref value } => { TypeResolution::Value(Ti::Scalar { @@ -455,10 +448,7 @@ impl<'a> ResolveContext<'a> { TypeResolution::Handle(arg.ty) } crate::Expression::GlobalVariable(h) => { - let var = self - .global_vars - .try_get(h) - .ok_or(ResolveError::GlobalVariableNotFound(h))?; + let var = self.global_vars.try_get(h)?; if var.class == crate::StorageClass::Handle { TypeResolution::Handle(var.ty) } else { @@ -469,10 +459,7 @@ impl<'a> ResolveContext<'a> { } } crate::Expression::LocalVariable(h) => { - let var = self - .local_vars - .try_get(h) - .ok_or(ResolveError::LocalVariableNotFound(h))?; + let var = self.local_vars.try_get(h)?; TypeResolution::Value(Ti::Pointer { base: var.ty, class: crate::StorageClass::Function, diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 4e59a9ef4b..80d3c82169 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -4,8 +4,8 @@ use super::{compose::validate_compose, FunctionInfo, ShaderStages, TypeFlags}; use crate::arena::UniqueArena; use crate::{ - arena::Handle, - proc::{ProcError, ResolveError}, + arena::{BadHandle, Handle}, + proc::{IndexableLengthError, ResolveError}, }; #[derive(Clone, Debug, thiserror::Error)] @@ -17,6 +17,8 @@ pub enum ExpressionError { NotInScope, #[error("Depends on {0:?}, which has not been processed yet")] ForwardDependency(Handle), + #[error(transparent)] + BadDependency(#[from] BadHandle), #[error("Base type {0:?} is not compatible with this expression")] InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] @@ -27,12 +29,6 @@ pub enum ExpressionError { IndexMustBeConstant(Handle), #[error("Function argument {0:?} doesn't exist")] FunctionArgumentDoesntExist(u32), - #[error("Constant {0:?} doesn't exist")] - ConstantDoesntExist(Handle), - #[error("Global variable {0:?} doesn't exist")] - GlobalVarDoesntExist(Handle), - #[error("Local variable {0:?} doesn't exist")] - LocalVarDoesntExist(Handle), #[error("Loading of {0:?} can't be done")] InvalidPointerType(Handle), #[error("Array length of {0:?} can't be done")] @@ -46,7 +42,7 @@ pub enum ExpressionError { #[error(transparent)] Compose(#[from] super::ComposeError), #[error(transparent)] - Proc(#[from] ProcError), + IndexableLength(#[from] IndexableLengthError), #[error("Operation {0:?} can't work with {1:?}")] InvalidUnaryOperandType(crate::UnaryOperator, Handle), #[error("Operation {0:?} can't work with {1:?} and {2:?}")] @@ -266,10 +262,7 @@ impl super::Validator { ShaderStages::all() } E::Constant(handle) => { - let _ = module - .constants - .try_get(handle) - .ok_or(ExpressionError::ConstantDoesntExist(handle))?; + let _ = module.constants.try_get(handle)?; ShaderStages::all() } E::Splat { size: _, value } => match *resolver.resolve(value)? { @@ -319,17 +312,11 @@ impl super::Validator { ShaderStages::all() } E::GlobalVariable(handle) => { - let _ = module - .global_variables - .try_get(handle) - .ok_or(ExpressionError::GlobalVarDoesntExist(handle))?; + let _ = module.global_variables.try_get(handle)?; ShaderStages::all() } E::LocalVariable(handle) => { - let _ = function - .local_variables - .try_get(handle) - .ok_or(ExpressionError::LocalVarDoesntExist(handle))?; + let _ = function.local_variables.try_get(handle)?; ShaderStages::all() } E::Load { pointer } => { diff --git a/src/valid/function.rs b/src/valid/function.rs index 491cee54bb..c6524872f0 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -1,6 +1,6 @@ -use crate::arena::Handle; #[cfg(feature = "validate")] use crate::arena::{Arena, UniqueArena}; +use crate::arena::{BadHandle, Handle}; use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, @@ -16,8 +16,8 @@ use bit_set::BitSet; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum CallError { - #[error("Bad function")] - InvalidFunction, + #[error(transparent)] + BadHandle(#[from] BadHandle), #[error("The callee is declared after the caller")] ForwardDeclaredFunction, #[error("Argument {index} expression is invalid")] @@ -67,6 +67,8 @@ pub enum LocalVariableError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum FunctionError { + #[error(transparent)] + BadHandle(#[from] BadHandle), #[error("Expression {handle:?} is invalid")] Expression { handle: Handle, @@ -121,8 +123,6 @@ pub enum FunctionError { pointer: Handle, value: Handle, }, - #[error("The expression {0:?} is currupted")] - InvalidExpression(Handle), #[error("Image store parameters are invalid")] InvalidImageStore(#[source] ExpressionError), #[error("Call to {function:?} is invalid")] @@ -201,9 +201,7 @@ impl<'a> BlockContext<'a> { &self, handle: Handle, ) -> Result<&'a crate::Expression, FunctionError> { - self.expressions - .try_get(handle) - .ok_or(FunctionError::InvalidExpression(handle)) + Ok(self.expressions.try_get(handle)?) } fn resolve_type_impl( @@ -256,7 +254,7 @@ impl super::Validator { let fun = context .functions .try_get(function) - .ok_or(CallError::InvalidFunction) + .map_err(CallError::BadHandle) .map_err(WithSpan::new)?; if fun.arguments.len() != arguments.len() { return Err(CallError::ArgumentCount { diff --git a/src/valid/type.rs b/src/valid/type.rs index 570e8efcc8..b7240578ce 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -1,6 +1,6 @@ use super::Capabilities; use crate::{ - arena::{Arena, Handle, UniqueArena}, + arena::{Arena, BadHandle, Handle, UniqueArena}, proc::Alignment, }; @@ -72,6 +72,8 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { + #[error(transparent)] + BadHandle(#[from] BadHandle), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -375,11 +377,12 @@ impl super::Validator { let sized_flag = match size { crate::ArraySize::Constant(const_handle) => { - let length_is_positive = match constants.try_get(const_handle) { - Some(&crate::Constant { + let constant = constants.try_get(const_handle)?; + let length_is_positive = match *constant { + crate::Constant { specialization: Some(_), .. - }) => { + } => { // Many of our back ends don't seem to support // specializable array lengths. If you want to try to make // this work, be sure to address all uses of @@ -389,28 +392,28 @@ impl super::Validator { const_handle, )); } - Some(&crate::Constant { + crate::Constant { inner: crate::ConstantInner::Scalar { width: _, value: crate::ScalarValue::Uint(length), }, .. - }) => length > 0, + } => length > 0, // Accept a signed integer size to avoid // requiring an explicit uint // literal. Type inference should make // this unnecessary. - Some(&crate::Constant { + crate::Constant { inner: crate::ConstantInner::Scalar { width: _, value: crate::ScalarValue::Sint(length), }, .. - }) => length > 0, - other => { - log::warn!("Array size {:?}", other); + } => length > 0, + _ => { + log::warn!("Array size {:?}", constant); return Err(TypeError::InvalidArraySizeConstant(const_handle)); } };