diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 204f3c9bd507..2fac39923c7d 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,6 +23,7 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
+from ...autograd import is_training
 from ... import ndarray
 from .. import Block
 from . import rnn_cell
@@ -185,15 +186,17 @@ def forward(self, inputs, states=None):
         for i in range(self._dir):
             self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
             self.i2h_weight[i]._finish_deferred_init()
-        if inputs.context.device_type == 'gpu':
-            out = self._forward_gpu(inputs, states)
+        if inputs.context.device_type == 'gpu' or \
+                (not is_training() and self._mode == 'lstm'):
+            out = self._forward_kernel(inputs, states)
         else:
-            out = self._forward_cpu(inputs, states)
+            out = self._forward(inputs, states)
 
         # out is (output, state)
         return out[0] if skip_states else out
 
-    def _forward_cpu(self, inputs, states):
+    def _forward(self, inputs, states):
+        """forward using gluon cell"""
         ns = len(states)
         axis = self._layout.find('T')
         states = sum(zip(*((j for j in i) for i in states)), ())
@@ -207,7 +210,8 @@ def _forward_cpu(self, inputs, states):
 
         return outputs, new_states
 
-    def _forward_gpu(self, inputs, states):
+    def _forward_kernel(self, inputs, states):
+        """forward using CUDNN or CPU kernel"""
         if self._layout == 'NTC':
             inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
         ctx = inputs.context
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index b4735b8eec64..13c077dd9e35 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -34,7 +34,11 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include "./math.h"
+#include "./math_functions-inl.h"
 #include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./linalg.h"
 
 namespace mxnet {
 namespace op {
@@ -153,6 +157,212 @@ class RNNOp : public Operator {
   RNNParam param_;
 };  // class RNNOp
 
+template<typename DType>
+class RNNOp<cpu, DType> : public Operator {
+ public:
+  explicit RNNOp(RNNParam param) {
+    this->param_ = param;
+    // RNN Mode
+    param_.lstm_q_ = false;
+    switch (param_.mode) {
+      case rnn_enum::kLstm:
+        param_.lstm_q_ = true;
+        break;
+      default:
+        LOG(FATAL) << "only LSTM is implemented on CPU";
+    }
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    // Layout TNC
+    CHECK(!ctx.is_train) << "only inference mode is available "
+                            "for cpu at the moment.";
+    size_t in_expected = param_.lstm_q_ ? 4 : 3;
+    size_t out_expected = param_.lstm_q_ ? 3 : 2;
+
+    if (!param_.state_outputs)
+      LOG(FATAL) << "running without state outputs is currently not supported for cpu.";
+
+    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
+    CHECK_EQ(in_data.size(), in_expected);
+    CHECK_EQ(out_data.size(), out_expected);
+
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensors
+    // w layout i2h_w, h2h_w, i2h_b, h2h_b
+    Tensor<cpu, 3, DType> x =
+        in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
+    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> hx =
+        in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
+    Tensor<cpu, 3, DType> y =
+        out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
+    int64_t seq_len = x.shape_[0];
+    int64_t num_layers = hx.shape_[0];
+    int64_t batch_size = x.shape_[1];
+    int64_t h_channel = hx.shape_[2];
+    int64_t in_channel = x.shape_[2];
+    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
+        .get_with_shape<cpu, 2, DType>(
+            mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
+    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
+        .get_with_shape<cpu, 2, DType>(
+            mshadow::Shape2(
+                y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
+
+    CHECK(x.CheckContiguous());
+    CHECK(w.CheckContiguous());
+    CHECK(hx.CheckContiguous());
+    CHECK(y.CheckContiguous());
+
+    if (param_.lstm_q_) {
+      const size_t kNumMat = 4;
+      int64_t fused_h_ch = kNumMat * h_channel;
+      int64_t h_size = batch_size * fused_h_ch;
+      int64_t num_dir = 1 + param_.bidirectional;
+      int64_t h2h_w_size = h_channel * fused_h_ch;
+
+      Tensor<cpu, 3, DType> cx =
+          in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
+      CHECK(cx.CheckContiguous());
+
+      Tensor<cpu, 3, DType> cy =
+          out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
+      Tensor<cpu, 3, DType> hy =
+          out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
+      CHECK(cy.CheckContiguous());
+      CHECK(hy.CheckContiguous());
+
+      DType* workspace_addr =
+          static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
+              .get_host_space_internal(sizeof(DType) *
+                                       (seq_len * h_size + h_size +
+                                        y.shape_[0] * y.shape_[1] * y.shape_[2])));
+      Tensor<cpu, 3, DType> i2h_y(
+          workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> i2h_y_flatten(
+          workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> h2h_y(workspace_addr +
+          seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
+      Tensor<cpu, 3, DType> y_tmp(workspace_addr +
+          (seq_len + 1) * h_size, y.shape_);
+      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr +
+          (seq_len + 1) * h_size, y_flatten.shape_);
+      CHECK(i2h_y.CheckContiguous());
+      CHECK(h2h_y.CheckContiguous());
+      CHECK(y_tmp.CheckContiguous());
+
+      for (int64_t layer = 0; layer < num_layers; layer++) {
+        int reverse_dir = 0;
+        int out_tmp = 0;
+        if (param_.bidirectional && layer % 2)
+          reverse_dir = 1;
+        if (layer / num_dir % 2 == 0)
+          out_tmp = 1;
+        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
+            (layer < num_dir) ? in_channel : num_dir * h_channel);
+        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
+        int64_t start = layer < num_dir ?
+            (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
+            (num_dir * (in_channel * fused_h_ch + h2h_w_size) +
+             (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
+        Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
+        start += layer < num_dir ?
+            in_channel * fused_h_ch : h2h_w_size * num_dir;
+        Tensor<cpu, 2, DType> h2h_w(w.dptr_ + start, h2h_w_shape);
+        start = num_dir * (in_channel * fused_h_ch + h2h_w_size) +
+            (num_layers - num_dir) * (h2h_w_size * (num_dir + 1)) +
+            layer * fused_h_ch * 2;
+        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
+        start += fused_h_ch;
+        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
+        if (out_tmp) {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w,
+                      i2h_y_flatten, false, true, s);
+        } else {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w,
+                      i2h_y_flatten, false, true, s);
+        }
+        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
+        for (int64_t t = 0; t < seq_len; t++) {
+          int64_t timestep = t;
+          if (reverse_dir)
+            timestep = seq_len - 1 - t;
+          linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
+                      false, true, s);
+          h2h_y += repmat(h2h_b, batch_size);
+          // fused element-wise ops
+          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
+                                     y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
+                                     hy[layer], cy[layer], batch_size, h_channel, t,
+                                     reverse_dir, out_tmp && (layer == num_layers - 1));
+        }
+      }
+    } else {
+      LOG(FATAL) << "only LSTM is available for cpu at the moment.";
+    }
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
+  }
+
+ private:
+  RNNParam param_;
+
+  void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
+                                  const Tensor<cpu, 2, DType> &cx,
+                                  const Tensor<cpu, 2, DType> &h2h_y,
+                                  const Tensor<cpu, 2, DType> &y,
+                                  // holding intermediate layer output
+                                  const Tensor<cpu, 2, DType> &tmp,
+                                  const Tensor<cpu, 2, DType> &hy,
+                                  const Tensor<cpu, 2, DType> &cy,
+                                  const int64_t batch_size,
+                                  const int64_t h_channel,
+                                  const int64_t t,
+                                  const int reverse_dir,
+                                  const int copy_tmp2y) {
+    int64_t length = batch_size * h_channel;
+    #pragma omp parallel for
+    for (int64_t ji = 0; ji < length; ++ji) {
+      int64_t j = ji / h_channel;  // batch dim
+      int64_t i = ji % h_channel;
+      int64_t f = i + h_channel;
+      int64_t c = i + h_channel * 2;
+      int64_t o = i + h_channel * 3;
+      int64_t j_pos = j * h_channel * 4;
+      h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i];
+      h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f];
+      h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o];
+      h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c];
+      h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i]));
+      h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f]));
+      h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o]));
+      h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]);
+      cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? cx[j][i]:cy[j][i]) +
+          h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c];
+      hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]);
+      tmp[j][i + h_channel * reverse_dir] = hy[j][i];
+      if (copy_tmp2y) {
+        y[j][i] = tmp[j][i];
+        if (reverse_dir)
+          y[j][i + h_channel] = tmp[j][i + h_channel];
+      }
+    }
+  }
+};  // class RNNOp
+
 template<typename xpu>
 Operator* CreateOp(RNNParam param, int dtype);
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 908428b383ca..a60adbcd2fbc 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -30,7 +30,6 @@ namespace mxnet {
 namespace op {
 template<>
 Operator *CreateOp<cpu>(RNNParam param, int dtype) {
-  LOG(FATAL) << "RNN is only available for gpu at the moment.";
   Operator *op = NULL;
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
     op = new RNNOp<cpu, DType>(param);
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 2cef29c03fa8..0d1496858736 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1523,6 +1523,23 @@ def check_rnn_layer(layer):
     for g, c in zip(gs, cs):
         assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
 
+def check_rnn_layer_w_rand_inputs(layer):
+    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    x = mx.nd.uniform(shape=(10, 16, 30))
+    with mx.gpu(0):
+        x = x.copyto(mx.gpu(0))
+        states = layer.begin_state(16)
+        go, gs = layer(x, states)
+
+    with mx.cpu(0):
+        x = x.copyto(mx.cpu(0))
+        states = layer.begin_state(16)
+        co, cs = layer(x, states)
+
+    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
+    for g, c in zip(gs, cs):
+        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
+
 @with_seed()
 def test_rnn_layer():
     check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
@@ -1531,7 +1548,7 @@ def test_rnn_layer():
     check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
 
     check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
-
+    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
 
 @with_seed()
 def test_sequence_reverse():
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 871deeb26c40..f22b13d65752 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -67,6 +67,20 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
+def test_lstm_cpu_inference():
+    # should behave the same as lstm cell
+    EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
+                                      [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
+                                     [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
+                                      [0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
+    x = mx.nd.ones(shape=(2, 2, 2))
+    model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
+    model.initialize(mx.init.One())
+    y = model(x).asnumpy()
+
+    mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
+                                      rtol=1e-3, atol=1e-5)
+
 
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')
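
Reviewer note (not part of the patch): a minimal usage sketch of the path this change enables, assuming an arbitrary layer size, input shape, and batch size. Outside of autograd.record(), is_training() is False, so a gluon LSTM called on CPU data dispatches through _forward_kernel to the fused CPU kernel added in rnn-inl.h; under record() it still falls back to the cell-based _forward, since CPU training is not implemented by this change.

# Usage sketch (hypothetical, not part of the patch): exercises the new CPU
# LSTM inference path. Layer size, input shape, and batch size are arbitrary.
import mxnet as mx
from mxnet import autograd, gluon

layer = gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)
layer.collect_params().initialize(ctx=mx.cpu(0))

with mx.cpu(0):
    x = mx.nd.uniform(shape=(10, 16, 30))   # TNC layout: (seq_len, batch, input size)
    states = layer.begin_state(16)

    # No autograd.record(): is_training() is False, so forward() dispatches to
    # _forward_kernel and hence to the fused CPU kernel in rnn-inl.h.
    out, new_states = layer(x, states)

    # Under record() the same call falls back to the unfused gluon-cell path
    # (_forward), since CPU training is not supported by this change.
    with autograd.record():
        out_train, _ = layer(x, states)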