Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Cpu lstm inference #9977

Merged
merged 26 commits into from
Mar 10, 2018
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __iter__(self):
worker.start()
workers.append(worker)

idx = -1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

idx might be reference before assignment when i run pylint

for idx, batch in enumerate(self._batch_sampler):
key_queue.put((idx, batch))
num_batches = idx + 1
Expand Down
14 changes: 9 additions & 5 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,18 @@ def forward(self, inputs, states=None):
for i in range(self._dir):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
if inputs.context.device_type == 'gpu':
out = self._forward_gpu(inputs, states)
import mxnet
if inputs.context.device_type == 'gpu' or \
(not mxnet.autograd.is_training() and self._mode == 'lstm'):
out = self._forward_kernel(inputs, states)
else:
out = self._forward_cpu(inputs, states)
out = self._forward(inputs, states)

# out is (output, state)
return out[0] if skip_states else out

def _forward_cpu(self, inputs, states):
def _forward(self, inputs, states):
"""forward using gluon cell"""
ns = len(states)
axis = self._layout.find('T')
states = sum(zip(*((j for j in i) for i in states)), ())
Expand All @@ -207,7 +210,8 @@ def _forward_cpu(self, inputs, states):

return outputs, new_states

def _forward_gpu(self, inputs, states):
def _forward_kernel(self, inputs, states):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
ctx = inputs.context
Expand Down
255 changes: 238 additions & 17 deletions src/operator/rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
#include <vector>
#include <string>
#include <utility>
#include "./math.h"
#include "./math_functions-inl.h"
#include "./operator_common.h"
#include "./mshadow_op.h"
#include "./linalg.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -78,10 +82,12 @@ inline int rnn_param_size(int layerNum,
int size = rnn_single_param_size(inputSize, hiddenSize, mode);
// get size of remaining layers
if (bidirectional) {
size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize, hiddenSize, mode);
size += (layerNum - 1) * rnn_single_param_size(2 * hiddenSize,
hiddenSize, mode);
size *= 2;
} else {
size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize, mode);
size += (layerNum - 1) * rnn_single_param_size(hiddenSize, hiddenSize,
mode);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you are just reformatting the code here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

}
return size;
}
Expand Down Expand Up @@ -114,7 +120,8 @@ struct RNNParam : public dmlc::Parameter<RNNParam> {

DMLC_DECLARE_FIELD(p).set_default(0.)
.set_range(0, 1)
.describe("Dropout probability, fraction of the input that gets dropped out at training time");
.describe("Dropout probability, fraction of the input that gets dropped"
"out at training time");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove this change. Length of this line is less than 100.
BTW, why there are still some parameters don't have their descriptions, like pkeep_, lstm_q_?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pkeep_, lstm_q_ are used in cudnn_rnn-inl.h


DMLC_DECLARE_FIELD(state_outputs).set_default(false)
.describe("Whether to have the states as symbol outputs.");
Expand All @@ -132,8 +139,6 @@ class RNNOp : public Operator {
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
// TODO(sbodenstein): add MShadow implementation
}

Expand All @@ -144,15 +149,224 @@ class RNNOp : public Operator {
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
// TODO(sbodenstein): add MShadow implementation
}

private:
RNNParam param_;
}; // class RNNOp

template<typename DType>
class RNNOp<cpu, DType> : public Operator {
public:
explicit RNNOp(RNNParam param) {
this->param_ = param;
// RNN Mode
switch (param_.mode) {
case rnn_enum::kLstm:
break;
default:
LOG(FATAL) << "only LSTM is implmented on CPU";
}
if (param_.mode == rnn_enum::kLstm)
param_.lstm_q_ = true;
else
param_.lstm_q_ = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems this check can be merged to the switch case statement above.

}

virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {
// Layout TNC

size_t in_expected = param_.lstm_q_ ? 4 : 3;
size_t out_expected = param_.lstm_q_ ? 3 : 2;

if (!param_.state_outputs)
LOG(FATAL) << "no state outputs is currently not supported for cpu.";

CHECK_EQ(req[rnn_enum::kOut], kWriteTo);
CHECK_EQ(in_data.size(), in_expected);
CHECK_EQ(out_data.size(), out_expected);

mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
// get input + output tensors
// w layout i2h_w, h2h_w, i2h_b, h2h_b
Tensor<cpu, 3, DType> x =
in_data[rnn_enum::kData].get<cpu, 3, DType>(s); // TNC
Tensor<cpu, 1, DType> w = in_data[rnn_enum::kParams].get<cpu, 1, DType>(s);
Tensor<cpu, 3, DType> hx =
in_data[rnn_enum::kState].get<cpu, 3, DType>(s); // LNC
Tensor<cpu, 3, DType> y =
out_data[rnn_enum::kOut].get<cpu, 3, DType>(s); // TNC
int64_t seq_len = x.shape_[0];
int64_t num_layers = hx.shape_[0];
int64_t batch_size = x.shape_[1];
int64_t h_channel = hx.shape_[2];
int64_t in_channel = x.shape_[2];
Tensor<cpu, 2, DType> x_flatten = in_data[rnn_enum::kData]
.get_with_shape<cpu, 2, DType>(
mshadow::Shape2(seq_len * batch_size, in_channel), s); // (T*N)C
Tensor<cpu, 2, DType> y_flatten = out_data[rnn_enum::kOut]
.get_with_shape<cpu, 2, DType>(
mshadow::Shape2(
y.shape_[0] * y.shape_[1], y.shape_[2]), s); // (T*N)C

CHECK_EQ(x.CheckContiguous(), true);
CHECK_EQ(w.CheckContiguous(), true);
CHECK_EQ(hx.CheckContiguous(), true);
CHECK_EQ(y.CheckContiguous(), true);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CHECK(x.CheckContiguous());


if (ctx.is_train)
LOG(FATAL) << "only inference mode is available for cpu at the moment.";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can do CHECK(!ctx.is_train) << "..."

if (param_.lstm_q_) {
const size_t kNumMat = 4;
int64_t fused_h_ch = kNumMat * h_channel;
int64_t h_size = batch_size * fused_h_ch;
int64_t num_dir = 1 + param_.bidirectional;
int64_t h2h_w_size = h_channel * fused_h_ch;

Tensor<cpu, 3, DType> cx =
in_data[rnn_enum::kStateCell].get<cpu, 3, DType>(s);
CHECK_EQ(cx.CheckContiguous(), true);

Tensor<cpu, 3, DType> cy =
out_data[rnn_enum::kStateCellOut].get<cpu, 3, DType>(s);
Tensor<cpu, 3, DType> hy =
out_data[rnn_enum::kStateOut].get<cpu, 3, DType>(s);
CHECK_EQ(cy.CheckContiguous(), true);
CHECK_EQ(hy.CheckContiguous(), true);

DType* workspace_addr =
static_cast<DType *>(ctx.requested[rnn_enum::kTempSpace]
.get_host_space_internal(sizeof(DType) *
(seq_len * h_size + h_size
+ y.shape_[0] * y.shape_[1] * y.shape_[2])));
Tensor<cpu, 3, DType> i2h_y(
workspace_addr, mshadow::Shape3(seq_len, batch_size, fused_h_ch));
Tensor<cpu, 2, DType> i2h_y_flatten(
workspace_addr, mshadow::Shape2(seq_len * batch_size, fused_h_ch));
Tensor<cpu, 2, DType> h2h_y(workspace_addr
+ seq_len * h_size, mshadow::Shape2(batch_size, fused_h_ch));
Tensor<cpu, 3, DType> y_tmp(workspace_addr
+ (seq_len + 1) * h_size, y.shape_);
Tensor<cpu, 2, DType> y_flatten_tmp(workspace_addr
+ (seq_len + 1) * h_size, y_flatten.shape_);
CHECK_EQ(i2h_y.CheckContiguous(), true);
CHECK_EQ(h2h_y.CheckContiguous(), true);
CHECK_EQ(y_tmp.CheckContiguous(), true);

for (int64_t layer = 0; layer < num_layers; layer++) {
int reverse_dir = 0;
int out_tmp = 0;
if (param_.bidirectional && layer % 2)
reverse_dir = 1;
if (layer / num_dir % 2 == 0)
out_tmp = 1;
mshadow::Shape<2> i2h_w_shape = mshadow::Shape2(fused_h_ch,
(layer < num_dir) ? in_channel : num_dir * h_channel);
mshadow::Shape<2> h2h_w_shape = mshadow::Shape2(fused_h_ch, h_channel);
int64_t start = layer < num_dir ?
(layer * (in_channel * fused_h_ch + h2h_w_size)) : // input layer
(num_dir * (in_channel * fused_h_ch + h2h_w_size)
+ (layer - num_dir) * (h2h_w_size * num_dir + h2h_w_size));
Tensor<cpu, 2, DType> i2h_w(w.Slice(start, start + (layer < num_dir ?
(in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_,
i2h_w_shape);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why slice? i think w.dptr_ + start is same as w.Slice(start, start + (layer < num_dir ? in_channel * fused_h_ch) : num_dir * h2h_w_size)).dptr_

start += layer < num_dir ?
in_channel * fused_h_ch : h2h_w_size * num_dir;
Tensor<cpu, 2, DType> h2h_w(w.Slice(start, start + h2h_w_size).dptr_,
h2h_w_shape);
start = num_dir * (in_channel * fused_h_ch + h2h_w_size)
+ (num_layers - num_dir) * (h2h_w_size * (num_dir + 1))
+ layer * fused_h_ch * 2;
Tensor<cpu, 1, DType> i2h_b = w.Slice(start, start + fused_h_ch);
start += fused_h_ch;
Tensor<cpu, 1, DType> h2h_b = w.Slice(start, start + fused_h_ch);
if (out_tmp) {
linalg_gemm(layer < num_dir ? x_flatten:y_flatten, i2h_w,
i2h_y_flatten, false, true, s);
} else {
linalg_gemm(layer < num_dir ? x_flatten:y_flatten_tmp, i2h_w,
i2h_y_flatten, false, true, s);
}
i2h_y_flatten += repmat(i2h_b, seq_len * batch_size);
for (int64_t t = 0; t < seq_len; t++) {
int64_t timestep = t;
if (reverse_dir)
timestep = seq_len - 1 - t;
linalg_gemm(t == 0 ? hx[layer]:hy[layer], h2h_w, h2h_y,
false, true, s);
h2h_y += repmat(h2h_b, batch_size);
// fused element-wise ops
LSTMFusedElementWiseCPUOps(i2h_y[timestep], cx[layer], h2h_y,
y[timestep], out_tmp ? y_tmp[timestep]: y[timestep],
hy[layer], cy[layer], batch_size, h_channel, t,
reverse_dir, out_tmp && (layer == num_layers - 1));
}
}
} else {
LOG(FATAL) << "only LSTM is available for cpu at the moment.";
}
}

virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
LOG(FATAL) << "LSTM backward is not available for cpu at the moment.";
}

private:
RNNParam param_;

virtual void LSTMFusedElementWiseCPUOps(const Tensor<cpu, 2, DType> &i2h_y,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why virtual?

const Tensor<cpu, 2, DType> &cx,
const Tensor<cpu, 2, DType> &h2h_y,
const Tensor<cpu, 2, DType> &y,
// holding intermediate layer output
const Tensor<cpu, 2, DType> &tmp,
const Tensor<cpu, 2, DType> &hy,
const Tensor<cpu, 2, DType> &cy,
const int64_t batch_size,
const int64_t h_channel,
const int64_t t,
const int reverse_dir,
const int copy_tmp2y) {
int64_t ji;
#pragma omp parallel for private(ji)
for (ji = 0; ji < batch_size * h_channel; ji++) {
int64_t j = ji / h_channel; // batch dim
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it ok to write batch_size * h_channel in condition expression? It will calculate ji times.
And ++ji is better than ji++.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

u mean move it out of condition expression?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't need to set ji private if define ji in for loop. like this:

#pragma omp parallel for
for(int64_t ji = 0; ... ; ....)

int64_t i = ji % h_channel;
int64_t f = i + h_channel;
int64_t c = i + h_channel * 2;
int64_t o = i + h_channel * 3;
h2h_y[j][i] += i2h_y[j][i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too many overloaded operator [] calls and temporary Tensor objects generated here. At least you can cache h2h_y[j], i2h_y[j], etc. for each loop.

Copy link
Contributor Author

@Jerryzcn Jerryzcn Mar 9, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tried, but i did not notice any difference in runtime. i think the tensor object probably does not generate new tensor object here. I think multiple [] are probably implemented as a single dereference operation rather multiple one. I suspect that assigning it to a local variable will actually use one of the register for holding the pointer to the object, which may actually slow down the process

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When you use data[i], where data is a 2D tensor, it returns a 1D temporary Tensor object for you, and then call the 1D tensor's operator[]. You would not be able to notice much runtime improvement after you make the change if the program didn't run for a long time only for this loop and the improvement could be dwarfed by other factors that are major bottlenecks. At least, it's not a good practice to write C++ code like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay. but it seems inside mshadow, all the ops are implemented using multiple []
https://github.com/dmlc/mshadow/blob/master/mshadow/tensor_cpu-inl.h#L380
I will probably access the dptr_

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion on this. If you could use dptr_, that's the best for performance because it saves function calls and temp tensor object creation, but it could introduce the issue of code readability and defeat the purpose of OO.

I think the rule of thumb here is try to avoid temp tensor creation and destruction while keep the code readable. So it's okay to use operator[] for a 1D Tensor since it only return values and cache the temp tensor created by calling operator[] for a 2D tensor.

h2h_y[j][f] += i2h_y[j][f];
h2h_y[j][o] += i2h_y[j][o];
h2h_y[j][c] += i2h_y[j][c];
h2h_y[j][i] = 1.0f / (1.0f + math::exp(-h2h_y[j][i]));
h2h_y[j][f] = 1.0f / (1.0f + math::exp(-h2h_y[j][f]));
h2h_y[j][o] = 1.0f / (1.0f + math::exp(-h2h_y[j][o]));
h2h_y[j][c] = tanh(h2h_y[j][c]);
cy[j][i] = h2h_y[j][f] * (t == 0 ? cx[j][i]:cy[j][i])
+ h2h_y[j][i] * h2h_y[j][c];
hy[j][i] = h2h_y[j][o] * tanh(cy[j][i]);
tmp[j][i + h_channel * reverse_dir] = hy[j][i];
if (copy_tmp2y) {
y[j][i] = tmp[j][i];
if (reverse_dir)
y[j][i + h_channel] = tmp[j][i + h_channel];
}
}
}
}; // class RNNOp

template<typename xpu>
Operator* CreateOp(RNNParam param, int dtype);

Expand Down Expand Up @@ -184,7 +398,8 @@ class RNNProp : public OperatorProperty {
return num_outputs;
}

void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs)
override {
param_.Init(kwargs);
}

Expand All @@ -195,28 +410,33 @@ class RNNProp : public OperatorProperty {
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
using namespace mshadow;
if (param_.mode == rnn_enum::kLstm) {
CHECK_EQ(in_shape->size(), 4U) << "Input:[data, parameters, state, cell_state]";
CHECK_EQ(in_shape->size(), 4U) <<
"Input:[data, parameters, state, cell_state]";
} else {
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, parameters, state]";
}
const TShape &dshape = (*in_shape)[rnn_enum::kData];
if (dshape.ndim() == 0) return false;
CHECK_EQ(dshape.ndim(), 3U) \
<< "Input data should be rank-3 tensor of dim [sequence length, batch size, input size]";
CHECK_EQ(dshape.ndim(), 3U)
<< "Input data should be rank-3 tensor of dim [sequence length, "
<< "batch size, input size]";
// data: [sequence len, batch, input dimension]
int batch_size = dshape[1];
int input_size = dshape[2];
int numDirections = param_.bidirectional ? 2 : 1;
int total_layers = numDirections * param_.num_layers; // double for bidirectional
// double for bidirectional
int total_layers = numDirections * param_.num_layers;

SHAPE_ASSIGN_CHECK(*in_shape,
rnn_enum::kState,
Shape3(total_layers, batch_size, param_.state_size));
mshadow::Shape3(total_layers, batch_size,
param_.state_size));
if (param_.mode == rnn_enum::kLstm)
SHAPE_ASSIGN_CHECK(*in_shape,
rnn_enum::kStateCell,
Shape3(total_layers, batch_size, param_.state_size));
mshadow::Shape3(total_layers, batch_size,
param_.state_size));

// calculate parameter vector length
int param_size = rnn_param_size(param_.num_layers,
Expand Down Expand Up @@ -287,8 +507,9 @@ class RNNProp : public OperatorProperty {
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
std::vector<int> dep = {in_data[rnn_enum::kData], in_data[rnn_enum::kParams],
in_data[rnn_enum::kState], out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
std::vector<int> dep = {in_data[rnn_enum::kData],
in_data[rnn_enum::kParams], in_data[rnn_enum::kState],
out_data[rnn_enum::kOut], out_grad[rnn_enum::kOut]};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i'm not sure why you want to change the code in this function. it seems you just reorganize the code a little bit.

Copy link
Contributor Author

@Jerryzcn Jerryzcn Mar 7, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it exceeds 80 char per line limit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the coding style in mxnet allows up to 100 char per line.
so the original code is fine.


if (param_.state_outputs) {
dep.push_back(out_data[rnn_enum::kStateOut]);
Expand Down
1 change: 0 additions & 1 deletion src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ namespace mxnet {
namespace op {
template<>
Operator *CreateOp<cpu>(RNNParam param, int dtype) {
LOG(FATAL) << "RNN is only available for gpu at the moment.";
Operator *op = NULL;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
op = new RNNOp<cpu, DType>(param);
Expand Down
Loading