diff --git a/natural_language_processing/text_generation/gpt/run.py b/natural_language_processing/text_generation/gpt/run.py
new file mode 100644
index 00000000..b8e9978e
--- /dev/null
+++ b/natural_language_processing/text_generation/gpt/run.py
@@ -0,0 +1,50 @@
+import os
+
+import torch
+from transformers import GPT2Tokenizer, GPT2LMHeadModel
+
+from utils.benchmark import run_model
+from utils.nlp.lambada import Lambada
+
+
+def run_pytorch_fp32(model_name, batch_size, num_runs, timeout, lambada_path, **kwargs):
+    from utils.pytorch import PyTorchRunnerV2, apply_compile_maybe
+
+    def run_single_pass(pytorch_runner, lambada):
+        start_ids = lambada.get_input_array()[0]
+        output = pytorch_runner.run(inputs=start_ids, max_new_tokens=10, pad_token_id=tokenizer.pad_token_id)
+        pytorch_runner.set_task_size(output.shape[1] - start_ids.shape[1])
+        output = detokenize(output[0])
+
+        for i in range(batch_size):
+            first_new_word = output.replace(detokenize(start_ids[0]), '').split()[0]
+            lambada.submit_prediction(i, first_new_word)
+
+    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    def detokenize(answer):
+        return tokenizer.decode(answer, skip_special_tokens=True)
+
+    def tokenize(text):
+        return tokenizer.encode(text, return_tensors='pt')
+
+    model = GPT2LMHeadModel.from_pretrained(model_name, torchscript=True).eval()
+    dataset = Lambada(batch_size, tokenize, detokenize, lambada_path)
+    aio = '_aio_profiler_print' in dir(torch._C) and os.environ.get("AIO_PROCESS_MODE") != "0"
+    model.greedy_search = apply_compile_maybe(model.greedy_search, aio)
+    runner = PyTorchRunnerV2(model.generate)
+
+    return run_model(run_single_pass, runner, dataset, batch_size, num_runs, timeout)
+
+
+if __name__ == "__main__":
+    from utils.helpers import DefaultArgParser
+
+    gpt_variants = ["gpt2"]
+    parser = DefaultArgParser(["pytorch"])
+    parser.require_model_name(gpt_variants)
+    parser.ask_for_batch_size()
+    parser.add_argument('--lambada_path', type=str, required=True, help="Path to Lambada dataset")
+    run_pytorch_fp32(**vars(parser.parse()))
diff --git a/utils/benchmark.py b/utils/benchmark.py
index b6f61f71..1895cc03 100644
--- a/utils/benchmark.py
+++ b/utils/benchmark.py
@@ -124,6 +124,7 @@ def set_task_size(self, new_task_size):
         """
         if new_task_size is None:
             return
+        assert len(self._finish_times) - len(self._workload_size) in [1, 0]
         self._workload_size.append(new_task_size)
 
 
diff --git a/utils/pytorch.py b/utils/pytorch.py
index 8ed76cdc..9aff5f86 100644
--- a/utils/pytorch.py
+++ b/utils/pytorch.py
@@ -92,7 +92,7 @@ def runner_func():
 
            self._start_times.append(start)
            self._finish_times.append(finish)
-           self._workload_size.append(task_size)
+           self.set_task_size(task_size)
            self._times_invoked += 1
 
            return output
@@ -208,6 +208,9 @@ def apply_jit_script(model):
 def apply_jit_trace(model, example_inputs):
     return load_from_cache_or_apply(model, lambda: torch.jit.trace(model, example_inputs))
 
+def apply_jit_trace_module(model, example_inputs):
+    return load_from_cache_or_apply(model, lambda: torch.jit.trace_module(model, example_inputs))
+
 
 def apply_compile_maybe(model, aio):
     if os.environ.get("TORCH_COMPILE") != "1":
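
A note on the bookkeeping change above: `PyTorchRunnerV2` now routes its per-run task size through `set_task_size()` instead of appending to `_workload_size` directly, so the None early-return and the new assert apply on every path. The assert encodes the invariant that `_workload_size` and `_finish_times` never drift more than one entry apart: a difference of 1 means the run that just finished is still awaiting its size, 0 means a size is being registered before or alongside the finish time. A minimal sketch of this invariant, using a hypothetical `TaskSizeLedger` class rather than the real runner:

```python
# Hypothetical, stripped-down model of the accounting in utils/benchmark.py;
# names other than set_task_size are illustrative, not the real classes.
class TaskSizeLedger:
    def __init__(self):
        self._finish_times = []    # one entry per completed run
        self._workload_size = []   # one entry per reported task size

    def finish_run(self, finish_time, task_size=None):
        self._finish_times.append(finish_time)
        # Route through set_task_size, as the runner now does, so the
        # None early-return and the invariant check live in one place.
        self.set_task_size(task_size)

    def set_task_size(self, new_task_size):
        if new_task_size is None:
            return  # caller will report the size later (e.g. after decoding)
        # Difference 1: the run that just finished has no size yet.
        # Difference 0: the size is registered ahead of the finish time.
        assert len(self._finish_times) - len(self._workload_size) in [1, 0]
        self._workload_size.append(new_task_size)


ledger = TaskSizeLedger()
ledger.finish_run(finish_time=1.0)                # size unknown at finish time
ledger.set_task_size(10)                          # backfilled; difference was 1
ledger.finish_run(finish_time=2.0, task_size=7)   # reported immediately
```

This mirrors the GPT script above, which only knows the task size (the number of generated tokens) after `generate` returns, and therefore backfills it via `pytorch_runner.set_task_size(...)`.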
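
Similarly, a hedged usage sketch for the new `apply_jit_trace_module` helper: unlike `torch.jit.trace`, `torch.jit.trace_module` accepts a dict mapping method names to example inputs, so methods other than `forward` can be traced. The model and example inputs below are assumptions for illustration, not taken from a real call site in this patch:

```python
import torch
from transformers import GPT2LMHeadModel

from utils.pytorch import apply_jit_trace_module

# Same model setup as the GPT script above; torchscript=True makes the
# model traceable.
model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True).eval()

# Hypothetical example inputs: a batch of token ids (50257 is the GPT-2
# vocabulary size); a real call site would supply its own tensors.
example_ids = torch.randint(0, 50257, (1, 8))

# trace_module takes {method name: example inputs}; caching is handled
# by load_from_cache_or_apply, exactly as for the existing helpers.
traced = apply_jit_trace_module(model, {"forward": example_ids})
```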