From 839359688f33309e1f79b1334fb7b056bff07a34 Mon Sep 17 00:00:00 2001
From: Jonathan Chang
Date: Sat, 3 Oct 2020 16:18:43 +0800
Subject: [PATCH 1/3] Add support for gpt2 batch inferencing

---
 src/transformers/modeling_gpt2.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index 1efd378ede4831..7fd3ea36bfe04c 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -701,10 +701,20 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)

+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+        else:
+            position_ids = None
         return {
             "input_ids": input_ids,
             "past_key_values": past,
             "use_cache": kwargs.get("use_cache"),
+            "position_ids": position_ids,
+            "attention_mask": attention_mask,
         }

     @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)

From 729733b9e230925266e12f9af50351c7d0a14588 Mon Sep 17 00:00:00 2001
From: patrickvonplaten
Date: Tue, 13 Oct 2020 23:42:02 +0200
Subject: [PATCH 2/3] add test

---
 src/transformers/modeling_gpt2.py |  5 +++-
 tests/test_modeling_gpt2.py       | 49 +++++++++++++++++++++++++++++++
 2 files changed, 53 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_gpt2.py b/src/transformers/modeling_gpt2.py
index 7fd3ea36bfe04c..b9ab17cff961dc 100644
--- a/src/transformers/modeling_gpt2.py
+++ b/src/transformers/modeling_gpt2.py
@@ -702,7 +702,10 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             input_ids = input_ids[:, -1].unsqueeze(-1)

         attention_mask = kwargs.get("attention_mask", None)
-        if attention_mask is not None:
+        position_ids = kwargs.get("position_ids", None)
+
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
             position_ids.masked_fill_(attention_mask == 0, 1)
             if past:
diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py
index 6d18d3638add1f..dfaad3aabc47a9 100644
--- a/tests/test_modeling_gpt2.py
+++ b/tests/test_modeling_gpt2.py
@@ -32,6 +32,7 @@
         GPT2DoubleHeadsModel,
         GPT2LMHeadModel,
         GPT2Model,
+        GPT2Tokenizer,
     )


@@ -405,6 +406,54 @@ def test_gpt2_gradient_checkpointing(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
         self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

+    def test_gpt2_batch_generation(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
+        self.model_tester.check_batch_generation(*config_and_inputs)
+
+    @slow
+    def test_batch_generation(self):
+        model = GPT2LMHeadModel.from_pretrained("gpt2")
+        model.to(torch_device)
+        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+        tokenizer.padding_side = "left"
+
+        # Define PAD Token = EOS Token = 50256
+        tokenizer.pad_token = tokenizer.eos_token
+        model.config.pad_token_id = model.config.eos_token_id
+
+        # use different length sentences to test batching
+        sentences = [
+            "Hello, my dog is a little",
+            "Today, I",
+        ]
+
+        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
+
+        torch.manual_seed(0)
+        outputs = model.generate(
+            input_ids=inputs["input_ids"].to(torch_device),
+            attention_mask=inputs["attention_mask"].to(torch_device),
+        )
+
+        inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
+        output_non_padded = model.generate(input_ids=inputs_non_padded)
+
+        num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
+        inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
+        output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)
+
+        batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
+        padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)
+
+        expected_output_sentence = [
+            "Hello, my dog is a little bit of a mess. I'm not sure if he's going",
+            "Today, I'm going to be doing a lot of research on this. I",
+        ]
+        self.assertListEqual(expected_output_sentence, batch_out_sentence)
+        self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:

From 380af05c189906ddc4ba79cbe427af0e1809c1e1 Mon Sep 17 00:00:00 2001
From: patrickvonplaten
Date: Tue, 13 Oct 2020 23:48:37 +0200
Subject: [PATCH 3/3] remove typo

---
 tests/test_modeling_gpt2.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py
index dfaad3aabc47a9..75c0de4a229bec 100644
--- a/tests/test_modeling_gpt2.py
+++ b/tests/test_modeling_gpt2.py
@@ -406,10 +406,6 @@ def test_gpt2_gradient_checkpointing(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
         self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

-    def test_gpt2_batch_generation(self):
-        config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
-        self.model_tester.check_batch_generation(*config_and_inputs)
-
     @slow
     def test_batch_generation(self):
         model = GPT2LMHeadModel.from_pretrained("gpt2")
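Note: the patch above derives position ids from the attention mask so that left-padded
batches generate the same text as unpadded inputs. A minimal standalone sketch of that
cumsum/masked_fill recipe is shown below; the toy attention mask and printed values are
made up for illustration and are not part of the diff.

    import torch

    # Two left-padded prompts: row 0 has no padding, row 1 has two pad slots on the left.
    attention_mask = torch.tensor([[1, 1, 1, 1],
                                   [0, 0, 1, 1]])

    # Same recipe as the patch: real tokens get positions 0, 1, 2, ...,
    # while padded slots get a dummy position (1) that the mask hides anyway.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 1)

    print(position_ids)
    # tensor([[0, 1, 2, 3],
    #         [1, 1, 0, 1]])

    # With a cached past, only the position of the newest token is kept,
    # mirroring the `if past:` branch added in the patch.
    print(position_ids[:, -1].unsqueeze(-1))
    # tensor([[3],
    #         [1]])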