From a6164dc68cb9a2911eb2971b4a7c9b750f14385a Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 26 Oct 2018 15:14:10 -0700 Subject: [PATCH 01/10] add env variable to choose deterministic cudnn alg --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 2 ++ src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 2 ++ src/operator/nn/cudnn/cudnn_pooling-inl.h | 3 ++- 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index d63d46821edc..72c9a39199e8 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -889,6 +889,7 @@ class CuDNNConvolutionOp { size_t workspace_byte, CuDNNAlgo *algo) { // Determine the fastest acceptable algo that matches the algo_preference (-1 = any), // regardless of mathType. + bool enforce_determinism = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0); for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) { const auto &result = perf_results[i]; bool algo_is_tensor_core = false; @@ -896,6 +897,7 @@ class CuDNNConvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && + (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index c0c56507bbf3..7a9cefcd5e56 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -829,6 +829,7 @@ class CuDNNDeconvolutionOp { void AlgoFinalSelect(const std::vector &perf_results, std::string kernel_name, size_t workspace_byte, CuDNNAlgo *algo) { // Determine the fastest acceptable algo regardless of mathType. 
+ bool enforce_determinism = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0); for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) { const auto &result = perf_results[i]; bool algo_is_tensor_core = false; @@ -836,6 +837,7 @@ class CuDNNDeconvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && + (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h index bc3ee366007c..bb3e9c7e3342 100644 --- a/src/operator/nn/cudnn/cudnn_pooling-inl.h +++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h @@ -48,7 +48,8 @@ class CuDNNPoolingOp { param_ = p; switch (param_.pool_type) { case pool_enum::kMaxPooling: - mode_ = CUDNN_POOLING_MAX; + mode_ = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0) ? 
+ CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; break; case pool_enum::kAvgPooling: if (param_.count_include_pad.has_value() && !param_.count_include_pad.value()) { From 3da0f7de60a1b82123bc50184e95a7440954ffd8 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 29 Oct 2018 13:34:09 -0700 Subject: [PATCH 02/10] set default value to false --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 2 +- src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 2 +- src/operator/nn/cudnn/cudnn_pooling-inl.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 72c9a39199e8..642bb1c049be 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -889,7 +889,7 @@ class CuDNNConvolutionOp { size_t workspace_byte, CuDNNAlgo *algo) { // Determine the fastest acceptable algo that matches the algo_preference (-1 = any), // regardless of mathType. - bool enforce_determinism = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0); + bool enforce_determinism = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false); for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) { const auto &result = perf_results[i]; bool algo_is_tensor_core = false; diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 7a9cefcd5e56..2f7d1ac65dab 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -829,7 +829,7 @@ class CuDNNDeconvolutionOp { void AlgoFinalSelect(const std::vector &perf_results, std::string kernel_name, size_t workspace_byte, CuDNNAlgo *algo) { // Determine the fastest acceptable algo regardless of mathType. 
- bool enforce_determinism = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0); + bool enforce_determinism = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false); for (decltype(perf_results.size()) i = 0; i != perf_results.size(); ++i) { const auto &result = perf_results[i]; bool algo_is_tensor_core = false; diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h index bb3e9c7e3342..29192b197ab3 100644 --- a/src/operator/nn/cudnn/cudnn_pooling-inl.h +++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h @@ -48,7 +48,7 @@ class CuDNNPoolingOp { param_ = p; switch (param_.pool_type) { case pool_enum::kMaxPooling: - mode_ = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0) ? + mode_ = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false) ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; break; case pool_enum::kAvgPooling: From 65d64d20ef7d51c8aa5899ae716d7ba10b40ce9b Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 29 Oct 2018 15:10:49 -0700 Subject: [PATCH 03/10] fix build failure in Windows GPU --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 2 +- src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 642bb1c049be..ab4fefa10d53 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -897,7 +897,7 @@ class CuDNNConvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && - (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && + (!enforce_determinism || result.determinism == PerfType::CUDNN_DETERMINISTIC) && (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h 
b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 2f7d1ac65dab..9db30ef4971e 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -837,7 +837,7 @@ class CuDNNDeconvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && - (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && + (!enforce_determinism || result.determinism == PerfType::CUDNN_DETERMINISTIC) && (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; From 7bc1d4d60e6cbd796031dd754d60659da12054ac Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 29 Oct 2018 15:16:55 -0700 Subject: [PATCH 04/10] revert the previous change --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 2 +- src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index ab4fefa10d53..642bb1c049be 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -897,7 +897,7 @@ class CuDNNConvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && - (!enforce_determinism || result.determinism == PerfType::CUDNN_DETERMINISTIC) && + (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 9db30ef4971e..2f7d1ac65dab 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ 
-837,7 +837,7 @@ class CuDNNDeconvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && - (!enforce_determinism || result.determinism == PerfType::CUDNN_DETERMINISTIC) && + (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; From ca98f5f52f6c8a9d127d4593ff6266eccc26f9e6 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 29 Oct 2018 15:41:46 -0700 Subject: [PATCH 05/10] only check determinism in CUDNN 7.x release --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 2 ++ src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 642bb1c049be..3bd6c5a3826b 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -897,7 +897,9 @@ class CuDNNConvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && + #if CUDNN_MAJOR >= 7 (!enforce_determinism || result.determinism == cudnnDeterminism_t::CUDNN_DETERMINISTIC) && + #endif (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 2f7d1ac65dab..72ba2c95fc6a 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -837,7 +837,9 @@ class CuDNNDeconvolutionOp { algo_is_tensor_core = result.mathType == CUDNN_TENSOR_OP_MATH; #endif if (result.status == CUDNN_STATUS_SUCCESS && + #if CUDNN_MAJOR >= 7 (!enforce_determinism || result.determinism == 
cudnnDeterminism_t::CUDNN_DETERMINISTIC) && + #endif (param_.cudnn_tune.value() != conv::kLimited || result.memory <= workspace_byte)) { algo->Set(result.algo, algo_is_tensor_core); return; From ad3a484b72da930be31c181676127923b48af6eb Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 29 Oct 2018 16:17:14 -0700 Subject: [PATCH 06/10] Add cudnn version check --- src/operator/nn/cudnn/cudnn_pooling-inl.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h index 29192b197ab3..5c88bd929ae4 100644 --- a/src/operator/nn/cudnn/cudnn_pooling-inl.h +++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h @@ -48,8 +48,11 @@ class CuDNNPoolingOp { param_ = p; switch (param_.pool_type) { case pool_enum::kMaxPooling: + #if CUDNN_MAJOR >= 7 mode_ = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false) ? CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; + #else + mode_ = CUDNN_POOLING_MAX; break; case pool_enum::kAvgPooling: if (param_.count_include_pad.has_value() && !param_.count_include_pad.value()) { From 2a108efc7e829bb7ef6501ed901836f16feb6326 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 29 Oct 2018 16:36:35 -0700 Subject: [PATCH 07/10] fix lint error --- src/operator/nn/cudnn/cudnn_pooling-inl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/cudnn/cudnn_pooling-inl.h b/src/operator/nn/cudnn/cudnn_pooling-inl.h index 5c88bd929ae4..89fa73ef5471 100644 --- a/src/operator/nn/cudnn/cudnn_pooling-inl.h +++ b/src/operator/nn/cudnn/cudnn_pooling-inl.h @@ -52,7 +52,8 @@ class CuDNNPoolingOp { mode_ = dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false) ? 
CUDNN_POOLING_MAX_DETERMINISTIC : CUDNN_POOLING_MAX; #else - mode_ = CUDNN_POOLING_MAX; + mode_ = CUDNN_POOLING_MAX; + #endif break; case pool_enum::kAvgPooling: if (param_.count_include_pad.has_value() && !param_.count_include_pad.value()) { From 57403b1d3a5a18b08765e54014a2f4b9a70663d9 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 30 Oct 2018 15:02:37 -0700 Subject: [PATCH 08/10] document env variable MXNET_ENFORCE_DETERMINISM --- docs/faq/env_var.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index 92cf4931b04b..e373377ee8de 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -196,6 +196,12 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca - Flag to enable or disable MKLDNN accelerator. On by default. - Only applies to mxnet that has been compiled with MKLDNN (```pip install mxnet-mkl``` or built from source with ```USE_MKLDNN=1```) +* MXNET_ENFORCE_DETERMINISM + - Values: 0(false) or 1(true) ```(default=0)``` + - If set to true, MXNet will only use deterministic algorithms in forward and backward computation. + If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice + of CUDNN convolution, deconvolution and max pooling algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details. 
+ Settings for Minimum Memory Usage --------------------------------- - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1``` From d1bdf0f38f50b8c499f22ae1d50770b819f27678 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Sat, 3 Nov 2018 16:33:25 -0700 Subject: [PATCH 09/10] use cudnnGet instead of cudnnFind when determinism required --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 8 ++++---- src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index 3bd6c5a3826b..a2dce33e91fb 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -635,8 +635,8 @@ class CuDNNConvolutionOp { std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); int actual_fwd_algos = 0; auto fwd_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 - : cudnnFindConvolutionForwardAlgorithm; + (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? + cudnnGetConvolutionForwardAlgorithm_v7 : cudnnFindConvolutionForwardAlgorithm; CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, in_desc_, filter_desc_, @@ -657,8 +657,8 @@ class CuDNNConvolutionOp { // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we // were summing into the output (i.e. beta != 0). Get() returned OK algos though. auto bwd_filter_algo_discoverer = - param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 - : cudnnFindConvolutionBackwardFilterAlgorithm; + (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? 
+ cudnnGetConvolutionBackwardFilterAlgorithm_v7 : cudnnFindConvolutionBackwardFilterAlgorithm; CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, in_desc_, out_desc_, diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 72ba2c95fc6a..7616acd6f687 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -562,8 +562,8 @@ class CuDNNDeconvolutionOp { std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); int actual_fwd_algos = 0; auto fwd_algo_discoverer = - param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 - : cudnnFindConvolutionForwardAlgorithm; + (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? + cudnnGetConvolutionForwardAlgorithm_v7 : cudnnFindConvolutionForwardAlgorithm; CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, out_desc_, filter_desc_, @@ -584,8 +584,8 @@ class CuDNNDeconvolutionOp { // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we // were summing into the output (i.e. beta != 0). Get() returned OK algos though. auto bwd_filter_algo_discoverer = - param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 - : cudnnFindConvolutionBackwardFilterAlgorithm; + (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? + cudnnGetConvolutionBackwardFilterAlgorithm_v7 : cudnnFindConvolutionBackwardFilterAlgorithm; CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, out_desc_, in_desc_, @@ -603,8 +603,8 @@ class CuDNNDeconvolutionOp { std::vector bwd_data_results(max_bwd_data_algos); int actual_bwd_data_algos = 0; auto bwd_data_algo_discoverer = - param_.cudnn_tune.value() == deconv::kOff ? 
cudnnGetConvolutionBackwardDataAlgorithm_v7 - : cudnnFindConvolutionBackwardDataAlgorithm; + (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? + cudnnGetConvolutionBackwardDataAlgorithm_v7 : cudnnFindConvolutionBackwardDataAlgorithm; CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, filter_desc_, in_desc_, From 71498ab1e6ef598ca4034472d4ab185d049a85ea Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Tue, 6 Nov 2018 09:56:04 -0800 Subject: [PATCH 10/10] Revert "use cudnnGet instead of cudnnFind when determinism required" This reverts commit d1bdf0f38f50b8c499f22ae1d50770b819f27678. --- src/operator/nn/cudnn/cudnn_convolution-inl.h | 8 ++++---- src/operator/nn/cudnn/cudnn_deconvolution-inl.h | 12 ++++++------ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/operator/nn/cudnn/cudnn_convolution-inl.h b/src/operator/nn/cudnn/cudnn_convolution-inl.h index a2dce33e91fb..3bd6c5a3826b 100644 --- a/src/operator/nn/cudnn/cudnn_convolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_convolution-inl.h @@ -635,8 +635,8 @@ class CuDNNConvolutionOp { std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); int actual_fwd_algos = 0; auto fwd_algo_discoverer = - (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? - cudnnGetConvolutionForwardAlgorithm_v7 : cudnnFindConvolutionForwardAlgorithm; + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 + : cudnnFindConvolutionForwardAlgorithm; CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, in_desc_, filter_desc_, @@ -657,8 +657,8 @@ class CuDNNConvolutionOp { // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we // were summing into the output (i.e. beta != 0). Get() returned OK algos though. auto bwd_filter_algo_discoverer = - (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? 
- cudnnGetConvolutionBackwardFilterAlgorithm_v7 : cudnnFindConvolutionBackwardFilterAlgorithm; + param_.cudnn_tune.value() == conv::kOff ? cudnnGetConvolutionBackwardFilterAlgorithm_v7 + : cudnnFindConvolutionBackwardFilterAlgorithm; CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, in_desc_, out_desc_, diff --git a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h index 7616acd6f687..72ba2c95fc6a 100644 --- a/src/operator/nn/cudnn/cudnn_deconvolution-inl.h +++ b/src/operator/nn/cudnn/cudnn_deconvolution-inl.h @@ -562,8 +562,8 @@ class CuDNNDeconvolutionOp { std::vector fwd_results(MaxForwardAlgos(s->dnn_handle_)); int actual_fwd_algos = 0; auto fwd_algo_discoverer = - (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? - cudnnGetConvolutionForwardAlgorithm_v7 : cudnnFindConvolutionForwardAlgorithm; + param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionForwardAlgorithm_v7 + : cudnnFindConvolutionForwardAlgorithm; CUDNN_CALL((*fwd_algo_discoverer)(s->dnn_handle_, out_desc_, filter_desc_, @@ -584,8 +584,8 @@ class CuDNNDeconvolutionOp { // In cudnn v7.1.4, find() returned wgrad algos that could fail for large c if we // were summing into the output (i.e. beta != 0). Get() returned OK algos though. auto bwd_filter_algo_discoverer = - (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? - cudnnGetConvolutionBackwardFilterAlgorithm_v7 : cudnnFindConvolutionBackwardFilterAlgorithm; + param_.cudnn_tune.value() == deconv::kOff ? 
cudnnGetConvolutionBackwardFilterAlgorithm_v7 + : cudnnFindConvolutionBackwardFilterAlgorithm; CUDNN_CALL((*bwd_filter_algo_discoverer)(s->dnn_handle_, out_desc_, in_desc_, @@ -603,8 +603,8 @@ class CuDNNDeconvolutionOp { std::vector bwd_data_results(max_bwd_data_algos); int actual_bwd_data_algos = 0; auto bwd_data_algo_discoverer = - (param_.cudnn_tune.value() == conv::kOff || dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", 0)) ? - cudnnGetConvolutionBackwardDataAlgorithm_v7 : cudnnFindConvolutionBackwardDataAlgorithm; + param_.cudnn_tune.value() == deconv::kOff ? cudnnGetConvolutionBackwardDataAlgorithm_v7 + : cudnnFindConvolutionBackwardDataAlgorithm; CUDNN_CALL((*bwd_data_algo_discoverer)(s->dnn_handle_, filter_desc_, in_desc_,