Skip to content

Commit

Permalink
[js] small fix to workaround formatter (#19400)
Browse files Browse the repository at this point in the history
### Description
Rename shader variable names to snake_case naming and also to avoid
formatter behaving inconsistently in win/linux.
  • Loading branch information
fs-eire committed Feb 21, 2024
1 parent 97ff17c commit 3fe2c13
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,28 +85,28 @@ const createLayerNormProgramInfo =
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')}
let offset = global_idx * uniforms.norm_size_vectorized;
var meanVector = ${fillVector('f32', components)};
var meanSquareVector = ${fillVector('f32', components)};
var mean_vector = ${fillVector('f32', components)};
var mean_square_vector = ${fillVector('f32', components)};
for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) {
let value = ${castToF32(dataType, components, 'x[h + offset]')};
meanVector += value;
meanSquareVector += value * value;
mean_vector += value;
mean_square_vector += value * value;
}
let mean = ${sumVector('meanVector', components)} / uniforms.norm_size;
let invStdDev =
inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon);
let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size;
let inv_std_dev = inverseSqrt(${
sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon);
for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) {
let f32input = ${castToF32(dataType, components, 'x[j + offset]')};
let f32scale = ${castToF32(dataType, components, 'scale[j]')};
output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale
output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale
${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''}
);
}
${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''};
${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''};
${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''};
}`;
};
const outputs = [{dims: outputShape, dataType: inputs[0].dataType}];
Expand Down

0 comments on commit 3fe2c13

Please sign in to comment.