diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 794e0fbcbd4aa..59ba1236c3435 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -150,7 +150,7 @@ bool is_permute_021(TensorIteratorBase &iter) {
     is_permute &= input.stride(0) == input.size(2) * input.stride(2);
     is_permute &= input.stride(1) == 1;
     is_permute &= input.stride(2) >= input.size(1);
-    is_permute &= output.is_contiguous();
+    is_permute &= output.is_contiguous() && !input.is_contiguous();
   }
   return is_permute;
 }
@@ -221,8 +221,8 @@ __global__ void transpose_tile_big_kernel(const void* __restrict a, void* __rest
     // Copy full tile with large loads
     constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
     constexpr uint32_t vmem_per_row_wr = row_bytes_wr / sizeof(__uint128_t);
-    constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
-    constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
+    constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
+    constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
     // Make sure WG isn't too large
     static_assert(wr_per_row >= 1);
     const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
@@ -325,8 +325,6 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
       gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-  } else if (is_permute_021(iter) && (dtype == kBFloat16 || dtype == kHalf)) {
-    transpose_last2dim(iter);
   } else {
     AT_DISPATCH_V2(
         dtype, "copy_", AT_WRAP([&] {