Measure acceptance rate #34
base: main
Changes from all commits: d92eb28, c5177e2, d2da6c0, 48117c5, 7f7df49, ae64a2a, 6e76cc0, 9999f6f, bc3f79c
dsi/configs/experiment/acceptance.py (new file)
@@ -0,0 +1,26 @@
from typing import Literal

import torch
from pydantic import Field

from dsi.configs.experiment.latency import ConfigLatency


class ConfigAcceptanteRate(ConfigLatency):
    """Includes all the parameters needed for measuring the acceptance rate
    of a (target, draft, dataset) triplet.
    """

    draft_model: str = Field(title="The draft model to use for the experiment")
    draft_dtype: Literal["float32", "float16", "bfloat16"] = Field(
        "bfloat16", title="The dtype of the draft model to use for the experiment"
    )
    draft_compile_model: bool = Field(
        False, title="Whether to torch.compile() the draft model"
    )
    draft_revision: None | str = Field(
        None, title="The revision of the draft model to use"
    )

    def get_torch_draft_dtype(self) -> torch.dtype:
        # getattr maps the configured name to torch.float32/float16/bfloat16
        # without resorting to eval().
        return getattr(torch, self.draft_dtype)
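A quick usage sketch (the checkpoint and dataset names are hypothetical placeholders; `model` and `dataset` are inherited from ConfigLatency):

import torch

from dsi.configs.experiment.acceptance import ConfigAcceptanteRate

config = ConfigAcceptanteRate(
    model="double7/vicuna-68m",  # hypothetical target checkpoint
    dataset="cnn_dailymail",     # hypothetical dataset name
    draft_model="double7/vicuna-68m",
    draft_dtype="float16",
)
assert config.get_torch_draft_dtype() == torch.float16
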
dsi/offline/acceptance/experiment.py (new file)
@@ -0,0 +1,114 @@
import logging

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from dsi.configs.experiment.acceptance import ConfigAcceptanteRate
from dsi.configs.experiment.generation import ConfigGen
from dsi.online.latency.experiment import ExperimentLatency
from dsi.types.result import ResultAcceptance

log = logging.getLogger(__name__)


class ExperimentAcceptanceRate(ExperimentLatency):
    """
    Measures the generation acceptance rate.
    """

    def __init__(
        self,
        config: ConfigAcceptanteRate,
        gen_config: ConfigGen,
        draft_gen_config: ConfigGen,
    ):
        self.config: ConfigAcceptanteRate
        super().__init__(config, gen_config)
        self.draft_gen_config: ConfigGen = draft_gen_config
Review comment on lines +21 to +29: Please rebase the […]. Suggested change: […]. The other arguments ([…]).

    def _load_draft_model_tokenizer(self) -> tuple:
        log.info(
            f"Loading model: {self.config.draft_model}, "
            f"compile={self.config.draft_compile_model}"
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.config.draft_model,
            device_map="auto",
            torch_dtype=self.config.get_torch_draft_dtype(),
            revision=self.config.draft_revision,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.config.draft_model)
        model = torch.compile(model) if self.config.draft_compile_model else model
        return model, tokenizer
    def _are_tokenizers_same(self, tokenizer1, tokenizer2) -> bool:

Review comment: Why not use a static method? Also, please rename the function to communicate that it only checks the tokens and their corresponding input ids. (The current function returns […].) Suggested change: […]
        # Compare vocabularies
        if tokenizer1.get_vocab() != tokenizer2.get_vocab():
            return False

        # Compare special tokens
        special_tokens_to_compare = [
            "eos_token_id",
            "pad_token_id",
            "bos_token_id",
            "unk_token_id",
        ]
        for token in special_tokens_to_compare:
            if getattr(tokenizer1, token, None) != getattr(tokenizer2, token, None):
                return False

        return True
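To see what the check rejects, a minimal stand-alone illustration with fake tokenizer objects (mirroring the MagicMock-based tests further down; not part of the PR):

class _FakeTokenizer:
    """Stand-in exposing only the attributes _are_tokenizers_same reads."""

    def __init__(self, vocab, eos=None, pad=None, bos=None, unk=None):
        self._vocab = vocab
        self.eos_token_id = eos
        self.pad_token_id = pad
        self.bos_token_id = bos
        self.unk_token_id = unk

    def get_vocab(self):
        return self._vocab

vocab = {"hello": 0, "world": 1}
a = _FakeTokenizer(vocab, eos=1)
b = _FakeTokenizer(vocab, eos=2)
# Identical vocabularies but different eos_token_id: the pair is rejected,
# so the experiment below refuses to compare this (target, draft) combination.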

    def _single_repeat(self) -> ResultAcceptance:
        all_n_matches = []

        examples = self._get_random_prompted_examples()

        target_model, target_tokenizer = self._load_model_tokenizer()
        draft_model, draft_tokenizer = self._load_draft_model_tokenizer()

        # Check if tokenizers are the same
        if not self._are_tokenizers_same(target_tokenizer, draft_tokenizer):
            raise ValueError("The target and draft tokenizers are not the same.")
Review comment: Could you define a custom exception in […]?
        target_gen_kwargs = dict(
            do_sample=self.gen_config.do_sample,
            temperature=self.gen_config.temperature,
            top_p=self.gen_config.top_p,
            pad_token_id=target_tokenizer.eos_token_id,
            max_new_tokens=self.config.max_new_tokens,
        )

        draft_gen_kwargs = dict(
            do_sample=self.draft_gen_config.do_sample,
            temperature=self.draft_gen_config.temperature,
            top_p=self.draft_gen_config.top_p,
            pad_token_id=target_tokenizer.eos_token_id,
            max_new_tokens=1,
        )
        for ex in tqdm(examples):
            inputs = target_tokenizer(ex, return_tensors="pt").to(target_model.device)
            n_matches = [0]

            output_target = target_model.generate(**inputs, **target_gen_kwargs)
            prompt_len = len(inputs.input_ids[0])

            # iterate over the tokens generated by the target and
            # check whether the draft produces the same token
            for i in range(prompt_len, len(output_target[0])):
                inputs["input_ids"] = output_target[0, 0:i].view(1, -1)
Review comment: It's not clear what […]. Suggested change:

def get_input_ids_prefix(output_seqs, tok_pos_last: int):
    """Returns a prefix of a sequence of input ids.

    Args:
        output_seqs: The output from Hugging Face's transformers
            `model.generate` function, expected to be a sequence-like
            data structure.
        tok_pos_last: int - The last token position to include in the
            returned sequence.
    """
    return output_seqs[0, 0:tok_pos_last]
                inputs["attention_mask"] = torch.tensor(
                    [[1] * i], device=draft_model.device
                )
Review comment on lines +102 to +104: Why not use […]? Suggested change: […]
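The reviewer's suggested change is truncated above. One plausible reading, offered as an assumption rather than the reviewer's confirmed intent, is to build the all-ones mask with torch.ones instead of a Python list:

                # Assumed variant (not confirmed by the truncated thread):
                # torch.ones builds the same all-ones attention mask directly
                # on the right device, avoiding the Python list round-trip.
                inputs["attention_mask"] = torch.ones(
                    (1, i), dtype=torch.long, device=draft_model.device
                )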
                output_draft = draft_model.generate(**inputs, **draft_gen_kwargs)
                if output_draft[-1, i] == output_target[-1, i]:
                    n_matches[-1] += 1
                elif i < len(output_target[0]) - 1:  # new window
                    n_matches.append(0)
Review comment: Please add tests covering the experiment. It seems like this line makes […]
                else:  # at the end, remove last window
                    n_matches.pop()
            all_n_matches += n_matches
        ar = 1 - (1 / (1 + np.array(all_n_matches).mean()))
        return ResultAcceptance(acceptance_rate=[ar])
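To make the window bookkeeping above concrete, here is a self-contained trace with made-up token ids (illustration only; not part of the PR):

# Target generated 4 tokens; the draft misses only the third one.
target = [7, 7, 7, 7]
draft = [7, 7, 9, 7]
n_matches = [0]
for i, (t, d) in enumerate(zip(target, draft)):
    if t == d:
        n_matches[-1] += 1
    elif i < len(target) - 1:  # mid-sequence miss: open a new window
        n_matches.append(0)
    else:  # miss on the final token: drop the still-open window
        n_matches.pop()
assert n_matches == [2, 1]
# The mean window length is 1.5, so the acceptance-rate estimate is
# 1 - 1/(1 + 1.5) = 0.6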
Tests for the acceptance experiment (new file):
@@ -0,0 +1,170 @@
from unittest.mock import MagicMock, patch

import pytest
import torch
from transformers import AutoTokenizer

from dsi.configs.experiment.acceptance import ConfigAcceptanteRate
from dsi.configs.experiment.generation import ConfigGen
from dsi.offline.acceptance.experiment import ExperimentAcceptanceRate
from dsi.types.result import ResultAcceptance


@pytest.fixture
def experiment():
    config = ConfigAcceptanteRate(
        model="target_model", dataset="dataset", draft_model="draft_model"
    )
    gen_config = ConfigGen()
    draft_gen_config = ConfigGen()
    return ExperimentAcceptanceRate(config, gen_config, draft_gen_config)

@pytest.fixture
def mock_dependencies():
    with patch(
        "transformers.AutoModelForCausalLM.from_pretrained"
    ) as mock_model, patch(
        "transformers.AutoTokenizer.from_pretrained"
    ) as mock_tokenizer, patch.object(
        ExperimentAcceptanceRate, "_get_random_prompted_examples"
    ) as mock_examples:
        mock_model.return_value = MagicMock()
        mock_model.return_value.device = torch.device("cpu")
        mock_tokenizer.return_value = MagicMock()
        mock_examples.return_value = ["example1"]
        yield mock_model, mock_tokenizer, mock_examples

def test_are_tokenizers_same_identical(experiment):
    tokenizer1 = AutoTokenizer.from_pretrained("double7/vicuna-68m")
    tokenizer2 = AutoTokenizer.from_pretrained("double7/vicuna-68m")
    assert experiment._are_tokenizers_same(tokenizer1, tokenizer2)


def test_are_tokenizers_same_diff_config(experiment):
    # NOTE: _are_tokenizers_same never reads .config; this test passes
    # because the two MagicMocks return distinct get_vocab() objects.
    tokenizer1 = MagicMock()
    tokenizer2 = MagicMock()
    tokenizer1.config = {"model_type": "bigcode/starcoder"}
    tokenizer2.config = {"model_type": "double7/vicuna-68m"}
    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)


def test_are_tokenizers_same_diff_vocab(experiment):
    tokenizer1 = MagicMock()
    tokenizer2 = MagicMock()
    tokenizer1.get_vocab.return_value = {"hello": 1, "world": 2}
    tokenizer2.get_vocab.return_value = {"hello": 1, "python": 3}
    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)


def test_are_tokenizers_same_diff_special_tokens(experiment):
    tokenizer1 = MagicMock()
    tokenizer2 = MagicMock()
    # Share one vocabulary so the special-token comparison, not the
    # vocabulary check, drives the result.
    vocab = {"hello": 1, "world": 2}
    tokenizer1.get_vocab.return_value = vocab
    tokenizer2.get_vocab.return_value = vocab
    tokenizer1.eos_token_id = 1
    tokenizer2.eos_token_id = 2
    tokenizer1.pad_token_id = 0
    tokenizer2.pad_token_id = 0
    tokenizer1.bos_token_id = -1
    tokenizer2.bos_token_id = -1
    tokenizer1.unk_token_id = 3
    tokenizer2.unk_token_id = 3
    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)

def test_single_repeat_all_match(experiment, mock_dependencies):
    mock_model, _, _ = mock_dependencies
    # Mock the generate method to simulate target and draft models
    # producing the same output
    mock_model.return_value.generate.side_effect = [
        torch.tensor([[0, 1, 2, 3]]),
        torch.tensor([[0]]),
        torch.tensor([[0, 1]]),
        torch.tensor([[0, 1, 2]]),
        torch.tensor([[0, 1, 2, 3]]),
    ]
    result = experiment._single_repeat()
    assert isinstance(result, ResultAcceptance)
    # All four tokens match, so all_n_matches == [4] and the estimator
    # returns 1 - 1/(1 + 4) = 0.8, not 1 (see the review thread below).
    assert result.acceptance_rate[0] == 0.8
Review thread on lines +88 to +89:
- It seems like a discrepancy. The test passes, suggesting that there is a bug.
- @keyboardAnt […]
- @jmamou […]
- @keyboardAnt […]
- @jmamou, as we discussed today over the phone, the acceptance rate has only one definition: […] We estimate the acceptance rate as follows. Let […]. In practice, drafters may have a 100% acceptance rate. For example, an instance of the target running on faster hardware as a drafter, with both the target and drafter sampled greedily. But in such cases the perfect acceptance rate is guaranteed, and I do not see a reason to test it using our experiment. On the contrary, in most practical settings, the acceptance rate is […].
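For context: the estimator on the last line of `_single_repeat` is the standard one under the usual i.i.d. acceptance model (the thread above is truncated, so connecting it to this formula is an editorial assumption). If each draft token is accepted independently with probability b, the window length n counted in `n_matches` (accepted tokens before the first rejection) is geometric:

    P(n = k) = b^k * (1 - b),  k = 0, 1, 2, ...
    E[n] = b / (1 - b)

Solving for b gives b = 1 - 1/(1 + E[n]), which is exactly `ar = 1 - (1 / (1 + np.array(all_n_matches).mean()))` with the empirical mean standing in for E[n].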

def test_single_repeat_no_match(experiment, mock_dependencies):
    mock_model, _, _ = mock_dependencies
    # Mock the generate method to simulate target and draft models
    # producing different outputs
    mock_model.return_value.generate.side_effect = [
        torch.tensor([[0, 1, 2, 3]]),
        torch.tensor([[4]]),
        torch.tensor([[0, 5]]),
        torch.tensor([[0, 1, 6]]),
        torch.tensor([[0, 1, 2, 7]]),
    ]
    result = experiment._single_repeat()
    assert isinstance(result, ResultAcceptance)
    # Since no tokens match, acceptance rate should be 0
    assert result.acceptance_rate[0] == 0

def test_single_repeat_partial_match(experiment, mock_dependencies):
    mock_model, _, _ = mock_dependencies
    # Mock the generate method to simulate target and draft models
    # producing partially matching outputs
    mock_model.return_value.generate.side_effect = [
        torch.tensor([[0, 1, 2, 3]]),
        torch.tensor([[0]]),
        torch.tensor([[0, 4]]),
        torch.tensor([[0, 1, 2]]),
        torch.tensor([[0, 1, 2, 5]]),
    ]
    result = experiment._single_repeat()
    assert isinstance(result, ResultAcceptance)
    # The matches produce all_n_matches == [1] (the final miss drops the
    # open window), so the estimator gives 1 - 1/(1 + 1) = 0.5
    assert result.acceptance_rate[0] == 0.5

def test_config_acceptance_initialization_defaults():
    config = ConfigAcceptanteRate(model="m", dataset="d", draft_model="dr")
    assert config.draft_model == "dr"
    assert config.draft_dtype == "bfloat16"
    assert config.draft_revision is None
    assert config.draft_compile_model is False


def test_config_acceptance_initialization_custom():
    config = ConfigAcceptanteRate(
        model="m",
        dataset="d",
        draft_model="test_model",
        draft_dtype="float32",
        draft_compile_model=True,
        draft_revision="test_revision",
    )
    assert config.draft_model == "test_model"
    assert config.draft_dtype == "float32"
    assert config.draft_compile_model is True
    assert config.draft_revision == "test_revision"

@pytest.mark.parametrize(
    "dtype,expected",
    [
        ("float32", torch.float32),
        ("float16", torch.float16),
        ("bfloat16", torch.bfloat16),
    ],
)
def test_get_torch_dtype(dtype, expected):
    config = ConfigAcceptanteRate(
        model="m", dataset="d", draft_model="dr", draft_dtype=dtype
    )
    assert config.get_torch_draft_dtype() == expected


def test_draft_revision_optional():
    config = ConfigAcceptanteRate(model="m", dataset="d", draft_model="dr")
    assert config.draft_revision is None
    config = ConfigAcceptanteRate(
        model="m", dataset="d", draft_model="dr", draft_revision="rev1"
    )
    assert config.draft_revision == "rev1"