Skip to content

Commit af72757

Browse files
author
coderfeli
committed
Fix register spill in thread_reduce by using a single accumulator vector instead of vt0 separate accumulators
1 parent 78543e6 commit af72757

File tree

1 file changed

+6
-17
lines changed

1 file changed

+6
-17
lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -569,14 +569,11 @@ struct ReduceOp {
569569
using load_t = at::native::memory::aligned_vector<scalar_t, output_vec_size>;
570570

571571
// Multiple accumulators to remove dependency between unrolled loops.
572-
arg_vec_t value_list[vt0];
572+
arg_vec_t value_list;
573573

574574
#pragma unroll
575-
for (int i = 0; i < vt0; i++) {
576-
#pragma unroll
577-
for (int j = 0; j < output_vec_size; j++) {
578-
value_list[i][j] = ident;
579-
}
575+
for (int j = 0; j < output_vec_size; j++) {
576+
value_list[j] = ident;
580577
}
581578

582579
load_t values[vt0];
@@ -591,7 +588,7 @@ struct ReduceOp {
591588
for (index_t i = 0; i < vt0; i++) {
592589
#pragma unroll
593590
for (index_t j = 0; j < output_vec_size; j++) {
594-
value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx + i * stride);
591+
value_list[j] = ops.reduce(value_list[j], values[i].val[j], idx + i * stride);
595592
}
596593
}
597594
idx += stride * vt0;
@@ -616,20 +613,12 @@ struct ReduceOp {
616613
}
617614
#pragma unroll
618615
for (index_t j = 0; j < output_vec_size; j++) {
619-
value_list[i][j] = ops.reduce(value_list[i][j], values[i].val[j], idx);
616+
value_list[j] = ops.reduce(value_list[j], values[i].val[j], idx);
620617
}
621618
idx += stride;
622619
}
623620

624-
// combine accumulators
625-
#pragma unroll
626-
for (int i = 1; i < vt0; i++) {
627-
#pragma unroll
628-
for (index_t j = 0; j < output_vec_size; j++) {
629-
value_list[0][j] = ops.combine(value_list[0][j], value_list[i][j]);
630-
}
631-
}
632-
return value_list[0];
621+
return value_list;
633622
}
634623

635624
template <int output_vec_size>

0 commit comments

Comments
 (0)