diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 00000000..a1a9f671
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,9 @@
+[submodule "mshadow"]
+	path = mshadow
+	url = https://github.com/dmlc/mshadow.git
+[submodule "rabit"]
+	path = rabit
+	url = https://github.com/dmlc/rabit.git
+[submodule "dmlc-core"]
+	path = dmlc-core
+	url = https://github.com/dmlc/dmlc-core.git
diff --git a/dmlc-core b/dmlc-core
new file mode 160000
index 00000000..75f1950d
--- /dev/null
+++ b/dmlc-core
@@ -0,0 +1 @@
+Subproject commit 75f1950d386d033b0b64919017515d27e698962a
diff --git a/mshadow b/mshadow
new file mode 160000
index 00000000..208a1982
--- /dev/null
+++ b/mshadow
@@ -0,0 +1 @@
+Subproject commit 208a198213ea011e42f91b128b14a7206cce62a5
diff --git a/rabit b/rabit
new file mode 160000
index 00000000..c71ed6fc
--- /dev/null
+++ b/rabit
@@ -0,0 +1 @@
+Subproject commit c71ed6fccbd3e62c9ed9b54631dedc21a33addcf
diff --git a/src/io/iter_attach_txt-inl.hpp b/src/io/iter_attach_txt-inl.hpp
index 0f948575..75f0744f 100644
--- a/src/io/iter_attach_txt-inl.hpp
+++ b/src/io/iter_attach_txt-inl.hpp
@@ -22,7 +22,7 @@ class AttachTxtIterator : public IIterator {
   }
   virtual void SetParam(const char *name, const char *val) {
     base_->SetParam(name, val);
-    if (!strcmp(name, "filename")) filename_ = val;
+    if (!strcmp(name, "txtfilename")) filename_ = val;
     if (!strcmp(name, "batch_size")) batch_size_ = (index_t)atoi(val);
     if (!strcmp(name, "round_batch")) round_batch_ = atoi(val);
   }
diff --git a/src/layer/layer.h b/src/layer/layer.h
index 816876ae..c326bb54 100644
--- a/src/layer/layer.h
+++ b/src/layer/layer.h
@@ -313,6 +313,7 @@ const int kPRelu = 29;
 const int kBatchNorm = 30;
 const int kFixConnect = 31;
 const int kBatchNorm_no_ma = 32;
+const int kLSTM = 1023;
 /*! \brief gap used to encode pairtest layer */
 const int kPairTestGap = 1024;
 /*! \brief use integer to encode layer types */
@@ -352,6 +353,7 @@ inline LayerType GetLayerType(const char *type) {
   if (!strcmp(type, "prelu")) return kPRelu;
   if (!strcmp(type, "batch_norm")) return kBatchNorm;
   if (!strcmp(type, "batch_norm_no_ma")) return kBatchNorm_no_ma;
+  if (!strcmp(type, "lstm")) return kLSTM;
 #if CXXNET_USE_CAFFE_ADAPTOR
   if (!strcmp(type, "caffe")) return kCaffe;
 #endif
diff --git a/src/layer/layer_impl-inl.hpp b/src/layer/layer_impl-inl.hpp
index 9dd5a2aa..72653175 100644
--- a/src/layer/layer_impl-inl.hpp
+++ b/src/layer/layer_impl-inl.hpp
@@ -25,6 +25,7 @@
 #include "./insanity_pooling_layer-inl.hpp"
 #include "./prelu_layer-inl.hpp"
 #include "./batch_norm_layer-inl.hpp"
+#include "./lstm_layer-inl.hpp"
 #include "./loss/softmax_layer-inl.hpp"
 #include "./loss/lp_loss_layer-inl.hpp"
 #include "./loss/multi_logistic_layer-inl.hpp"
@@ -69,6 +70,7 @@ ILayer<xpu>* CreateLayer_(LayerType type,
     case kBatchNorm_no_ma: return new BatchNormLayer<xpu>(p_rnd);
     case kLpLoss: return new LpLossLayer<xpu>(label_info);
     case kMultiLogistic: return new MultiLogisticLayer<xpu>(label_info);
+    case kLSTM: return new LSTMLayer<xpu>(p_rnd);
 #if CXXNET_USE_CAFFE_ADAPTOR
     case kCaffe: return new CaffeLayer<xpu>();
 #endif
diff --git a/src/layer/lstm_layer-inl.hpp b/src/layer/lstm_layer-inl.hpp
new file mode 100644
index 00000000..a9d7299c
--- /dev/null
+++ b/src/layer/lstm_layer-inl.hpp
@@ -0,0 +1,340 @@
+#ifndef CXXNET_LAYER_LSTM_LAYER_INL_HPP_
+#define CXXNET_LAYER_LSTM_LAYER_INL_HPP_
+
+#include <mshadow/tensor.h>
+#include "./layer.h"
+#include "./param.h"
+#include "./op.h"
+#include "../utils/utils.h"
+
+namespace cxxnet {
+namespace layer {
+
+template<typename xpu>
+class LSTMLayer : public ILayer<xpu> {
+ public:
+  LSTMLayer(mshadow::Random<xpu> *p_rnd) : prnd_(p_rnd) {
+    this->parallel_size = 1;
+  }
+  virtual ~LSTMLayer(void) {}
+  virtual void SetParam(const char *name, const char* val) {
+    param_.SetParam(name, val);
+    if (!strcmp(name, "parallel_size")) this->parallel_size = atoi(val);
+  }
+  virtual void ApplyVisitor(typename ILayer<xpu>::IVisitor *pvisitor) {
+    pvisitor->Visit("wmat", wmat_, gwmat_);
+    pvisitor->Visit("bias", bias_, gbias_);
+  }
+  virtual void InitModel(void) {
+    // ifog weights: input, forget, output and cell gate * [input vector; previous hidden state]
+    // ifog bias: input, forget, output and cell gate
+    wmat_.Resize(mshadow::Shape2(param_.num_hidden * 4, param_.num_hidden + param_.num_input_node));
+    bias_.Resize(mshadow::Shape1(param_.num_hidden * 4));
+    param_.RandInitWeight(this->prnd_, wmat_, wmat_.size(1), wmat_.size(0));
+    bias_ = param_.init_bias;
+    // TODO: fancy_forget_bias_init
+    // https://gist.github.com/karpathy/587454dc0146a6ae21fc
+    gwmat_.Resize(wmat_.shape_);
+    gbias_.Resize(bias_.shape_);
+    gwmat_ = 0.0f; gbias_ = 0.0f;
+  }
+  virtual void SaveModel(utils::IStream &fo) const {
+    fo.Write(&param_, sizeof(LayerParam));
+    wmat_.SaveBinary(fo);
+    bias_.SaveBinary(fo);
+  }
+  virtual void LoadModel(utils::IStream &fi) {
+    utils::Check(fi.Read(&param_, sizeof(LayerParam)) != 0,
+                 "LSTMLayer:LoadModel invalid model file");
+    wmat_.LoadBinary(fi);
+    bias_.LoadBinary(fi);
+    gwmat_.Resize(wmat_.shape_);
+    gbias_.Resize(bias_.shape_);
+    gwmat_ = 0.0f; gbias_ = 0.0f;
+  }
+  virtual void SetStream(mshadow::Stream<xpu> *stream) {
+    wmat_.set_stream(stream);
+    bias_.set_stream(stream);
+    gwmat_.set_stream(stream);
+    gbias_.set_stream(stream);
+
+    it.set_stream(stream);
+    ft.set_stream(stream);
+    ot.set_stream(stream);
+    gt.set_stream(stream);
+    ct.set_stream(stream);
+    c_tanht.set_stream(stream);
+    ht.set_stream(stream);
+
+    flush.set_stream(stream);
+    t.set_stream(stream);
+    xhprev.set_stream(stream);
+    lifog.set_stream(stream);
+    d_xhprev.set_stream(stream);
+    d_lifog.set_stream(stream);
+    d_c.set_stream(stream);
+    d_cprev.set_stream(stream);
+  }
+  virtual void InitConnection(const std::vector<Node<xpu>*> &nodes_in,
+                              const std::vector<Node<xpu>*> &nodes_out,
+                              ConnectState<xpu> *p_cstate) {
+    utils::Check((nodes_in.size() == 1 || nodes_in.size() == 2) && nodes_out.size() == 1,
+                 "LSTMLayer: Layer only support 2(w/sequence label)-1 connection");
+    utils::Check(param_.num_hidden > 0, "LSTMLayer: must set nhidden correctly");
+    nodes_out[0]->data.shape_ =
+        mshadow::Shape4(nodes_in[0]->data.size(0), 1, 1, param_.num_hidden);
+    if (param_.num_input_node == 0) {
+      param_.num_input_node = static_cast<int>(nodes_in[0]->data.size(3));
+    } else {
+      utils::Check(param_.num_input_node == static_cast<int>(nodes_in[0]->data.size(3)),
+                   "LSTMLayer: input hidden nodes is not consistent");
+    }
+    this->seq_length = nodes_in[0]->data.size(0);
+    nodes_in[0]->must_contiguous = true;
+    nodes_in[1]->must_contiguous = true;
+    nodes_out[0]->must_contiguous = true;
+    this->initTemp();
+  }
+
+  virtual void OnBatchSizeChanged(const std::vector<Node<xpu>*> &nodes_in,
+                                  const std::vector<Node<xpu>*> &nodes_out,
+                                  ConnectState<xpu> *p_cstate) {
+    this->seq_length = nodes_in[0]->data.size(0);
+    this->initTemp();
+  }
+
+  /*
+    nodes_in[0] size: [batch_size][1][1][input_width]
+    nodes_in[1] size: [batch_size][1][1][1]
+    nodes_out[0] size: [batch_size][1][1][hidden_size]
+
+    The input sequence nodes_in[0] should be:
+    Seq[0][i], Seq[1][j], ... , Seq[parallel_size][k], Seq[0][i + 1], Seq[1][j + 1], ... , Seq[parallel_size][k + 1], ...
+    The corresponding sequence label (in nodes_in[1]) should be '1' when it is the beginning of a sequence.
+  */
+  virtual void Forward(bool is_train,
+                       const std::vector<Node<xpu>*> &nodes_in,
+                       const std::vector<Node<xpu>*> &nodes_out,
+                       ConnectState<xpu> *p_cstate) {
+    mshadow::Tensor<xpu, 4> &node_in = nodes_in[0]->data;
+    mshadow::Tensor<xpu, 4> &node_out = nodes_out[0]->data;
+    mshadow::Tensor<xpu, 4> xt = node_in;
+    mshadow::Tensor<xpu, 4> seq_label = nodes_in[1]->data;
+
+    CHECK(nodes_out[0]->data.CheckContiguous());
+    CHECK(xt.CheckContiguous());
+    CHECK(seq_label.CheckContiguous());
+    CHECK(ht.CheckContiguous());
+
+    index_t n_seq = seq_length / parallel_size;
+    xt.shape_ = mshadow::Shape4(n_seq, 1, parallel_size, node_in.size(3));
+    seq_label.shape_ = mshadow::Shape4(n_seq, 1, 1, parallel_size);
+    seq_label.stride_ = parallel_size;
+
+    for (index_t i = 0; i < n_seq; i++) {
+      // broadcast the per-stream sequence label so it can mask the carried hidden/cell state
+      flush = mshadow::expr::broadcast<0>(seq_label[i][0][0], flush.shape_);
+      if (i != 0)
+        t = flush * ht[i-1][0];
+      else
+        t = flush * ht[n_seq-1][0];
+      concat2D(xhprev, xt[i][0], t);
+      if (i != 0)
+        t = flush * ct[i-1][0];
+      else
+        t = flush * ct[n_seq-1][0];
+      LSTM_Forward(xhprev, t, ht[i][0], ct[i][0], it[i][0], ft[i][0], ot[i][0], gt[i][0], c_tanht[i][0]);
+    }
+    ht.shape_ = node_out.shape_;
+    mshadow::Copy(node_out, ht, ht.stream_);
+    ht.shape_ = ct.shape_;
+  }
+
+  virtual void Backprop(bool prop_grad,
+                        const std::vector<Node<xpu>*> &nodes_in,
+                        const std::vector<Node<xpu>*> &nodes_out,
+                        ConnectState<xpu> *p_cstate) {
+    mshadow::Tensor<xpu, 4> &node_in = nodes_in[0]->data;
+    mshadow::Tensor<xpu, 4> &node_out = nodes_out[0]->data;
+    mshadow::Tensor<xpu, 4> d_xt = node_in;
+    mshadow::Tensor<xpu, 4> d_ht = node_out;
+    mshadow::Tensor<xpu, 4> seq_label = nodes_in[1]->data;
+
+    CHECK(d_xt.CheckContiguous());
+    CHECK(d_ht.CheckContiguous());
+    CHECK(seq_label.CheckContiguous());
+
+    index_t n_seq = seq_length / parallel_size;
+    d_xt.shape_ = mshadow::Shape4(n_seq, 1, parallel_size, node_in.size(3));
+    d_ht.shape_ = mshadow::Shape4(n_seq, 1, parallel_size, node_out.size(3));
+    seq_label.shape_ = mshadow::Shape4(n_seq, 1, 1, parallel_size);
+    seq_label.stride_ = parallel_size;
+    d_cprev = 0.0f;
+
+    for (index_t i = n_seq - 1; i < n_seq; i--) {  // index_t is unsigned: i wraps past 0, which ends the loop
+      mshadow::Copy(d_c, d_cprev, d_cprev.stream_);
+      if (i == 0) {
+        flush = 0.0f;
+        concat2D(xhprev, d_xt[i][0], flush);
+        LSTM_Backprop(d_ht[i][0], xhprev, flush, c_tanht[i][0], it[i][0], ft[i][0], ot[i][0], gt[i][0], d_xhprev, d_c, d_cprev);
+      } else {
+        flush = mshadow::expr::broadcast<0>(seq_label[i][0][0], flush.shape_);
+        t = flush * ht[i-1][0];
+        concat2D(xhprev, d_xt[i][0], t);
+        t = flush * ct[i-1][0];
+        LSTM_Backprop(d_ht[i][0], xhprev, t, c_tanht[i][0], it[i][0], ft[i][0], ot[i][0], gt[i][0], d_xhprev, d_c, d_cprev);
+        t = d_xhprev.Slice(param_.num_input_node, param_.num_input_node + param_.num_hidden).T();
+        d_ht[i-1][0] += flush * t;
+        d_cprev *= flush;
+      }
+      if (prop_grad) {
+        d_xt[i][0] = d_xhprev.Slice(0, param_.num_input_node).T();
+      }
+    }
+  }
+
+ protected:
+  void LSTM_Forward(mshadow::Tensor<xpu, 2> xhprev,
+                    mshadow::Tensor<xpu, 2> cprev,
+                    mshadow::Tensor<xpu, 2> h,
+                    mshadow::Tensor<xpu, 2> c,
+                    mshadow::Tensor<xpu, 2> i,
+                    mshadow::Tensor<xpu, 2> f,
+                    mshadow::Tensor<xpu, 2> o,
+                    mshadow::Tensor<xpu, 2> g,
+                    mshadow::Tensor<xpu, 2> c_tanh) {
+    using namespace cxxnet::op;
+    using namespace mshadow::expr;
+    /*
+      li_t = w_ix * x_t + w_ih * h_t-1 + b_i
+      lf_t = w_fx * x_t + w_fh * h_t-1 + b_f
+      lo_t = w_ox * x_t + w_oh * h_t-1 + b_o
+      lg_t = w_gx * x_t + w_gh * h_t-1 + b_g
+      lifog = [li_t, lf_t, lo_t, lg_t]
+    */
+    lifog = broadcast<0>(bias_, lifog.shape_);
+    lifog += dot(wmat_, xhprev.T());
+    mshadow::Tensor<xpu, 2> li, lf, lo, lg;
+    li = lifog.Slice(0 * param_.num_hidden, 1 * param_.num_hidden);
+    lf = lifog.Slice(1 * param_.num_hidden, 2 * param_.num_hidden);
+    lo = lifog.Slice(2 * param_.num_hidden, 3 * param_.num_hidden);
+    lg = lifog.Slice(3 * param_.num_hidden, 4 * param_.num_hidden);
+    /*
+      i_t = sigmoid(li_t)
+      f_t = sigmoid(lf_t)
+      o_t = sigmoid(lo_t)
+      g_t = tanh(lg_t)
+    */
+    i = F<op::sigmoid>(li.T());
+    f = F<op::sigmoid>(lf.T());
+    o = F<op::sigmoid>(lo.T());
+    g = F<op::tanh>(lg.T());
+    /*
+      c_t = f_t * c_t-1 + i_t * g_t
+      h_t = o_t * tanh(c_t)
+    */
+    c = f * cprev + i * g;
+    c_tanh = F<op::tanh>(c);
+    h = o * c_tanh;
+  }
+
+  void LSTM_Backprop(mshadow::Tensor<xpu, 2> d_h,
+                     mshadow::Tensor<xpu, 2> xhprev,
+                     mshadow::Tensor<xpu, 2> cprev,
+                     mshadow::Tensor<xpu, 2> c_tanh,
+                     mshadow::Tensor<xpu, 2> i,
+                     mshadow::Tensor<xpu, 2> f,
+                     mshadow::Tensor<xpu, 2> o,
+                     mshadow::Tensor<xpu, 2> g,
+                     mshadow::Tensor<xpu, 2> d_xhprev,
+                     mshadow::Tensor<xpu, 2> d_c,
+                     mshadow::Tensor<xpu, 2> d_cprev) {
+    using namespace cxxnet::op;
+    using namespace mshadow::expr;
+
+    // accumulate dL/dc_t through the output gate, then pass it back to c_t-1
+    d_c += F<op::tanh_grad>(c_tanh) * o * d_h;
+    d_cprev = f * d_c;
+
+    mshadow::Tensor<xpu, 2> d_li, d_lf, d_lo, d_lg;
+    d_li = d_lifog.Slice(0 * param_.num_hidden, 1 * param_.num_hidden);
+    d_lf = d_lifog.Slice(1 * param_.num_hidden, 2 * param_.num_hidden);
+    d_lo = d_lifog.Slice(2 * param_.num_hidden, 3 * param_.num_hidden);
+    d_lg = d_lifog.Slice(3 * param_.num_hidden, 4 * param_.num_hidden);
+
+    d_li = F<op::sigmoid_grad>(i.T()) * g.T() * d_c.T();
+    d_lf = F<op::sigmoid_grad>(f.T()) * cprev.T() * d_c.T();
+    d_lo = F<op::sigmoid_grad>(o.T()) * c_tanh.T() * d_h.T();
+    d_lg = F<op::tanh_grad>(g.T()) * i.T() * d_c.T();
+
+    gwmat_ += dot(d_lifog, xhprev);
+    gbias_ += sum_rows(d_lifog.T());
+    d_xhprev = dot(wmat_.T(), d_lifog);
+  }
+
+  inline void tensor2To4(mshadow::Tensor<xpu, 2> a, mshadow::Tensor<xpu, 4> *a4) {
+    CHECK(a.CheckContiguous());
+    a4->set_stream(a.stream_);
+    a4->dptr_ = a.dptr_;
+    a4->stride_ = a.stride_;
+    a4->shape_ = mshadow::Shape4(1, 1, a.size(0), a.size(1));
+    CHECK(a4->CheckContiguous());
+  }
+
+  inline void concat2D(mshadow::Tensor<xpu, 2> dst, mshadow::Tensor<xpu, 2> a, mshadow::Tensor<xpu, 2> b) {
+    utils::Check(a.size(0) == b.size(0) && b.size(0) == dst.size(0), "LSTMLayer: concat size[0] mismatch");
+    utils::Check(a.size(1) + b.size(1) == dst.size(1), "LSTMLayer: concat size[1] mismatch");
+    mshadow::Tensor<xpu, 4> dst4, a4, b4;
+    tensor2To4(dst, &dst4);
+    tensor2To4(a, &a4);
+    tensor2To4(b, &b4);
+    dst4 = mshadow::expr::concat<3>(a4, b4);
+    CHECK(dst.CheckContiguous());
+  }
+
+  inline void initTemp() {
+    it.Resize(mshadow::Shape4(seq_length / parallel_size, 1, parallel_size, param_.num_hidden));
+    ft.Resize(mshadow::Shape4(seq_length / parallel_size, 1, parallel_size, param_.num_hidden));
+    ot.Resize(mshadow::Shape4(seq_length / parallel_size, 1, parallel_size, param_.num_hidden));
+    gt.Resize(mshadow::Shape4(seq_length / parallel_size, 1, parallel_size, param_.num_hidden));
+    ct.Resize(mshadow::Shape4(seq_length / parallel_size, 1, parallel_size, param_.num_hidden));
+    c_tanht.Resize(mshadow::Shape4(seq_length / parallel_size, 1, parallel_size, param_.num_hidden));
+    ht.Resize(mshadow::Shape4(seq_length / parallel_size, 1, parallel_size, param_.num_hidden));
+
+    flush.Resize(mshadow::Shape2(parallel_size, param_.num_hidden));
+    t.Resize(mshadow::Shape2(parallel_size, param_.num_hidden));
+    xhprev.Resize(mshadow::Shape2(parallel_size, param_.num_input_node + param_.num_hidden));
+    d_xhprev.Resize(mshadow::Shape2(param_.num_input_node + param_.num_hidden, parallel_size));
+    d_c.Resize(mshadow::Shape2(parallel_size, param_.num_hidden));
+    d_cprev.Resize(mshadow::Shape2(parallel_size, param_.num_hidden));
+    lifog.Resize(mshadow::Shape2(4 * param_.num_hidden, parallel_size));
+    d_lifog.Resize(mshadow::Shape2(4 * param_.num_hidden, parallel_size));
+  }
+
+  /*! \brief random number generator */
+  mshadow::Random<xpu> *prnd_;
+  /*! \brief parameters that are potentially useful */
+  LayerParam param_;
+  /*! \brief weight matrix */
+  mshadow::TensorContainer<xpu, 2> wmat_;
+  /*! \brief bias */
+  mshadow::TensorContainer<xpu, 1> bias_;
+  /*! \brief accumulates the gradient of weight matrix */
+  mshadow::TensorContainer<xpu, 2> gwmat_;
+  /*! \brief accumulates the gradient of bias */
+  mshadow::TensorContainer<xpu, 1> gbias_;
+
+  /*! \brief number of parallel streams and batch length for batched BPTT */
+  size_t parallel_size, seq_length;
+
+  /*! \brief gate and state buffers used in the LSTM layer */
+  mshadow::TensorContainer<xpu, 4> it, ft, ot, gt, ct, c_tanht, ht;
+  mshadow::TensorContainer<xpu, 2> flush, t;
+  mshadow::TensorContainer<xpu, 2> xhprev;
+  mshadow::TensorContainer<xpu, 2> lifog;
+  mshadow::TensorContainer<xpu, 2> d_xhprev;
+  mshadow::TensorContainer<xpu, 2> d_lifog;
+  mshadow::TensorContainer<xpu, 2> d_c;
+  mshadow::TensorContainer<xpu, 2> d_cprev;
+};
+}  // namespace layer
+}  // namespace cxxnet
+#endif  // CXXNET_LAYER_LSTM_LAYER_INL_HPP_
diff --git a/src/nnet/neural_net-inl.hpp b/src/nnet/neural_net-inl.hpp
index 681dc024..fe836543 100644
--- a/src/nnet/neural_net-inl.hpp
+++ b/src/nnet/neural_net-inl.hpp
@@ -242,6 +242,7 @@ struct NeuralNet {
     // setup extra data
     for (int i = 0; i < cfg.param.extra_data_num; ++i) {
       const std::vector<int>& extra_shape = cfg.extra_shape;
+      nodes[i + 1].must_contiguous = true;
      nodes[i + 1].data.shape_ = mshadow::Shape4(
           max_batch, extra_shape[i * 3], extra_shape[i * 3 + 1], extra_shape[i * 3 + 2]);
     }
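
Usage note (illustrative, not part of the patch): the layer's doc comment above says that nodes_in[0] must interleave `parallel_size` independent streams step by step, with nodes_in[1] set to '1' at the first step of each sequence. The host-side C++11 sketch below shows one way such a batch could be assembled; `Step`, `streams`, `rows`, and `labels` are hypothetical names used for illustration only and are not cxxnet APIs.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

struct Step {
  std::vector<float> x;  // input features for one time step
  bool seq_begin;        // true if this step starts a new sequence
};

int main() {
  const std::size_t parallel_size = 2;
  // Two hypothetical streams, already cut to the same number of steps.
  std::vector<std::vector<Step> > streams(parallel_size);
  streams[0] = { { {0.1f, 0.2f}, true }, { {0.3f, 0.4f}, false }, { {0.5f, 0.6f}, true } };
  streams[1] = { { {1.0f, 1.1f}, true }, { {1.2f, 1.3f}, true }, { {1.4f, 1.5f}, false } };

  std::vector<std::vector<float> > rows;  // row-major content for nodes_in[0]
  std::vector<float> labels;              // content for nodes_in[1]

  const std::size_t n_seq = streams[0].size();
  for (std::size_t t = 0; t < n_seq; ++t) {            // outer loop over time steps
    for (std::size_t s = 0; s < parallel_size; ++s) {  // inner loop over streams
      assert(streams[s].size() == n_seq);
      rows.push_back(streams[s][t].x);
      // '1' marks a sequence start, per the layer's doc comment
      labels.push_back(streams[s][t].seq_begin ? 1.0f : 0.0f);
    }
  }
  // The batch handed to the layer has n_seq * parallel_size rows, i.e. a
  // multiple of parallel_size, matching seq_length / parallel_size in the layer.
  std::printf("rows=%zu labels=%zu\n", rows.size(), labels.size());
  return 0;
}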