From 9e48c783712640cb06f4fa0ced9ce340c71695b5 Mon Sep 17 00:00:00 2001
From: Jerryzcn
Date: Thu, 28 Dec 2017 20:41:39 -0800
Subject: [PATCH 01/22] fix autograd import path

---
 example/gluon/style_transfer/main.py | 4 ++--
 example/gluon/style_transfer/net.py  | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/example/gluon/style_transfer/main.py b/example/gluon/style_transfer/main.py
index 0ed2d831bbc0..fa21a3695de6 100644
--- a/example/gluon/style_transfer/main.py
+++ b/example/gluon/style_transfer/main.py
@@ -23,8 +23,8 @@
 np.set_printoptions(precision=2)
 from PIL import Image
 
-from mxnet import gluon
-from mxnet.gluon import nn, autograd, Block, HybridBlock, Parameter, ParameterDict
+from mxnet import autograd, gluon
+from mxnet.gluon import nn, Block, HybridBlock, Parameter, ParameterDict
 import mxnet.ndarray as F
 
 import net
diff --git a/example/gluon/style_transfer/net.py b/example/gluon/style_transfer/net.py
index 353a52c66c0b..a33ce427fbdf 100644
--- a/example/gluon/style_transfer/net.py
+++ b/example/gluon/style_transfer/net.py
@@ -17,8 +17,8 @@
 
 import numpy as np
 import mxnet as mx
-from mxnet import gluon
-from mxnet.gluon import nn, autograd, Block, HybridBlock, Parameter
+from mxnet import autograd, gluon
+from mxnet.gluon import nn, Block, HybridBlock, Parameter
 from mxnet.base import numeric_types
 import mxnet.ndarray as F

From 1e8a7e7e076ad0644fd84d8080df276cb3dc6b31 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 2 Mar 2018 02:31:35 +0000
Subject: [PATCH 02/22] cpu lstm working

---
 python/mxnet/gluon/rnn/rnn_layer.py |  17 ++-
 src/operator/rnn-inl.h              | 194 +++++++++++++++++++++++++++-
 2 files changed, 203 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 204f3c9bd507..3492a9caf356 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -185,15 +185,17 @@ def forward(self, inputs, states=None):
         for i in range(self._dir):
             self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
             self.i2h_weight[i]._finish_deferred_init()
-        if inputs.context.device_type == 'gpu':
-            out = self._forward_gpu(inputs, states)
+        import mxnet
+        if inputs.context.device_type == 'gpu' or not mxnet.autograd.is_training():
+            out = self._forward_kernel(inputs, states)
         else:
-            out = self._forward_cpu(inputs, states)
+            out = self._forward(inputs, states)
 
         # out is (output, state)
         return out[0] if skip_states else out
 
-    def _forward_cpu(self, inputs, states):
+    def _forward(self, inputs, states):
+        """forward using gluon cell"""
         ns = len(states)
         axis = self._layout.find('T')
         states = sum(zip(*((j for j in i) for i in states)), ())
@@ -206,8 +208,9 @@ def _forward_cpu(self, inputs, states):
             new_states.append(state)
 
         return outputs, new_states
-
+    
+    def _forward_kernel(self, inputs, states):
+        """ forward using CUDNN or CPU kernel"""
-    def _forward_gpu(self, inputs, states):
         if self._layout == 'NTC':
             inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
         ctx = inputs.context
@@ -215,7 +218,7 @@ def _forward_gpu(self, inputs, states):
         params += sum(zip(self.i2h_bias, self.h2h_bias), ())
         params = (i.data(ctx).reshape((-1,)) for i in params)
         params = ndarray.concat(*params, dim=0)
-
+        
         rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
                           num_layers=self._num_layers, bidirectional=self._dir == 2,
                           p=self._dropout, state_outputs=True, mode=self._mode)
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index b4735b8eec64..f6d2be61fe6b 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -34,7 +34,11 @@
 #include <vector>
 #include <string>
 #include <utility>
+#include "./math.h"
+#include "./math_functions-inl.h"
 #include "./operator_common.h"
+#include "./mshadow_op.h"
+#include "./linalg.h"
 
 namespace mxnet {
 namespace op {
@@ -120,7 +124,7 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
       .describe("Whether to have the states as symbol outputs.");
   }
 };
-
+  
 template<typename xpu, typename DType>
 class RNNOp : public Operator {
  public:
@@ -153,6 +157,194 @@ class RNNOp : public Operator {
   RNNParam param_;
 };  // class RNNOp
 
+template<typename DType>
+class RNNOp<cpu, DType> : public Operator {
+ public:
+  explicit RNNOp(RNNParam param) {
+    this->param_ = param;
+    // RNN Mode
+    switch (param_.mode) {
+      case rnn_enum::kLstm:
+        break;
+      default:
+        LOG(FATAL) << "Not implemented";
+    }
+    if (param_.mode == rnn_enum::kLstm)
+      param_.lstm_q_ = true;
+    else
+      param_.lstm_q_ = false;
+  }
+
+  virtual void Forward(const OpContext &ctx,
+                       const std::vector<TBlob> &in_data,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &out_data,
+                       const std::vector<TBlob> &aux_args) {
+    // Layout TNC
+    using namespace mshadow;
+    using namespace mshadow::expr;
+
+    size_t in_expected = param_.lstm_q_ ? 4 : 3;
+    size_t out_expected = param_.lstm_q_ ? 3 : 2;
+
+    if (!param_.state_outputs)
+      LOG(FATAL) << "no state outputs is currently not supported for cpu.";
+
+    CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
+    CHECK_EQ(in_data.size(), in_expected);
+    CHECK_EQ(out_data.size(), out_expected);
+
+    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
+    // get input + output tensors
+    // w layout i2h_w, h2h_w, i2h_b, h2h_b
+    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
+    Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
+    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
+    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
+    size_t seq_len = x.shape_[0];
+    size_t num_layers = hx.shape_[0];
+    size_t batch_size = x.shape_[1];
+    size_t h_channel = hx.shape_[2];
+    size_t in_channel = x.shape_[2];
+
+
+    CHECK_EQ(x.CheckContiguous(), true);
+    CHECK_EQ(w.CheckContiguous(), true);
+    CHECK_EQ(hx.CheckContiguous(), true);
+    CHECK_EQ(y.CheckContiguous(), true);
+
+    if (ctx.is_train)
+      LOG(FATAL) << "only inference mode is available for cpu at the moment.";
+    if (param_.lstm_q_) {
+      const size_t kNumMat = 4;
+      size_t fused_h_ch = kNumMat * h_channel;
+      size_t h_size = batch_size * fused_h_ch;
+      size_t num_dir = 1 + param_.bidirectional;
+      size_t h2h_w_size = h_channel * fused_h_ch;
+
+      Tensor<cpu, 3, DType> cx = in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
+      CHECK_EQ(cx.CheckContiguous(), true);
+
+      Tensor<cpu, 3, DType> cy = out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
+      Tensor<cpu, 3, DType> hy = out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
+      CHECK_EQ(cy.CheckContiguous(), true);
+      CHECK_EQ(hy.CheckContiguous(), true);
+      LOG(INFO) << "w size: " << w.shape_;
+      LOG(INFO) << "dropout: " << param_.p;
+
+      DType* workspace_addr =
+        static_cast<DType*>(ctx.requested[rnn_enum::kTempSpace]
+            .get_host_space_internal(sizeof(DType) *
+                                     (seq_len * h_size + h_size
+                                      + y.shape_[0] * y.shape_[1] * y.shape_[2])));
+      Tensor<cpu, 3, DType> i2h_y(workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> h2h_y(workspace_addr + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
+      Tensor<cpu, 3, DType> y_tmp(workspace_addr + (seq_len + 1) * h_size, y.shape_);
+      CHECK_EQ(i2h_y.CheckContiguous(), true);
+      CHECK_EQ(h2h_y.CheckContiguous(), true);
+      CHECK_EQ(y_tmp.CheckContiguous(), true);
+
+      for (size_t layer = 0; layer < num_layers; layer++) {
+        int reverse_dir = 0;
+        int out_tmp = 0;
+        if (param_.bidirectional && layer % 2)
+          reverse_dir = 1;
+        if (layer / num_dir % 2 == 0)
+          out_tmp = 1;
+        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch, (layer < num_dir) ? in_channel : num_dir * h_channel);
+        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
+        size_t start = layer < num_dir ?
+          (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
+          (num_dir * (in_channel * fused_h_ch + h2h_w_size) + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
+        Tensor<cpu, 2, DType> i2h_w(w.Slice(start, start + (layer < num_dir ? (in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_,
+                                    i2h_w_shape);
+        start += layer < num_dir ? in_channel * fused_h_ch : h2h_w_size * num_dir;
+        Tensor<cpu, 2, DType> h2h_w(w.Slice(start, start + h2h_w_size).dptr_, h2h_w_shape);
+        start = num_dir * (in_channel * fused_h_ch + h2h_w_size) + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))  // weight offset
+                + layer * fused_h_ch * 2;
+        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
+        start += fused_h_ch;
+        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
+
+        for (size_t t = 0; t < seq_len; t++) {
+          size_t timestep = t;
+          if (reverse_dir)
+            timestep = seq_len - 1 - t;
+          if (out_tmp) {
+            linalg_gemm(layer < num_dir ? x[timestep]:y[timestep], i2h_w, i2h_y[timestep], false, true, s);
+          } else {
+            linalg_gemm(layer < num_dir ? x[timestep]:y_tmp[timestep], i2h_w, i2h_y[timestep], false, true, s);
+          }
+          linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y, false, true, s);
+          h2h_y += repmat(h2h_b, batch_size);
+          i2h_y[timestep] += repmat(i2h_b, batch_size);
+          // fused element-wise ops
+          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y, y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
+                                     hy[layer], cy[layer], batch_size, h_channel, t,
+                                     reverse_dir, out_tmp && (layer == num_layers - 1));
+        }
+      }
+    } else {
+      LOG(FATAL) << "only LSTM is available for cpu at the moment.";
+    }
+  }
+
+  virtual void Backward(const OpContext &ctx,
+                        const std::vector<TBlob> &out_grad,
+                        const std::vector<TBlob> &in_data,
+                        const std::vector<TBlob> &out_data,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &in_grad,
+                        const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    // TODO: add implementation
+    LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
+  }
+
+ private:
+  RNNParam param_;
+
+  virtual void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
+                                          const Tensor<cpu, 2, DType> &cx,
+                                          const Tensor<cpu, 2, DType> &h2h_y,
+                                          const Tensor<cpu, 2, DType> &y,
+                                          const Tensor<cpu, 2, DType> &tmp,  // for holding intermediate layer output
+                                          const Tensor<cpu, 2, DType> &hy,
+                                          const Tensor<cpu, 2, DType> &cy,
+                                          const size_t batch_size,
+                                          const size_t h_channel,
+                                          const size_t t,
+                                          const int reverse_dir,
+                                          const int copy_tmp2y) {
+    size_t ji;
+    // #pragma omp parallel for private(ji)
+    for (ji = 0; ji < batch_size * h_channel; ji++) {
+      size_t j = ji / h_channel;  // batch dim
+      size_t i = ji % h_channel;
+      size_t f = i + h_channel;
+      size_t c = i + h_channel * 2;
+      size_t o = i + h_channel * 3;
+      h2h_y[j][i] += i2h_y[j][i];
+      h2h_y[j][f] += i2h_y[j][f];
+      h2h_y[j][o] += i2h_y[j][o];
+      h2h_y[j][c] += i2h_y[j][c];
+      h2h_y[j][i] = 1.0f / (1.0f + math::exp(-h2h_y[j][i]));
+      h2h_y[j][f] = 1.0f / (1.0f + math::exp(-h2h_y[j][f]));
+      h2h_y[j][o] = 1.0f / (1.0f + math::exp(-h2h_y[j][o]));
+      h2h_y[j][c] = tanh(h2h_y[j][c]);
+      cy[j][i] = h2h_y[j][f] * (t == 0 ? cx[j][i]:cy[j][i]) + h2h_y[j][i] * h2h_y[j][c];
+      hy[j][i] = h2h_y[j][o] * tanh(cy[j][i]);
+      tmp[j][i + h_channel * reverse_dir] = hy[j][i];
+      if (copy_tmp2y) {
+        y[j][i] = tmp[j][i];
+        if (reverse_dir)
+          y[j][i + h_channel] = tmp[j][i + h_channel];
+      }
+    }
+  }
+};  // class RNNOp
+
 template<typename xpu>
 Operator* CreateOp(RNNParam param, int dtype);
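Note: the fused element-wise kernel added in PATCH 02 (LSTMFusedElementWiseCPUOps) implements the
standard LSTM cell update. The two GEMMs per step produce i2h_y and h2h_y, the input-to-hidden and
hidden-to-hidden halves of the pre-activations for the four gates, stacked along the fused channel
axis in the order given by the offsets i, f = i + h_channel, c = i + 2*h_channel, o = i + 3*h_channel.
In the usual notation (sigma matching the 1.0f / (1.0f + exp(-x)) lines, and the ternary
t == 0 ? cx : cy selecting the initial cell state on the first step):

    \begin{aligned}
    i_t &= \sigma(z_i), \quad f_t = \sigma(z_f), \quad o_t = \sigma(z_o), \quad g_t = \tanh(z_c),\\
    z   &= W_{i2h}\,x_t + b_{i2h} + W_{h2h}\,h_{t-1} + b_{h2h},\\
    c_t &= f_t \odot c_{t-1} + i_t \odot g_t,\\
    h_t &= o_t \odot \tanh(c_t).
    \end{aligned}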
From 6f209d2a712075ae3767d1f49c6eec0d57c9f9ce Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 2 Mar 2018 02:40:41 +0000
Subject: [PATCH 03/22] remove fatal log

---
 src/operator/rnn.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 908428b383ca..a60adbcd2fbc 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -30,7 +30,6 @@
 namespace mxnet {
 namespace op {
 template<>
 Operator *CreateOp<cpu>(RNNParam param, int dtype) {
-  LOG(FATAL) << "RNN is only available for gpu at the moment.";
   Operator *op = NULL;
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
     op = new RNNOp<cpu, DType>(param);

From 7e930cbf7b81d0d21bc685275798440a436e24bb Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 2 Mar 2018 20:09:59 +0000
Subject: [PATCH 04/22] add simple unittest remove redundant log enable openmp

---
 src/operator/rnn-inl.h                  |  6 ++----
 tests/python/unittest/test_gluon_rnn.py | 16 ++++++++++++++++
 2 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index f6d2be61fe6b..6f8fc46319ec 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -229,9 +229,7 @@ class RNNOp : public Operator {
       Tensor<cpu, 3, DType> hy = out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
       CHECK_EQ(cy.CheckContiguous(), true);
       CHECK_EQ(hy.CheckContiguous(), true);
-      LOG(INFO) << "w size: " << w.shape_;
-      LOG(INFO) << "dropout: " << param_.p;
-
+      
       DType* workspace_addr =
         static_cast<DType*>(ctx.requested[rnn_enum::kTempSpace]
             .get_host_space_internal(sizeof(DType) *
@@ -316,7 +316,7 @@ class RNNOp : public Operator {
                                           const int reverse_dir,
                                           const int copy_tmp2y) {
     size_t ji;
-    // #pragma omp parallel for private(ji)
+    #pragma omp parallel for private(ji)
     for (ji = 0; ji < batch_size * h_channel; ji++) {
       size_t j = ji / h_channel;  // batch dim
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 228884219258..ec6d9c1cb57a 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -67,6 +67,22 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
+def test_lstm_cpu_inference():
+    # should behave the same as lstm cell
+    atol = 1e-6
+    x = mx.nd.ones(shape=(2, 2, 2))
+    model = mx.gluon.nn.Sequential()
+    with model.name_scope():
+        model.add(mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True))
+    model.initialize(mx.init.One())
+    y = model(x).asnumpy()
+
+    mx.test_utils.assert_almost_equal(y, np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
+                                                    [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
+                                                   [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
+                                                    [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]),
+                                      rtol=1e-3, atol=1e-5)
+
 
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')

From 19c85aa769d85b2bd3ee0e7d097a84ae71c0dde5 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Fri, 2 Mar 2018 23:55:21 +0000
Subject: [PATCH 05/22] fused input2hidden gemm

---
 src/operator/rnn-inl.h | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 6f8fc46319ec..46b503252bc4 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -206,6 +206,10 @@ class RNNOp : public Operator {
     size_t batch_size = x.shape_[1];
     size_t h_channel = hx.shape_[2];
     size_t in_channel = x.shape_[2];
+    Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
+      .get_with_shape<cpu, 2, DType>(mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
+    Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
+      .get_with_shape<cpu, 2, DType>(mshadow::Shape2(y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
 
 
     CHECK_EQ(x.CheckContiguous(), true);
@@ -236,8 +240,10 @@ class RNNOp : public Operator {
       Tensor<cpu, 3, DType> i2h_y(workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> i2h_y_flatten(workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
       Tensor<cpu, 2, DType> h2h_y(workspace_addr + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
       Tensor<cpu, 3, DType> y_tmp(workspace_addr + (seq_len + 1) * h_size, y.shape_);
+      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr + (seq_len + 1) * h_size, y_flatten.shape_);
       CHECK_EQ(i2h_y.CheckContiguous(), true);
       CHECK_EQ(h2h_y.CheckContiguous(), true);
       CHECK_EQ(y_tmp.CheckContiguous(), true);
@@ -263,19 +269,18 @@ class RNNOp : public Operator {
         Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
         start += fused_h_ch;
         Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
-
+        if (out_tmp) {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w, i2h_y_flatten, false, true, s);
+        } else {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w, i2h_y_flatten, false, true, s);
+        }
+        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
         for (size_t t = 0; t < seq_len; t++) {
           size_t timestep = t;
           if (reverse_dir)
-            timestep = seq_len - 1 - t;
-          if (out_tmp) {
-            linalg_gemm(layer < num_dir ? x[timestep]:y[timestep], i2h_w, i2h_y[timestep], false, true, s);
-          } else {
-            linalg_gemm(layer < num_dir ? x[timestep]:y_tmp[timestep], i2h_w, i2h_y[timestep], false, true, s);
-          }
+            timestep = seq_len - 1 - t;
           linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y, false, true, s);
           h2h_y += repmat(h2h_b, batch_size);
-          i2h_y[timestep] += repmat(i2h_b, batch_size);
           // fused element-wise ops
           LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y, y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
                                      hy[layer], cy[layer], batch_size, h_channel, t,
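Note: PATCH 05 hoists the input-to-hidden projection out of the per-timestep loop. Since the i2h
weights do not change across time, the seq_len separate (batch, in_channel) x (4*h_channel,
in_channel)^T GEMMs collapse into one GEMM on the flattened (seq_len*batch, in_channel) view.
A minimal NumPy sketch of the equivalence (shapes only; the names here are illustrative, not
taken from the patch):

    import numpy as np

    T, N, C, H = 5, 4, 10, 8           # seq_len, batch, in_channel, h_channel
    x = np.random.rand(T, N, C)        # TNC input, as in the operator
    i2h_w = np.random.rand(4 * H, C)   # fused i2h weights for the 4 LSTM gates

    # per-timestep projection (what the loop did before this patch)
    per_step = np.stack([x[t] @ i2h_w.T for t in range(T)])

    # single fused GEMM on the flattened (T*N, C) view (what the patch does)
    fused = (x.reshape(T * N, C) @ i2h_w.T).reshape(T, N, 4 * H)

    assert np.allclose(per_step, fused)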
From 7c84239c379d6acc89442569392f20c70a7f116d Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 5 Mar 2018 21:50:51 +0000
Subject: [PATCH 06/22] fix lint

---
 src/operator/rnn-inl.h | 228 +++++++++++++++++++++++------------------
 1 file changed, 127 insertions(+), 101 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 46b503252bc4..95239f2f7bc9 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -82,10 +82,12 @@ inline int rnn_param_size(int layerNum,
   int size = rnn_single_param_size(inputSize, hiddenSize, mode);
   // get size of remaining layers
   if (bidirectional) {
-    size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize, hiddenSize, mode);
+    size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize,
+                                                   hiddenSize, mode);
     size *= 2;
   } else {
-    size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode);
+    size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize,
+                                                   mode);
   }
   return size;
 }
@@ -118,13 +120,14 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
 
     DMLC_DECLARE_FIELD(p).set_default(0.)
     .set_range(0, 1)
-    .describe("Dropout probability, fraction of the input that gets dropped out at training time");
+    .describe("Dropout probability, fraction of the input that gets dropped"
+              "out at training time");
 
     DMLC_DECLARE_FIELD(state_outputs).set_default(false)
     .describe("Whether to have the states as symbol outputs.");
   }
 };
-  
+
 template<typename xpu, typename DType>
 class RNNOp : public Operator {
  public:
@@ -136,8 +139,6 @@ class RNNOp : public Operator {
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
                        const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
     // TODO(sbodenstein): add MShadow implementation
   }
 
@@ -148,8 +149,6 @@ class RNNOp : public Operator {
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &in_grad,
                         const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
     // TODO(sbodenstein): add MShadow implementation
   }
 
@@ -181,9 +180,7 @@ class RNNOp : public Operator {
                        const std::vector<TBlob> &out_data,
                        const std::vector<TBlob> &aux_args) {
     // Layout TNC
-    using namespace mshadow;
-    using namespace mshadow::expr;
-
+
     size_t in_expected = param_.lstm_q_ ? 4 : 3;
     size_t out_expected = param_.lstm_q_ ? 3 : 2;
 
@@ -193,30 +190,35 @@ class RNNOp : public Operator {
     CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
     CHECK_EQ(in_data.size(), in_expected);
     CHECK_EQ(out_data.size(), out_expected);
-
+
     mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
     // get input + output tensors
     // w layout i2h_w, h2h_w, i2h_b, h2h_b
-    Tensor<cpu, 3, DType> x = in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
+    Tensor<cpu, 3, DType> x =
+      in_data[rnn_enum::kData].get<cpu, 3, DType>(s);  // TNC
     Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
-    Tensor<cpu, 3, DType> hx = in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
-    Tensor<cpu, 3, DType> y = out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
+    Tensor<cpu, 3, DType> hx =
+      in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
+    Tensor<cpu, 3, DType> y =
+      out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
     size_t seq_len = x.shape_[0];
     size_t num_layers = hx.shape_[0];
     size_t batch_size = x.shape_[1];
     size_t h_channel = hx.shape_[2];
     size_t in_channel = x.shape_[2];
     Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
-      .get_with_shape<cpu, 2, DType>(mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
+      .get_with_shape<cpu, 2, DType>(
+        mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
     Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
-      .get_with_shape<cpu, 2, DType>(mshadow::Shape2(y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
+      .get_with_shape<cpu, 2, DType>(
+        mshadow::Shape2(
+          y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
 
     CHECK_EQ(x.CheckContiguous(), true);
     CHECK_EQ(w.CheckContiguous(), true);
     CHECK_EQ(hx.CheckContiguous(), true);
     CHECK_EQ(y.CheckContiguous(), true);
-
+
     if (ctx.is_train)
       LOG(FATAL) << "only inference mode is available for cpu at the moment.";
     if (param_.lstm_q_) {
@@ -225,67 +227,85 @@ class RNNOp : public Operator {
       size_t h_size = batch_size * fused_h_ch;
       size_t num_dir = 1 + param_.bidirectional;
       size_t h2h_w_size = h_channel * fused_h_ch;
-
-      Tensor<cpu, 3, DType> cx = in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
+      
+      Tensor<cpu, 3, DType> cx =
+        in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
       CHECK_EQ(cx.CheckContiguous(), true);
 
-      Tensor<cpu, 3, DType> cy = out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
-      Tensor<cpu, 3, DType> hy = out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
+      Tensor<cpu, 3, DType> cy =
+        out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
+      Tensor<cpu, 3, DType> hy =
+        out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
       CHECK_EQ(cy.CheckContiguous(), true);
       CHECK_EQ(hy.CheckContiguous(), true);
-      
+
       DType* workspace_addr =
-        static_cast<DType*>(ctx.requested[rnn_enum::kTempSpace]
-            .get_host_space_internal(sizeof(DType) *
-                                     (seq_len * h_size + h_size
-                                      + y.shape_[0] * y.shape_[1] * y.shape_[2])));
-      Tensor<cpu, 3, DType> i2h_y(workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
-      Tensor<cpu, 2, DType> i2h_y_flatten(workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
-      Tensor<cpu, 2, DType> h2h_y(workspace_addr + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
-      Tensor<cpu, 3, DType> y_tmp(workspace_addr + (seq_len + 1) * h_size, y.shape_);
-      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr + (seq_len + 1) * h_size, y_flatten.shape_);
+        static_cast<DType*>(ctx.requested[rnn_enum::kTempSpace]
+                            .get_host_space_internal(sizeof(DType) *
+                                                     (seq_len * h_size + h_size
+                                                      + y.shape_[0] * y.shape_[1] * y.shape_[2])));
+      Tensor<cpu, 3, DType> i2h_y(
+        workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> i2h_y_flatten(
+        workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
+      Tensor<cpu, 2, DType> h2h_y(workspace_addr
+                                  + seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
+      Tensor<cpu, 3, DType> y_tmp(workspace_addr
+                                  + (seq_len + 1) * h_size, y.shape_);
+      Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
+                                          + (seq_len + 1) * h_size, y_flatten.shape_);
       CHECK_EQ(i2h_y.CheckContiguous(), true);
       CHECK_EQ(h2h_y.CheckContiguous(), true);
       CHECK_EQ(y_tmp.CheckContiguous(), true);
 
       for (size_t layer = 0; layer < num_layers; layer++) {
-        int reverse_dir = 0;
-        int out_tmp = 0;
-        if (param_.bidirectional && layer % 2)
-          reverse_dir = 1;
-        if (layer / num_dir % 2 == 0)
-          out_tmp = 1;
-        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch, (layer < num_dir) ? in_channel : num_dir * h_channel);
-        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
-        size_t start = layer < num_dir ?
-          (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
-          (num_dir * (in_channel * fused_h_ch + h2h_w_size) + (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
-        Tensor<cpu, 2, DType> i2h_w(w.Slice(start, start + (layer < num_dir ? (in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_,
-                                    i2h_w_shape);
-        start += layer < num_dir ? in_channel * fused_h_ch : h2h_w_size * num_dir;
-        Tensor<cpu, 2, DType> h2h_w(w.Slice(start, start + h2h_w_size).dptr_, h2h_w_shape);
-        start = num_dir * (in_channel * fused_h_ch + h2h_w_size) + (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))  // weight offset
-                + layer * fused_h_ch * 2;
-        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
-        start += fused_h_ch;
-        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
-        if (out_tmp) {
-          linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w, i2h_y_flatten, false, true, s);
-        } else {
-          linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w, i2h_y_flatten, false, true, s);
-        }
-        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
-        for (size_t t = 0; t < seq_len; t++) {
-          size_t timestep = t;
-          if (reverse_dir)
-            timestep = seq_len - 1 - t;
-          linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y, false, true, s);
-          h2h_y += repmat(h2h_b, batch_size);
-          // fused element-wise ops
-          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y, y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
-                                     hy[layer], cy[layer], batch_size, h_channel, t,
-                                     reverse_dir, out_tmp && (layer == num_layers - 1));
-        }
+        int reverse_dir = 0;
+        int out_tmp = 0;
+        if (param_.bidirectional && layer % 2)
+          reverse_dir = 1;
+        if (layer / num_dir % 2 == 0)
+          out_tmp = 1;
+        mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
+          (layer < num_dir) ? in_channel : num_dir * h_channel);
+        mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
+        size_t start = layer < num_dir ?
+          (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
+          (num_dir * (in_channel * fused_h_ch + h2h_w_size) +
+           (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
+        Tensor<cpu, 2, DType> i2h_w(w.Slice(start, start + (layer < num_dir ?
+          (in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_,
+          i2h_w_shape);
+        start += layer < num_dir ?
+          in_channel * fused_h_ch : h2h_w_size * num_dir;
+        Tensor<cpu, 2, DType> h2h_w(w.Slice(start, start + h2h_w_size).dptr_,
+          h2h_w_shape);
+        start = num_dir * (in_channel * fused_h_ch + h2h_w_size) +
+          (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
+          + layer * fused_h_ch * 2;
+        Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
+        start += fused_h_ch;
+        Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
+        if (out_tmp) {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w,
+            i2h_y_flatten, false, true, s);
+        } else {
+          linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w,
+            i2h_y_flatten, false, true, s);
+        }
+        i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
+        for (size_t t = 0; t < seq_len; t++) {
+          size_t timestep = t;
+          if (reverse_dir)
+            timestep = seq_len - 1 - t;
+          linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
+            false, true, s);
+          h2h_y += repmat(h2h_b, batch_size);
+          // fused element-wise ops
+          LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
+            y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
+            hy[layer], cy[layer], batch_size, h_channel, t,
+            reverse_dir, out_tmp && (layer == num_layers - 1));
+        }
       }
     } else {
       LOG(FATAL) << "only LSTM is available for cpu at the moment.";
@@ -295,13 +315,10 @@ class RNNOp : public Operator {
   virtual void Backward(const OpContext &ctx,
                         const std::vector<TBlob> &out_grad,
                         const std::vector<TBlob> &in_data,
-                        const std::vector<TBlob> &out_data, 
+                        const std::vector<TBlob> &out_data,
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &in_grad,
                         const std::vector<TBlob> &aux_args) {
-    using namespace mshadow;
-    using namespace mshadow::expr;
-    // TODO: add implementation
     LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
   }
 
@@ -309,21 +326,22 @@ class RNNOp : public Operator {
   RNNParam param_;
 
   virtual void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
-                                          const Tensor<cpu, 2, DType> &cx,
-                                          const Tensor<cpu, 2, DType> &h2h_y,
-                                          const Tensor<cpu, 2, DType> &y,
-                                          const Tensor<cpu, 2, DType> &tmp,  // for holding intermediate layer output
-                                          const Tensor<cpu, 2, DType> &hy,
-                                          const Tensor<cpu, 2, DType> &cy,
-                                          const size_t batch_size,
-                                          const size_t h_channel,
-                                          const size_t t,
-                                          const int reverse_dir,
-                                          const int copy_tmp2y) {
+                                          const Tensor<cpu, 2, DType> &cx,
+                                          const Tensor<cpu, 2, DType> &h2h_y,
+                                          const Tensor<cpu, 2, DType> &y,
+                                          // holding intermediate layer output
+                                          const Tensor<cpu, 2, DType> &tmp,
+                                          const Tensor<cpu, 2, DType> &hy,
+                                          const Tensor<cpu, 2, DType> &cy,
+                                          const size_t batch_size,
+                                          const size_t h_channel,
+                                          const size_t t,
+                                          const int reverse_dir,
+                                          const int copy_tmp2y) {
     size_t ji;
     #pragma omp parallel for private(ji)
     for (ji = 0; ji < batch_size * h_channel; ji++) {
-      size_t j = ji / h_channel; // batch dim
+      size_t j = ji / h_channel;  // batch dim
       size_t i = ji % h_channel;
       size_t f = i + h_channel;
       size_t c = i + h_channel * 2;
       size_t o = i + h_channel * 3;
@@ -336,18 +354,19 @@ class RNNOp : public Operator {
       h2h_y[j][f] = 1.0f / (1.0f + math::exp(-h2h_y[j][f]));
       h2h_y[j][o] = 1.0f / (1.0f + math::exp(-h2h_y[j][o]));
       h2h_y[j][c] = tanh(h2h_y[j][c]);
-      cy[j][i] = h2h_y[j][f] * (t == 0 ? cx[j][i]:cy[j][i]) + h2h_y[j][i] * h2h_y[j][c];
+      cy[j][i] = h2h_y[j][f] * (t == 0 ? cx[j][i]:cy[j][i])
+                 + h2h_y[j][i] * h2h_y[j][c];
       hy[j][i] = h2h_y[j][o] * tanh(cy[j][i]);
       tmp[j][i + h_channel * reverse_dir] = hy[j][i];
       if (copy_tmp2y) {
-        y[j][i] = tmp[j][i];
-        if (reverse_dir)
-          y[j][i + h_channel] = tmp[j][i + h_channel];
+        y[j][i] = tmp[j][i];
+        if (reverse_dir)
+          y[j][i + h_channel] = tmp[j][i + h_channel];
       }
     }
   }
 };  // class RNNOp
-
+
 template<typename xpu>
 Operator* CreateOp(RNNParam param, int dtype);
@@ -379,7 +398,8 @@ class RNNProp : public OperatorProperty {
     return num_outputs;
   }
 
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs)
+  override {
     param_.Init(kwargs);
   }
 
@@ -390,28 +410,33 @@ class RNNProp : public OperatorProperty {
   bool InferShape(std::vector<TShape> *in_shape,
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
-    using namespace mshadow;
     if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
+      CHECK_EQ(in_shape->size(), 4U) <<
+        "Input:[data, parameters, state, cell_state]";
     } else {
       CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
     }
     const TShape &dshape = (*in_shape)[rnn_enum::kData];
     if (dshape.ndim() ==  0) return false;
-    CHECK_EQ(dshape.ndim(), 3U) \
-        << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
+    CHECK_EQ(dshape.ndim(), 3U)
+        << "Input data should be rank-3 tensor of dim [sequence length, "
+        << "batch size, input size]";
     // data: [sequence len, batch, input dimension]
     int batch_size = dshape[1];
     int input_size = dshape[2];
     int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for bidirectional
+    // double for bidirectional
+    int total_layers = numDirections * param_.num_layers;
+
     SHAPE_ASSIGN_CHECK(*in_shape,
                        rnn_enum::kState,
-                       Shape3(total_layers, batch_size, param_.state_size));
+                       mshadow::Shape3(total_layers, batch_size,
+                                       param_.state_size));
     if (param_.mode == rnn_enum::kLstm)
       SHAPE_ASSIGN_CHECK(*in_shape,
                          rnn_enum::kStateCell,
-                         Shape3(total_layers, batch_size, param_.state_size));
+                         mshadow::Shape3(total_layers, batch_size,
+                                         param_.state_size));
 
     // calculate parameter vector length
     int param_size = rnn_param_size(param_.num_layers,
@@ -482,8 +507,9 @@ class RNNProp : public OperatorProperty {
       const std::vector<int> &out_grad,
       const std::vector<int> &in_data,
       const std::vector<int> &out_data) const override {
-    std::vector<int> dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams],
-        in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
+    std::vector<int> dep = {in_data[rnn_enum::kData],
+        in_data[rnn_enum::kParams], in_data[rnn_enum::kState],
+        out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
 
     if (param_.state_outputs) {
       dep.push_back(out_data[rnn_enum::kStateOut]);

From 73a632b787b94d0c06a2dd55408790bb45c3576c Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 5 Mar 2018 21:58:38 +0000
Subject: [PATCH 07/22] fix pylint

---
 python/mxnet/gluon/data/dataloader.py | 1 +
 python/mxnet/gluon/rnn/rnn_layer.py   | 4 ++--
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index 02aa5c041433..eb482aa56e3e 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -214,6 +214,7 @@ def __iter__(self):
             worker.start()
             workers.append(worker)
 
+        idx = -1
         for idx, batch in enumerate(self._batch_sampler):
             key_queue.put((idx, batch))
         num_batches = idx + 1
diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 3492a9caf356..da7a6b48582d 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -208,7 +208,7 @@ def _forward(self, inputs, states):
             new_states.append(state)
 
         return outputs, new_states
-    
+
     def _forward_kernel(self, inputs, states):
         """ forward using CUDNN or CPU kernel"""
         if self._layout == 'NTC':
@@ -218,7 +218,7 @@ def _forward_kernel(self, inputs, states):
         params += sum(zip(self.i2h_bias, self.h2h_bias), ())
         params = (i.data(ctx).reshape((-1,)) for i in params)
         params = ndarray.concat(*params, dim=0)
-        
+
         rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
                           num_layers=self._num_layers, bidirectional=self._dir == 2,
                           p=self._dropout, state_outputs=True, mode=self._mode)
From 417552f6ba841c1c8d91f6478f7d0ef139a8f333 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 5 Mar 2018 22:32:43 +0000
Subject: [PATCH 08/22] fix windows build error

---
 src/operator/rnn-inl.h | 44 +++++++++++++++++++++---------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 95239f2f7bc9..1479e7a6eaa8 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -201,11 +201,11 @@ class RNNOp : public Operator {
       in_data[rnn_enum::kState].get<cpu, 3, DType>(s);  // LNC
     Tensor<cpu, 3, DType> y =
       out_data[rnn_enum::kOut].get<cpu, 3, DType>(s);  // TNC
-    size_t seq_len = x.shape_[0];
-    size_t num_layers = hx.shape_[0];
-    size_t batch_size = x.shape_[1];
-    size_t h_channel = hx.shape_[2];
-    size_t in_channel = x.shape_[2];
+    int64_t seq_len = x.shape_[0];
+    int64_t num_layers = hx.shape_[0];
+    int64_t batch_size = x.shape_[1];
+    int64_t h_channel = hx.shape_[2];
+    int64_t in_channel = x.shape_[2];
     Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
       .get_with_shape<cpu, 2, DType>(
         mshadow::Shape2(seq_len * batch_size, in_channel), s);  // (T*N)C
@@ -223,10 +223,10 @@ class RNNOp : public Operator {
     if (param_.lstm_q_) {
       const size_t kNumMat = 4;
-      size_t fused_h_ch = kNumMat * h_channel;
-      size_t h_size = batch_size * fused_h_ch;
-      size_t num_dir = 1 + param_.bidirectional;
-      size_t h2h_w_size = h_channel * fused_h_ch;
+      int64_t fused_h_ch = kNumMat * h_channel;
+      int64_t h_size = batch_size * fused_h_ch;
+      int64_t num_dir = 1 + param_.bidirectional;
+      int64_t h2h_w_size = h_channel * fused_h_ch;
 
       Tensor<cpu, 3, DType> cx =
         in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
@@ -258,7 +258,7 @@ class RNNOp : public Operator {
-      for (size_t layer = 0; layer < num_layers; layer++) {
+      for (int64_t layer = 0; layer < num_layers; layer++) {
         int reverse_dir = 0;
         int out_tmp = 0;
         if (param_.bidirectional && layer % 2)
           reverse_dir = 1;
         if (layer / num_dir % 2 == 0)
           out_tmp = 1;
         mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
           (layer < num_dir) ? in_channel : num_dir * h_channel);
         mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
@@ -268,7 +268,7 @@ class RNNOp : public Operator {
-        size_t start = layer < num_dir ?
+        int64_t start = layer < num_dir ?
           (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
           (num_dir * (in_channel * fused_h_ch + h2h_w_size) +
            (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
@@ -293,8 +293,8 @@ class RNNOp : public Operator {
         i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
-        for (size_t t = 0; t < seq_len; t++) {
-          size_t timestep = t;
+        for (int64_t t = 0; t < seq_len; t++) {
+          int64_t timestep = t;
           if (reverse_dir)
             timestep = seq_len - 1 - t;
           linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
@@ -333,19 +333,19 @@ class RNNOp : public Operator {
                                           const Tensor<cpu, 2, DType> &tmp,
                                           const Tensor<cpu, 2, DType> &hy,
                                           const Tensor<cpu, 2, DType> &cy,
-                                          const size_t batch_size,
-                                          const size_t h_channel,
-                                          const size_t t,
+                                          const int64_t batch_size,
+                                          const int64_t h_channel,
+                                          const int64_t t,
                                           const int reverse_dir,
                                           const int copy_tmp2y) {
-    size_t ji;
+    int64_t ji;
     #pragma omp parallel for private(ji)
     for (ji = 0; ji < batch_size * h_channel; ji++) {
-      size_t j = ji / h_channel;  // batch dim
-      size_t i = ji % h_channel;
-      size_t f = i + h_channel;
-      size_t c = i + h_channel * 2;
-      size_t o = i + h_channel * 3;
+      int64_t j = ji / h_channel;  // batch dim
+      int64_t i = ji % h_channel;
+      int64_t f = i + h_channel;
+      int64_t c = i + h_channel * 2;
+      int64_t o = i + h_channel * 3;
       h2h_y[j][i] += i2h_y[j][i];
       h2h_y[j][f] += i2h_y[j][f];
       h2h_y[j][o] += i2h_y[j][o];

From 3f0661843b6ea39998998811521ffd73cd5ce9a4 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Mon, 5 Mar 2018 23:40:33 +0000
Subject: [PATCH 09/22] fix gluon rnn interface

---
 python/mxnet/gluon/rnn/rnn_layer.py | 3 ++-
 src/operator/rnn-inl.h              | 2 +-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index da7a6b48582d..e56c80d0a6c8 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -186,7 +186,8 @@ def forward(self, inputs, states=None):
             self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
             self.i2h_weight[i]._finish_deferred_init()
         import mxnet
-        if inputs.context.device_type == 'gpu' or not mxnet.autograd.is_training():
+        if inputs.context.device_type == 'gpu' or \
+            (not mxnet.autograd.is_training() and self._mode == 'lstm'):
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 1479e7a6eaa8..5c6b81d495ff 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -166,7 +166,7 @@ class RNNOp : public Operator {
       case rnn_enum::kLstm:
         break;
       default:
-        LOG(FATAL) << "Not implemented";
+        LOG(FATAL) << "only LSTM is implemented on CPU";
     }
     if (param_.mode == rnn_enum::kLstm)
       param_.lstm_q_ = true;

From e7e67af693a376862dbcf72b01714d333945c5ec Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Tue, 6 Mar 2018 17:17:18 -0800
Subject: [PATCH 10/22] Update dataloader.py

---
 python/mxnet/gluon/data/dataloader.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index eb482aa56e3e..02aa5c041433 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -214,7 +214,6 @@ def __iter__(self):
             worker.start()
             workers.append(worker)
 
-        idx = -1
         for idx, batch in enumerate(self._batch_sampler):
             key_queue.put((idx, batch))
         num_batches = idx + 1

From df2f836edd6bb39b3250f426f3f3b993ae3813b8 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 7 Mar 2018 21:05:07 +0000
Subject: [PATCH 11/22] address cr

---
 src/operator/rnn-inl.h | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 5c6b81d495ff..225fff6fb9e3 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -82,12 +82,12 @@ inline int rnn_param_size(int layerNum,
   int size = rnn_single_param_size(inputSize, hiddenSize, mode);
   // get size of remaining layers
   if (bidirectional) {
-    size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize,
-                                                   hiddenSize, mode);
+    size += (layerNum - 1) * rnn_single_param_size(
+        2 * hiddenSize, hiddenSize, mode);
     size *= 2;
   } else {
-    size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize,
-                                                   mode);
+    size += (layerNum - 1) * rnn_single_param_size(
+        hiddenSize, hiddenSize, mode);
   }
   return size;
 }
@@ -219,8 +219,9 @@ class RNNOp : public Operator {
     CHECK_EQ(hx.CheckContiguous(), true);
     CHECK_EQ(y.CheckContiguous(), true);
 
-    if (ctx.is_train)
-      LOG(FATAL) << "only inference mode is available for cpu at the moment.";
+    CHECK(!ctx.is_train) << "only inference mode is available "
+        "for cpu at the moment.";
+
     if (param_.lstm_q_) {
       const size_t kNumMat = 4;
@@ -508,8 +509,8 @@ class RNNProp : public OperatorProperty {
       const std::vector<int> &in_data,
       const std::vector<int> &out_data) const override {
     std::vector<int> dep = {in_data[rnn_enum::kData],
-        in_data[rnn_enum::kParams], in_data[rnn_enum::kState],
-        out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
+      in_data[rnn_enum::kParams], in_data[rnn_enum::kState],
+      out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
 
     if (param_.state_outputs) {
       dep.push_back(out_data[rnn_enum::kStateOut]);

From b8ca9c8479cf30ffe7c7eacc9795abd2e1c52ba7 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 7 Mar 2018 21:15:09 +0000
Subject: [PATCH 12/22] address cr

---
 src/operator/rnn-inl.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 225fff6fb9e3..d56a288d811f 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -162,16 +162,14 @@ class RNNOp : public Operator {
   explicit RNNOp(RNNParam param) {
     this->param_ = param;
     // RNN Mode
+    param_.lstm_q_ = false;
     switch (param_.mode) {
       case rnn_enum::kLstm:
+        param_.lstm_q_ = true;
         break;
       default:
         LOG(FATAL) << "only LSTM is implemented on CPU";
     }
-    if (param_.mode == rnn_enum::kLstm)
-      param_.lstm_q_ = true;
-    else
-      param_.lstm_q_ = false;
   }

From 9b919afdb05e572e75c57bb8347b1f956bb0f8fc Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Thu, 8 Mar 2018 23:22:43 +0000
Subject: [PATCH 13/22] fix import

---
 python/mxnet/gluon/rnn/rnn_layer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index e56c80d0a6c8..b11e767ebfb2 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,6 +23,8 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
+
+from ...autograd import is_training
 from ... import ndarray
 from .. import Block
 from . import rnn_cell
@@ -185,9 +187,9 @@ def forward(self, inputs, states=None):
         for i in range(self._dir):
             self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
             self.i2h_weight[i]._finish_deferred_init()
-        import mxnet
+
         if inputs.context.device_type == 'gpu' or \
-            (not mxnet.autograd.is_training() and self._mode == 'lstm'):
+            (not is_training() and self._mode == 'lstm'):
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
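Note: after PATCH 13 the dispatch in forward() is settled: the fused ndarray.RNN kernel handles the
GPU always, and the CPU only for LSTM inference; everything else falls back to the unrolled gluon
cells. A short usage sketch of the two paths (this assumes an MXNet build with this patch series
applied; the shapes are arbitrary):

    import mxnet as mx
    from mxnet import autograd

    layer = mx.gluon.rnn.LSTM(20, num_layers=2, bidirectional=True)
    layer.initialize()
    x = mx.nd.uniform(shape=(5, 4, 10))    # TNC layout: (seq_len, batch, input_size)
    states = layer.begin_state(batch_size=4)

    out, new_states = layer(x, states)     # CPU inference -> _forward_kernel (fused C++ op)
    with autograd.record():                # CPU training -> _forward (unrolled gluon cells)
        out_train, _ = layer(x, states)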
From 110010d2e5c6c6c941a875ebce88512050a36f5f Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 9 Mar 2018 06:31:29 +0000
Subject: [PATCH 14/22] revert some cosmetic change

---
 src/operator/rnn-inl.h | 47 +++++++++++++++++++-----------------------
 1 file changed, 21 insertions(+), 26 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index d56a288d811f..eac7ba44efcb 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -120,8 +120,7 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {
     DMLC_DECLARE_FIELD(p).set_default(0.)
     .set_range(0, 1)
-    .describe("Dropout probability, fraction of the input that gets dropped"
-              "out at training time");
+    .describe("Dropout probability, fraction of the input that gets dropped out at training time");
 
     DMLC_DECLARE_FIELD(state_outputs).set_default(false)
     .describe("Whether to have the states as symbol outputs.");
@@ -139,6 +138,8 @@ class RNNOp : public Operator {
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &out_data,
                        const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
     // TODO(sbodenstein): add MShadow implementation
   }
 
@@ -149,6 +150,8 @@ class RNNOp : public Operator {
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &in_grad,
                         const std::vector<TBlob> &aux_args) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
     // TODO(sbodenstein): add MShadow implementation
   }
 
@@ -178,7 +181,8 @@ class RNNOp : public Operator {
                        const std::vector<TBlob> &out_data,
                        const std::vector<TBlob> &aux_args) {
     // Layout TNC
-
+    CHECK(!ctx.is_train) << "only inference mode is available "
+        "for cpu at the moment.";
     size_t in_expected = param_.lstm_q_ ? 4 : 3;
     size_t out_expected = param_.lstm_q_ ? 3 : 2;
@@ -212,13 +216,10 @@ class RNNOp : public Operator {
         mshadow::Shape2(
           y.shape_[0] * y.shape_[1], y.shape_[2]), s);  // (T*N)C
 
-    CHECK_EQ(x.CheckContiguous(), true);
-    CHECK_EQ(w.CheckContiguous(), true);
-    CHECK_EQ(hx.CheckContiguous(), true);
-    CHECK_EQ(y.CheckContiguous(), true);
-
-    CHECK(!ctx.is_train) << "only inference mode is available "
-        "for cpu at the moment.";
+    CHECK(x.CheckContiguous());
+    CHECK(w.CheckContiguous());
+    CHECK(hx.CheckContiguous());
+    CHECK(y.CheckContiguous());
 
     if (param_.lstm_q_) {
       const size_t kNumMat = 4;
@@ -397,8 +398,7 @@ class RNNProp : public OperatorProperty {
     return num_outputs;
   }
 
-  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs)
-  override {
+  void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
     param_.Init(kwargs);
   }
 
@@ -409,33 +409,29 @@ class RNNProp : public OperatorProperty {
   bool InferShape(std::vector<TShape> *in_shape,
                   std::vector<TShape> *out_shape,
                   std::vector<TShape> *aux_shape) const override {
+    using namespace mshadow;
     if (param_.mode == rnn_enum::kLstm) {
-      CHECK_EQ(in_shape->size(), 4U) <<
-        "Input:[data, parameters, state, cell_state]";
+      CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
     } else {
       CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
     }
     const TShape &dshape = (*in_shape)[rnn_enum::kData];
     if (dshape.ndim() ==  0) return false;
-    CHECK_EQ(dshape.ndim(), 3U)
-        << "Input data should be rank-3 tensor of dim [sequence length, "
-        << "batch size, input size]";
+    CHECK_EQ(dshape.ndim(), 3U) \
+        << "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
     // data: [sequence len, batch, input dimension]
     int batch_size = dshape[1];
     int input_size = dshape[2];
     int numDirections = param_.bidirectional ? 2 : 1;
-    // double for bidirectional
-    int total_layers = numDirections * param_.num_layers;
+    int total_layers = numDirections * param_.num_layers;  // double for bibdirectional
 
     SHAPE_ASSIGN_CHECK(*in_shape,
                        rnn_enum::kState,
-                       mshadow::Shape3(total_layers, batch_size,
-                                       param_.state_size));
+                       Shape3(total_layers, batch_size, param_.state_size));
     if (param_.mode == rnn_enum::kLstm)
       SHAPE_ASSIGN_CHECK(*in_shape,
                          rnn_enum::kStateCell,
-                         mshadow::Shape3(total_layers, batch_size,
-                                         param_.state_size));
+                         Shape3(total_layers, batch_size, param_.state_size));
 
     // calculate parameter vector length
     int param_size = rnn_param_size(param_.num_layers,
@@ -506,9 +502,8 @@ class RNNProp : public OperatorProperty {
       const std::vector<int> &out_grad,
       const std::vector<int> &in_data,
       const std::vector<int> &out_data) const override {
-    std::vector<int> dep = {in_data[rnn_enum::kData],
-      in_data[rnn_enum::kParams], in_data[rnn_enum::kState],
-      out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
+    std::vector<int> dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams],
+        in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
 
     if (param_.state_outputs) {
       dep.push_back(out_data[rnn_enum::kStateOut]);

From f346598bfc069f94e24dbd6224e4f40105b9bea3 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 9 Mar 2018 06:46:16 +0000
Subject: [PATCH 15/22] fix typo

---
 src/operator/rnn-inl.h | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index eac7ba44efcb..eac7af67cf7e 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -82,12 +82,10 @@ inline int rnn_param_size(int layerNum,
   int size = rnn_single_param_size(inputSize, hiddenSize, mode);
   // get size of remaining layers
   if (bidirectional) {
-    size += (layerNum - 1) * rnn_single_param_size(
-        2 * hiddenSize, hiddenSize, mode);
+    size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize, hiddenSize, mode);
     size *= 2;
   } else {
-    size += (layerNum - 1) * rnn_single_param_size(
-        hiddenSize, hiddenSize, mode);
+    size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode);
   }
   return size;
 }
@@ -423,7 +421,7 @@ class RNNProp : public OperatorProperty {
     int batch_size = dshape[1];
     int input_size = dshape[2];
     int numDirections = param_.bidirectional ? 2 : 1;
-    int total_layers = numDirections * param_.num_layers;  // double for bibdirectional
+    int total_layers = numDirections * param_.num_layers;  // double for bidirectional
 
     SHAPE_ASSIGN_CHECK(*in_shape,
                        rnn_enum::kState,
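Note: the weight-offset arithmetic in Forward() assumes the packed parameter layout that
rnn_param_size sizes: every layer/direction's i2h and h2h weight matrices first, then all the
biases, with four gates' worth of parameters per LSTM layer. A hedged Python sketch of the sizing
logic (it mirrors rnn_param_size above for mode='lstm'; the single-layer formula is inferred from
the surrounding source, so treat it as an assumption):

    def lstm_single_param_size(input_size, hidden_size):
        # 4 gates x (i2h weights + h2h weights + i2h bias + h2h bias) per layer/direction
        return 4 * hidden_size * (input_size + hidden_size + 2)

    def lstm_param_size(num_layers, input_size, hidden_size, bidirectional=False):
        size = lstm_single_param_size(input_size, hidden_size)
        if bidirectional:
            # deeper layers consume the concatenated fwd+bwd outputs, then everything doubles
            size += (num_layers - 1) * lstm_single_param_size(2 * hidden_size, hidden_size)
            size *= 2
        else:
            size += (num_layers - 1) * lstm_single_param_size(hidden_size, hidden_size)
        return size

    # e.g. the 6-layer bidirectional LSTM(2) from the unit test, with input size 2:
    print(lstm_param_size(6, 2, 2, bidirectional=True))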
From 9de6bf63e6fdda5e088928b98c213521dd29dd3a Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 9 Mar 2018 06:49:40 +0000
Subject: [PATCH 16/22] remove newline

---
 python/mxnet/gluon/rnn/rnn_layer.py | 2 --
 src/operator/rnn-inl.h              | 1 -
 2 files changed, 3 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index b11e767ebfb2..2fac39923c7d 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,7 +23,6 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-
 from ...autograd import is_training
 from ... import ndarray
 from .. import Block
@@ -186,7 +186,6 @@ def forward(self, inputs, states=None):
         for i in range(self._dir):
             self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
             self.i2h_weight[i]._finish_deferred_init()
-
         if inputs.context.device_type == 'gpu' or \
             (not is_training() and self._mode == 'lstm'):
             out = self._forward_kernel(inputs, states)
diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index eac7af67cf7e..99fc0c9f5831 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -422,7 +422,6 @@ class RNNProp : public OperatorProperty {
     int input_size = dshape[2];
     int numDirections = param_.bidirectional ? 2 : 1;
     int total_layers = numDirections * param_.num_layers;  // double for bidirectional
-
     SHAPE_ASSIGN_CHECK(*in_shape,
                        rnn_enum::kState,
                        Shape3(total_layers, batch_size, param_.state_size));

From f41d8efbd4d3f6adac89f62b166a8b919d82e233 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 9 Mar 2018 09:12:03 +0000
Subject: [PATCH 17/22] rm virtual mv hardcoded number to constant

---
 src/operator/rnn-inl.h                  | 26 ++++++++++++------------
 tests/python/unittest/test_gluon_rnn.py |  9 +++++----
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 99fc0c9f5831..a9509094241a 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -323,19 +323,19 @@ class RNNOp : public Operator {
  private:
   RNNParam param_;
 
-  virtual void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
-                                          const Tensor<cpu, 2, DType> &cx,
-                                          const Tensor<cpu, 2, DType> &h2h_y,
-                                          const Tensor<cpu, 2, DType> &y,
-                                          // holding intermediate layer output
-                                          const Tensor<cpu, 2, DType> &tmp,
-                                          const Tensor<cpu, 2, DType> &hy,
-                                          const Tensor<cpu, 2, DType> &cy,
-                                          const int64_t batch_size,
-                                          const int64_t h_channel,
-                                          const int64_t t,
-                                          const int reverse_dir,
-                                          const int copy_tmp2y) {
+  void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
+                                  const Tensor<cpu, 2, DType> &cx,
+                                  const Tensor<cpu, 2, DType> &h2h_y,
+                                  const Tensor<cpu, 2, DType> &y,
+                                  // holding intermediate layer output
+                                  const Tensor<cpu, 2, DType> &tmp,
+                                  const Tensor<cpu, 2, DType> &hy,
+                                  const Tensor<cpu, 2, DType> &cy,
+                                  const int64_t batch_size,
+                                  const int64_t h_channel,
+                                  const int64_t t,
+                                  const int reverse_dir,
+                                  const int copy_tmp2y) {
     int64_t ji;
     #pragma omp parallel for private(ji)
     for (ji = 0; ji < batch_size * h_channel; ji++) {
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index ec6d9c1cb57a..3e37a3ac7547 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -67,6 +67,10 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
+EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
+                                  [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
+                                 [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
+                                  [0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
 def test_lstm_cpu_inference():
     # should behave the same as lstm cell
     atol = 1e-6
@@ -77,10 +81,7 @@ def test_lstm_cpu_inference():
     model.initialize(mx.init.One())
     y = model(x).asnumpy()
 
-    mx.test_utils.assert_almost_equal(y, np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
-                                                    [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
-                                                   [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
-                                                    [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]),
+    mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
                                       rtol=1e-3, atol=1e-5)

From a8cda1ae4375ef9affc9827900cf100c1b816127 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 9 Mar 2018 21:57:14 +0000
Subject: [PATCH 18/22] address cr add tests

---
 src/operator/rnn-inl.h                | 48 +++++++++++++++++----------
 tests/python/gpu/test_operator_gpu.py | 18 +++++++++-
 2 files changed, 40 insertions(+), 26 deletions(-)

diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index a9509094241a..13c077dd9e35 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -228,14 +228,14 @@ class RNNOp : public Operator {
       Tensor<cpu, 3, DType> cx =
         in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
-      CHECK_EQ(cx.CheckContiguous(), true);
+      CHECK(cx.CheckContiguous());
 
       Tensor<cpu, 3, DType> cy =
         out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
       Tensor<cpu, 3, DType> hy =
         out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
-      CHECK_EQ(cy.CheckContiguous(), true);
-      CHECK_EQ(hy.CheckContiguous(), true);
+      CHECK(cy.CheckContiguous());
+      CHECK(hy.CheckContiguous());
 
       DType* workspace_addr =
         static_cast<DType*>(ctx.requested[rnn_enum::kTempSpace]
@@ -252,9 +252,9 @@ class RNNOp : public Operator {
       Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
                                           + (seq_len + 1) * h_size, y_flatten.shape_);
-      CHECK_EQ(i2h_y.CheckContiguous(), true);
-      CHECK_EQ(h2h_y.CheckContiguous(), true);
-      CHECK_EQ(y_tmp.CheckContiguous(), true);
+      CHECK(i2h_y.CheckContiguous());
+      CHECK(h2h_y.CheckContiguous());
+      CHECK(y_tmp.CheckContiguous());
 
       for (int64_t layer = 0; layer < num_layers; layer++) {
         int reverse_dir = 0;
@@ -270,13 +270,10 @@ class RNNOp : public Operator {
           (layer * (in_channel * fused_h_ch + h2h_w_size)) :  // input layer
           (num_dir * (in_channel * fused_h_ch + h2h_w_size) +
            (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
-        Tensor<cpu, 2, DType> i2h_w(w.Slice(start, start + (layer < num_dir ?
-          (in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_,
-          i2h_w_shape);
+        Tensor<cpu, 2, DType> i2h_w(w.dptr_ + start, i2h_w_shape);
         start += layer < num_dir ?
           in_channel * fused_h_ch : h2h_w_size * num_dir;
-        Tensor<cpu, 2, DType> h2h_w(w.Slice(start, start + h2h_w_size).dptr_,
-          h2h_w_shape);
+        Tensor<cpu, 2, DType> h2h_w(w.dptr_ + start, h2h_w_shape);
         start = num_dir * (in_channel * fused_h_ch + h2h_w_size) +
           (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
           + layer * fused_h_ch * 2;
@@ -333,26 +330,27 @@ class RNNOp : public Operator {
                                   const int reverse_dir,
                                   const int copy_tmp2y) {
-    int64_t ji;
-    #pragma omp parallel for private(ji)
-    for (ji = 0; ji < batch_size * h_channel; ji++) {
+    int64_t length = batch_size * h_channel;
+    #pragma omp parallel for
+    for (int64_t ji = 0; ji < length; ++ji) {
       int64_t j = ji / h_channel;  // batch dim
       int64_t i = ji % h_channel;
       int64_t f = i + h_channel;
       int64_t c = i + h_channel * 2;
       int64_t o = i + h_channel * 3;
-      h2h_y[j][i] += i2h_y[j][i];
-      h2h_y[j][f] += i2h_y[j][f];
-      h2h_y[j][o] += i2h_y[j][o];
-      h2h_y[j][c] += i2h_y[j][c];
-      h2h_y[j][i] = 1.0f / (1.0f + math::exp(-h2h_y[j][i]));
-      h2h_y[j][f] = 1.0f / (1.0f + math::exp(-h2h_y[j][f]));
-      h2h_y[j][o] = 1.0f / (1.0f + math::exp(-h2h_y[j][o]));
-      h2h_y[j][c] = tanh(h2h_y[j][c]);
-      cy[j][i] = h2h_y[j][f] * (t == 0 ? cx[j][i]:cy[j][i])
-                 + h2h_y[j][i] * h2h_y[j][c];
-      hy[j][i] = h2h_y[j][o] * tanh(cy[j][i]);
+      int64_t j_pos = j * h_channel * 4;
+      h2h_y.dptr_[j_pos + i] += i2h_y.dptr_[j_pos + i];
+      h2h_y.dptr_[j_pos + f] += i2h_y.dptr_[j_pos + f];
+      h2h_y.dptr_[j_pos + o] += i2h_y.dptr_[j_pos + o];
+      h2h_y.dptr_[j_pos + c] += i2h_y.dptr_[j_pos + c];
+      h2h_y.dptr_[j_pos + i] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + i]));
+      h2h_y.dptr_[j_pos + f] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + f]));
+      h2h_y.dptr_[j_pos + o] = 1.0f / (1.0f + math::exp(-h2h_y.dptr_[j_pos + o]));
+      h2h_y.dptr_[j_pos + c] = tanh(h2h_y.dptr_[j_pos + c]);
+      cy[j][i] = h2h_y.dptr_[j_pos + f] * (t == 0 ? cx[j][i]:cy[j][i])
+                 + h2h_y.dptr_[j_pos + i] * h2h_y.dptr_[j_pos + c];
+      hy[j][i] = h2h_y.dptr_[j_pos + o] * tanh(cy[j][i]);
       tmp[j][i + h_channel * reverse_dir] = hy[j][i];
       if (copy_tmp2y) {
         y[j][i] = tmp[j][i];
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 55bb30cc7d6a..320a9b2ef573 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1371,6 +1371,22 @@ def check_rnn_layer(layer):
     for g, c in zip(gs, cs):
         assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-8)
 
+def check_rnn_layer_w_rand_inputs(layer):
+    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    with mx.gpu(0):
+        x = mx.nd.uniform((10, 16, 30))
+        states = layer.begin_state(16)
+        go, gs = layer(x, states)
+
+    with mx.cpu(0):
+        x = x.copyto(mx.cpu(0))
+        states = layer.begin_state(16)
+        co, cs = layer(x, states)
+
+    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-8)
+    for g, c in zip(gs, cs):
+        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-8)
+
 
 def test_rnn_layer():
     check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
@@ -1379,7 +1395,7 @@ def test_rnn_layer():
     check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
 
     check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
-
+    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
 
 def test_sequence_reverse():
     check_sequence_reverse(mx.gpu(0))

From f05129377c048147521cf736d0ff1792e4b218db Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 9 Mar 2018 22:08:49 +0000
Subject: [PATCH 19/22] simplify test

---
 tests/python/unittest/test_gluon_rnn.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 3e37a3ac7547..098259416cf1 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -75,9 +75,7 @@ def test_lstm_cpu_inference():
     # should behave the same as lstm cell
     atol = 1e-6
     x = mx.nd.ones(shape=(2, 2, 2))
-    model = mx.gluon.nn.Sequential()
-    with model.name_scope():
-        model.add(mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True))
+    model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
     model.initialize(mx.init.One())
     y = model(x).asnumpy()

From 6e2134eb5e2641be50ac6f17e9378796b202c678 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Fri, 9 Mar 2018 23:24:34 +0000
Subject: [PATCH 20/22] fix test

---
 tests/python/gpu/test_operator_gpu.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 192c4d49653d..1a4da4ce13a6 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1526,7 +1526,7 @@ def check_rnn_layer(layer):
 def check_rnn_layer_w_rand_inputs(layer):
     layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
     with mx.gpu(0):
-        x = mx.nd.uniform((10, 16, 30))
+        x = mx.nd.uniform(shape=(10, 16, 30))
         states = layer.begin_state(16)
         go, gs = layer(x, states)

From d065b101326f6e6ede33f4ec2c0d167e875daf5d Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Sat, 10 Mar 2018 00:24:20 +0000
Subject: [PATCH 21/22] fix tests

---
 tests/python/gpu/test_operator_gpu.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 1a4da4ce13a6..0d1496858736 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1525,8 +1525,9 @@ def check_rnn_layer(layer):
 
 def check_rnn_layer_w_rand_inputs(layer):
     layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    x = mx.nd.uniform(shape=(10, 16, 30))
     with mx.gpu(0):
-        x = mx.nd.uniform(shape=(10, 16, 30))
+        x = x.copyto(mx.gpu(0))
         states = layer.begin_state(16)
         go, gs = layer(x, states)

From ef1e19d5446e7d210d8987af190e535a55944712 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Sat, 10 Mar 2018 01:48:25 +0000
Subject: [PATCH 22/22] change magic number scope

---
 tests/python/unittest/test_gluon_rnn.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index ef1571bd241a..f22b13d65752 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -67,13 +67,12 @@ def test_lstm_forget_bias():
                                forget_bias * np.ones(100, ), np.zeros((2 * 100,))])
     assert_allclose(mod.get_params()[0][bias_argument].asnumpy(), expected_bias)
 
-EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
+def test_lstm_cpu_inference():
+    # should behave the same as lstm cell
+    EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213],
                                   [0.72045636, 0.72045636, 0.95215213, 0.95215213]],
                                  [[0.95215213, 0.95215213, 0.72045636, 0.72045636],
                                   [0.95215213, 0.95215213, 0.72045636, 0.72045636]]])
-def test_lstm_cpu_inference():
-    # should behave the same as lstm cell
-    atol = 1e-6
     x = mx.nd.ones(shape=(2, 2, 2))
     model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True)
     model.initialize(mx.init.One())
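Note: the unit test's premise ("should behave the same as lstm cell") can also be checked directly
by unrolling an LSTMCell with the same all-ones parameters and comparing against the fused layer.
A hedged sketch of such a cross-check (not part of the patch series; it assumes the gluon
LSTMCell.unroll API of this MXNet version):

    import mxnet as mx
    import numpy as np

    x = mx.nd.ones(shape=(2, 2, 2))                # TNC
    layer = mx.gluon.rnn.LSTM(2, num_layers=1)     # fused CPU inference path
    layer.initialize(mx.init.One())

    cell = mx.gluon.rnn.LSTMCell(2)                # reference: unrolled cell
    cell.initialize(mx.init.One())

    y_layer = layer(x).asnumpy()
    y_cell = cell.unroll(2, x, layout='TNC', merge_outputs=True)[0].asnumpy()
    np.testing.assert_allclose(y_layer, y_cell, rtol=1e-3, atol=1e-5)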