fix start token and determining input sequence
lucidrains committed Oct 2, 2023
1 parent 94002e5 commit fd2cc0f
Showing 3 changed files with 4 additions and 4 deletions.

setup.py (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.0.14',
+  version = '0.0.15',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',

speculative_decoding/speculative_decoding_with_prophet.py (2 additions, 2 deletions)

@@ -388,8 +388,8 @@ def forward(
         has_start_tokens = exists(start_tokens)

         if return_loss:
-            label_start_index = (1 if not has_start_tokens else 0)
-            x, labels = x[:, :-1], x[:, label_start_index:]
+            start_index = (1 if has_start_tokens else 0)
+            x, labels = x[:, start_index:-1], x[:, 1:]

         x = self.token_emb(x)
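This hunk is the substantive fix. The labels are now always the sequence shifted left by one (x[:, 1:]), and when a learned start token will be prepended to the embedded input inside forward, the raw input additionally drops its first token, keeping input and label lengths equal. A minimal sketch of the corrected slicing on a toy batch (the tensor values and the name inp are illustrative, not from the repository):

import torch

x = torch.arange(8).unsqueeze(0)   # toy batch of token ids: [[0, 1, 2, 3, 4, 5, 6, 7]]
has_start_tokens = True

start_index = 1 if has_start_tokens else 0
inp, labels = x[:, start_index:-1], x[:, 1:]

# after the model prepends the start token embedding, the effective input
# [start, 1, 2, 3, 4, 5, 6] lines up one-for-one with the labels
# [1, 2, 3, 4, 5, 6, 7]; without a start token, inp = x[:, :-1] gives the
# usual next-token shift against labels = x[:, 1:]
print(inp)     # tensor([[1, 2, 3, 4, 5, 6]])
print(labels)  # tensor([[1, 2, 3, 4, 5, 6, 7]])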

train_prophet.py (1 addition, 1 deletion)

@@ -135,7 +135,7 @@ def __len__(self):
     optim.step()
     optim.zero_grad()

-    if i % GENERATE_EVERY == 0:
+    if False and i % GENERATE_EVERY == 0:
         model.eval()

         inp = random.choice(val_dataset)[:PRIME_LENGTH]
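In the updated line, False and i % GENERATE_EVERY == 0 always short-circuits to False, so the periodic sampling block below it is skipped during training.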
