-
Notifications
You must be signed in to change notification settings - Fork 1
Update core.py to have 1 extra token #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
mosheber
commented
Jul 14, 2024
- added 1 token from fix history and all correct from the target
* added 1 token from fix history and all correct from the target
* added +1 to the expected token count per iteration
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see the attached comments.
dsi/online/simul/core.py
Outdated
total_tokens += correct + 1 | ||
|
||
sim_shared_dict["total_tokens"] = total_tokens |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since total_tokens
is not used after the following line, please consider this minor suggested change:
total_tokens += correct + 1 | |
sim_shared_dict["total_tokens"] = total_tokens | |
sim_shared_dict["total_tokens"] = total_tokens + correct + 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
dsi/online/simul/core.py
Outdated
@@ -99,7 +99,7 @@ def target_done_callback(args, res): | |||
else: | |||
# ALL CORRECT with {total_tokens + draft_tokens} | |||
|
|||
res_dict["total_tokens"] += res_dict["correct"] | |||
res_dict["total_tokens"] += res_dict["correct"] + 1 | |||
|
|||
if res_dict["total_tokens"] > args.max_tokens: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it mean we generate max_tokens+1
new tokens? Why not >=
?
if res_dict["total_tokens"] > args.max_tokens: | |
if res_dict["total_tokens"] >= args.max_tokens: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -99,7 +99,7 @@ def target_done_callback(args, res): | |||
else: | |||
# ALL CORRECT with {total_tokens + draft_tokens} | |||
|
|||
res_dict["total_tokens"] += res_dict["correct"] | |||
res_dict["total_tokens"] += res_dict["correct"] + 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We call this line if the target accepts all the 'lookahead' draft tokens, right? Does setting res_dict["total_tokens"] += res_dict["correct"] + 1
mean we accept an additional token? I'm asking because we shouldn't accept an additional token unless it is the last token. We accept an additional token only if it is the last token (e.g., the 50th token where config.S == 50
) or the target rejects at least one draft token (in the current iteration). Instead, we should terminate the iteration that speculates this additional token with probability 1 - acceptance_rate
.
To conclude, there are two changes to the online simulation to boost the speedup of DSI:
- Accept an extra token if the target rejects a draft or if the extra token is the last.
- Simulate an immediate validation of the extra token by terminating the corresponding speculating iteration with probability
1 - acceptance_rate
.
We can separate them into two PRs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@mosheber, please see my previous comments and let me know when the tests pass so I can do another iteration. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please merge or rebase main
and remove the skip marker in test_duration
and test_num_of_fix_history
(tests/integration/online/test_simul.py
). To run these two tests serially: python ./scripts/test.py online -- -vvv