From fe468c689e108e85cbaf62bbba2df32e1186cd4f Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Fri, 29 Jun 2018 15:13:14 -0700
Subject: [PATCH] revert

---
 python/mxnet/gluon/rnn/rnn_layer.py     | 11 ++++++++---
 tests/python/unittest/test_gluon_rnn.py |  2 +-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 3bc9eccd7c45..6a686f1a4115 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -180,6 +180,10 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
     def hybrid_forward(self, F, inputs, states=None, **kwargs):
         if F is ndarray:
             batch_size = inputs.shape[self._layout.find('N')]
+        if self._input_size == 0:
+            for i in range(self._dir):
+                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
+                self.i2h_weight[i]._finish_deferred_init()
         skip_states = states is None
         if skip_states:
             if F is ndarray:
@@ -199,12 +203,13 @@ def hybrid_forward(self, F, inputs, states=None, **kwargs):
         # out is (output, state)
         return out[0] if skip_states else out

-    def infer_shape(self, inputs, *states):
-        if self._input_size == 0:
+    def __call__(self, inputs, *states):
+        if self._input_size == 0 and isinstance(inputs, ndarray.NDArray):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        return super(_RNNLayer, self).infer_shape(inputs, *states)
+        states = list(filter(lambda x: x is not None, states))
+        return super(_RNNLayer, self).__call__(inputs, *states)

     def _forward_kernel(self, F, inputs, states, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index d14ad3d54d3c..8329317be028 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -305,7 +305,7 @@ def test_rnn_layers():
                     mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)

     net = gluon.nn.HybridSequential()
-    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
+    net.add(gluon.rnn.LSTM(10, bidirectional=True))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
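
Note: this revert moves the deferred input-size inference for i2h_weight back
into __call__ (resolving shapes from inputs.shape[2] on the first NDArray
call) instead of relying on infer_shape. A minimal sketch of the behavior the
patch restores, assuming MXNet 1.x-era Gluon as in the test above; shapes and
values are illustrative only:

    import mxnet as mx
    from mxnet import gluon

    # Hidden size 10, bidirectional; input_size is left unset (0), so the
    # i2h_weight shapes stay deferred until the first forward call.
    lstm = gluon.rnn.LSTM(10, bidirectional=True)
    lstm.initialize()  # deferred init: parameter shapes resolved on first call

    x = mx.nd.ones((8, 3, 20))   # (seq_len, batch, input_size), default 'TNC' layout
    out = lstm(x)                # __call__ finishes deferred init from x.shape[2]
    print(out.shape)             # (8, 3, 20): 2 directions * hidden size 10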