[operator] Add logsigmoid activation function (apache#20268)
* [operator] Add logsigmoid activation function

* Improve GPU path for logsigmoid

* Fix website check

* Add log sigmoid to subgraph tests
bartekkuncer authored May 21, 2021
1 parent 4400f42 commit 9308174
Showing 23 changed files with 125 additions and 10 deletions.
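The patch exposes the new activation in three ways: a standalone log_sigmoid operator, a log_sigmoid act_type for the generic Activation operator, and fluent methods on NDArray/Symbol. A minimal usage sketch, assuming an MXNet build that includes this commit (array values are illustrative):

    import mxnet as mx
    import numpy as np

    x = mx.nd.array([-2.0, 0.0, 2.0])

    # standalone operator registered in elemwise_unary_op_basic.cc
    y1 = mx.nd.log_sigmoid(x)

    # same math through the generic Activation operator
    y2 = mx.nd.Activation(x, act_type='log_sigmoid')

    # reference value: y = log(1 / (1 + exp(-x)))
    ref = np.log(1.0 / (1.0 + np.exp(-x.asnumpy())))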
7 changes: 4 additions & 3 deletions benchmark/opperf/nd_operations/nn_activation_operators.py
@@ -36,9 +36,10 @@
8. Activation
8.1 relu
8.2 sigmoid
-8.3 softrelu
-8.4 softsign
-8.5 tanh
+8.3 log_sigmoid
+8.4 softrelu
+8.5 softsign
+8.6 tanh
"""

2 changes: 1 addition & 1 deletion benchmark/opperf/rules/default_params.py
@@ -375,7 +375,7 @@

# For NN operators
DEFAULT_ACT_TYPE_LR = ['leaky', 'elu', 'selu', 'gelu']
-DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'softrelu', 'softsign', 'tanh']
+DEFAULT_ACT_TYPE_ACTIVATION = ['relu', 'sigmoid', 'log_sigmoid', 'softrelu', 'softsign', 'tanh']
DEFAULT_LABEL_SOFTMAX = [(1024, 1024), (10000, 1), (10000, 100)]

DEFAULT_LABEL_SOFTMAX_LARGE_TENSOR = [(2**32, 1)]
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_bf16.py
@@ -288,6 +288,7 @@
'hard_sigmoid',
'identity',
'logical_not',
'log_sigmoid',
'max_axis',
'max',
'min',
1 change: 1 addition & 0 deletions python/mxnet/amp/lists/symbol_fp16.py
@@ -395,6 +395,7 @@
'lamb_update_phase1',
'lamb_update_phase2',
'logical_not',
'log_sigmoid',
'max',
'min',
'mp_lamb_update_phase1',
8 changes: 8 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -2171,6 +2171,14 @@ def log1p(self, *args, **kwargs):
"""
return op.log1p(self, *args, **kwargs)

def log_sigmoid(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log_sigmoid`.
The arguments are the same as for :py:func:`log_sigmoid`, with
this array as data.
"""
return op.log_sigmoid(self, *args, **kwargs)

def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.
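The fluent method added above mirrors the existing ones such as log1p and sqrt; a tiny sketch of the call it enables, again assuming a build with this commit:

    import mxnet as mx

    x = mx.nd.array([-1.0, 0.0, 1.0])
    y = x.log_sigmoid()   # delegates to op.log_sigmoid with this array as data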
8 changes: 8 additions & 0 deletions python/mxnet/numpy/multiarray.py
@@ -2260,6 +2260,14 @@ def log1p(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute log1p')

def log_sigmoid(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log_sigmoid`.
The arguments are the same as for :py:func:`log_sigmoid`, with
this array as data.
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute log_sigmoid')

def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.
8 changes: 8 additions & 0 deletions python/mxnet/symbol/symbol.py
@@ -2519,6 +2519,14 @@ def log1p(self, *args, **kwargs):
"""
return op.log1p(self, *args, **kwargs)

def log_sigmoid(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`log_sigmoid`.
The arguments are the same as for :py:func:`log_sigmoid`, with
this array as data.
"""
return op.log_sigmoid(self, *args, **kwargs)

def sqrt(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`sqrt`.
2 changes: 2 additions & 0 deletions src/api/operator/numpy_extension/npx_activation_op.cc
@@ -34,6 +34,8 @@ inline int String2MXNetActType(const std::string& s) {
return activation::kReLU;
} else if (s == "sigmoid") {
return activation::kSigmoid;
} else if (s == "log_sigmoid") {
return activation::kLogSigmoid;
} else if (s == "tanh") {
return activation::kTanh;
} else if (s == "softrelu") {
10 changes: 8 additions & 2 deletions src/common/cuda/rtc/backward_functions-inl.h
@@ -40,8 +40,14 @@ backward_relu(const DTypeGrad grad, const DType val) {
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
-backward_sigmoid(const DTypeGrad grad, const DType out) {
-  return grad * out * (1 - out);
+backward_sigmoid(const DTypeGrad grad, const DType val) {
+  return grad * val * (1 - val);
}
template <typename DType, typename DTypeGrad>
__device__ inline mixed_type<DTypeGrad, DType>
backward_log_sigmoid(const DTypeGrad grad, const DType val) {
return grad * 1 / (1 + op::exp(val));
}
template <typename DType, typename DTypeGrad>
9 changes: 9 additions & 0 deletions src/common/cuda/rtc/forward_functions-inl.h
@@ -685,6 +685,15 @@ __device__ inline DType sigmoid(const DType val) {
}
}
template <typename DType>
__device__ inline DType log_sigmoid(const DType val) {
if (type_util::has_double_or_integral<DType>::value) {
return ::log(1./(1 + ::exp(-val)));
} else {
return ::logf(1.f/(1 + expf(-val)));
}
}
template <typename DType>
__device__ inline DType softrelu(const DType val) {
if (type_util::has_double_or_integral<DType>::value) {
2 changes: 2 additions & 0 deletions src/operator/fusion/fused_op-inl.h
@@ -56,6 +56,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"_backward_amp_cast" , {{"op::identity(%)", "_0"}}},
{"relu" , {{"op::relu(%)", "_0"}}},
{"sigmoid" , {{"op::sigmoid(%)", "_0"}}},
{"log_sigmoid" , {{"op::log_sigmoid(%)", "_0"}}},
{"softsign" , {{"op::softsign(%)", "_0"}}},
{"exp" , {{"op::exp(%)", "_0"}}},
{"expm1" , {{"op::expm1(%)", "_0"}}},
@@ -135,6 +136,7 @@ const std::map<std::string, std::vector<std::vector<std::string>>> ops_desc = {
{"logical_not" , {{"op::logical_not(%)", "_0"}}},
{"_backward_relu" , {{"op::backward_relu(%, %)", "_0", "_1"}}},
{"_backward_sigmoid" , {{"op::backward_sigmoid(%, %)", "_0", "_1"}}},
{"_backward_log_sigmoid" , {{"op::backward_log_sigmoid(%, %)", "_0", "_1"}}},
{"_backward_expm1" , {{"op::backward_expm1(%, %)", "_0", "_1"}}},
{"_backward_log" , {{"op::backward_log(%, %)", "_0", "_1"}}},
{"_backward_log10" , {{"op::backward_log10(%, %)", "_0", "_1"}}},
4 changes: 4 additions & 0 deletions src/operator/mshadow_op.h
@@ -411,6 +411,10 @@ MXNET_UNARY_MATH_OP(sigmoid, 1.0f / (1.0f + math::exp(-a)));

MXNET_UNARY_MATH_OP(sigmoid_grad, math::id(a) * (1.0f - math::id(a)));

MXNET_UNARY_MATH_OP(log_sigmoid, math::log(1.0f / (1.0f + math::exp(-a))));

MXNET_UNARY_MATH_OP(log_sigmoid_grad, 1.0f / (1.0f + math::exp(a)));

MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
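A quick check that log_sigmoid_grad above really is the derivative of the forward expression: with \sigma(a) = 1/(1+e^{-a}),

    \frac{d}{da}\log\sigma(a) = \frac{\sigma'(a)}{\sigma(a)}
                              = \frac{\sigma(a)\,(1-\sigma(a))}{\sigma(a)}
                              = 1-\sigma(a)
                              = \frac{1}{1+e^{a}}

which matches both the CPU macro (1.0f / (1.0f + math::exp(a))) and the GPU kernel backward_log_sigmoid (grad * 1 / (1 + op::exp(val))).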
13 changes: 12 additions & 1 deletion src/operator/nn/activation-inl.h
@@ -47,7 +47,7 @@ namespace activation {
enum ActivationOpInputs {kData};
enum ActivationOpOutputs {kOut};
enum ActivationOpResource {kTempSpace};
-enum ActivationOpType {kReLU, kSigmoid, kTanh, kSoftReLU, kSoftSign};
+enum ActivationOpType {kReLU, kSigmoid, kLogSigmoid, kTanh, kSoftReLU, kSoftSign};

// Get the number of inputs to the gradient depending on the activation type
int GradNumInputs(int act_type);
@@ -60,6 +60,7 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
DMLC_DECLARE_FIELD(act_type)
.add_enum("relu", activation::kReLU)
.add_enum("sigmoid", activation::kSigmoid)
.add_enum("log_sigmoid", activation::kLogSigmoid)
.add_enum("tanh", activation::kTanh)
.add_enum("softrelu", activation::kSoftReLU)
.add_enum("softsign", activation::kSoftSign)
@@ -75,6 +76,8 @@ struct ActivationParam : public dmlc::Parameter<ActivationParam> {
return "relu";
case activation::kSigmoid:
return "sigmoid";
case activation::kLogSigmoid:
return "log_sigmoid";
case activation::kTanh:
return "tanh";
case activation::kSoftReLU:
@@ -159,6 +162,10 @@ void ActivationComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
ActivationForward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kLogSigmoid:
ActivationForward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationForward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], req[0], outputs[0]);
@@ -190,6 +197,10 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
ActivationBackward<xpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kLogSigmoid:
ActivationBackward<xpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
break;
case activation::kTanh:
ActivationBackward<xpu, mshadow_op::tanh, mshadow_op::tanh_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
3 changes: 3 additions & 0 deletions src/operator/nn/activation.cc
@@ -51,6 +51,7 @@ int GradNumInputs(int act_type) {
case kSoftSign:
case kTanh:
case kSigmoid:
case kLogSigmoid:
return 3;
default:
CHECK(false) << "missing activation type";
@@ -91,6 +92,7 @@ struct ActivationGrad {
case kSoftSign:
case kTanh:
case kSigmoid:
case kLogSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
@@ -168,6 +170,7 @@ The following activation functions are supported:
- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`
- `sigmoid`: :math:`y = \frac{1}{1 + exp(-x)}`
- `log_sigmoid`: :math:`y = log(\frac{1}{1 + exp(-x)})`
- `tanh`: Hyperbolic tangent, :math:`y = \frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`
- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`
- `softsign`: :math:`y = \frac{x}{1 + abs(x)}`
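A related identity, not stated in the docstring but implied by the formulas above: log_sigmoid is a reflected softrelu,

    \log\frac{1}{1+\exp(-x)} = -\log(1+\exp(-x)) = -\mathrm{softrelu}(-x)

so the two entries differ only by a sign and a reflection of the argument.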
3 changes: 3 additions & 0 deletions src/operator/nn/activation.cu
@@ -115,6 +115,9 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
} else if (act_type == activation::kSigmoid) {
ActivationBackward<gpu, mshadow_op::sigmoid, mshadow_op::sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kLogSigmoid) {
ActivationBackward<gpu, mshadow_op::log_sigmoid, mshadow_op::log_sigmoid_grad>(
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else {
LOG(FATAL) << "unknown activation type";
}
3 changes: 3 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -43,6 +43,7 @@ namespace op {
bool SupportMKLDNNAct(const ActivationParam& param) {
return param.act_type == activation::kReLU
|| param.act_type == activation::kSigmoid
|| param.act_type == activation::kLogSigmoid
|| param.act_type == activation::kSoftReLU
|| param.act_type == activation::kTanh;
}
@@ -83,6 +84,8 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
return mkldnn::algorithm::eltwise_relu;
case activation::kSigmoid:
return mkldnn::algorithm::eltwise_logistic;
case activation::kLogSigmoid:
return mkldnn::algorithm::eltwise_logsigmoid;
case activation::kTanh:
return mkldnn::algorithm::eltwise_tanh;
case activation::kSoftReLU:
2 changes: 2 additions & 0 deletions src/operator/operator_tune.cc
@@ -236,6 +236,8 @@ IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::reciprocal); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::reciprocal_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::sigmoid_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::log_sigmoid); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::log_sigmoid_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::softsign); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::softsign_grad); // NOLINT()
IMPLEMENT_UNARY_WORKLOAD_FWD(mxnet::op::mshadow_op::relu); // NOLINT()
17 changes: 17 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cc
@@ -149,6 +149,23 @@ MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_sigmoid,
return ret;
});

// log_sigmoid
MXNET_OPERATOR_REGISTER_UNARY(log_sigmoid)
MXNET_ADD_SPARSE_OP_ALIAS(log_sigmoid)
.describe(R"code(Computes log_sigmoid of x element-wise.
.. math::
y = log(1 / (1 + exp(-x)))
The storage type of ``log_sigmoid`` output is always dense
)code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log_sigmoid>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log_sigmoid"});

MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_log_sigmoid,
unary_bwd<mshadow_op::log_sigmoid_grad>);



DMLC_REGISTER_PARAMETER(HardSigmoidParam);
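Because the gradient is registered with ElemwiseGradUseIn and unary_bwd<mshadow_op::log_sigmoid_grad>, the backward pass is a function of the input, 1/(1+exp(x)). A small end-to-end sketch with autograd, assuming a build that includes this commit (values are illustrative):

    import mxnet as mx

    x = mx.nd.array([-2.0, 0.0, 2.0])
    x.attach_grad()
    with mx.autograd.record():
        y = mx.nd.log_sigmoid(x)
    y.backward()                              # head gradient defaults to ones

    expected = 1.0 / (1.0 + mx.nd.exp(x))     # log_sigmoid_grad applied to the input
    # x.grad and expected should agree to float32 precision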
6 changes: 6 additions & 0 deletions src/operator/tensor/elemwise_unary_op_basic.cu
@@ -39,6 +39,12 @@ NNVM_REGISTER_OP(sigmoid)
NNVM_REGISTER_OP(_backward_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_sigmoid"});

NNVM_REGISTER_OP(log_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", UnaryRTCCompute{"log_sigmoid"});

NNVM_REGISTER_OP(_backward_log_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryRTCCompute{"backward_log_sigmoid"});

NNVM_REGISTER_OP(hard_sigmoid)
.set_attr<FCompute>("FCompute<gpu>", HardSigmoidForward<gpu>);

1 change: 1 addition & 0 deletions tests/cpp/operator/activation_perf.cc
@@ -43,6 +43,7 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
vector<string> activations = {
"relu",
"sigmoid",
"log_sigmoid",
"tanh",
"softrelu",
"softsign"
4 changes: 4 additions & 0 deletions tests/python/mkl/subgraphs/test_conv_subgraph.py
@@ -107,6 +107,7 @@ def hybrid_forward(self, F, x):
@pytest.mark.parametrize('alg,quantize', [
("relu", False), #TODO(bgawrych): investigate
("sigmoid", True),
("log_sigmoid", False),
("tanh", False), #TODO(bgawrych): investigate
#("softrelu", True), #TODO(bgawrych): bug in oneDNN with AVX
("relu6", False), #TODO(bgawrych): investigate
@@ -147,6 +148,7 @@ def hybrid_forward(self, F, x):
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("tanh", True),
("softrelu", True),
("relu6", True),
@@ -183,6 +185,7 @@ def hybrid_forward(self, F, x):
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("tanh", True),
#("softrelu", True), #TODO(bgawrych): failing fusion check - difference in random single element
("relu6", True),
@@ -289,6 +292,7 @@ def hybrid_forward(self, F, x, shared_weight):
@pytest.mark.parametrize('alg,quantize', [
("relu", True),
("sigmoid", True),
("log_sigmoid", True),
("tanh", True),
("softrelu", True),
("relu6", True),
4 changes: 2 additions & 2 deletions tests/python/mkl/subgraphs/test_fc_subgraph.py
@@ -23,7 +23,7 @@
from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal_with_err

-fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
+fc_post_ops_list=['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu', 'gelu', 'elu', 'leaky',
'square', 'square_root', 'abs', 'exp', 'bounded_relu']

def test_float64_fallback():
@@ -69,7 +69,7 @@ def __init__(self, use_bias, flatten, alg, **kwargs):

def hybrid_forward(self, F, x):
fc_out = self.fc(x)
-if self.alg in ['relu', 'sigmoid', 'tanh', 'softrelu']:
+if self.alg in ['relu', 'sigmoid', 'log_sigmoid', 'tanh', 'softrelu']:
out = F.Activation(fc_out, act_type=self.alg)
elif self.alg in ['gelu', 'elu', 'leaky']:
out = F.LeakyReLU(fc_out, act_type=self.alg)
17 changes: 16 additions & 1 deletion tests/python/unittest/test_operator.py
@@ -677,6 +677,21 @@ def fsigmoid(a):
check_symbolic_forward(y, [xa], [ya])
check_symbolic_backward(y, [xa], [np.ones(shape)], [ya * (1 - ya)])

def test_log_sigmoid():
def flog_sigmoid(a):
return np.log(np.divide(1.0, np.add(1.0, np.exp(-a))))
def flog_sigmoid_grad(a):
return np.divide(1.0, np.add(1.0, np.exp(a)))
shape = (3, 4)
x = mx.symbol.Variable("x")
y = mx.sym.log_sigmoid(x)
xa = np.random.uniform(low=-1.0,high=1.0,size=shape)
ya = flog_sigmoid(xa)
ya_grad = flog_sigmoid_grad(xa)
check_numeric_gradient(y, [xa], numeric_eps=1E-3)
check_symbolic_forward(y, [xa], [ya])
check_symbolic_backward(y, [xa], [np.ones(shape)], [ya_grad])

def test_shape_array():
for i in range(1,6):
shape = rand_shape_nd(i)
@@ -8697,7 +8712,7 @@ def test_get_operator_arguments():
assert isinstance(operator_arguments, OperatorArguments)
assert operator_arguments.names == ['data', 'act_type']
assert operator_arguments.types \
== ['NDArray-or-Symbol', "{'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
== ['NDArray-or-Symbol', "{'log_sigmoid', 'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'}, required"]
assert operator_arguments.narg == 2

