diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc
index 39b73163794f..1ec01ae31cbb 100644
--- a/onnxruntime/core/providers/cuda/cudnn_common.cc
+++ b/onnxruntime/core/providers/cuda/cudnn_common.cc
@@ -37,13 +37,28 @@ Status CudnnTensor::Set(gsl::span<const int64_t> input_dims, cudnnDataType_t dataType
   TensorPitches pitches(input_dims);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> dims(rank);
   InlinedVector<int, kTensorShapeSmallBufferElementsSize> strides(rank);
-  for (int i = 0; i < rank; i++) {
-    dims[i] = gsl::narrow_cast<int>(input_dims[i]);
-    strides[i] = gsl::narrow_cast<int>(pitches[i]);
-  }
-  if (is_nhwc) {
-    std::swap(dims[1], dims[rank - 1]);
-    std::swap(strides[1], strides[rank - 1]);
+
+  if (!is_nhwc) {
+    for (int i = 0; i < rank; i++) {
+      dims[i] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i] = gsl::narrow_cast<int>(pitches[i]);
+    }
+  } else {
+    // NHWDC <-> NCHWD
+
+    // N
+    dims[0] = gsl::narrow_cast<int>(input_dims[0]);
+    strides[0] = gsl::narrow_cast<int>(pitches[0]);
+
+    // HWD
+    for (int i = 1; i < rank - 1; i++) {
+      dims[i + 1] = gsl::narrow_cast<int>(input_dims[i]);
+      strides[i + 1] = gsl::narrow_cast<int>(pitches[i]);
+    }
+
+    // C
+    dims[1] = gsl::narrow_cast<int>(input_dims[rank - 1]);
+    strides[1] = gsl::narrow_cast<int>(pitches[rank - 1]);
   }
   CUDNN_RETURN_IF_ERROR(cudnnSetTensorNdDescriptor(tensor_, dataType, static_cast<int>(rank), dims.data(), strides.data()));
   return Status::OK();
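For reference, the else-branch above hands cuDNN NCHW-ordered dims whose strides still describe the physical NHWC memory layout, which is what cudnnSetTensorNdDescriptor expects for a channels-last tensor. A minimal standalone sketch of the same permutation, with plain std::vector standing in for ORT's TensorPitches/InlinedVector and an illustrative shape:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the is_nhwc branch of CudnnTensor::Set: input_dims is the logical
// NHWC shape; cuDNN wants dims in NCHW order with strides that still describe
// the NHWC memory layout.
int main() {
  std::vector<int64_t> input_dims = {2, 4, 8, 3};  // N, H, W, C (NHWC)
  const int rank = static_cast<int>(input_dims.size());

  // Row-major pitches over the NHWC shape (what TensorPitches computes).
  std::vector<int64_t> pitches(rank, 1);
  for (int i = rank - 2; i >= 0; --i) pitches[i] = pitches[i + 1] * input_dims[i + 1];

  std::vector<int> dims(rank), strides(rank);
  dims[0] = static_cast<int>(input_dims[0]);         // N stays in front
  strides[0] = static_cast<int>(pitches[0]);
  for (int i = 1; i < rank - 1; ++i) {               // spatial dims shift right by one
    dims[i + 1] = static_cast<int>(input_dims[i]);
    strides[i + 1] = static_cast<int>(pitches[i]);
  }
  dims[1] = static_cast<int>(input_dims[rank - 1]);  // C moves to position 1
  strides[1] = static_cast<int>(pitches[rank - 1]);

  for (int i = 0; i < rank; ++i) std::cout << dims[i] << " / " << strides[i] << "\n";
  // Prints 2/96, 3/1, 4/24, 8/3: NCHW dims carrying NHWC strides.
}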
diff --git a/onnxruntime/core/providers/cuda/nn/pool.cc b/onnxruntime/core/providers/cuda/nn/pool.cc
index 09ee8855eb8b..3bf3e8fd789b 100644
--- a/onnxruntime/core/providers/cuda/nn/pool.cc
+++ b/onnxruntime/core/providers/cuda/nn/pool.cc
@@ -147,8 +147,8 @@ class CudnnPoolingDescriptor final {
   cudnnPoolingDescriptor_t desc_;
 };
 
-template <typename T, typename PoolType, bool NHWC>
-Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, typename PoolType, bool Layout>
+Status Pool<T, PoolType, Layout>::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   const Tensor* X = context->Input<Tensor>(0);
   const TensorShape& x_shape = X->Shape();
@@ -163,12 +163,16 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   auto strides = pool_attrs_.strides;
 
   if (pool_attrs_.global_pooling) {
-    kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
+    if constexpr (Layout == LAYOUT_NCHW) {
+      kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 1);
+    }
     pads.assign(kernel_shape.size(), 0);
     strides.assign(kernel_shape.size(), 1);
   }
-  auto out_channel = NHWC ? x_shape[x_dims.size() - 1] : x_shape[1];
-  auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC);
+  auto out_channel = (Layout == LAYOUT_NHWC) ? x_shape[x_dims.size() - 1] : x_shape[1];
+  auto y_dims = pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC);
   TensorShape y_shape(y_dims);
   Tensor* Y = context->Output(0, y_shape);
   // special case when there is a dim value of 0 in the shape.
@@ -180,20 +184,20 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   TensorShapeVector x_dims_cudnn(x_dims.begin(), x_dims.end());
   TensorShapeVector y_dims_cudnn(y_dims);
   if (kernel_shape.size() < 2) {
-    // cudnn only takes 4D or 5D input, so pad dimensions if needed
-    if (NHWC) {
+    // cuDNN only takes 4D or 5D input, so pad dimensions if needed
+    if (Layout == LAYOUT_NHWC) {
       x_dims_cudnn.insert(x_dims_cudnn.begin() + 1, 1);
       y_dims_cudnn.insert(y_dims_cudnn.begin() + 1, 1);
-      kernel_shape.insert(kernel_shape.begin() + 1, 1);
-      strides.insert(strides.begin() + 1, 1);
+      pads.insert(pads.begin(), 0);
+      kernel_shape.insert(kernel_shape.begin(), 1);
+      strides.insert(strides.begin(), 1);
     } else {
-      x_dims_cudnn.push_back(1);
-      y_dims_cudnn.push_back(1);
-      kernel_shape.push_back(1);
-      strides.push_back(1);
+      x_dims_cudnn.insert(x_dims_cudnn.begin() + 2, 1);
+      y_dims_cudnn.insert(y_dims_cudnn.begin() + 2, 1);
+      pads.insert(pads.begin(), 0);
+      kernel_shape.insert(kernel_shape.begin(), 1);
+      strides.insert(strides.begin(), 1);
     }
-    pads.insert(pads.begin() + kernel_shape.size(), 0);
-    pads.insert(pads.end(), 0);
   }
 
   cudnnPoolingMode_t mode = CUDNN_POOLING_MAX;
@@ -210,8 +214,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   const auto beta = Consts<CudaT>::Zero;
   CudnnTensor x_tensor;
   CudnnTensor y_tensor;
-  ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
-  ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
+  ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
+  ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
 
   const auto input_count = x_shape.Size();
   const auto output_count = y_shape.Size();
@@ -227,8 +231,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   const auto beta = Consts<CudaT>::Zero;
   CudnnTensor x_tensor;
   CudnnTensor y_tensor;
-  ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
-  ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), NHWC));
+  ORT_RETURN_IF_ERROR(x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
+  ORT_RETURN_IF_ERROR(y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>(), Layout == LAYOUT_NHWC));
 
   CUDNN_RETURN_IF_ERROR(
       PoolingForwardHelper(GetCudnnHandle(context), pooling_desc, &alpha, x_tensor, x_data, &beta, y_tensor, y_data));
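The reworked 1-D handling above lifts the tensor to 2-D by inserting the dummy unit dimension next to the real spatial dim and prepending matching kernel/stride/pad entries, rather than appending them after the kernel as before. A small standalone sketch of the NHWC branch; the shape is illustrative, and the reading that the pooling descriptor consumes one leading pad per spatial dim is my assumption, not something this diff shows:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch: cuDNN pooling only accepts 4-D/5-D tensors, so a 1-D pool over an
// NHWC input {N, W, C} is lifted to 2-D as {N, 1, W, C}, with kernel, strides
// and pads gaining a unit/zero entry at the front (mirrors the diff above).
int main() {
  std::vector<int64_t> x_dims = {1, 5, 1};  // N, W, C
  std::vector<int64_t> kernel = {2}, strides = {1}, pads = {0, 1};

  x_dims.insert(x_dims.begin() + 1, 1);  // {1, 1, 5, 1}: dummy H before W
  kernel.insert(kernel.begin(), 1);      // {1, 2}
  strides.insert(strides.begin(), 1);    // {1, 1}
  pads.insert(pads.begin(), 0);          // {0, 0, 1}

  for (auto v : {&x_dims, &kernel, &strides, &pads}) {
    for (auto d : *v) std::cout << d << ' ';
    std::cout << '\n';
  }
}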
@@ -237,8 +241,8 @@ Status Pool<T, PoolType, NHWC>::ComputeInternal(OpKernelContext* context) const
   return Status::OK();
 }
 
-template <typename T, bool NHWC>
-Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) const {
+template <typename T, bool Layout>
+Status Pool<T, MaxPool<8>, Layout>::ComputeInternal(OpKernelContext* context) const {
   typedef typename ToCudaType<T>::MappedType CudaT;
   const Tensor* X = context->Input<Tensor>(0);
   const TensorShape& x_shape = X->Shape();
@@ -253,12 +257,19 @@ Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) const
   auto strides = this->pool_attrs_.strides;
 
   if (this->pool_attrs_.global_pooling) {
-    kernel_shape.assign(x_dims.begin() + 2, x_dims.end());
+    // The logic below is most likely broken. Unfortunately no test runs through this case.
+    // Accessing x_dims.end() should result in a crash since it is OOB.
+    // I assume the last element is supposed to be accessed, hence end() - 1 / end() - 2 is used.
+    if constexpr (Layout == LAYOUT_NCHW) {
+      kernel_shape.assign(x_dims.begin() + 2, x_dims.end() - 1);
+    } else if constexpr (Layout == LAYOUT_NHWC) {
+      kernel_shape.assign(x_dims.begin() + 1, x_dims.end() - 2);
+    }
     pads.assign(kernel_shape.size(), 0);
     strides.assign(kernel_shape.size(), 1);
   }
-  auto out_channel = NHWC ? x_shape[x_shape.NumDimensions() - 1] : x_shape[1];
-  auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, NHWC);
+  auto out_channel = Layout == LAYOUT_NHWC ? x_shape[x_shape.NumDimensions() - 1] : x_shape[1];
+  auto y_dims = this->pool_attrs_.SetOutputSize(x_shape, out_channel, &pads, Layout == LAYOUT_NHWC);
   Tensor* Y = context->Output(0, TensorShape(y_dims));
 
   // special case when there is a dim value of 0 in the shape.
@@ -269,17 +280,20 @@ Status Pool<T, MaxPool<8>, NHWC>::ComputeInternal(OpKernelContext* context) const
 
   // I is in NCHW format and the contained indices use NCHW math to compute the index
   auto i_dims = y_dims;
-  if (NHWC) {
-    std::swap(i_dims[1], i_dims[x_shape.NumDimensions() - 1]);
+  if constexpr (Layout == LAYOUT_NHWC) {
+    // y_dims is in NHWDC format, i_dims has to be in NCHWD format.
+    i_dims.insert(i_dims.begin() + 1, i_dims.back());  // N*C*HWDC
+    i_dims.pop_back();                                 // NCHWD
   }
   Tensor* I = context->Output(1, TensorShape(i_dims));
 
   if (nullptr != I || !this->pool_attrs_.default_dilations) {
     auto i_data = nullptr == I ? nullptr : I->MutableData<int64_t>();
-    MaxPoolWithIndex<CudaT>(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape, strides, pads,
-                            this->pool_attrs_.dilations, this->pool_attrs_.storage_order, x_data, y_data, i_data);
+    MaxPoolWithIndex<CudaT>(this->Stream(context), x_shape, TensorShape(y_dims), kernel_shape,
+                            strides, pads, this->pool_attrs_.dilations,
+                            this->pool_attrs_.storage_order, x_data, y_data, i_data);
   } else {
-    ORT_RETURN_IF_ERROR((Pool<T, MaxPool<1>, NHWC>::ComputeInternal(context)));
+    ORT_RETURN_IF_ERROR((Pool<T, MaxPool<1>, Layout == LAYOUT_NHWC>::ComputeInternal(context)));
   }
   return Status::OK();
 }
diff --git a/onnxruntime/core/providers/cuda/nn/pool.h b/onnxruntime/core/providers/cuda/nn/pool.h
index 8b5152a1565a..97f7c8b8762d 100644
--- a/onnxruntime/core/providers/cuda/nn/pool.h
+++ b/onnxruntime/core/providers/cuda/nn/pool.h
@@ -19,10 +19,10 @@ class Pool : public CudaKernel, public PoolBase {
   Status ComputeInternal(OpKernelContext* context) const override;
 };
 
-template <typename T, bool NHWC>
-class Pool<T, MaxPool<8>, NHWC> final : public Pool<T, MaxPool<1>, NHWC> {
+template <typename T, bool Layout>
+class Pool<T, MaxPool<8>, Layout> final : public Pool<T, MaxPool<1>, Layout> {
  public:
-  explicit Pool(const OpKernelInfo& info) : Pool<T, MaxPool<1>, NHWC>(info) {}
+  explicit Pool(const OpKernelInfo& info) : Pool<T, MaxPool<1>, Layout>(info) {}
 
   Status ComputeInternal(OpKernelContext* context) const override;
 };
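The insert/pop_back pair above rotates the trailing channel dim into position 1 while keeping the spatial dims in order; the old std::swap produced {N, C, W, H} for a 4-D NHWC shape, i.e. it left H and W transposed. A standalone sketch with an illustrative shape:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch of the i_dims conversion above: the MaxPool index tensor I is defined
// in NCHW terms, so the channel dim of the NHWC output shape is copied into
// position 1 and the trailing original is dropped.
int main() {
  std::vector<int64_t> y_dims = {2, 4, 8, 3};  // N, H, W, C
  auto i_dims = y_dims;
  const int64_t channels = i_dims.back();
  i_dims.insert(i_dims.begin() + 1, channels);  // {2, 3, 4, 8, 3}
  i_dims.pop_back();                            // {2, 3, 4, 8} = N, C, H, W
  for (auto d : i_dims) std::cout << d << ' ';  // prints: 2 3 4 8
  std::cout << '\n';
}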
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index 820eb3163784..5e59abc0c995 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -185,7 +185,7 @@ TEST(PoolTest, MaxPool_8_With_Index) {
   MaxPool_8_WithIndexTest(true, 1 /*storage_order*/);  // col major
 }
 
-TEST(PoolTest, MaxPool1D) {
+TEST(PoolTest, MaxPool1D_case1) {
   OpTester test("MaxPool");
 
   test.AddAttribute("auto_pad", "");
@@ -200,7 +200,46 @@ TEST(PoolTest, MaxPool1D) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+}
+
+TEST(PoolTest, MaxPool1D_case2) {
+  OpTester test("MaxPool");
+  // no padding
+  test.AddAttribute("auto_pad", "VALID");
+  test.AddAttribute("strides", std::vector<int64_t>{1});
+  test.AddAttribute("pads", vector<int64_t>{0, 0});
+  test.AddAttribute("kernel_shape", vector<int64_t>{2});
+
+  std::vector<float> x_vals = {1, 2, 3, 4, 5};
+  std::vector<int64_t> x_dims = {1, 1, 5};
+  // The last dim is (5-2+1)/1 = 4
+  std::vector<int64_t> expected_dims = {1, 1, 4};
+  std::vector<float> expected_vals = {2, 3, 4, 5};
+
+  test.AddInput<float>("X", x_dims, x_vals);
+  test.AddOutput<float>("Y", expected_dims, expected_vals);
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+}
+
+TEST(PoolTest, MaxPool1D_case3) {
+  OpTester test("MaxPool");
+  test.AddAttribute("auto_pad", "");
+  test.AddAttribute("strides", std::vector<int64_t>{1});
+  // Pad one element
+  test.AddAttribute("pads", vector<int64_t>{0, 1});
+  test.AddAttribute("kernel_shape", vector<int64_t>{2});
+
+  std::vector<float> x_vals = {1, 2, 3, 4, 5};
+  std::vector<int64_t> x_dims = {1, 1, 5};
+  // Since we padded it, the last dim is larger compared to the case above
+  std::vector<int64_t> expected_dims = {1, 1, 5};
+  std::vector<float> expected_vals = {2, 3, 4, 5, 5};
+
+  test.AddInput<float>("X", x_dims, x_vals);
+  test.AddOutput<float>("Y", expected_dims, expected_vals);
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider});
 }
 
 static void MaxPool1D_8_WithIndexTest(int64_t storage_order) {
@@ -707,7 +746,7 @@ TEST(PoolTest, GlobalMaxPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalMaxPool3D) {
@@ -783,7 +822,7 @@ TEST(PoolTest, GlobalMaxPool3D) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 TEST(PoolTest, AveragePool) {
@@ -864,7 +903,7 @@ TEST(PoolTest, AveragePool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 TEST(PoolTest, AveragePool_IncludePadPixel) {
@@ -888,7 +927,7 @@ TEST(PoolTest, AveragePool_IncludePadPixel) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 // test 'strides' attribute not specified
@@ -907,7 +946,7 @@ TEST(PoolTest, AveragePool_DefaultStrides) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
 }
 
 TEST(PoolTest, AveragePool_10_ceil1_2d) {
@@ -931,7 +970,7 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider});
+           {kTensorrtExecutionProvider, kAclExecutionProvider});
 }
 
 TEST(PoolTest, AveragePool_19_dilation_2d) {
@@ -955,7 +994,9 @@ TEST(PoolTest, AveragePool_19_dilation_2d) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider,
+            kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider});
 }
 
 TEST(PoolTest, GlobalAveragePool) {
@@ -1031,7 +1072,7 @@ TEST(PoolTest, GlobalAveragePool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalAveragePool_Large_128) {
@@ -1044,7 +1085,7 @@ TEST(PoolTest, GlobalAveragePool_Large_128) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals,
                         /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, GlobalAveragePool_Large_256) {
@@ -1057,7 +1098,7 @@ TEST(PoolTest, GlobalAveragePool_Large_256) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals,
                         /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {});
 }
 
 TEST(PoolTest, LpPool) {
@@ -1364,7 +1405,7 @@ TEST(PoolTest, LpPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
 }
 
 // test data generated with lp_pool_test_generator.py
@@ -1396,7 +1437,8 @@ TEST(PoolTest, LpPool1d) {
 
       // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060
      // TensorRT does not support 1d pooling
-      test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+      test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+               {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
       y_count++;
     }
   }
@@ -1428,7 +1470,7 @@ TEST(PoolTest, LpPool2d) {
       test.AddAttribute("kernel_shape", kernel_sizes[kernel_size_count]);
       test.AddOutput<float>("Y", y_sizes[y_count], ys[y_count]);
 
-      test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
+      test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
       y_count++;
     }
   }
@@ -1446,7 +1488,8 @@ TEST(PoolTest, LpPoolCeilMode) {
 
   // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060
   // TensorRT does not support 1d pooling
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider});
 }
 
 TEST(PoolTest, GlobalLpPool) {
@@ -1701,7 +1744,7 @@ TEST(PoolTest, GlobalLpPool) {
 
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
-  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider});
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider});
 }
 
 TEST(PoolTest, MaxPoolDimWithZeroForN) {
@@ -1719,7 +1762,7 @@ TEST(PoolTest, MaxPoolDimWithZeroForN) {
   test.AddInput<float>("X", x_dims, x_vals);
   test.AddOutput<float>("Y", expected_dims, expected_vals);
   test.Run(OpTester::ExpectResult::kExpectSuccess, "",
-           {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider});
+           {kTensorrtExecutionProvider, kQnnExecutionProvider});
 }
 
 }  // namespace test
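As a cross-check on the expected shapes in the new MaxPool1D cases, the output length follows the usual pooling size formula out = (in + pad_begin + pad_end - kernel) / stride + 1. A tiny standalone check; PoolOutLen is a hypothetical helper for illustration, not part of the test file:

#include <cassert>
#include <cstdint>

// Expected 1-D pooling output length (floor division, as in the ONNX spec).
int64_t PoolOutLen(int64_t in, int64_t pad_begin, int64_t pad_end, int64_t kernel, int64_t stride) {
  return (in + pad_begin + pad_end - kernel) / stride + 1;
}

int main() {
  assert(PoolOutLen(5, 0, 0, 2, 1) == 4);  // MaxPool1D_case2: {1, 1, 5} -> {1, 1, 4}
  assert(PoolOutLen(5, 0, 1, 2, 1) == 5);  // MaxPool1D_case3: {1, 1, 5} -> {1, 1, 5}
  return 0;
}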