This repository was archived by the owner on Nov 17, 2023. It is now read-only.

[WIP][MXNET-107] Fused LSTM implementation for CPU #10104

Merged 39 commits on May 14, 2018.

Commits
fd24ed2  register RNN fused-API with nnvm, finish single-layer && unidirection … (Mar 8, 2018)
ba0fe6d  fix coding style and lint complaints (TaoLv, Mar 8, 2018)
a3c34ab  add single-layer && unidirectional LSTM backward function (Mar 8, 2018)
b5c1ef7  make interface universal for other RNN modes (Mar 9, 2018)
73ed6dd  share intermediate results between forward and backward in a tricky way (Mar 9, 2018)
d72fe17  add comments for important parameters (Mar 12, 2018)
d6811b5  modify testcase (Mar 14, 2018)
d0306e5  fix coding style and error message (TaoLv, Mar 14, 2018)
c2e7c8f  fix OpenMP collapse error (Mar 15, 2018)
154aa3b  fix const (Mar 15, 2018)
7c0cc29  remove rnn.cu and skip related testcases temporarily for building on GPU (Mar 15, 2018)
b59f009  support multi-layer and bidirectional LSTM inference (Mar 17, 2018)
26d32d2  remove some testcases in test_gluon_rnn.py to build on GPU (Mar 18, 2018)
1b89cff  temporarily remove the testcase comparing fp32 and fp64 (Mar 22, 2018)
afd831d  retrigger CI (TaoLv, Mar 22, 2018)
ce818d3  fix some logs (Mar 26, 2018)
f24ee4b  use a better way to share memory (Mar 26, 2018)
d51dafd  fix cudnn registration (Mar 26, 2018)
cdaadf7  fix invariant calculations and enable some gpu testcases (Mar 26, 2018)
4161f3b  add thread-local cache for cudnn rnn op (TaoLv, Mar 26, 2018)
f3dcb07  add thread-local cache for rnn op (Mar 28, 2018)
09f6e9a  fix bugs (Mar 28, 2018)
c28bbc8  remove some testcases to check segfault (Mar 29, 2018)
3370cb4  remove cudnn registration to check segfault (Mar 29, 2018)
46af847  support multi-layer LSTM training (Mar 30, 2018)
e42e7f9  modify lstm testcase (Apr 2, 2018)
e5b8b51  add bidirectional support for lstm (Apr 3, 2018)
8a67315  fix gluon and coding style (Apr 4, 2018)
78edb41  fix bugs (Apr 4, 2018)
f50f5c0  remove nnvm registration (Apr 8, 2018)
35a4a4b  enable gpu testcases (Apr 9, 2018)
19ef217  add detailed descriptions (Apr 9, 2018)
b0cfcf8  add dropout check (Apr 10, 2018)
b6b567e  fix workspace size (Apr 27, 2018)
1471836  Merge remote-tracking branch 'upstream/master' into lstm (TaoLv, May 8, 2018)
a52b5ef  dropout is not supported, add unit test for it (TaoLv, May 9, 2018)
a60de72  Merge remote-tracking branch 'upstream/master' into lstm (TaoLv, May 9, 2018)
3c61b84  fix review comments (TaoLv, May 12, 2018)
aeb8e9d  Merge remote-tracking branch 'upstream/master' into lstm (TaoLv, May 12, 2018)
Files changed
python/mxnet/gluon/rnn/rnn_layer.py (4 changes: 1 addition & 3 deletions)
@@ -23,7 +23,6 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ...autograd import is_training
 from ... import ndarray
 from .. import Block
 from . import rnn_cell
@@ -186,8 +185,7 @@ def forward(self, inputs, states=None):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        if inputs.context.device_type == 'gpu' or \
-                (not is_training() and self._mode == 'lstm'):
+        if inputs.context.device_type == 'gpu' or self._mode == 'lstm':
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
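For context on what this routing change enables: a Gluon LSTM on CPU now takes the fused `_forward_kernel` path for both training and inference, where previously CPU training fell back to the unfused, cell-by-cell `_forward`. Below is a minimal sketch of code exercising the new path; the layer sizes and shapes are arbitrary illustrations, not taken from this PR:

```python
import mxnet as mx

# Multi-layer, bidirectional LSTM on CPU. Dropout is left at its default
# of 0 because the fused CPU kernel does not support it (see commit a52b5ef).
layer = mx.gluon.rnn.LSTM(100, num_layers=2, bidirectional=True)
layer.initialize(ctx=mx.cpu())

inputs = mx.nd.random.uniform(shape=(5, 3, 10))         # (seq_len, batch, input_size)
states = layer.begin_state(batch_size=3, ctx=mx.cpu())

# With this change, the recorded (training) pass on CPU also routes
# through the fused kernel rather than the unfolded cells.
with mx.autograd.record():
    output, new_states = layer(inputs, states)
output.backward()
```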
src/operator/cudnn_rnn-inl.h (3 changes: 2 additions & 1 deletion)
@@ -38,7 +38,7 @@ namespace mxnet {
 namespace op {
 #if defined(__CUDACC__) && MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 template<typename DType>
-class CuDNNRNNOp : public Operator {
+class CuDNNRNNOp : public Operator{
  public:
   explicit CuDNNRNNOp(RNNParam param) {
     this->param_ = param;
@@ -101,6 +101,7 @@ class CuDNNRNNOp : public Operator {
       CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
       Storage::Get()->Free(dropout_states_);
       Storage::Get()->Free(reserve_space_);
+      init_cudnn_ = false;
     }
   }
 