the extra norm in projecting to prophet model dimensions hurt for some reason
lucidrains committed Oct 4, 2023
1 parent 2315a8a commit 5e8b036
Showing 3 changed files with 7 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'speculative-decoding',
   packages = find_packages(exclude=[]),
-  version = '0.1.0',
+  version = '0.1.1',
   license='MIT',
   description = 'Speculative Decoding',
   author = 'Phil Wang',
6 changes: 4 additions & 2 deletions speculative_decoding/speculative_decoding_with_prophet.py
@@ -301,15 +301,17 @@ def __init__(
         model: Decoder,
         prophet: Decoder,
         prophet_train_length = 8,   # should be greater than spec decoding gamma, as main model cache embedding is one step behind
-        detach_model_embed_for_prophet = False
+        detach_model_embed_for_prophet = False,
+        num_leading_start_tokens = 1
     ):
         super().__init__()
         self.model = model
         self.prophet = prophet

         model_prophet_same_dim = model.dim == prophet.dim
-        self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Sequential(RMSNorm(model.dim), nn.Linear(model.dim, prophet.dim, bias = False))
+        self.to_prophet_start_token = nn.Identity() if model_prophet_same_dim else nn.Linear(model.dim, prophet.dim, bias = False)

+        self.num_leading_start_tokens = num_leading_start_tokens
         self.prophet_train_length = prophet_train_length
         self.detach_model_embed_for_prophet = detach_model_embed_for_prophet
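For context, here is a minimal sketch of what this hunk changes, written in plain PyTorch with a stand-in RMSNorm rather than the repo's exact module: when the main model and prophet dimensions differ, the cached start-token embedding is now projected with a bare nn.Linear instead of RMSNorm followed by nn.Linear.

import torch
import torch.nn.functional as F
from torch import nn

class RMSNorm(nn.Module):
    # minimal RMSNorm stand-in, along the lines of the one defined in this repo
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** 0.5
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return F.normalize(x, dim = -1) * self.scale * self.gamma

model_dim, prophet_dim = 512, 256   # example dims where the projection path is taken

# before this commit: normalize the main model's cached embedding, then project down
to_prophet_before = nn.Sequential(RMSNorm(model_dim), nn.Linear(model_dim, prophet_dim, bias = False))

# after this commit: project directly, no extra norm
to_prophet_after = nn.Linear(model_dim, prophet_dim, bias = False)

embeds = torch.randn(2, 8, model_dim)      # (batch, seq, model_dim) cached embeddings
print(to_prophet_before(embeds).shape)     # torch.Size([2, 8, 256])
print(to_prophet_after(embeds).shape)      # torch.Size([2, 8, 256])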
3 changes: 2 additions & 1 deletion train_prophet.py
@@ -76,14 +76,15 @@ def inner(*args, **kwargs):

 prophet = Decoder(
     num_tokens = 256,
-    dim = 512,
+    dim = 256,
     depth = 2
 )

 model_and_prophet = ModelWithProphetWrapper(
     model,
     prophet,
     prophet_train_length = GAMMA + 2,
+    num_leading_start_tokens = 1,
     detach_model_embed_for_prophet = False # train end to end, shouldn't hurt (although benefits is dubious) given ProphetNet paper - of course, trying to get to the bottom of the benefits in spec decoding setting here
 ).to(device)
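A quick sanity check, assuming the main model in train_prophet.py keeps a larger dimension (e.g. 512): with the prophet dropped to dim = 256, model_prophet_same_dim evaluates False inside the wrapper, so this training script now exercises the simplified Linear projection from the hunk above rather than nn.Identity.

from torch import nn

model_dim   = 512   # assumed main model dimension (not shown in this hunk)
prophet_dim = 256   # prophet dimension after this commit

# same conditional as in ModelWithProphetWrapper.__init__ after this change
projection = nn.Identity() if model_dim == prophet_dim else nn.Linear(model_dim, prophet_dim, bias = False)
print(projection)   # Linear(in_features=512, out_features=256, bias=False)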
