
Update core.py to have 1 extra token #30


Open · wants to merge 6 commits into base: main

Conversation

mosheber (Collaborator)

  • added 1 token from fix history and all correct from the target
@mosheber mosheber requested a review from keyboardAnt July 14, 2024 13:55
mosheber and others added 2 commits July 14, 2024 17:27
* added +1 to the expected token count per iteration
keyboardAnt (Owner) left a comment


Please see the attached comments.

Comment on lines 31 to 33
total_tokens += correct + 1

sim_shared_dict["total_tokens"] = total_tokens
keyboardAnt (Owner)

Since total_tokens is not used after the following line, please consider this minor suggested change:

Suggested change
- total_tokens += correct + 1
- sim_shared_dict["total_tokens"] = total_tokens
+ sim_shared_dict["total_tokens"] = total_tokens + correct + 1
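To see that the suggestion is behavior-preserving, here is a minimal before/after sketch (the function names and sample values are made up; only the two statements come from the snippet above):

```python
def update_before(sim_shared_dict, total_tokens, correct):
    # Original: increment a local that is never read again afterwards.
    total_tokens += correct + 1
    sim_shared_dict["total_tokens"] = total_tokens
    return sim_shared_dict

def update_after(sim_shared_dict, total_tokens, correct):
    # Suggested: fold the increment into the single shared-dict write.
    sim_shared_dict["total_tokens"] = total_tokens + correct + 1
    return sim_shared_dict

# Both produce the same shared state (sample values are arbitrary):
assert update_before({}, 10, 3) == update_after({}, 10, 3) == {"total_tokens": 14}
```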

mosheber (Collaborator, Author)

done

@@ -99,7 +99,7 @@ def target_done_callback(args, res):
     else:
         # ALL CORRECT with {total_tokens + draft_tokens}

-        res_dict["total_tokens"] += res_dict["correct"]
+        res_dict["total_tokens"] += res_dict["correct"] + 1

     if res_dict["total_tokens"] > args.max_tokens:
keyboardAnt (Owner)

Does it mean we generate max_tokens+1 new tokens? Why not >=?

Suggested change
- if res_dict["total_tokens"] > args.max_tokens:
+ if res_dict["total_tokens"] >= args.max_tokens:
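The off-by-one can be checked with a toy loop (a made-up minimal simulation, not the repository's code): with max_tokens == 5 and one accepted token per iteration, `>` stops after 6 tokens while `>=` stops after 5.

```python
def count_generated(max_tokens, stop_at_or_above):
    """Made-up minimal loop: one token is 'accepted' per iteration."""
    total_tokens = 0
    while True:
        total_tokens += 1
        if stop_at_or_above:
            done = total_tokens >= max_tokens   # the suggested `>=`
        else:
            done = total_tokens > max_tokens    # the original `>`
        if done:
            return total_tokens

assert count_generated(5, stop_at_or_above=False) == 6  # `>` generates max_tokens + 1
assert count_generated(5, stop_at_or_above=True) == 5   # `>=` stops at max_tokens
```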

mosheber (Collaborator, Author)

done

@@ -99,7 +99,7 @@ def target_done_callback(args, res):
     else:
         # ALL CORRECT with {total_tokens + draft_tokens}

-        res_dict["total_tokens"] += res_dict["correct"]
+        res_dict["total_tokens"] += res_dict["correct"] + 1
keyboardAnt (Owner)

We call this line when the target accepts all the lookahead draft tokens, right? Does res_dict["total_tokens"] += res_dict["correct"] + 1 mean we accept an additional token? I'm asking because we should accept an additional token only if it is the last token (e.g., the 50th token where config.S == 50) or if the target rejects at least one draft token in the current iteration. Instead, we should terminate the iteration that speculates this additional token with probability 1 - acceptance_rate.

To conclude, there are two changes to the online simulation to boost the speedup of DSI:

  1. Accept an extra token if the target rejects a draft or if the extra token is the last.
  2. Simulate an immediate validation of the extra token by terminating the corresponding speculating iteration with probability 1 - acceptance_rate.

We can separate them into two PRs.
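The two proposed changes could be sketched roughly as follows (a hedged illustration only: simulate_iteration, lookahead, and acceptance_rate are hypothetical names, and the real simulation tracks durations and shared state rather than a single counter):

```python
import random

def simulate_iteration(total_tokens, lookahead, acceptance_rate, max_tokens, rng):
    """Illustrative only: one speculation iteration with the proposed extra-token rule."""
    correct = 0
    for _ in range(lookahead):
        if rng.random() < acceptance_rate:
            correct += 1
        else:
            # Change 1: on a rejection, the target contributes one extra (corrected) token.
            return total_tokens + correct + 1
    # All drafts accepted: keep the extra token unconditionally only if it is the last token...
    if total_tokens + correct + 1 >= max_tokens:
        return total_tokens + correct + 1
    # ...otherwise (change 2) its immediate validation fails w.p. 1 - acceptance_rate.
    if rng.random() < acceptance_rate:
        return total_tokens + correct + 1
    return total_tokens + correct

# Deterministic corner cases (no randomness when acceptance_rate is 0 or 1):
assert simulate_iteration(0, 5, 1.0, 50, random.Random(0)) == 6   # all accepted + validated extra
assert simulate_iteration(0, 5, 0.0, 50, random.Random(0)) == 1   # first draft rejected
assert simulate_iteration(44, 5, 1.0, 50, random.Random(0)) == 50 # extra token is the last
```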

mosheber (Collaborator, Author)

done

keyboardAnt (Owner)

@mosheber, please see my previous comments and let me know when the tests pass so I can do another iteration.

@keyboardAnt keyboardAnt self-requested a review July 28, 2024 20:45
keyboardAnt (Owner) left a comment

Please merge or rebase main and remove the skip marker in test_duration and test_num_of_fix_history (tests/integration/online/test_simul.py). To run these two tests serially: python ./scripts/test.py online -- -vvv
