From a0623d3bc8855259b95885bd3e0ffa04cabcc45b Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 26 Apr 2022 15:36:46 +0200 Subject: [PATCH 1/2] [wgsl-in] implement firstTrailingBit/firstLeadingBit u32 overloads --- src/back/glsl/mod.rs | 30 ++++++++++++++++- src/front/glsl/builtins.rs | 40 +++++++++++++++++++++++ src/proc/typifier.rs | 10 +++--- tests/in/bits.wgsl | 4 +-- tests/out/glsl/bits.main.Compute.glsl | 4 +-- tests/out/msl/bits.msl | 4 +-- tests/out/spv/bits.spvasm | 8 ++--- tests/out/wgsl/bits.wgsl | 4 +-- tests/out/wgsl/bits_glsl-frag.wgsl | 46 +++++++++++++-------------- 9 files changed, 109 insertions(+), 41 deletions(-) diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 699348f058..7ede19e4a7 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2806,6 +2806,30 @@ impl<'a, W: Write> Writer<'a, W> { let extract_bits = fun == Mf::ExtractBits; let insert_bits = fun == Mf::InsertBits; + // we might need to cast to unsigned integers since + // GLSL's findLSB / findMSB always return signed integers + let need_extra_paren = { + (fun == Mf::FindLsb || fun == Mf::FindMsb) + && match *ctx.info[arg].ty.inner_with(&self.module.types) { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => { + write!(self.out, "uint(")?; + true + } + crate::TypeInner::Vector { + kind: crate::ScalarKind::Uint, + size, + .. + } => { + write!(self.out, "uvec{}(", size as u8)?; + true + } + _ => false, + } + }; + write!(self.out, "{}(", fun_name)?; self.write_expr(arg, ctx)?; if let Some(arg) = arg1 { @@ -2838,7 +2862,11 @@ impl<'a, W: Write> Writer<'a, W> { self.write_expr(arg, ctx)?; } } - write!(self.out, ")")? + write!(self.out, ")")?; + + if need_extra_paren { + write!(self.out, ")")? + } } // `As` is always a call. // If `convert` is true the function name is the type diff --git a/src/front/glsl/builtins.rs b/src/front/glsl/builtins.rs index 6e61481382..5bee3948f9 100644 --- a/src/front/glsl/builtins.rs +++ b/src/front/glsl/builtins.rs @@ -727,6 +727,17 @@ fn inject_standard_builtins( _ => {} } + // we need to cast the return type of findLsb / findMsb + let mc = if kind == Sk::Uint { + match mc { + MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint, + MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint, + mc => mc, + } + } else { + mc + }; + declaration.overloads.push(module.add_builtin(args, mc)) } } @@ -1580,6 +1591,8 @@ pub enum MacroCall { }, ImageStore, MathFunction(MathFunction), + FindLsbUint, + FindMsbUint, BitfieldExtract, BitfieldInsert, Relational(RelationalFunction), @@ -1848,6 +1861,33 @@ impl MacroCall { Span::default(), body, ), + mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => { + let fun = match mc { + MacroCall::FindLsbUint => MathFunction::FindLsb, + MacroCall::FindMsbUint => MathFunction::FindMsb, + _ => unreachable!(), + }; + let res = ctx.add_expression( + Expression::Math { + fun, + arg: args[0], + arg1: None, + arg2: None, + arg3: None, + }, + Span::default(), + body, + ); + ctx.add_expression( + Expression::As { + expr: res, + kind: Sk::Sint, + convert: Some(4), + }, + Span::default(), + body, + ) + } MacroCall::BitfieldInsert => { let conv_arg_2 = ctx.add_expression( Expression::As { diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 9f3f2a4947..2d48d2ae4b 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -821,13 +821,13 @@ impl<'a> ResolveContext<'a> { Mf::CountOneBits | Mf::ReverseBits | Mf::ExtractBits | - Mf::InsertBits => res_arg.clone(), + Mf::InsertBits | Mf::FindLsb | Mf::FindMsb => match *res_arg.inner_with(types) { - Ti::Scalar { kind: _, width } => - TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Sint, width }), - Ti::Vector { size, kind: _, width } => - TypeResolution::Value(Ti::Vector { size, kind: crate::ScalarKind::Sint, width }), + Ti::Scalar { kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } => + TypeResolution::Value(Ti::Scalar { kind, width }), + Ti::Vector { size, kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } => + TypeResolution::Value(Ti::Vector { size, kind, width }), ref other => return Err(ResolveError::IncompatibleOperands( format!("{:?}({:?})", fun, other) )), diff --git a/tests/in/bits.wgsl b/tests/in/bits.wgsl index 24101d9933..1c78dae201 100644 --- a/tests/in/bits.wgsl +++ b/tests/in/bits.wgsl @@ -37,7 +37,7 @@ fn main() { u3 = extractBits(u3, 5u, 10u); u4 = extractBits(u4, 5u, 10u); i = firstTrailingBit(i); - i2 = firstTrailingBit(u2); + u2 = firstTrailingBit(u2); i3 = firstLeadingBit(i3); - i = firstLeadingBit(u); + u = firstLeadingBit(u); } diff --git a/tests/out/glsl/bits.main.Compute.glsl b/tests/out/glsl/bits.main.Compute.glsl index 9cb276a9c2..3166d01363 100644 --- a/tests/out/glsl/bits.main.Compute.glsl +++ b/tests/out/glsl/bits.main.Compute.glsl @@ -88,11 +88,11 @@ void main() { int _e120 = i; i = findLSB(_e120); uvec2 _e122 = u2_; - i2_ = findLSB(_e122); + u2_ = uvec2(findLSB(_e122)); ivec3 _e124 = i3_; i3_ = findMSB(_e124); uint _e126 = u; - i = findMSB(_e126); + u = uint(findMSB(_e126)); return; } diff --git a/tests/out/msl/bits.msl b/tests/out/msl/bits.msl index cc3a5660ae..a87b6a70dd 100644 --- a/tests/out/msl/bits.msl +++ b/tests/out/msl/bits.msl @@ -88,10 +88,10 @@ kernel void main_( int _e120 = i; i = (((1 + int(metal::ctz(_e120))) % 33) - 1); metal::uint2 _e122 = u2_; - i2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1); + u2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1); metal::int3 _e124 = i3_; i3_ = (((1 + int3(metal::clz(_e124))) % 33) - 1); uint _e126 = u; - i = (((1 + int(metal::clz(_e126))) % 33) - 1); + u = (((1 + int(metal::clz(_e126))) % 33) - 1); return; } diff --git a/tests/out/spv/bits.spvasm b/tests/out/spv/bits.spvasm index b224d04ab3..6d224eee20 100644 --- a/tests/out/spv/bits.spvasm +++ b/tests/out/spv/bits.spvasm @@ -155,13 +155,13 @@ OpStore %33 %110 %112 = OpExtInst %4 %1 FindILsb %111 OpStore %19 %112 %113 = OpLoad %14 %29 -%114 = OpExtInst %11 %1 FindILsb %113 -OpStore %21 %114 +%114 = OpExtInst %14 %1 FindILsb %113 +OpStore %29 %114 %115 = OpLoad %12 %23 %116 = OpExtInst %12 %1 FindSMsb %115 OpStore %23 %116 %117 = OpLoad %6 %27 -%118 = OpExtInst %4 %1 FindUMsb %117 -OpStore %19 %118 +%118 = OpExtInst %6 %1 FindUMsb %117 +OpStore %27 %118 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/bits.wgsl b/tests/out/wgsl/bits.wgsl index a7c3b01e4a..2bdaf6e9ff 100644 --- a/tests/out/wgsl/bits.wgsl +++ b/tests/out/wgsl/bits.wgsl @@ -82,10 +82,10 @@ fn main() { let _e120 = i; i = firstTrailingBit(_e120); let _e122 = u2_; - i2_ = firstTrailingBit(_e122); + u2_ = firstTrailingBit(_e122); let _e124 = i3_; i3_ = firstLeadingBit(_e124); let _e126 = u; - i = firstLeadingBit(_e126); + u = firstLeadingBit(_e126); return; } diff --git a/tests/out/wgsl/bits_glsl-frag.wgsl b/tests/out/wgsl/bits_glsl-frag.wgsl index 977e1ba5bc..012f51f460 100644 --- a/tests/out/wgsl/bits_glsl-frag.wgsl +++ b/tests/out/wgsl/bits_glsl-frag.wgsl @@ -79,29 +79,29 @@ fn main_1() { let _e232 = i4_; i4_ = firstTrailingBit(_e232); let _e235 = u; - i = firstTrailingBit(_e235); - let _e238 = u2_; - i2_ = firstTrailingBit(_e238); - let _e241 = u3_; - i3_ = firstTrailingBit(_e241); - let _e244 = u4_; - i4_ = firstTrailingBit(_e244); - let _e247 = i; - i = firstLeadingBit(_e247); - let _e250 = i2_; - i2_ = firstLeadingBit(_e250); - let _e253 = i3_; - i3_ = firstLeadingBit(_e253); - let _e256 = i4_; - i4_ = firstLeadingBit(_e256); - let _e259 = u; - i = firstLeadingBit(_e259); - let _e262 = u2_; - i2_ = firstLeadingBit(_e262); - let _e265 = u3_; - i3_ = firstLeadingBit(_e265); - let _e268 = u4_; - i4_ = firstLeadingBit(_e268); + i = i32(firstTrailingBit(_e235)); + let _e239 = u2_; + i2_ = vec2(firstTrailingBit(_e239)); + let _e243 = u3_; + i3_ = vec3(firstTrailingBit(_e243)); + let _e247 = u4_; + i4_ = vec4(firstTrailingBit(_e247)); + let _e251 = i; + i = firstLeadingBit(_e251); + let _e254 = i2_; + i2_ = firstLeadingBit(_e254); + let _e257 = i3_; + i3_ = firstLeadingBit(_e257); + let _e260 = i4_; + i4_ = firstLeadingBit(_e260); + let _e263 = u; + i = i32(firstLeadingBit(_e263)); + let _e267 = u2_; + i2_ = vec2(firstLeadingBit(_e267)); + let _e271 = u3_; + i3_ = vec3(firstLeadingBit(_e271)); + let _e275 = u4_; + i4_ = vec4(firstLeadingBit(_e275)); return; } From de59845fb7fc4b6292d5357a6bc66f6ae5f77394 Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Tue, 26 Apr 2022 15:46:32 +0200 Subject: [PATCH 2/2] fix MSL type issue reverts https://github.com/gfx-rs/naga/pull/1473/commits/b9162e443dd58cf95c2d7930e9c42b0e0e00ede7 --- src/back/msl/writer.rs | 31 ++++--------------------------- tests/out/msl/bits.msl | 8 ++++---- 2 files changed, 8 insertions(+), 31 deletions(-) diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index 836985846c..1e36b5a9b1 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -1591,21 +1591,6 @@ impl Writer { crate::TypeInner::Scalar { .. } => true, _ => false, }; - let argument_size_suffix = match *context.resolve_type(arg) { - crate::TypeInner::Vector { - size: crate::VectorSize::Bi, - .. - } => "2", - crate::TypeInner::Vector { - size: crate::VectorSize::Tri, - .. - } => "3", - crate::TypeInner::Vector { - size: crate::VectorSize::Quad, - .. - } => "4", - _ => "", - }; let fun_name = match fun { // comparison @@ -1705,21 +1690,13 @@ impl Writer { self.put_expression(arg1.unwrap(), context, false)?; write!(self.out, ")")?; } else if fun == Mf::FindLsb { - write!( - self.out, - "(((1 + int{}({}::ctz(", - argument_size_suffix, NAMESPACE - )?; + write!(self.out, "((({}::ctz(", NAMESPACE)?; self.put_expression(arg, context, true)?; - write!(self.out, "))) % 33) - 1)")?; + write!(self.out, ") + 1) % 33) - 1)")?; } else if fun == Mf::FindMsb { - write!( - self.out, - "(((1 + int{}({}::clz(", - argument_size_suffix, NAMESPACE - )?; + write!(self.out, "((({}::clz(", NAMESPACE)?; self.put_expression(arg, context, true)?; - write!(self.out, "))) % 33) - 1)")?; + write!(self.out, ") + 1) % 33) - 1)")? } else if fun == Mf::Unpack2x16float { write!(self.out, "float2(as_type(")?; self.put_expression(arg, context, false)?; diff --git a/tests/out/msl/bits.msl b/tests/out/msl/bits.msl index a87b6a70dd..88503320bb 100644 --- a/tests/out/msl/bits.msl +++ b/tests/out/msl/bits.msl @@ -86,12 +86,12 @@ kernel void main_( metal::uint4 _e116 = u4_; u4_ = metal::extract_bits(_e116, 5u, 10u); int _e120 = i; - i = (((1 + int(metal::ctz(_e120))) % 33) - 1); + i = (((metal::ctz(_e120) + 1) % 33) - 1); metal::uint2 _e122 = u2_; - u2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1); + u2_ = (((metal::ctz(_e122) + 1) % 33) - 1); metal::int3 _e124 = i3_; - i3_ = (((1 + int3(metal::clz(_e124))) % 33) - 1); + i3_ = (((metal::clz(_e124) + 1) % 33) - 1); uint _e126 = u; - u = (((1 + int(metal::clz(_e126))) % 33) - 1); + u = (((metal::clz(_e126) + 1) % 33) - 1); return; }