Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[hlsl-out] Fix return type for firstbitlow/high #2315

Merged
merged 2 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2574,6 +2574,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack2x16float,
Regular(&'static str),
MissingIntOverload(&'static str),
MissingIntReturnType(&'static str),
CountTrailingZeros,
CountLeadingZeros,
}
Expand Down Expand Up @@ -2642,8 +2643,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::CountLeadingZeros => Function::CountLeadingZeros,
Mf::CountOneBits => Function::MissingIntOverload("countbits"),
Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
Mf::FindLsb => Function::Regular("firstbitlow"),
Mf::FindMsb => Function::Regular("firstbithigh"),
Mf::FindLsb => Function::MissingIntReturnType("firstbitlow"),
Mf::FindMsb => Function::MissingIntReturnType("firstbithigh"),
Mf::Unpack2x16float => Function::Unpack2x16float,
_ => return Err(Error::Unimplemented(format!("write_expr_math {fun:?}"))),
};
Expand Down Expand Up @@ -2707,6 +2708,21 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?;
}
}
Function::MissingIntReturnType(fun_name) => {
let scalar_kind = &func_ctx.info[arg]
.ty
.inner_with(&module.types)
.scalar_kind();
if let Some(ScalarKind::Sint) = *scalar_kind {
write!(self.out, "asint({fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "{fun_name}(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")")?;
}
}
Function::CountTrailingZeros => {
match *func_ctx.info[arg].ty.inner_with(&module.types) {
TypeInner::Vector { size, kind, .. } => {
Expand All @@ -2721,9 +2737,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "asint(min((32u){s}, asuint(firstbitlow(")?;
write!(self.out, "asint(min((32u){s}, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))))")?;
write!(self.out, ")))")?;
}
}
TypeInner::Scalar { kind, .. } => {
Expand All @@ -2732,9 +2748,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "asint(min(32u, asuint(firstbitlow(")?;
write!(self.out, "asint(min(32u, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))))")?;
write!(self.out, ")))")?;
}
}
_ => unreachable!(),
Expand All @@ -2752,31 +2768,36 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
};

if let ScalarKind::Uint = kind {
write!(self.out, "asuint((31){s} - firstbithigh(")?;
write!(self.out, "((31u){s} - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(
self.out,
" < (0){s} ? (0){s} : (31){s} - firstbithigh("
" < (0){s} ? (0){s} : (31){s} - asint(firstbithigh("
)?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
}
TypeInner::Scalar { kind, .. } => {
if let ScalarKind::Uint = kind {
write!(self.out, "asuint(31 - firstbithigh(")?;
write!(self.out, "(31u - firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " < 0 ? 0 : 31 - firstbithigh(")?;
write!(self.out, " < 0 ? 0 : 31 - asint(firstbithigh(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ")))")?;
}
}
_ => unreachable!(),
}

self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;

return Ok(());
}
}
Expand Down
7 changes: 7 additions & 0 deletions tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ fn main() {
let g = refract(v, v, f);
let const_dot = dot(vec2<i32>(), vec2<i32>());
let first_leading_bit_abs = firstLeadingBit(abs(0u));
let flb_a = firstLeadingBit(-1);
let flb_b = firstLeadingBit(vec2(-1));
let flb_c = firstLeadingBit(vec2(1u));
let ftb_a = firstTrailingBit(-1);
let ftb_b = firstTrailingBit(1u);
let ftb_c = firstTrailingBit(vec2(-1));
let ftb_d = firstTrailingBit(vec2(1u));
let ctz_a = countTrailingZeros(0u);
let ctz_b = countTrailingZeros(0);
let ctz_c = countTrailingZeros(0xFFFFFFFFu);
Expand Down
11 changes: 9 additions & 2 deletions tests/out/glsl/math-functions.main.Fragment.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ void main() {
vec4 g = refract(v, v, 1.0);
int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y);
uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u)))));
int flb_a = findMSB(-1);
ivec2 flb_b = findMSB(ivec2(-1));
uvec2 flb_c = uvec2(findMSB(uvec2(1u)));
int ftb_a = findLSB(-1);
uint ftb_b = uint(findLSB(1u));
ivec2 ftb_c = findLSB(ivec2(-1));
uvec2 ftb_d = uvec2(findLSB(uvec2(1u)));
uint ctz_a = min(uint(findLSB(0u)), 32u);
int ctz_b = int(min(uint(findLSB(0)), 32u));
uint ctz_c = min(uint(findLSB(4294967295u)), 32u);
Expand All @@ -24,8 +31,8 @@ void main() {
ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u)));
int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1));
uint clz_b = uint(31 - findMSB(1u));
ivec2 _e40 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e40), ivec2(0), lessThan(_e40, ivec2(0)));
ivec2 _e58 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e58), ivec2(0), lessThan(_e58, ivec2(0)));
uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u)));
}

25 changes: 16 additions & 9 deletions tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@ void main()
float4 g = refract(v, v, 1.0);
int const_dot = dot(int2(0, 0), int2(0, 0));
uint first_leading_bit_abs = firstbithigh(abs(0u));
int flb_a = asint(firstbithigh(-1));
int2 flb_b = asint(firstbithigh((-1).xx));
uint2 flb_c = firstbithigh((1u).xx);
int ftb_a = asint(firstbitlow(-1));
uint ftb_b = firstbitlow(1u);
int2 ftb_c = asint(firstbitlow((-1).xx));
uint2 ftb_d = firstbitlow((1u).xx);
uint ctz_a = min(32u, firstbitlow(0u));
int ctz_b = asint(min(32u, asuint(firstbitlow(0))));
int ctz_b = asint(min(32u, firstbitlow(0)));
uint ctz_c = min(32u, firstbitlow(4294967295u));
int ctz_d = asint(min(32u, asuint(firstbitlow(-1))));
int ctz_d = asint(min(32u, firstbitlow(-1)));
uint2 ctz_e = min((32u).xx, firstbitlow((0u).xx));
int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow((0).xx))));
int2 ctz_f = asint(min((32u).xx, firstbitlow((0).xx)));
uint2 ctz_g = min((32u).xx, firstbitlow((1u).xx));
int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow((1).xx))));
int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1));
uint clz_b = asuint(31 - firstbithigh(1u));
int2 _expr40 = (-1).xx;
int2 clz_c = (_expr40 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr40));
uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx));
int2 ctz_h = asint(min((32u).xx, firstbitlow((1).xx)));
int clz_a = (-1 < 0 ? 0 : 31 - asint(firstbithigh(-1)));
uint clz_b = (31u - firstbithigh(1u));
int2 _expr58 = (-1).xx;
int2 clz_c = (_expr58 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr58)));
uint2 clz_d = ((31u).xx - firstbithigh((1u).xx));
}
9 changes: 9 additions & 0 deletions tests/out/msl/math-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@ fragment void main_(
int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y);
uint _e13 = metal::abs(0u);
uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1);
int flb_a = metal::select(31 - metal::clz(metal::select(-1, ~-1, -1 < 0)), int(-1), -1 == 0 || -1 == -1);
metal::int2 _e18 = metal::int2(-1);
metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e18, ~_e18, _e18 < 0)), int2(-1), _e18 == 0 || _e18 == -1);
metal::uint2 _e21 = metal::uint2(1u);
metal::uint2 flb_c = metal::select(31 - metal::clz(_e21), uint2(-1), _e21 == 0 || _e21 == -1);
int ftb_a = (((metal::ctz(-1) + 1) % 33) - 1);
uint ftb_b = (((metal::ctz(1u) + 1) % 33) - 1);
metal::int2 ftb_c = (((metal::ctz(metal::int2(-1)) + 1) % 33) - 1);
metal::uint2 ftb_d = (((metal::ctz(metal::uint2(1u)) + 1) % 33) - 1);
uint ctz_a = metal::ctz(0u);
int ctz_b = metal::ctz(0);
uint ctz_c = metal::ctz(4294967295u);
Expand Down
89 changes: 50 additions & 39 deletions tests/out/spv/math-functions.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 76
; Bound: 87
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
Expand All @@ -15,9 +15,9 @@ OpExecutionMode %18 OriginUpperLeft
%6 = OpConstant %7 0
%9 = OpTypeInt 32 0
%8 = OpConstant %9 0
%10 = OpConstant %9 4294967295
%11 = OpConstant %7 -1
%12 = OpConstant %9 1
%10 = OpConstant %7 -1
%11 = OpConstant %9 1
%12 = OpConstant %9 4294967295
%13 = OpConstant %7 1
%14 = OpTypeVector %4 4
%15 = OpTypeVector %7 2
Expand All @@ -26,11 +26,11 @@ OpExecutionMode %18 OriginUpperLeft
%27 = OpConstantComposite %14 %5 %5 %5 %5
%28 = OpConstantComposite %14 %3 %3 %3 %3
%31 = OpConstantNull %7
%42 = OpConstant %9 32
%50 = OpTypeVector %9 2
%53 = OpConstantComposite %50 %42 %42
%65 = OpConstant %7 31
%71 = OpConstantComposite %15 %65 %65
%44 = OpTypeVector %9 2
%54 = OpConstant %9 32
%64 = OpConstantComposite %44 %54 %54
%76 = OpConstant %7 31
%82 = OpConstantComposite %15 %76 %76
%18 = OpFunction %2 None %19
%17 = OpLabel
OpBranch %20
Expand All @@ -52,35 +52,46 @@ OpBranch %20
%30 = OpIAdd %7 %35 %38
%39 = OpCopyObject %9 %8
%40 = OpExtInst %9 %1 FindUMsb %39
%43 = OpExtInst %9 %1 FindILsb %8
%41 = OpExtInst %9 %1 UMin %42 %43
%45 = OpExtInst %7 %1 FindILsb %6
%44 = OpExtInst %7 %1 UMin %42 %45
%47 = OpExtInst %9 %1 FindILsb %10
%46 = OpExtInst %9 %1 UMin %42 %47
%49 = OpExtInst %7 %1 FindILsb %11
%48 = OpExtInst %7 %1 UMin %42 %49
%51 = OpCompositeConstruct %50 %8 %8
%54 = OpExtInst %50 %1 FindILsb %51
%52 = OpExtInst %50 %1 UMin %53 %54
%55 = OpCompositeConstruct %15 %6 %6
%57 = OpExtInst %15 %1 FindILsb %55
%56 = OpExtInst %15 %1 UMin %53 %57
%58 = OpCompositeConstruct %50 %12 %12
%60 = OpExtInst %50 %1 FindILsb %58
%59 = OpExtInst %50 %1 UMin %53 %60
%61 = OpCompositeConstruct %15 %13 %13
%63 = OpExtInst %15 %1 FindILsb %61
%62 = OpExtInst %15 %1 UMin %53 %63
%66 = OpExtInst %7 %1 FindUMsb %11
%64 = OpISub %7 %65 %66
%68 = OpExtInst %7 %1 FindUMsb %12
%67 = OpISub %9 %65 %68
%69 = OpCompositeConstruct %15 %11 %11
%72 = OpExtInst %15 %1 FindUMsb %69
%70 = OpISub %15 %71 %72
%73 = OpCompositeConstruct %50 %12 %12
%75 = OpExtInst %15 %1 FindUMsb %73
%74 = OpISub %50 %71 %75
%41 = OpExtInst %7 %1 FindSMsb %10
%42 = OpCompositeConstruct %15 %10 %10
%43 = OpExtInst %15 %1 FindSMsb %42
%45 = OpCompositeConstruct %44 %11 %11
%46 = OpExtInst %44 %1 FindUMsb %45
%47 = OpExtInst %7 %1 FindILsb %10
%48 = OpExtInst %9 %1 FindILsb %11
%49 = OpCompositeConstruct %15 %10 %10
%50 = OpExtInst %15 %1 FindILsb %49
%51 = OpCompositeConstruct %44 %11 %11
%52 = OpExtInst %44 %1 FindILsb %51
%55 = OpExtInst %9 %1 FindILsb %8
%53 = OpExtInst %9 %1 UMin %54 %55
%57 = OpExtInst %7 %1 FindILsb %6
%56 = OpExtInst %7 %1 UMin %54 %57
%59 = OpExtInst %9 %1 FindILsb %12
%58 = OpExtInst %9 %1 UMin %54 %59
%61 = OpExtInst %7 %1 FindILsb %10
%60 = OpExtInst %7 %1 UMin %54 %61
%62 = OpCompositeConstruct %44 %8 %8
%65 = OpExtInst %44 %1 FindILsb %62
%63 = OpExtInst %44 %1 UMin %64 %65
%66 = OpCompositeConstruct %15 %6 %6
%68 = OpExtInst %15 %1 FindILsb %66
%67 = OpExtInst %15 %1 UMin %64 %68
%69 = OpCompositeConstruct %44 %11 %11
%71 = OpExtInst %44 %1 FindILsb %69
%70 = OpExtInst %44 %1 UMin %64 %71
%72 = OpCompositeConstruct %15 %13 %13
%74 = OpExtInst %15 %1 FindILsb %72
%73 = OpExtInst %15 %1 UMin %64 %74
%77 = OpExtInst %7 %1 FindUMsb %10
%75 = OpISub %7 %76 %77
%79 = OpExtInst %7 %1 FindUMsb %11
%78 = OpISub %9 %76 %79
%80 = OpCompositeConstruct %15 %10 %10
%83 = OpExtInst %15 %1 FindUMsb %80
%81 = OpISub %15 %82 %83
%84 = OpCompositeConstruct %44 %11 %11
%86 = OpExtInst %15 %1 FindUMsb %84
%85 = OpISub %44 %82 %86
OpReturn
OpFunctionEnd
7 changes: 7 additions & 0 deletions tests/out/wgsl/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ fn main() {
let g = refract(v, v, 1.0);
let const_dot = dot(vec2<i32>(0, 0), vec2<i32>(0, 0));
let first_leading_bit_abs = firstLeadingBit(abs(0u));
let flb_a = firstLeadingBit(-1);
let flb_b = firstLeadingBit(vec2<i32>(-1));
let flb_c = firstLeadingBit(vec2<u32>(1u));
let ftb_a = firstTrailingBit(-1);
let ftb_b = firstTrailingBit(1u);
let ftb_c = firstTrailingBit(vec2<i32>(-1));
let ftb_d = firstTrailingBit(vec2<u32>(1u));
let ctz_a = countTrailingZeros(0u);
let ctz_b = countTrailingZeros(0);
let ctz_c = countTrailingZeros(4294967295u);
Expand Down