From 0583f76c67abf4bdee6fa0929afde6a07b33811f Mon Sep 17 00:00:00 2001 From: "Andrew M. Dai" Date: Tue, 17 Apr 2018 12:06:58 -0700 Subject: [PATCH 1/2] Specify that pretraining maskgan is not optional. --- research/maskgan/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/research/maskgan/README.md b/research/maskgan/README.md index 315ab6516c7..6ef1ab156e9 100644 --- a/research/maskgan/README.md +++ b/research/maskgan/README.md @@ -14,10 +14,10 @@ tested. Pretraining may not work correctly. For training on PTB: -1. (Optional) Pretrain a LM on PTB and store the checkpoint in `/tmp/pretrain-lm/`. +1. Pretrain a LM on PTB and store the checkpoint in `/tmp/pretrain-lm/`. Instructions WIP. -2. (Optional) Run MaskGAN in MLE pretraining mode. If step 1 was not run, set +2. Run MaskGAN in MLE pretraining mode. If step 1 was not run, set `language_model_ckpt_dir` to empty. ```bash From 20662f3d93ff5fa8f51a99cf09bf6ab300d29d18 Mon Sep 17 00:00:00 2001 From: Andrew Dai Date: Sun, 29 Apr 2018 04:57:49 -0400 Subject: [PATCH 2/2] Increase minimum TF version for DEFINE_enum and rename variable mappings for change to RNN variable names. --- research/maskgan/README.md | 2 +- .../maskgan/model_utils/variable_mapping.py | 268 ++++++++---------- 2 files changed, 121 insertions(+), 149 deletions(-) diff --git a/research/maskgan/README.md b/research/maskgan/README.md index 6ef1ab156e9..17a5763fbf2 100644 --- a/research/maskgan/README.md +++ b/research/maskgan/README.md @@ -5,7 +5,7 @@ ______*](https://arxiv.org/abs/1801.07736) published at ICLR 2018. ## Requirements -* TensorFlow >= v1.3 +* TensorFlow >= v1.5 ## Instructions diff --git a/research/maskgan/model_utils/variable_mapping.py b/research/maskgan/model_utils/variable_mapping.py index abfb0b9eec6..0301b969716 100644 --- a/research/maskgan/model_utils/variable_mapping.py +++ b/research/maskgan/model_utils/variable_mapping.py @@ -163,52 +163,48 @@ def rnn_zaremba(hparams, model): if v.op.name == str(model) + '/rnn/embedding' ][0] lstm_w_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - str(model) + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == str(model) + + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] lstm_b_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - str(model) + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == str(model) + + '/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] lstm_w_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - str(model) + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == str(model) + + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] lstm_b_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - str(model) + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == str(model) + + '/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] # Dictionary mapping. if model == 'gen': variable_mapping = { 'Model/embedding': embedding, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1, + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0, + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1, 'Model/softmax_w': softmax_w, 'Model/softmax_b': softmax_b } else: if FLAGS.dis_share_embedding: variable_mapping = { - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1 + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0, + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1 } else: variable_mapping = { 'Model/embedding': embedding, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': lstm_w_0, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': lstm_b_0, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': lstm_w_1, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': lstm_b_1 + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': lstm_w_0, + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': lstm_b_0, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': lstm_w_1, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': lstm_b_1 } return variable_mapping @@ -356,24 +352,20 @@ def gen_encoder_seq2seq(hparams): if v.op.name == 'gen/encoder/rnn/embedding' ][0] encoder_lstm_w_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] encoder_lstm_b_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] encoder_lstm_w_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] encoder_lstm_b_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] if FLAGS.data_set == 'ptb': @@ -385,24 +377,24 @@ def gen_encoder_seq2seq(hparams): variable_mapping = { str(model_str) + '/embedding': encoder_embedding, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': encoder_lstm_w_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': encoder_lstm_b_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': encoder_lstm_w_1, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': encoder_lstm_b_1 } else: variable_mapping = { - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': encoder_lstm_w_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': encoder_lstm_b_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': encoder_lstm_w_1, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': encoder_lstm_b_1 } return variable_mapping @@ -418,24 +410,20 @@ def gen_decoder_seq2seq(hparams): if v.op.name == 'gen/decoder/rnn/embedding' ][0] decoder_lstm_w_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] decoder_lstm_b_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] decoder_lstm_w_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] decoder_lstm_b_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] decoder_softmax_b = [ v for v in tf.trainable_variables() @@ -450,13 +438,13 @@ def gen_decoder_seq2seq(hparams): variable_mapping = { str(model_str) + '/embedding': decoder_embedding, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': decoder_lstm_w_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': decoder_lstm_b_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': decoder_lstm_w_1, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': decoder_lstm_b_1, str(model_str) + '/softmax_b': decoder_softmax_b @@ -487,34 +475,34 @@ def dis_fwd_bidirectional(hparams): ][0] fw_lstm_w_0 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] fw_lstm_b_0 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] fw_lstm_w_1 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] fw_lstm_b_1 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + if v.op.name == 'dis/rnn/fw/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] if FLAGS.dis_share_embedding: variable_mapping = { - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': fw_lstm_w_0, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': fw_lstm_b_0, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': fw_lstm_w_1, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': fw_lstm_b_1 + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': fw_lstm_w_0, + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': fw_lstm_b_0, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': fw_lstm_w_1, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': fw_lstm_b_1 } else: variable_mapping = { 'Model/embedding': embedding, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': fw_lstm_w_0, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': fw_lstm_b_0, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': fw_lstm_w_1, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': fw_lstm_b_1 + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': fw_lstm_w_0, + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': fw_lstm_b_0, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': fw_lstm_w_1, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': fw_lstm_b_1 } return variable_mapping @@ -537,26 +525,26 @@ def dis_bwd_bidirectional(hparams): # Backward Discriminator Elements. bw_lstm_w_0 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] bw_lstm_b_0 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] bw_lstm_w_1 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] bw_lstm_b_1 = [ v for v in tf.trainable_variables() - if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + if v.op.name == 'dis/rnn/bw/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] variable_mapping = { - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': bw_lstm_w_0, - 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': bw_lstm_b_0, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': bw_lstm_w_1, - 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': bw_lstm_b_1 + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': bw_lstm_w_0, + 'Model/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': bw_lstm_b_0, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': bw_lstm_w_1, + 'Model/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': bw_lstm_b_1 } return variable_mapping @@ -576,24 +564,20 @@ def dis_encoder_seq2seq(hparams): ## Encoder forward variables. encoder_lstm_w_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] encoder_lstm_b_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] encoder_lstm_w_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] encoder_lstm_b_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] if FLAGS.data_set == 'ptb': @@ -602,13 +586,13 @@ def dis_encoder_seq2seq(hparams): model_str = 'model' variable_mapping = { - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': encoder_lstm_w_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': encoder_lstm_b_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': encoder_lstm_w_1, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': encoder_lstm_b_1 } return variable_mapping @@ -624,24 +608,20 @@ def dis_decoder_seq2seq(hparams): if v.op.name == 'dis/decoder/rnn/embedding' ][0] decoder_lstm_w_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] decoder_lstm_b_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] decoder_lstm_w_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] decoder_lstm_b_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] if FLAGS.data_set == 'ptb': @@ -653,24 +633,24 @@ def dis_decoder_seq2seq(hparams): variable_mapping = { str(model_str) + '/embedding': decoder_embedding, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': decoder_lstm_w_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': decoder_lstm_b_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': decoder_lstm_w_1, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': decoder_lstm_b_1 } else: variable_mapping = { - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': decoder_lstm_w_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_0/basic_lstm_cell/bias': decoder_lstm_b_0, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': decoder_lstm_w_1, - str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + str(model_str) + '/RNN/multi_rnn_cell/cell_1/basic_lstm_cell/bias': decoder_lstm_b_1, } return variable_mapping @@ -688,24 +668,20 @@ def dis_seq2seq_vd(hparams): ## Encoder variables. encoder_lstm_w_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] encoder_lstm_b_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] encoder_lstm_w_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] encoder_lstm_b_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] ## Attention. @@ -721,43 +697,39 @@ def dis_seq2seq_vd(hparams): ## Decoder. decoder_lstm_w_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel' ][0] decoder_lstm_b_0 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias' ][0] decoder_lstm_w_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel' ][0] decoder_lstm_b_1 = [ - v for v in tf.trainable_variables() - if v.op.name == - 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases' + v for v in tf.trainable_variables() if v.op.name == + 'dis/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias' ][0] # Standard variable mappings. variable_mapping = { - 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': encoder_lstm_w_0, - 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + 'gen/encoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias': encoder_lstm_b_0, - 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': encoder_lstm_w_1, - 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + 'gen/encoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias': encoder_lstm_b_1, - 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/weights': + 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel': decoder_lstm_w_0, - 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/biases': + 'gen/decoder/rnn/multi_rnn_cell/cell_0/basic_lstm_cell/bias': decoder_lstm_b_0, - 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/weights': + 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/kernel': decoder_lstm_w_1, - 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/biases': + 'gen/decoder/rnn/multi_rnn_cell/cell_1/basic_lstm_cell/bias': decoder_lstm_b_1 }