diff --git a/benchmark/opperf/nd_operations/nn_activation_operators.py b/benchmark/opperf/nd_operations/nn_activation_operators.py
index 161dfe72123e..7c59065aabbf 100644
--- a/benchmark/opperf/nd_operations/nn_activation_operators.py
+++ b/benchmark/opperf/nd_operations/nn_activation_operators.py
@@ -36,9 +36,10 @@
 8. Activation
     8.1 relu
     8.2 sigmoid
-    8.3 softrelu
-    8.4 softsign
-    8.5 tanh
+    8.3 log_sigmoid
+    8.4 softrelu
+    8.5 softsign
+    8.6 tanh
 
 """
diff --git a/benchmark/opperf/rules/default_params.py b/benchmark/opperf/rules/default_params.py
index 4e8bb6b6cc6f..94181932be83 100644
--- a/benchmark/opperf/rules/default_params.py
+++ b/benchmark/opperf/rules/default_params.py
@@ -375,7 +375,7 @@
 
 # For NN operators
 DEFAULT_ACT_TYPE_LR = ['leaky', 'elu', 'selu', 'gelu']
-DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'softrelu', 'softsign', 'tanh']
+DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'log_sigmoid', 'softrelu', 'softsign', 'tanh']
 
 DEFAULT_LABEL_SOFTMAX = [(1024, 1024), (10000, 1), (10000, 100)]
 DEFAULT_LABEL_SOFTMAX_LARGE_TENSOR = [(2**32, 1)]
diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py
index e7f14fa3f79d..b7cb85327143 100644
--- a/python/mxnet/amp/lists/symbol_bf16.py
+++ b/python/mxnet/amp/lists/symbol_bf16.py
@@ -288,6 +288,7 @@
     'hard_sigmoid',
     'identity',
     'logical_not',
+    'log_sigmoid',
     'max_axis',
     'max',
     'min',
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 6359384489c4..3c927b253af1 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -395,6 +395,7 @@
     'lamb_update_phase1',
     'lamb_update_phase2',
     'logical_not',
+    'log_sigmoid',
     'max',
     'min',
     'mp_lamb_update_phase1',
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index f5d3c26c23a4..aa0ab51bb4a6 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2171,6 +2171,14 @@ def log1p(self, *args, **kwargs):
         """
         return op.log1p(self, *args, **kwargs)
 
+    def log_sigmoid(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`log_sigmoid`.
+
+        The arguments are the same as for :py:func:`log_sigmoid`, with
+        this array as data.
+        """
+        return op.log_sigmoid(self, *args, **kwargs)
+
     def sqrt(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sqrt`.
 
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index b64c170cd7df..7dbd3f2d3e29 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -2260,6 +2260,14 @@ def log1p(self, *args, **kwargs):
         """
         raise AttributeError('mxnet.numpy.ndarray object has no attribute log1p')
 
+    def log_sigmoid(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`log_sigmoid`.
+
+        The arguments are the same as for :py:func:`log_sigmoid`, with
+        this array as data.
+        """
+        raise AttributeError('mxnet.numpy.ndarray object has no attribute log_sigmoid')
+
     def sqrt(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sqrt`.
 
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 3ef6281faea5..496265605b44 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -2519,6 +2519,14 @@ def log1p(self, *args, **kwargs):
         """
         return op.log1p(self, *args, **kwargs)
 
+    def log_sigmoid(self, *args, **kwargs):
+        """Convenience fluent method for :py:func:`log_sigmoid`.
+
+        The arguments are the same as for :py:func:`log_sigmoid`, with
+        this array as data.
+        """
+        return op.log_sigmoid(self, *args, **kwargs)
+
     def sqrt(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`sqrt`.
 
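Not part of the patch: a minimal usage sketch for the fluent method wired up above, assuming a build that already includes the `log_sigmoid` operator registered later in this diff::

    import mxnet as mx
    import numpy as np

    x = mx.nd.array([-1.0, 0.0, 1.0])
    y = x.log_sigmoid()                     # fluent form of mx.nd.log_sigmoid(x)
    ref = np.log(1.0 / (1.0 + np.exp(-x.asnumpy())))
    np.testing.assert_allclose(y.asnumpy(), ref, rtol=1e-5)
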
diff --git a/src/api/operator/numpy_extension/npx_activation_op.cc b/src/api/operator/numpy_extension/npx_activation_op.cc
index c072f6e9fc70..ad8cc3cb17a8 100644
--- a/src/api/operator/numpy_extension/npx_activation_op.cc
+++ b/src/api/operator/numpy_extension/npx_activation_op.cc
@@ -34,6 +34,8 @@ inline int String2MXNetActType(const std::string& s) {
     return activation::kReLU;
   } else if (s == "sigmoid") {
     return activation::kSigmoid;
+  } else if (s == "log_sigmoid") {
+    return activation::kLogSigmoid;
   } else if (s == "tanh") {
     return activation::kTanh;
   } else if (s == "softrelu") {
diff --git a/src/common/cuda/rtc/backward_functions-inl.h b/src/common/cuda/rtc/backward_functions-inl.h
index 50f0c671bf48..64ec2515f44c 100644
--- a/src/common/cuda/rtc/backward_functions-inl.h
+++ b/src/common/cuda/rtc/backward_functions-inl.h
@@ -40,8 +40,14 @@ backward_relu(const DTypeGrad grad, const DType val) {
 
 template <typename DType, typename DTypeGrad>
 __device__ inline mixed_type<DTypeGrad, DType>
-backward_sigmoid(const DTypeGrad grad, const DType out) {
-  return grad * out * (1 - out);
+backward_sigmoid(const DTypeGrad grad, const DType val) {
+  return grad * val * (1 - val);
+}
+
+template <typename DType, typename DTypeGrad>
+__device__ inline mixed_type<DTypeGrad, DType>
+backward_log_sigmoid(const DTypeGrad grad, const DType val) {
+  return grad * 1 / (1 + op::exp(val));
 }
 
 template <typename DType, typename DTypeGrad>
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index f4d08e6d1a60..9018a5d435d6 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -685,6 +685,15 @@ __device__ inline DType sigmoid(const DType val) {
   }
 }
 
+template <typename DType>
+__device__ inline DType log_sigmoid(const DType val) {
+  if (type_util::has_double_or_integral<DType>::value) {
+    return ::log(1./(1 + ::exp(-val)));
+  } else {
+    return ::logf(1.f/(1 + expf(-val)));
+  }
+}
+
 template <typename DType>
 __device__ inline DType softrelu(const DType val) {
   if (type_util::has_double_or_integral<DType>::value) {
diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h
index 0add7eaa99da..df6d67e5fda5 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -56,6 +56,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
   {"_backward_amp_cast" , {{"op::identity(%)", "_0"}}},
   {"relu" , {{"op::relu(%)", "_0"}}},
   {"sigmoid" , {{"op::sigmoid(%)", "_0"}}},
+  {"log_sigmoid" , {{"op::log_sigmoid(%)", "_0"}}},
   {"softsign" , {{"op::softsign(%)", "_0"}}},
   {"exp" , {{"op::exp(%)", "_0"}}},
   {"expm1" , {{"op::expm1(%)", "_0"}}},
@@ -135,6 +136,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
   {"logical_not" , {{"op::logical_not(%)", "_0"}}},
   {"_backward_relu" , {{"op::backward_relu(%, %)", "_0", "_1"}}},
   {"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_0", "_1"}}},
+  {"_backward_log_sigmoid" , {{"op::backward_log_sigmoid(%, %)", "_0", "_1"}}},
   {"_backward_expm1" , {{"op::backward_expm1(%, %)", "_0", "_1"}}},
   {"_backward_log" , {{"op::backward_log(%, %)", "_0", "_1"}}},
   {"_backward_log10" , {{"op::backward_log10(%, %)", "_0", "_1"}}},
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 7c7c18f39c3c..c33dad4601d7 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -411,6 +411,10 @@ MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));
 
 MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));
 
+MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));
+
+MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));
+
 MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
 
 MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
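For reference (not part of the patch), the gradient registered above as ``log_sigmoid_grad`` and ``backward_log_sigmoid`` follows from the identity

.. math::
   \frac{d}{dx} \log\left(\frac{1}{1 + e^{-x}}\right) = 1 - \frac{1}{1 + e^{-x}} = \frac{1}{1 + e^{x}},

so the backward kernels can be written in terms of the forward input ``x`` alone.
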
diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h
index 1111464b9697..647debf32fc3 100644
--- a/src/operator/nn/activation-inl.h
+++ b/src/operator/nn/activation-inl.h
@@ -47,7 +47,7 @@ namespace activation {
 enum ActivationOpInputs {kData};
 enum ActivationOpOutputs {kOut};
 enum ActivationOpResource {kTempSpace};
-enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
+enum ActivationOpType {kReLU, kSigmoid, kLogSigmoid, kTanh, kSoftReLU, kSoftSign};
 
 // Get the number of inputs to the gradient depending on the activation type
 int GradNumInputs(int act_type);
@@ -60,6 +60,7 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
   DMLC_DECLARE_FIELD(act_type)
   .add_enum("relu", activation::kReLU)
   .add_enum("sigmoid", activation::kSigmoid)
+  .add_enum("log_sigmoid", activation::kLogSigmoid)
   .add_enum("tanh", activation::kTanh)
   .add_enum("softrelu", activation::kSoftReLU)
   .add_enum("softsign", activation::kSoftSign)
@@ -75,6 +76,8 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
         return "relu";
       case activation::kSigmoid:
         return "sigmoid";
+      case activation::kLogSigmoid:
+        return "log_sigmoid";
       case activation::kTanh:
         return "tanh";
       case activation::kSoftReLU:
@@ -159,6 +162,10 @@ void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
           ctx, inputs[0], req[0], outputs[0]);
       break;
+    case activation::kLogSigmoid:
+      ActivationForward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
+          ctx, inputs[0], req[0], outputs[0]);
+      break;
     case activation::kTanh:
       ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
           ctx, inputs[0], req[0], outputs[0]);
@@ -190,6 +197,10 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ct
       ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
           ctx, inputs[0], inputs[1], req[0], outputs[0]);
       break;
+    case activation::kLogSigmoid:
+      ActivationBackward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
+          ctx, inputs[0], inputs[1], req[0], outputs[0]);
+      break;
     case activation::kTanh:
       ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
           ctx, inputs[0], inputs[1], req[0], outputs[0]);
diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index e9c5251404bd..12a80843cbc5 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -51,6 +51,7 @@ int GradNumInputs(int act_type) {
     case kSoftSign:
     case kTanh:
     case kSigmoid:
+    case kLogSigmoid:
      return 3;
     default:
       CHECK(false) << "missing activation type";
@@ -91,6 +92,7 @@ struct ActivationGrad {
       case kSoftSign:
       case kTanh:
       case kSigmoid:
+      case kLogSigmoid:
         heads.push_back(n->inputs[activation::kData]);
         break;
       default:
@@ -168,6 +170,7 @@ The following activation functions are supported:
 
 - `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
 - `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
+- `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
 - `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
 - `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
 - `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu
index 1116cf20165b..18962f5740de 100644
--- a/src/operator/nn/activation.cu
+++ b/src/operator/nn/activation.cu
@@ -115,6 +115,9 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs,
   } else if (act_type == activation::kSigmoid) {
     ActivationBackward<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
         ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
+  } else if (act_type == activation::kLogSigmoid) {
+    ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
+        ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
   } else {
     LOG(FATAL) << "unknown activation type";
   }
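Not part of the patch: an illustrative end-to-end check of the new ``act_type`` through the Activation path above, assuming a build that includes these changes::

    import mxnet as mx
    import numpy as np
    from mxnet import autograd

    x = mx.nd.array(np.linspace(-3.0, 3.0, 7))
    x.attach_grad()
    with autograd.record():
        y = mx.nd.Activation(x, act_type='log_sigmoid')
    y.backward()                                   # head gradient defaults to ones
    expected = 1.0 / (1.0 + np.exp(x.asnumpy()))   # d/dx log_sigmoid(x)
    np.testing.assert_allclose(x.grad.asnumpy(), expected, rtol=1e-5)
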
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index 761ab86a2c7f..a4fe78025535 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -43,6 +43,7 @@ namespace op {
 bool SupportMKLDNNAct(const ActivationParam& param) {
   return param.act_type == activation::kReLU
       || param.act_type == activation::kSigmoid
+      || param.act_type == activation::kLogSigmoid
       || param.act_type == activation::kSoftReLU
       || param.act_type == activation::kTanh;
 }
@@ -83,6 +84,8 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
       return mkldnn::algorithm::eltwise_relu;
     case activation::kSigmoid:
       return mkldnn::algorithm::eltwise_logistic;
+    case activation::kLogSigmoid:
+      return mkldnn::algorithm::eltwise_logsigmoid;
     case activation::kTanh:
       return mkldnn::algorithm::eltwise_tanh;
     case activation::kSoftReLU:
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 557338e2b408..5a4b27a646d4 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -236,6 +236,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log_sigmoid);  // NOLINT()
+IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_sigmoid_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad);  // NOLINT()
 IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu);  // NOLINT()
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 5107de8b161d..7a951e2819c6 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -149,6 +149,23 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
     return ret;
   });
 
+// log_sigmoid
+MXNET_OPERATOR_REGISTER_UNARY(log_sigmoid)
+MXNET_ADD_SPARSE_OP_ALIAS(log_sigmoid)
+.describe(R"code(Computes log_sigmoid of x element-wise.
+
+.. math::
+   y = log(1 / (1 + exp(-x)))
+
+The storage type of ``log_sigmoid`` output is always dense
+
+)code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log_sigmoid>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log_sigmoid"});
+
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
+                                               unary_bwd<mshadow_op::log_sigmoid_grad>);
+
 DMLC_REGISTER_PARAMETER(HardSigmoidParam);
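Not part of the patch: a NumPy-only sanity check of the forward formula and gradient registered above::

    import numpy as np

    x = np.array([-2.0, 0.0, 2.0])
    y = np.log(1.0 / (1.0 + np.exp(-x)))   # forward: log_sigmoid(x)
    dy = 1.0 / (1.0 + np.exp(x))           # gradient used by _backward_log_sigmoid
    print(y)    # approx [-2.1269, -0.6931, -0.1269]
    print(dy)   # approx [ 0.8808,  0.5   ,  0.1192]
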
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 074f7ac69a26..e9f52f1c6a0b 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -39,6 +39,12 @@ NNVM_REGISTER_OP(sigmoid)
 NNVM_REGISTER_OP(_backward_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_sigmoid"});
 
+NNVM_REGISTER_OP(log_sigmoid)
+.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"log_sigmoid"});
+
+NNVM_REGISTER_OP(_backward_log_sigmoid)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_log_sigmoid"});
+
 NNVM_REGISTER_OP(hard_sigmoid)
 .set_attr<FCompute>("FCompute<gpu>", HardSigmoidForward<gpu>);
diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc
index 29deda92e01b..61b96267f84b 100644
--- a/tests/cpp/operator/activation_perf.cc
+++ b/tests/cpp/operator/activation_perf.cc
@@ -43,6 +43,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
   vector<string> activations = {
     "relu",
     "sigmoid",
+    "log_sigmoid",
     "tanh",
     "softrelu",
     "softsign"
diff --git a/tests/python/mkl/subgraphs/test_conv_subgraph.py b/tests/python/mkl/subgraphs/test_conv_subgraph.py
index e8116044b5f1..c38c75cf7424 100644
--- a/tests/python/mkl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/mkl/subgraphs/test_conv_subgraph.py
@@ -107,6 +107,7 @@ def hybrid_forward(self, F, x):
 @pytest.mark.parametrize('alg,quantize', [
     ("relu", False), #TODO(bgawrych): investigate
     ("sigmoid", True),
+    ("log_sigmoid", False),
     ("tanh", False), #TODO(bgawrych): investigate
     #("softrelu", True), #TODO(bgawrych): bug in oneDNN with AVX
     ("relu6", False), #TODO(bgawrych): investigate
@@ -147,6 +148,7 @@ def hybrid_forward(self, F, x):
 @pytest.mark.parametrize('alg,quantize', [
     ("relu", True),
     ("sigmoid", True),
+    ("log_sigmoid", True),
     ("tanh", True),
     ("softrelu", True),
     ("relu6", True),
@@ -183,6 +185,7 @@ def hybrid_forward(self, F, x):
 @pytest.mark.parametrize('alg,quantize', [
     ("relu", True),
     ("sigmoid", True),
+    ("log_sigmoid", True),
     ("tanh", True),
     #("softrelu", True), #TODO(bgawrych): failing fusion check - difference in random single element
     ("relu6", True),
@@ -289,6 +292,7 @@ def hybrid_forward(self, F, x, shared_weight):
 @pytest.mark.parametrize('alg,quantize', [
     ("relu", True),
     ("sigmoid", True),
+    ("log_sigmoid", True),
     ("tanh", True),
     ("softrelu", True),
     ("relu6", True),
diff --git a/tests/python/mkl/subgraphs/test_fc_subgraph.py b/tests/python/mkl/subgraphs/test_fc_subgraph.py
index 07151ad22227..5b4c61d7756d 100644
--- a/tests/python/mkl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/mkl/subgraphs/test_fc_subgraph.py
@@ -23,7 +23,7 @@ from mxnet.gluon import nn
 from mxnet.test_utils import assert_almost_equal_with_err
 
-fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
+fc_post_ops_list=['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
                   'square', 'square_root', 'abs', 'exp', 'bounded_relu']
 
 def test_float64_fallback():
@@ -69,7 +69,7 @@ def __init__(self, use_bias, flatten, alg, **kwargs):
 
     def hybrid_forward(self, F, x):
       fc_out = self.fc(x)
-      if self.alg in ['relu', 'sigmoid', 'tanh', 'softrelu']:
+      if self.alg in ['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu']:
        out = F.Activation(fc_out, act_type=self.alg)
      elif self.alg in ['gelu', 'elu', 'leaky']:
        out = F.LeakyReLU(fc_out, act_type=self.alg)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 07487579d775..b2995ec9c38a 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -677,6 +677,21 @@ def fsigmoid(a):
     check_symbolic_forward(y, [xa], [ya])
     check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])
 
+def test_log_sigmoid():
+    def flog_sigmoid(a):
+        return np.log(np.divide(1.0, np.add(1.0, np.exp(-a))))
+    def flog_sigmoid_grad(a):
+        return np.divide(1.0, np.add(1.0, np.exp(a)))
+    shape = (3, 4)
+    x = mx.symbol.Variable("x")
+    y = mx.sym.log_sigmoid(x)
+    xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
+    ya = flog_sigmoid(xa)
+    ya_grad = flog_sigmoid_grad(xa)
+    check_numeric_gradient(y, [xa], numeric_eps=1E-3)
+    check_symbolic_forward(y, [xa], [ya])
+    check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad])
+
 def test_shape_array():
     for i in range(1,6):
         shape = rand_shape_nd(i)
@@ -8697,7 +8712,7 @@ def test_get_operator_arguments():
     assert isinstance(operator_arguments, OperatorArguments)
     assert operator_arguments.names == ['data', 'act_type']
     assert operator_arguments.types \
-        == ['NDArray-or-Symbol', "{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
+        == ['NDArray-or-Symbol', "{'log_sigmoid', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
     assert operator_arguments.narg == 2