Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wgsl-in] implement firstTrailingBit/firstLeadingBit u32 overloads #1865

Merged
merged 2 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2806,6 +2806,30 @@ impl<'a, W: Write> Writer<'a, W> {
let extract_bits = fun == Mf::ExtractBits;
let insert_bits = fun == Mf::InsertBits;

// we might need to cast to unsigned integers since
// GLSL's findLSB / findMSB always return signed integers
let need_extra_paren = {
(fun == Mf::FindLsb || fun == Mf::FindMsb)
&& match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Uint,
..
} => {
write!(self.out, "uint(")?;
true
}
crate::TypeInner::Vector {
kind: crate::ScalarKind::Uint,
size,
..
} => {
write!(self.out, "uvec{}(", size as u8)?;
true
}
_ => false,
}
};

write!(self.out, "{}(", fun_name)?;
self.write_expr(arg, ctx)?;
if let Some(arg) = arg1 {
Expand Down Expand Up @@ -2838,7 +2862,11 @@ impl<'a, W: Write> Writer<'a, W> {
self.write_expr(arg, ctx)?;
}
}
write!(self.out, ")")?
write!(self.out, ")")?;

if need_extra_paren {
write!(self.out, ")")?
}
}
// `As` is always a call.
// If `convert` is true the function name is the type
Expand Down
31 changes: 4 additions & 27 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1591,21 +1591,6 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Scalar { .. } => true,
_ => false,
};
let argument_size_suffix = match *context.resolve_type(arg) {
crate::TypeInner::Vector {
size: crate::VectorSize::Bi,
..
} => "2",
crate::TypeInner::Vector {
size: crate::VectorSize::Tri,
..
} => "3",
crate::TypeInner::Vector {
size: crate::VectorSize::Quad,
..
} => "4",
_ => "",
};

let fun_name = match fun {
// comparison
Expand Down Expand Up @@ -1705,21 +1690,13 @@ impl<W: Write> Writer<W> {
self.put_expression(arg1.unwrap(), context, false)?;
write!(self.out, ")")?;
} else if fun == Mf::FindLsb {
write!(
self.out,
"(((1 + int{}({}::ctz(",
argument_size_suffix, NAMESPACE
)?;
write!(self.out, "((({}::ctz(", NAMESPACE)?;
self.put_expression(arg, context, true)?;
write!(self.out, "))) % 33) - 1)")?;
write!(self.out, ") + 1) % 33) - 1)")?;
} else if fun == Mf::FindMsb {
write!(
self.out,
"(((1 + int{}({}::clz(",
argument_size_suffix, NAMESPACE
)?;
write!(self.out, "((({}::clz(", NAMESPACE)?;
self.put_expression(arg, context, true)?;
write!(self.out, "))) % 33) - 1)")?;
write!(self.out, ") + 1) % 33) - 1)")?
} else if fun == Mf::Unpack2x16float {
write!(self.out, "float2(as_type<half2>(")?;
self.put_expression(arg, context, false)?;
Expand Down
40 changes: 40 additions & 0 deletions src/front/glsl/builtins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,17 @@ fn inject_standard_builtins(
_ => {}
}

// we need to cast the return type of findLsb / findMsb
let mc = if kind == Sk::Uint {
match mc {
MacroCall::MathFunction(MathFunction::FindLsb) => MacroCall::FindLsbUint,
MacroCall::MathFunction(MathFunction::FindMsb) => MacroCall::FindMsbUint,
mc => mc,
}
} else {
mc
};

declaration.overloads.push(module.add_builtin(args, mc))
}
}
Expand Down Expand Up @@ -1580,6 +1591,8 @@ pub enum MacroCall {
},
ImageStore,
MathFunction(MathFunction),
FindLsbUint,
FindMsbUint,
BitfieldExtract,
BitfieldInsert,
Relational(RelationalFunction),
Expand Down Expand Up @@ -1848,6 +1861,33 @@ impl MacroCall {
Span::default(),
body,
),
mc @ (MacroCall::FindLsbUint | MacroCall::FindMsbUint) => {
let fun = match mc {
MacroCall::FindLsbUint => MathFunction::FindLsb,
MacroCall::FindMsbUint => MathFunction::FindMsb,
_ => unreachable!(),
};
let res = ctx.add_expression(
Expression::Math {
fun,
arg: args[0],
arg1: None,
arg2: None,
arg3: None,
},
Span::default(),
body,
);
ctx.add_expression(
Expression::As {
expr: res,
kind: Sk::Sint,
convert: Some(4),
},
Span::default(),
body,
)
}
MacroCall::BitfieldInsert => {
let conv_arg_2 = ctx.add_expression(
Expression::As {
Expand Down
10 changes: 5 additions & 5 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -821,13 +821,13 @@ impl<'a> ResolveContext<'a> {
Mf::CountOneBits |
Mf::ReverseBits |
Mf::ExtractBits |
Mf::InsertBits => res_arg.clone(),
Mf::InsertBits |
Mf::FindLsb |
Mf::FindMsb => match *res_arg.inner_with(types) {
Ti::Scalar { kind: _, width } =>
TypeResolution::Value(Ti::Scalar { kind: crate::ScalarKind::Sint, width }),
Ti::Vector { size, kind: _, width } =>
TypeResolution::Value(Ti::Vector { size, kind: crate::ScalarKind::Sint, width }),
Ti::Scalar { kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
TypeResolution::Value(Ti::Scalar { kind, width }),
Ti::Vector { size, kind: kind @ (crate::ScalarKind::Sint | crate::ScalarKind::Uint), width } =>
TypeResolution::Value(Ti::Vector { size, kind, width }),
ref other => return Err(ResolveError::IncompatibleOperands(
format!("{:?}({:?})", fun, other)
)),
Expand Down
4 changes: 2 additions & 2 deletions tests/in/bits.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn main() {
u3 = extractBits(u3, 5u, 10u);
u4 = extractBits(u4, 5u, 10u);
i = firstTrailingBit(i);
i2 = firstTrailingBit(u2);
u2 = firstTrailingBit(u2);
i3 = firstLeadingBit(i3);
i = firstLeadingBit(u);
u = firstLeadingBit(u);
}
4 changes: 2 additions & 2 deletions tests/out/glsl/bits.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,11 @@ void main() {
int _e120 = i;
i = findLSB(_e120);
uvec2 _e122 = u2_;
i2_ = findLSB(_e122);
u2_ = uvec2(findLSB(_e122));
ivec3 _e124 = i3_;
i3_ = findMSB(_e124);
uint _e126 = u;
i = findMSB(_e126);
u = uint(findMSB(_e126));
return;
}

8 changes: 4 additions & 4 deletions tests/out/msl/bits.msl
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ kernel void main_(
metal::uint4 _e116 = u4_;
u4_ = metal::extract_bits(_e116, 5u, 10u);
int _e120 = i;
i = (((1 + int(metal::ctz(_e120))) % 33) - 1);
i = (((metal::ctz(_e120) + 1) % 33) - 1);
metal::uint2 _e122 = u2_;
i2_ = (((1 + int2(metal::ctz(_e122))) % 33) - 1);
u2_ = (((metal::ctz(_e122) + 1) % 33) - 1);
metal::int3 _e124 = i3_;
i3_ = (((1 + int3(metal::clz(_e124))) % 33) - 1);
i3_ = (((metal::clz(_e124) + 1) % 33) - 1);
uint _e126 = u;
i = (((1 + int(metal::clz(_e126))) % 33) - 1);
u = (((metal::clz(_e126) + 1) % 33) - 1);
return;
}
8 changes: 4 additions & 4 deletions tests/out/spv/bits.spvasm
Original file line number Diff line number Diff line change
Expand Up @@ -155,13 +155,13 @@ OpStore %33 %110
%112 = OpExtInst %4 %1 FindILsb %111
OpStore %19 %112
%113 = OpLoad %14 %29
%114 = OpExtInst %11 %1 FindILsb %113
OpStore %21 %114
%114 = OpExtInst %14 %1 FindILsb %113
OpStore %29 %114
%115 = OpLoad %12 %23
%116 = OpExtInst %12 %1 FindSMsb %115
OpStore %23 %116
%117 = OpLoad %6 %27
%118 = OpExtInst %4 %1 FindUMsb %117
OpStore %19 %118
%118 = OpExtInst %6 %1 FindUMsb %117
OpStore %27 %118
OpReturn
OpFunctionEnd
4 changes: 2 additions & 2 deletions tests/out/wgsl/bits.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ fn main() {
let _e120 = i;
i = firstTrailingBit(_e120);
let _e122 = u2_;
i2_ = firstTrailingBit(_e122);
u2_ = firstTrailingBit(_e122);
let _e124 = i3_;
i3_ = firstLeadingBit(_e124);
let _e126 = u;
i = firstLeadingBit(_e126);
u = firstLeadingBit(_e126);
return;
}
46 changes: 23 additions & 23 deletions tests/out/wgsl/bits_glsl-frag.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -79,29 +79,29 @@ fn main_1() {
let _e232 = i4_;
i4_ = firstTrailingBit(_e232);
let _e235 = u;
i = firstTrailingBit(_e235);
let _e238 = u2_;
i2_ = firstTrailingBit(_e238);
let _e241 = u3_;
i3_ = firstTrailingBit(_e241);
let _e244 = u4_;
i4_ = firstTrailingBit(_e244);
let _e247 = i;
i = firstLeadingBit(_e247);
let _e250 = i2_;
i2_ = firstLeadingBit(_e250);
let _e253 = i3_;
i3_ = firstLeadingBit(_e253);
let _e256 = i4_;
i4_ = firstLeadingBit(_e256);
let _e259 = u;
i = firstLeadingBit(_e259);
let _e262 = u2_;
i2_ = firstLeadingBit(_e262);
let _e265 = u3_;
i3_ = firstLeadingBit(_e265);
let _e268 = u4_;
i4_ = firstLeadingBit(_e268);
i = i32(firstTrailingBit(_e235));
let _e239 = u2_;
i2_ = vec2<i32>(firstTrailingBit(_e239));
let _e243 = u3_;
i3_ = vec3<i32>(firstTrailingBit(_e243));
let _e247 = u4_;
i4_ = vec4<i32>(firstTrailingBit(_e247));
let _e251 = i;
i = firstLeadingBit(_e251);
let _e254 = i2_;
i2_ = firstLeadingBit(_e254);
let _e257 = i3_;
i3_ = firstLeadingBit(_e257);
let _e260 = i4_;
i4_ = firstLeadingBit(_e260);
let _e263 = u;
i = i32(firstLeadingBit(_e263));
let _e267 = u2_;
i2_ = vec2<i32>(firstLeadingBit(_e267));
let _e271 = u3_;
i3_ = vec3<i32>(firstLeadingBit(_e271));
let _e275 = u4_;
i4_ = vec4<i32>(firstLeadingBit(_e275));
return;
}

Expand Down