diff --git a/setup.py b/setup.py index a3ada28..fa83872 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name = 'speculative-decoding', packages = find_packages(exclude=[]), - version = '0.0.14', + version = '0.0.15', license='MIT', description = 'Speculative Decoding', author = 'Phil Wang', diff --git a/speculative_decoding/speculative_decoding_with_prophet.py b/speculative_decoding/speculative_decoding_with_prophet.py index 3e093f4..31f4248 100644 --- a/speculative_decoding/speculative_decoding_with_prophet.py +++ b/speculative_decoding/speculative_decoding_with_prophet.py @@ -388,8 +388,8 @@ def forward( has_start_tokens = exists(start_tokens) if return_loss: - label_start_index = (1 if not has_start_tokens else 0) - x, labels = x[:, :-1], x[:, label_start_index:] + start_index = (1 if has_start_tokens else 0) + x, labels = x[:, start_index:-1], x[:, 1:] x = self.token_emb(x) diff --git a/train_prophet.py b/train_prophet.py index 9313a97..ebd0854 100644 --- a/train_prophet.py +++ b/train_prophet.py @@ -135,7 +135,7 @@ def __len__(self): optim.step() optim.zero_grad() - if i % GENERATE_EVERY == 0: + if False and i % GENERATE_EVERY == 0: model.eval() inp = random.choice(val_dataset)[:PRIME_LENGTH]