Skip to content

Commit 8967fac

Browse files
authored
Merge pull request #5 from urialon/fix_finalise
Removing argument 'is_train' which is not used in SeqDecoder.finalise_minibatch
2 parents 0753df1 + 282a38a commit 8967fac

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

Models/exprsynth/seq2seqmodel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _extend_minibatch_by_sample(self, batch_data: Dict[str, Any], sample: Dict[s
8383

8484
def _finalise_minibatch(self, batch_data: Dict[str, Any], is_train: bool) -> Dict[tf.Tensor, Any]:
8585
minibatch = super()._finalise_minibatch(batch_data, is_train)
86-
self._decoder_model.finalise_minibatch(batch_data, minibatch, is_train)
86+
self._decoder_model.finalise_minibatch(batch_data, minibatch)
8787
return minibatch
8888

8989
# ------- These are the bits that we only need for test-time:

Models/exprsynth/seqdecoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def extend_minibatch_by_sample(self, batch_data: Dict[str, Any], sample: Dict[st
290290
batch_data['target_token_ids'].append(sample['target_token_ids'])
291291
batch_data['target_token_ids_mask'].append(sample['target_token_ids_mask'])
292292

293-
def finalise_minibatch(self, batch_data: Dict[str, Any], minibatch: Dict[tf.Tensor, Any], is_train: bool) -> None:
293+
def finalise_minibatch(self, batch_data: Dict[str, Any], minibatch: Dict[tf.Tensor, Any]) -> None:
294294
write_to_minibatch(minibatch, self.placeholders['target_token_ids'], batch_data['target_token_ids'])
295295
write_to_minibatch(minibatch, self.placeholders['target_token_ids_mask'], batch_data['target_token_ids_mask'])
296296

0 commit comments

Comments
 (0)