[fix]: permute_021 copy fix #1993

Open · wants to merge 1 commit into base: 2.5_perf_fix
8 changes: 3 additions & 5 deletions aten/src/ATen/native/cuda/Copy.cu
@@ -150,7 +150,7 @@ bool is_permute_021(TensorIteratorBase &iter) {
     is_permute &= input.stride(0) == input.size(2) * input.stride(2);
     is_permute &= input.stride(1) == 1;
     is_permute &= input.stride(2) >= input.size(1);
-    is_permute &= output.is_contiguous();
+    is_permute &= output.is_contiguous() && !input.is_contiguous();


Reviewer: why shouldn't input be contiguous?

Author: is_permute_021 should detect whether the input tensor has its last two dimensions in transposed order. But an input of shape (n, 1, 1) satisfies the stride checks while being inherently contiguous, so those cases have to be excluded explicitly.
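A minimal sketch of why that happens, replaying the stride checks from the hunk above in plain C++ (the size value 16 is hypothetical):

```cpp
// A contiguous tensor of sizes (n, 1, 1) has strides (1, 1, 1), which
// satisfies every stride condition in is_permute_021:
//   stride(0) == size(2) * stride(2)  ->  1 == 1 * 1
//   stride(1) == 1
//   stride(2) >= size(1)              ->  1 >= 1
// Without the !input.is_contiguous() guard, such a plain contiguous copy
// would be routed into the transpose kernel.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t size[3]   = {16, 1, 1};  // hypothetical (n, 1, 1) input
  const int64_t stride[3] = {1, 1, 1};   // its contiguous strides
  bool is_permute = true;
  is_permute &= stride[0] == size[2] * stride[2];
  is_permute &= stride[1] == 1;
  is_permute &= stride[2] >= size[1];
  assert(is_permute);  // passes: the stride checks alone cannot rule this case out
  return 0;
}
```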

   }
   return is_permute;
 }
@@ -221,8 +221,8 @@ __global__ void transpose_tile_big_kernel(const void* __restrict a, void* __rest
     // Copy full tile with large loads
     constexpr uint32_t row_bytes_wr = BIG_TILE_SIZE_N * sizeof(T);
     constexpr uint32_t vmem_per_row_wr = row_bytes_wr / sizeof(__uint128_t);
-    constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
-    constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
+    constexpr uint32_t rows_per_wg_wr = BLOCK_SIZE / vmem_per_row_wr;
+    constexpr uint32_t wr_per_row = BIG_TILE_SIZE_K / rows_per_wg_wr;
     // Make sure WG isn't too large
     static_assert(wr_per_row >= 1);
     const uint8_t* pc = (const uint8_t*)c + tj * BIG_TILE_SIZE_K * stride_k + ti * row_bytes_wr + current_m * out_stride_nk;
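To make the tile bookkeeping above concrete, here is a worked instance under assumed parameters (BIG_TILE_SIZE_N = 64, BIG_TILE_SIZE_K = 64, BLOCK_SIZE = 256, a 2-byte element type; the kernel's actual template arguments may differ):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Assumed tile configuration; the real instantiation may use other values.
  constexpr uint32_t BIG_TILE_SIZE_N = 64, BIG_TILE_SIZE_K = 64, BLOCK_SIZE = 256;
  constexpr uint32_t sizeof_T = 2;  // e.g. half / bfloat16

  constexpr uint32_t row_bytes_wr    = BIG_TILE_SIZE_N * sizeof_T;        // 128 bytes per tile row
  constexpr uint32_t vmem_per_row_wr = row_bytes_wr / 16;                 // 8 x 128-bit (__uint128_t) stores per row
  constexpr uint32_t rows_per_wg_wr  = BLOCK_SIZE / vmem_per_row_wr;      // 32 rows written per workgroup pass
  constexpr uint32_t wr_per_row      = BIG_TILE_SIZE_K / rows_per_wg_wr;  // 2 passes cover the full 64-row tile
  static_assert(wr_per_row >= 1, "workgroup would be larger than one tile pass");

  std::printf("row_bytes=%u vmem_per_row=%u rows_per_wg=%u passes=%u\n",
              row_bytes_wr, vmem_per_row_wr, rows_per_wg_wr, wr_per_row);
  return 0;
}
```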
@@ -325,8 +325,6 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
       gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
     });
-  } else if (is_permute_021(iter) && (dtype == kBFloat16 || dtype == kHalf)) {
-    transpose_last2dim(iter);
   } else {
     AT_DISPATCH_V2(
         dtype, "copy_", AT_WRAP([&] {
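For reference, a minimal ATen sketch of the copy pattern the permute-021 fast path targets (assumes a CUDA build of PyTorch; the shapes are hypothetical): a source whose last two dimensions were swapped by permute, copied into a contiguous destination.

```cpp
#include <ATen/ATen.h>

int main() {
  // Source with last two dims transposed: sizes (8, 128, 64), strides (8192, 1, 128).
  auto src = at::randn({8, 64, 128}, at::device(at::kCUDA).dtype(at::kBFloat16))
                 .permute({0, 2, 1});
  auto dst = at::empty({8, 128, 64}, src.options());
  // Non-contiguous src + contiguous dst: this copy lands in
  // direct_copy_kernel_cuda, where the is_permute_021 detection applies.
  dst.copy_(src);
  return 0;
}
```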