@@ -569,14 +569,11 @@ struct ReduceOp {
569
569
using load_t = at::native::memory::aligned_vector<scalar_t , output_vec_size>;
570
570
571
571
// Multiple accumulators to remove dependency between unrolled loops.
572
- arg_vec_t value_list[vt0] ;
572
+ arg_vec_t value_list;
573
573
574
574
#pragma unroll
575
- for (int i = 0 ; i < vt0; i++) {
576
- #pragma unroll
577
- for (int j = 0 ; j < output_vec_size; j++) {
578
- value_list[i][j] = ident;
579
- }
575
+ for (int j = 0 ; j < output_vec_size; j++) {
576
+ value_list[j] = ident;
580
577
}
581
578
582
579
load_t values[vt0];
@@ -591,7 +588,7 @@ struct ReduceOp {
591
588
for (index_t i = 0 ; i < vt0; i++) {
592
589
#pragma unroll
593
590
for (index_t j = 0 ; j < output_vec_size; j++) {
594
- value_list[i][ j] = ops.reduce (value_list[i] [j], values[i].val [j], idx + i * stride);
591
+ value_list[j] = ops.reduce (value_list[j], values[i].val [j], idx + i * stride);
595
592
}
596
593
}
597
594
idx += stride * vt0;
@@ -616,20 +613,12 @@ struct ReduceOp {
616
613
}
617
614
#pragma unroll
618
615
for (index_t j = 0 ; j < output_vec_size; j++) {
619
- value_list[i][ j] = ops.reduce (value_list[i] [j], values[i].val [j], idx);
616
+ value_list[j] = ops.reduce (value_list[j], values[i].val [j], idx);
620
617
}
621
618
idx += stride;
622
619
}
623
620
624
- // combine accumulators
625
- #pragma unroll
626
- for (int i = 1 ; i < vt0; i++) {
627
- #pragma unroll
628
- for (index_t j = 0 ; j < output_vec_size; j++) {
629
- value_list[0 ][j] = ops.combine (value_list[0 ][j], value_list[i][j]);
630
- }
631
- }
632
- return value_list[0 ];
621
+ return value_list;
633
622
}
634
623
635
624
template <int output_vec_size>
0 commit comments