From 1db6ab2fb43218238d4c497cdebb59a14770113a Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Thu, 11 May 2023 18:09:05 +0200 Subject: [PATCH] Introduce `GlobalCtx` --- src/back/hlsl/conv.rs | 21 +++++-------- src/back/hlsl/help.rs | 7 +---- src/back/hlsl/storage.rs | 4 +-- src/back/hlsl/writer.rs | 2 +- src/back/msl/writer.rs | 40 ++++++++++++------------- src/front/glsl/functions.rs | 2 +- src/front/glsl/parser/types.rs | 5 +--- src/front/spv/function.rs | 7 +++++ src/front/spv/mod.rs | 14 +++------ src/front/wgsl/lower/construction.rs | 8 ++--- src/front/wgsl/lower/mod.rs | 24 +++++++-------- src/front/wgsl/mod.rs | 45 +++++++++++----------------- src/proc/layouter.rs | 12 +++----- src/proc/mod.rs | 23 +++++++++++--- src/valid/compose.rs | 28 ++++++++--------- src/valid/expression.rs | 3 +- src/valid/function.rs | 9 +++--- src/valid/interface.rs | 6 ++-- src/valid/mod.rs | 28 ++++++++--------- src/valid/type.rs | 20 +++++-------- 20 files changed, 138 insertions(+), 170 deletions(-) diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 5eb24962f6..69334ecb1d 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -40,11 +40,7 @@ impl crate::TypeInner { } } - pub(super) fn size_hlsl( - &self, - types: &crate::UniqueArena, - constants: &crate::Arena, - ) -> u32 { + pub(super) fn size_hlsl(&self, gctx: crate::proc::GlobalCtx) -> u32 { match *self { Self::Matrix { columns, @@ -58,26 +54,25 @@ impl crate::TypeInner { Self::Array { base, size, stride } => { let count = match size { crate::ArraySize::Constant(handle) => { - constants[handle].to_array_length().unwrap_or(1) + gctx.constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element crate::ArraySize::Dynamic => 1, }; - let last_el_size = types[base].inner.size_hlsl(types, constants); + let last_el_size = gctx.types[base].inner.size_hlsl(gctx); ((count - 1) * stride) + last_el_size } - _ => self.size(constants), + _ => self.size(gctx), } } /// Used to generate the name of the wrapped type constructor pub(super) fn hlsl_type_id<'a>( base: crate::Handle, - types: &crate::UniqueArena, - constants: &crate::Arena, + gctx: crate::proc::GlobalCtx, names: &'a crate::FastHashMap, ) -> Result, Error> { - Ok(match types[base].inner { + Ok(match gctx.types[base].inner { crate::TypeInner::Scalar { kind, width } => Cow::Borrowed(kind.to_hlsl_str(width)?), crate::TypeInner::Vector { size, kind, width } => Cow::Owned(format!( "{}{}", @@ -100,8 +95,8 @@ impl crate::TypeInner { .. } => Cow::Owned(format!( "array{}_{}_", - constants[size].to_array_length().unwrap(), - Self::hlsl_type_id(base, types, constants, names)? + gctx.constants[size].to_array_length().unwrap(), + Self::hlsl_type_id(base, gctx, names)? )), crate::TypeInner::Struct { .. } => { Cow::Borrowed(&names[&crate::proc::NameKey::Type(base)]) diff --git a/src/back/hlsl/help.rs b/src/back/hlsl/help.rs index 8ae9baf62d..51ec5ef4d1 100644 --- a/src/back/hlsl/help.rs +++ b/src/back/hlsl/help.rs @@ -347,12 +347,7 @@ impl<'a, W: Write> super::Writer<'a, W> { module: &crate::Module, constructor: WrappedConstructor, ) -> BackendResult { - let name = crate::TypeInner::hlsl_type_id( - constructor.ty, - &module.types, - &module.constants, - &self.names, - )?; + let name = crate::TypeInner::hlsl_type_id(constructor.ty, module.to_ctx(), &self.names)?; write!(self.out, "Construct{name}")?; Ok(()) } diff --git a/src/back/hlsl/storage.rs b/src/back/hlsl/storage.rs index 813dd73649..833a6dce96 100644 --- a/src/back/hlsl/storage.rs +++ b/src/back/hlsl/storage.rs @@ -218,7 +218,7 @@ impl super::Writer<'_, W> { self.write_wrapped_constructor_function_name(module, constructor)?; write!(self.out, "(")?; let count = module.constants[const_handle].to_array_length().unwrap(); - let stride = module.types[base].inner.size(&module.constants); + let stride = module.types[base].inner.size(module.to_ctx()); let iter = (0..count).map(|i| (TypeResolution::Handle(base), stride * i)); self.write_storage_load_sequence(module, var_handle, iter, func_ctx)?; write!(self.out, ")")?; @@ -381,7 +381,7 @@ impl super::Writer<'_, W> { writeln!(self.out, ";")?; // then iterate the stores let count = module.constants[const_handle].to_array_length().unwrap(); - let stride = module.types[base].inner.size(&module.constants); + let stride = module.types[base].inner.size(module.to_ctx()); for i in 0..count { self.temp_access_chain.push(SubAccess::Offset(i * stride)); let sv = StoreValue::TempIndex { diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index 3d0ed9cd3e..a97b5eb3b3 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -877,7 +877,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants); + last_offset = member.offset + ty_inner.size_hlsl(module.to_ctx()); // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 7e41e8a7e1..0621bbdf9e 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -77,7 +77,7 @@ const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e"; struct TypeContext<'a> { handle: Handle, - module: &'a crate::Module, + gctx: proc::GlobalCtx<'a>, names: &'a FastHashMap, access: crate::StorageAccess, binding: Option<&'a super::ResolvedBinding>, @@ -86,7 +86,7 @@ struct TypeContext<'a> { impl<'a> Display for TypeContext<'a> { fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> { - let ty = &self.module.types[self.handle]; + let ty = &self.gctx.types[self.handle]; if ty.needs_alias() && !self.first_time { let name = &self.names[&NameKey::Type(self.handle)]; return write!(out, "{name}"); @@ -223,7 +223,7 @@ impl<'a> Display for TypeContext<'a> { } else if let crate::ArraySize::Constant(size) = size { let constant_ctx = ConstantContext { handle: size, - arena: &self.module.constants, + arena: self.gctx.constants, names: self.names, first_time: false, }; @@ -271,7 +271,7 @@ impl<'a> TypedGlobalVariable<'a> { }; let ty_name = TypeContext { handle: var.ty, - module: self.module, + gctx: self.module.to_ctx(), names: self.names, access: storage_access, binding: self.binding, @@ -399,7 +399,7 @@ fn should_pack_struct_member( } let ty_inner = &module.types[member.ty].inner; - let last_offset = member.offset + ty_inner.size(&module.constants); + let last_offset = member.offset + ty_inner.size(module.to_ctx()); let next_offset = match members.get(index + 1) { Some(next) => next.offset, None => span, @@ -1153,7 +1153,7 @@ impl Writer { crate::TypeInner::Array { base, stride, .. } => ( context.module.types[base] .inner - .size(&context.module.constants), + .size(context.module.to_ctx()), stride, ), _ => return Err(Error::Validation), @@ -1336,7 +1336,7 @@ impl Writer { crate::Expression::ZeroValue(ty) => { let ty_name = TypeContext { handle: ty, - module: context.module, + gctx: context.module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -2459,7 +2459,7 @@ impl Writer { TypeResolution::Handle(ty_handle) => { let ty_name = TypeContext { handle: ty_handle, - module: context.module, + gctx: context.module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3155,7 +3155,7 @@ impl Writer { } => { let base_name = TypeContext { handle: base, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3200,7 +3200,7 @@ impl Writer { writeln!(self.out, "{}char _pad{}[{}];", back::INDENT, index, pad)?; } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset + ty_inner.size(&module.constants); + last_offset = member.offset + ty_inner.size(module.to_ctx()); let member_name = &self.names[&NameKey::StructMember(handle, index as u32)]; @@ -3219,7 +3219,7 @@ impl Writer { None => { let base_name = TypeContext { handle: member.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3250,7 +3250,7 @@ impl Writer { _ => { let ty_name = TypeContext { handle, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3310,7 +3310,7 @@ impl Writer { let name = &self.names[&NameKey::Constant(handle)]; let ty_name = TypeContext { handle: ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3450,7 +3450,7 @@ impl Writer { Some(ref result) => { let ty_name = TypeContext { handle: result.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3468,7 +3468,7 @@ impl Writer { let name = &self.names[&NameKey::FunctionArgument(fun_handle, index as u32)]; let param_type_name = TypeContext { handle: arg.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3517,7 +3517,7 @@ impl Writer { for (local_handle, local) in fun.local_variables.iter() { let ty_name = TypeContext { handle: local.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3714,7 +3714,7 @@ impl Writer { let name = &self.names[name_key]; let ty_name = TypeContext { handle: ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3758,7 +3758,7 @@ impl Writer { for (name, ty, binding) in result_members { let ty_name = TypeContext { handle: ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -3853,7 +3853,7 @@ impl Writer { let ty_name = TypeContext { handle: ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, @@ -4090,7 +4090,7 @@ impl Writer { let name = &self.names[&NameKey::EntryPointLocal(ep_index as _, local_handle)]; let ty_name = TypeContext { handle: local.ty, - module, + gctx: module.to_ctx(), names: &self.names, access: crate::StorageAccess::empty(), binding: None, diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 10c964b5e0..355b1453b6 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1501,7 +1501,7 @@ impl Frontend { offset: span, }); - span += self.module.types[ty].inner.size(&self.module.constants); + span += self.module.types[ty].inner.size(self.module.to_ctx()); let len = expressions.len(); let load = expressions.append(Expression::Load { pointer }, Default::default()); diff --git a/src/front/glsl/parser/types.rs b/src/front/glsl/parser/types.rs index 08a70669a0..5c92b4f3d5 100644 --- a/src/front/glsl/parser/types.rs +++ b/src/front/glsl/parser/types.rs @@ -53,10 +53,7 @@ impl<'source> ParsingContext<'source> { ArraySize::Constant(constant) }; - frontend - .layouter - .update(&frontend.module.types, &frontend.module.constants) - .unwrap(); + frontend.layouter.update(frontend.module.to_ctx()).unwrap(); let stride = frontend.layouter[*ty].to_stride(); *ty = frontend.module.types.insert( Type { diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 5dc781504e..89ad4a0e16 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -568,6 +568,13 @@ impl> super::Frontend { } impl<'function> BlockContext<'function> { + pub(super) fn gctx(&self) -> crate::proc::GlobalCtx { + crate::proc::GlobalCtx { + types: self.type_arena, + constants: self.const_arena, + } + } + /// Consumes the `BlockContext` producing a Ir [`Block`](crate::Block) fn lower(mut self) -> crate::Block { fn lower_impl( diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index e920db1451..14bae2a9cd 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -2089,7 +2089,7 @@ impl> Frontend { let result_ty = self.lookup_type.lookup(result_type_id)?; let inner = &ctx.type_arena[result_ty.handle].inner; let kind = inner.scalar_kind().unwrap(); - let size = inner.size(ctx.const_arena) as u8; + let size = inner.size(ctx.gctx()) as u8; let left_cast = ctx.expressions.append( crate::Expression::As { @@ -4387,9 +4387,7 @@ impl> Frontend { let decor = self.future_decor.remove(&id).unwrap_or_default(); let base = self.lookup_type.lookup(type_id)?.handle; - self.layouter - .update(&module.types, &module.constants) - .unwrap(); + self.layouter.update(module.to_ctx()).unwrap(); // HACK if the underlying type is an image or a sampler, let's assume // that we're dealing with a binding-array @@ -4470,9 +4468,7 @@ impl> Frontend { let decor = self.future_decor.remove(&id).unwrap_or_default(); let base = self.lookup_type.lookup(type_id)?.handle; - self.layouter - .update(&module.types, &module.constants) - .unwrap(); + self.layouter.update(module.to_ctx()).unwrap(); // HACK same case as in `parse_type_array()` let inner = if let crate::TypeInner::Image { .. } | crate::TypeInner::Sampler { .. } = @@ -4523,9 +4519,7 @@ impl> Frontend { .as_ref() .map_or(false, |decor| decor.storage_buffer); - self.layouter - .update(&module.types, &module.constants) - .unwrap(); + self.layouter.update(module.to_ctx()).unwrap(); let mut members = Vec::::with_capacity(inst.wc as usize - 2); let mut member_lookups = Vec::with_capacity(members.capacity()); diff --git a/src/front/wgsl/lower/construction.rs b/src/front/wgsl/lower/construction.rs index 0c2f29bd98..c1457b0d92 100644 --- a/src/front/wgsl/lower/construction.rs +++ b/src/front/wgsl/lower/construction.rs @@ -471,9 +471,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ctx.module.constants.fetch_or_append(size, Span::UNDEFINED), ), stride: { - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); self.layouter[base].to_stride() }, }; @@ -645,9 +643,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::ArraySize::Dynamic => crate::ArraySize::Dynamic, }; - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); let ty = ctx.ensure_type_exists(crate::TypeInner::Array { base, size, diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index e066a29f1b..6e1d2f24e6 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -343,7 +343,7 @@ impl<'a> ExpressionContext<'a, '_, '_> { } fn format_typeinner(&self, inner: &crate::TypeInner) -> String { - inner.to_wgsl(&self.module.types, &self.module.constants) + inner.to_wgsl(self.module.to_ctx()) } fn format_type(&self, handle: Handle) -> String { @@ -624,14 +624,16 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { if let Some(explicit) = explicit_ty { if explicit != inferred_type { let ty = &ctx.module.types[explicit]; - let explicit = ty.name.clone().unwrap_or_else(|| { - ty.inner.to_wgsl(&ctx.module.types, &ctx.module.constants) - }); + let explicit = ty + .name + .clone() + .unwrap_or_else(|| ty.inner.to_wgsl(ctx.module.to_ctx())); let ty = &ctx.module.types[inferred_type]; - let inferred = ty.name.clone().unwrap_or_else(|| { - ty.inner.to_wgsl(&ctx.module.types, &ctx.module.constants) - }); + let inferred = ty + .name + .clone() + .unwrap_or_else(|| ty.inner.to_wgsl(ctx.module.to_ctx())); return Err(Error::InitializationTypeMismatch( c.name.span, @@ -2065,9 +2067,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { for member in s.members.iter() { let ty = self.resolve_ast_type(member.ty, ctx.reborrow())?; - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); let member_min_size = self.layouter[ty].size; let member_min_alignment = self.layouter[ty].alignment; @@ -2154,9 +2154,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::Type::Array { base, size } => { let base = self.resolve_ast_type(base, ctx.reborrow())?; - self.layouter - .update(&ctx.module.types, &ctx.module.constants) - .unwrap(); + self.layouter.update(ctx.module.to_ctx()).unwrap(); crate::TypeInner::Array { base, diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index eb21fae6c9..a64f1f8f17 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -11,8 +11,6 @@ mod parse; #[cfg(test)] mod tests; -use crate::arena::{Arena, UniqueArena}; - use crate::front::wgsl::error::Error; use crate::front::wgsl::parse::Parser; use thiserror::Error; @@ -100,11 +98,7 @@ impl crate::TypeInner { /// For example `vec3`. /// /// Note: The names of a `TypeInner::Struct` is not known. Therefore this method will simply return "struct" for them. - fn to_wgsl( - &self, - types: &UniqueArena, - constants: &Arena, - ) -> String { + fn to_wgsl(&self, gctx: crate::proc::GlobalCtx) -> String { use crate::TypeInner as Ti; match *self { @@ -128,7 +122,7 @@ impl crate::TypeInner { format!("atomic<{}>", kind.to_wgsl(width)) } Ti::Pointer { base, .. } => { - let base = &types[base]; + let base = &gctx.types[base]; let name = base.name.as_deref().unwrap_or("unknown"); format!("ptr<{name}>") } @@ -136,11 +130,11 @@ impl crate::TypeInner { format!("ptr<{}>", kind.to_wgsl(width)) } Ti::Array { base, size, .. } => { - let member_type = &types[base]; + let member_type = &gctx.types[base]; let base = member_type.name.as_deref().unwrap_or("unknown"); match size { crate::ArraySize::Constant(size) => { - let constant = &constants[size]; + let constant = &gctx.constants[size]; let size = constant .name .clone() @@ -209,11 +203,11 @@ impl crate::TypeInner { Ti::AccelerationStructure => "acceleration_structure".to_string(), Ti::RayQuery => "ray_query".to_string(), Ti::BindingArray { base, size, .. } => { - let member_type = &types[base]; + let member_type = &gctx.types[base]; let base = member_type.name.as_deref().unwrap_or("unknown"); match size { crate::ArraySize::Constant(size) => { - let size = constants[size].name.as_deref().unwrap_or("unknown"); + let size = gctx.constants[size].name.as_deref().unwrap_or("unknown"); format!("binding_array<{base}, {size}>") } crate::ArraySize::Dynamic => format!("binding_array<{base}>"), @@ -261,19 +255,23 @@ mod type_inner_tests { Default::default(), ); + let gctx = crate::proc::GlobalCtx { + types: &types, + constants: &constants, + }; let array = crate::TypeInner::Array { base: mytype1, stride: 4, size: crate::ArraySize::Constant(c), }; - assert_eq!(array.to_wgsl(&types, &constants), "array"); + assert_eq!(array.to_wgsl(gctx), "array"); let mat = crate::TypeInner::Matrix { rows: crate::VectorSize::Quad, columns: crate::VectorSize::Bi, width: 8, }; - assert_eq!(mat.to_wgsl(&types, &constants), "mat2x4"); + assert_eq!(mat.to_wgsl(gctx), "mat2x4"); let ptr = crate::TypeInner::Pointer { base: mytype2, @@ -281,7 +279,7 @@ mod type_inner_tests { access: crate::StorageAccess::default(), }, }; - assert_eq!(ptr.to_wgsl(&types, &constants), "ptr"); + assert_eq!(ptr.to_wgsl(gctx), "ptr"); let img1 = crate::TypeInner::Image { dim: crate::ImageDimension::D2, @@ -291,36 +289,27 @@ mod type_inner_tests { multi: true, }, }; - assert_eq!( - img1.to_wgsl(&types, &constants), - "texture_multisampled_2d" - ); + assert_eq!(img1.to_wgsl(gctx), "texture_multisampled_2d"); let img2 = crate::TypeInner::Image { dim: crate::ImageDimension::Cube, arrayed: true, class: crate::ImageClass::Depth { multi: false }, }; - assert_eq!(img2.to_wgsl(&types, &constants), "texture_depth_cube_array"); + assert_eq!(img2.to_wgsl(gctx), "texture_depth_cube_array"); let img3 = crate::TypeInner::Image { dim: crate::ImageDimension::D2, arrayed: false, class: crate::ImageClass::Depth { multi: true }, }; - assert_eq!( - img3.to_wgsl(&types, &constants), - "texture_depth_multisampled_2d" - ); + assert_eq!(img3.to_wgsl(gctx), "texture_depth_multisampled_2d"); let array = crate::TypeInner::BindingArray { base: mytype1, size: crate::ArraySize::Constant(c), }; - assert_eq!( - array.to_wgsl(&types, &constants), - "binding_array" - ); + assert_eq!(array.to_wgsl(gctx), "binding_array"); } } diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 65369d1cc8..11b2250e93 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, Handle, UniqueArena}; +use crate::arena::Handle; use std::{fmt::Display, num::NonZeroU32, ops}; /// A newtype struct where its only valid values are powers of 2 @@ -165,15 +165,11 @@ impl Layouter { /// constant arenas, and then assume that layouts are available for all /// types. #[allow(clippy::or_fun_call)] - pub fn update( - &mut self, - types: &UniqueArena, - constants: &Arena, - ) -> Result<(), LayoutError> { + pub fn update(&mut self, gctx: super::GlobalCtx) -> Result<(), LayoutError> { use crate::TypeInner as Ti; - for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { - let size = ty.inner.size(constants); + for (ty_handle, ty) in gctx.types.iter().skip(self.layouts.len()) { + let size = ty.inner.size(gctx); let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { let alignment = Alignment::new(width as u32) diff --git a/src/proc/mod.rs b/src/proc/mod.rs index ed501693d1..80ae567ee3 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -199,7 +199,7 @@ impl super::TypeInner { } /// Get the size of this type. - pub fn size(&self, constants: &super::Arena) -> u32 { + pub fn size(&self, gctx: GlobalCtx) -> u32 { match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { @@ -221,7 +221,7 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(handle) => { - constants[handle].to_array_length().unwrap_or(1) + gctx.constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, @@ -575,16 +575,31 @@ impl super::ImageClass { } } +impl crate::Module { + pub const fn to_ctx(&self) -> GlobalCtx<'_> { + GlobalCtx { + types: &self.types, + constants: &self.constants, + } + } +} + +#[derive(Clone, Copy)] +pub struct GlobalCtx<'a> { + pub types: &'a crate::UniqueArena, + pub constants: &'a crate::Arena, +} + #[test] fn test_matrix_size() { - let constants = crate::Arena::new(); + let module = crate::Module::default(); assert_eq!( crate::TypeInner::Matrix { columns: crate::VectorSize::Tri, rows: crate::VectorSize::Tri, width: 4 } - .size(&constants), + .size(module.to_ctx()), 48, ); } diff --git a/src/valid/compose.rs b/src/valid/compose.rs index e77d538255..a7a9619794 100644 --- a/src/valid/compose.rs +++ b/src/valid/compose.rs @@ -1,8 +1,5 @@ #[cfg(feature = "validate")] -use crate::{ - arena::{Arena, UniqueArena}, - proc::TypeResolution, -}; +use crate::proc::TypeResolution; use crate::arena::Handle; @@ -20,18 +17,17 @@ pub enum ComposeError { #[cfg(feature = "validate")] pub fn validate_compose( self_ty_handle: Handle, - constant_arena: &Arena, - type_arena: &UniqueArena, + gctx: crate::proc::GlobalCtx, component_resolutions: impl ExactSizeIterator, ) -> Result<(), ComposeError> { use crate::TypeInner as Ti; - match type_arena[self_ty_handle].inner { + match gctx.types[self_ty_handle].inner { // vectors are composed from scalars or other vectors Ti::Vector { size, kind, width } => { let mut total = 0; for (index, comp_res) in component_resolutions.enumerate() { - total += match *comp_res.inner_with(type_arena) { + total += match *comp_res.inner_with(gctx.types) { Ti::Scalar { kind: comp_kind, width: comp_width, @@ -74,7 +70,7 @@ pub fn validate_compose( }); } for (index, comp_res) in component_resolutions.enumerate() { - if comp_res.inner_with(type_arena) != &inner { + if comp_res.inner_with(gctx.types) != &inner { log::error!("Matrix component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, @@ -87,7 +83,7 @@ pub fn validate_compose( size: crate::ArraySize::Constant(handle), stride: _, } => { - let count = constant_arena[handle].to_array_length().unwrap(); + let count = gctx.constants[handle].to_array_length().unwrap(); if count as usize != component_resolutions.len() { return Err(ComposeError::ComponentCount { expected: count, @@ -95,11 +91,11 @@ pub fn validate_compose( }); } for (index, comp_res) in component_resolutions.enumerate() { - let base_inner = &type_arena[base].inner; - let comp_res_inner = comp_res.inner_with(type_arena); + let base_inner = &gctx.types[base].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); // We don't support arrays of pointers, but it seems best not to // embed that assumption here, so use `TypeInner::equivalent`. - if !base_inner.equivalent(comp_res_inner, type_arena) { + if !base_inner.equivalent(comp_res_inner, gctx.types) { log::error!("Array component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, @@ -116,11 +112,11 @@ pub fn validate_compose( } for (index, (member, comp_res)) in members.iter().zip(component_resolutions).enumerate() { - let member_inner = &type_arena[member.ty].inner; - let comp_res_inner = comp_res.inner_with(type_arena); + let member_inner = &gctx.types[member.ty].inner; + let comp_res_inner = comp_res.inner_with(gctx.types); // We don't support pointers in structs, but it seems best not to embed // that assumption here, so use `TypeInner::equivalent`. - if !comp_res_inner.equivalent(member_inner, type_arena) { + if !comp_res_inner.equivalent(member_inner, gctx.types) { log::error!("Struct component[{}] type {:?}", index, comp_res); return Err(ComposeError::ComponentType { index: index as u32, diff --git a/src/valid/expression.rs b/src/valid/expression.rs index ece98d70fd..b5cea3194f 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -301,8 +301,7 @@ impl super::Validator { E::Compose { ref components, ty } => { validate_compose( ty, - &module.constants, - &module.types, + module.to_ctx(), components.iter().map(|&handle| info[handle].ty.clone()), )?; ShaderStages::all() diff --git a/src/valid/function.rs b/src/valid/function.rs index 8ebcddd9e3..c18f86555d 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -894,8 +894,7 @@ impl super::Validator { fn validate_local_var( &self, var: &crate::LocalVariable, - types: &UniqueArena, - constants: &Arena, + gctx: crate::proc::GlobalCtx, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -910,13 +909,13 @@ impl super::Validator { } if let Some(const_handle) = var.init { - match constants[const_handle].inner { + match gctx.constants[const_handle].inner { crate::ConstantInner::Scalar { width, ref value } => { let ty_inner = crate::TypeInner::Scalar { width, kind: value.scalar_kind(), }; - if types[var.ty].inner != ty_inner { + if gctx.types[var.ty].inner != ty_inner { return Err(LocalVariableError::InitializerType); } } @@ -942,7 +941,7 @@ impl super::Validator { #[cfg(feature = "validate")] for (var_handle, var) in fun.local_variables.iter() { - self.validate_local_var(var, &module.types, &module.constants) + self.validate_local_var(var, module.to_ctx()) .map_err(|source| { FunctionError::LocalVariable { handle: var_handle, diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 1fafef2f66..47e2f201c8 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -395,12 +395,12 @@ impl super::Validator { pub(super) fn validate_global_var( &self, var: &crate::GlobalVariable, - types: &UniqueArena, + gctx: crate::proc::GlobalCtx, ) -> Result<(), GlobalVariableError> { use super::TypeFlags; log::debug!("var {:?}", var); - let inner_ty = match types[var.ty].inner { + let inner_ty = match gctx.types[var.ty].inner { // A binding array is (mostly) supposed to behave the same as a // series of individually bound resources, so we can (mostly) // validate a `binding_array` as if it were just a plain `T`. @@ -444,7 +444,7 @@ impl super::Validator { ) } crate::AddressSpace::Handle => { - match types[inner_ty].inner { + match gctx.types[inner_ty].inner { crate::TypeInner::Image { class, .. } => match class { crate::ImageClass::Storage { format: diff --git a/src/valid/mod.rs b/src/valid/mod.rs index bed067db8f..3e689160b5 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -11,7 +11,7 @@ mod interface; mod r#type; #[cfg(feature = "validate")] -use crate::arena::{Arena, UniqueArena}; +use crate::arena::UniqueArena; use crate::{ arena::Handle, @@ -296,10 +296,9 @@ impl Validator { fn validate_constant( &self, handle: Handle, - constants: &Arena, - types: &UniqueArena, + gctx: crate::proc::GlobalCtx, ) -> Result<(), ConstantError> { - let con = &constants[handle]; + let con = &gctx.constants[handle]; match con.inner { crate::ConstantInner::Scalar { width, ref value } => { if self.check_width(value.scalar_kind(), width).is_err() { @@ -309,11 +308,10 @@ impl Validator { crate::ConstantInner::Composite { ty, ref components } => { compose::validate_compose( ty, - constants, - types, + gctx, components .iter() - .map(|&component| constants[component].inner.resolve_type()), + .map(|&component| gctx.constants[component].inner.resolve_type()), )?; } } @@ -331,17 +329,15 @@ impl Validator { #[cfg(feature = "validate")] Self::validate_module_handles(module).map_err(|e| e.with_span())?; - self.layouter - .update(&module.types, &module.constants) - .map_err(|e| { - let handle = e.ty; - ValidationError::from(e).with_span_handle(handle, &module.types) - })?; + self.layouter.update(module.to_ctx()).map_err(|e| { + let handle = e.ty; + ValidationError::from(e).with_span_handle(handle, &module.types) + })?; #[cfg(feature = "validate")] if self.flags.contains(ValidationFlags::CONSTANTS) { for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, &module.constants, &module.types) + self.validate_constant(handle, module.to_ctx()) .map_err(|source| { ValidationError::Constant { handle, @@ -361,7 +357,7 @@ impl Validator { for (handle, ty) in module.types.iter() { let ty_info = self - .validate_type(handle, &module.types, &module.constants) + .validate_type(handle, module.to_ctx()) .map_err(|source| { ValidationError::Type { handle, @@ -376,7 +372,7 @@ impl Validator { #[cfg(feature = "validate")] for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, &module.types) + self.validate_global_var(var, module.to_ctx()) .map_err(|source| { ValidationError::GlobalVariable { handle: var_handle, diff --git a/src/valid/type.rs b/src/valid/type.rs index 1838ce88ab..e122fec44f 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -1,8 +1,5 @@ use super::Capabilities; -use crate::{ - arena::{Arena, Handle, UniqueArena}, - proc::Alignment, -}; +use crate::{arena::Handle, proc::Alignment}; bitflags::bitflags! { /// Flags associated with [`Type`]s by [`Validator`]. @@ -246,11 +243,10 @@ impl super::Validator { pub(super) fn validate_type( &self, handle: Handle, - types: &UniqueArena, - constants: &Arena, + gctx: crate::proc::GlobalCtx, ) -> Result { use crate::TypeInner as Ti; - Ok(match types[handle].inner { + Ok(match gctx.types[handle].inner { Ti::Scalar { kind, width } => { self.check_width(kind, width)?; let shareable = if kind.is_numeric() { @@ -419,7 +415,7 @@ impl super::Validator { let type_info_mask = match size { crate::ArraySize::Constant(const_handle) => { - let constant = &constants[const_handle]; + let constant = &gctx.constants[const_handle]; let length_is_positive = match *constant { crate::Constant { specialization: Some(_), @@ -535,7 +531,7 @@ impl super::Validator { } } - let base_size = types[member.ty].inner.size(constants); + let base_size = gctx.types[member.ty].inner.size(gctx); min_offset = member.offset + base_size; if min_offset > span { return Err(TypeError::MemberOutOfBounds { @@ -579,14 +575,14 @@ impl super::Validator { } }; - prev_struct_data = match types[member.ty].inner { + prev_struct_data = match gctx.types[member.ty].inner { crate::TypeInner::Struct { span, .. } => Some((span, member.offset)), _ => None, }; // The last field may be an unsized array. if !base_info.flags.contains(TypeFlags::SIZED) { - let is_array = match types[member.ty].inner { + let is_array = match gctx.types[member.ty].inner { crate::TypeInner::Array { .. } => true, _ => false, }; @@ -635,7 +631,7 @@ impl super::Validator { if base_info.flags.contains(TypeFlags::DATA) { // Currently Naga only supports binding arrays of structs for non-handle types. - match types[base].inner { + match gctx.types[base].inner { crate::TypeInner::Struct { .. } => {} _ => return Err(TypeError::BindingArrayBaseTypeNotStruct(base)), };