add gpt2 model #215

Status: Open. MarcelWilnicki wants to merge 70 commits into main.

Commits (70)
All commits by MarcelWilnicki:

b984a02  first commit (Oct 26, 2023)
c78ce12  wip (Oct 30, 2023)
065d7de  wip (Oct 30, 2023)
d21745a  Merge branch 'marcel/gpt2_v2' into marcel/gpt-j (Nov 6, 2023)
66ec296  first commit (Nov 6, 2023)
5350070  wip (Nov 6, 2023)
840b7af  wip (Nov 7, 2023)
d5cbb4a  wip (Nov 7, 2023)
947af87  wip (Nov 7, 2023)
79920f8  wip (Nov 7, 2023)
0acdbba  wip (Nov 7, 2023)
5dfaac5  wip (Nov 7, 2023)
3cc3c7b  wip (Nov 7, 2023)
93b256e  wip (Nov 7, 2023)
ec39eb5  wip (Nov 8, 2023)
d4da8db  wip (Nov 8, 2023)
2bb78f1  wip (Nov 8, 2023)
3cee9a4  wip (Nov 8, 2023)
bd4eb26  wip (Nov 8, 2023)
29e607f  wip (Nov 8, 2023)
391b705  wip (Nov 8, 2023)
2c7f8c9  wip (Nov 8, 2023)
da03dd8  wip (Nov 8, 2023)
e8f4a94  wip (Nov 8, 2023)
59f653a  wip (Nov 8, 2023)
5ecb464  wip (Nov 8, 2023)
4fb0140  wip (Nov 8, 2023)
89ac8ea  wip (Nov 8, 2023)
7022601  wip (Nov 8, 2023)
e891622  wip (Nov 8, 2023)
cf09b18  wip (Nov 8, 2023)
5aecdc1  wip (Nov 27, 2023)
a9a9064  wip (Nov 27, 2023)
ddfcc4e  wip (Nov 27, 2023)
483b8ab  wip (Nov 27, 2023)
65b9c93  wip (Nov 27, 2023)
59d5a14  wip (Nov 27, 2023)
00f4f04  wip (Nov 27, 2023)
e844c18  wip (Nov 27, 2023)
f2f680f  wip (Nov 27, 2023)
0ac0aa4  wip (Nov 27, 2023)
2d20c60  wip (Nov 27, 2023)
0409761  wip (Nov 27, 2023)
ebcea3c  wip (Nov 27, 2023)
7d20adf  wip (Nov 27, 2023)
8c7f13c  wip (Nov 29, 2023)
867e062  wip (Nov 30, 2023)
901684d  wip (Nov 30, 2023)
e0fdb55  wip (Nov 30, 2023)
4ee143f  wip (Dec 12, 2023)
3f4cf94  wip (Dec 12, 2023)
e73173a  wip (Dec 12, 2023)
006bacb  wip (Dec 13, 2023)
9766d1a  wip (Dec 13, 2023)
38abcae  wip (Dec 13, 2023)
a2968c6  wip (Dec 13, 2023)
f2dee88  wip (Dec 14, 2023)
0b79157  wip (Dec 14, 2023)
f4f2e8f  cleanup (Dec 14, 2023)
d2b84b4  wip (Dec 14, 2023)
b0b358c  wip (Dec 14, 2023)
e7882dd  wip (Dec 14, 2023)
3a33cd9  wip (Dec 14, 2023)
61b9049  clean (Dec 14, 2023)
e300b42  wip (Dec 14, 2023)
fa6bf52  wip (Dec 18, 2023)
097f2b5  wip (Dec 18, 2023)
ac83528  wip (Dec 18, 2023)
9d74028  wip (Dec 18, 2023)
0944b67  wip (Dec 18, 2023)
50 changes: 50 additions & 0 deletions natural_language_processing/text_generation/gpt/run.py
@@ -0,0 +1,50 @@
import os

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

from utils.benchmark import run_model
from utils.nlp.lambada import Lambada


def run_pytorch_fp32(model_name, batch_size, num_runs, timeout, lambada_path, **kwargs):
    from utils.pytorch import PyTorchRunnerV2, apply_compile_maybe

    def run_single_pass(pytorch_runner, lambada):
        start_ids = lambada.get_input_array()[0]
        output = pytorch_runner.run(inputs=start_ids, max_new_tokens=10, pad_token_id=tokenizer.pad_token_id)
        # only the newly generated tokens count towards the measured task size
        pytorch_runner.set_task_size(output.shape[1] - start_ids.shape[1])

        for i in range(batch_size):
            # strip the prompt from each decoded sequence; the first remaining
            # word is the model's guess for the passage's final word
            first_new_word = detokenize(output[i]).replace(detokenize(start_ids[i]), '').split()[0]
            lambada.submit_prediction(i, first_new_word)

    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # GPT-2 ships without a pad token; reuse EOS so generate() can pad
        tokenizer.pad_token = tokenizer.eos_token

    def detokenize(answer):
        return tokenizer.decode(answer, skip_special_tokens=True)

    def tokenize(text):
        return tokenizer.encode(text, return_tensors='pt')

    model = GPT2LMHeadModel.from_pretrained(model_name, torchscript=True).eval()
    dataset = Lambada(batch_size, tokenize, detokenize, lambada_path)
    aio = '_aio_profiler_print' in dir(torch._C) and os.environ.get("AIO_PROCESS_MODE") != "0"
    model.greedy_search = apply_compile_maybe(model.greedy_search, aio)
    runner = PyTorchRunnerV2(model.generate)

    return run_model(run_single_pass, runner, dataset, batch_size, num_runs, timeout)


if __name__ == "__main__":
    from utils.helpers import DefaultArgParser

    gpt_variants = ["gpt2"]
    parser = DefaultArgParser(["pytorch"])
    parser.require_model_name(gpt_variants)
    parser.ask_for_batch_size()
    parser.add_argument('--lambada_path', type=str, required=True, help="Path to Lambada dataset")
    run_pytorch_fp32(**vars(parser.parse()))
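
For reference, the core of run_single_pass can be reproduced with stock transformers alone. A minimal sketch of the LAMBADA-style last-word task, assuming the public "gpt2" checkpoint; the prompt is illustrative and the repo's benchmark bookkeeping (PyTorchRunnerV2, run_model) is omitted:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

    # LAMBADA-style input: a passage minus its final word; the model should
    # produce that word as the first newly generated token(s)
    prompt = "He handed her the keys and she unlocked the"
    start_ids = tokenizer.encode(prompt, return_tensors="pt")

    with torch.no_grad():
        output = model.generate(start_ids, max_new_tokens=10,
                                pad_token_id=tokenizer.pad_token_id)

    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    first_new_word = decoded.replace(prompt, "").split()[0]
    print(first_new_word)  # the prediction scored against the reference word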
1 change: 1 addition & 0 deletions utils/benchmark.py
@@ -124,6 +124,7 @@ def set_task_size(self, new_task_size):
        """
        if new_task_size is None:
            return

        assert len(self._finish_times) - len(self._workload_size) in [1, 0]
        self._workload_size.append(new_task_size)
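
This None guard lets every invocation route through set_task_size() even when the workload size is unknown at call time: with generation, the number of new tokens is only known after generate() returns, so run_single_pass() fills in the real value afterwards. A toy sketch of the resulting invariant (the surrounding class is trimmed to the two lists that matter; everything else is assumed):

    class RunnerSketch:
        """Minimal stand-in for the benchmark runner's bookkeeping."""

        def __init__(self):
            self._finish_times = []
            self._workload_size = []

        def set_task_size(self, new_task_size):
            # None means "size unknown for now"; skip without corrupting the lists
            if new_task_size is None:
                return
            # at most one finished run may still be waiting for its size
            assert len(self._finish_times) - len(self._workload_size) in [1, 0]
            self._workload_size.append(new_task_size)

    runner = RunnerSketch()
    runner._finish_times.append(0.0)  # a run just finished...
    runner.set_task_size(None)        # ...size not known yet: no-op
    runner.set_task_size(10)          # later: 10 new tokens were generated
    assert runner._workload_size == [10]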
5 changes: 4 additions & 1 deletion utils/pytorch.py
@@ -92,7 +92,7 @@ def runner_func():

             self._start_times.append(start)
             self._finish_times.append(finish)
-            self._workload_size.append(task_size)
+            self.set_task_size(task_size)
             self._times_invoked += 1

             return output
@@ -208,6 +208,9 @@ def apply_jit_script(model):

 def apply_jit_trace(model, example_inputs):
     return load_from_cache_or_apply(model, lambda: torch.jit.trace(model, example_inputs))

+def apply_jit_trace_module(model, example_inputs):
+    return load_from_cache_or_apply(model, lambda: torch.jit.trace_module(model, example_inputs))
+

 def apply_compile_maybe(model, aio):
     if os.environ.get("TORCH_COMPILE") != "1":
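
The new apply_jit_trace_module helper parallels apply_jit_trace but wraps torch.jit.trace_module, which takes a dict mapping method names to example inputs and can therefore trace entry points other than forward. A minimal usage sketch (the module and its methods are hypothetical):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x * 2

        def scaled(self, x, factor):
            return x * factor

    example = torch.randn(3)
    # unlike torch.jit.trace, trace_module accepts {method_name: example_inputs},
    # so non-forward methods can be traced too
    traced = torch.jit.trace_module(Toy(), {"forward": example,
                                            "scaled": (example, torch.tensor(2.0))})
    print(traced.scaled(example, torch.tensor(3.0)))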