Measure acceptance rate #34

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Status: Open. Wants to merge 9 commits into base: main.

26 changes: 26 additions & 0 deletions dsi/configs/experiment/acceptance.py
@@ -0,0 +1,26 @@
from typing import Literal

import torch
from pydantic import Field

from dsi.configs.experiment.latency import ConfigLatency


class ConfigAcceptanteRate(ConfigLatency):
"""Includes all the parameters needed for measuring the acceptance rate
of a (target, draft, dataset) triplet.
"""

Owner:
Suggested change:
    draft_gen_config: ConfigGen = Field(
        default_factory=ConfigGen,
        title="Configuration of the generation from the drafter",
    )

    draft_model: str = Field(title="The draft model to use for the experiment")
    draft_dtype: Literal["float32", "float16", "bfloat16"] = Field(
        "bfloat16", title="The dtype of the draft model to use for the experiment"
    )
    draft_compile_model: bool = Field(
        False, title="Whether to torch.compile() the draft model"
    )
    draft_revision: None | str = Field(
        None, title="The revision of the draft model to use"
    )

    def get_torch_draft_dtype(self) -> torch.dtype:
        return eval(f"torch.{self.draft_dtype}")
114 changes: 114 additions & 0 deletions dsi/offline/acceptance/experiment.py
@@ -0,0 +1,114 @@
import logging

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from dsi.configs.experiment.acceptance import ConfigAcceptanteRate
from dsi.configs.experiment.generation import ConfigGen
from dsi.online.latency.experiment import ExperimentLatency
from dsi.types.result import ResultAcceptance

log = logging.getLogger(__name__)


class ExperimentAcceptanceRate(ExperimentLatency):
"""
Measures the generation acceptance rate.
"""

    def __init__(
        self,
        config: ConfigAcceptanteRate,
        gen_config: ConfigGen,
        draft_gen_config: ConfigGen,
    ):
        self.config: ConfigAcceptanteRate
        super().__init__(config, gen_config)
        self.draft_gen_config: ConfigGen = draft_gen_config
Owner (on lines +21 to +29):
Please rebase the main branch. To follow the interface of _Experiment, please use the same __init__ signature. That is:

Suggested change:
    def __init__(self, config: ConfigAcceptanteRate):
        self.config: ConfigAcceptanteRate
        super().__init__(config)

The other arguments (gen_config: ConfigGen and draft_gen_config: ConfigGen) should be passed within config: ConfigAcceptanteRate. After rebasing main, you'll see that ExperimentLatency is now fixed similarly.


    def _load_draft_model_tokenizer(self) -> tuple:
        log.info(
            f"Loading model: {self.config.draft_model}, \
            compile={self.config.draft_compile_model}"
        )
        model = AutoModelForCausalLM.from_pretrained(
            self.config.draft_model,
            device_map="auto",
            torch_dtype=self.config.get_torch_draft_dtype(),
            revision=self.config.draft_revision,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.config.draft_model)
        model = torch.compile(model) if self.config.draft_compile_model else model
        return model, tokenizer

    def _are_tokenizers_same(self, tokenizer1, tokenizer2) -> bool:
Owner:
Why not use a static method? Also, please rename the function to communicate that it only checks the tokens and their corresponding input ids. (The current function returns True for tokenizers with the same vocab, even if they encode an input id to different vectors.)

Suggested change:
    @staticmethod
    def _are_vocabs_same(tokenizer1, tokenizer2) -> bool:

        # Compare vocabularies
        if tokenizer1.get_vocab() != tokenizer2.get_vocab():
            return False

        # Compare special tokens
        special_tokens_to_compare = [
            "eos_token_id",
            "pad_token_id",
            "bos_token_id",
            "unk_token_id",
        ]
        for token in special_tokens_to_compare:
            if getattr(tokenizer1, token, None) != getattr(tokenizer2, token, None):
                return False

        return True

    def _single_repeat(self) -> ResultAcceptance:
        all_n_matches = []

        examples = self._get_random_prompted_examples()

        target_model, target_tokenizer = self._load_model_tokenizer()
        draft_model, draft_tokenizer = self._load_draft_model_tokenizer()

        # Check if tokenizers are the same
        if not self._are_tokenizers_same(target_tokenizer, draft_tokenizer):
            raise ValueError("The target and draft tokenizers are not the same.")
Owner:
Could you define a custom exception in dsi/types/exception.py instead of using ValueError? Doing so aligns with best practices by enhancing the specificity and clarity of our error handling.
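
A minimal sketch of what such a custom exception might look like (the module path comes from the comment above; the class name and how it would be raised are illustrative assumptions, not code from the PR):

# dsi/types/exception.py (illustrative sketch)
class IncompatibleTokenizersError(Exception):
    """Raised when the target and draft tokenizers do not match."""


# In _single_repeat, the ValueError above could then become:
#     raise IncompatibleTokenizersError(
#         "The target and draft tokenizers are not the same."
#     )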


        target_gen_kwargs = dict(
            do_sample=self.gen_config.do_sample,
            temperature=self.gen_config.temperature,
            top_p=self.gen_config.top_p,
            pad_token_id=target_tokenizer.eos_token_id,
            max_new_tokens=self.config.max_new_tokens,
        )

        draft_gen_kwargs = dict(
            do_sample=self.draft_gen_config.do_sample,
            temperature=self.draft_gen_config.temperature,
            top_p=self.draft_gen_config.top_p,
            pad_token_id=target_tokenizer.eos_token_id,
            max_new_tokens=1,
        )

        for ex in tqdm(examples):
            inputs = target_tokenizer(ex, return_tensors="pt").to(target_model.device)
            n_matches = [0]
            output_target = target_model.generate(**inputs, **target_gen_kwargs)
            prompt_len = len(inputs.input_ids[0])

            # iterate over the tokens generated by the target and
            # check whether the draft produces the same token
            for i in range(prompt_len, len(output_target[0])):
Owner:
What is len(output_target[0])? Why use range(prompt_len, len(output_target[0]))? Please document this part.

Owner:
It is still unclear what output_target is and why it is accessed with output_target[0]. Please consider adding descriptive variables and/or functions to enhance readability. Documentation would also be helpful.
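
For reference (based on the Hugging Face transformers API rather than on this PR), a short annotation of the shapes involved:

# Illustration only, not code from the diff:
# output_target = target_model.generate(**inputs, **target_gen_kwargs)
# output_target.shape == (1, prompt_len + num_generated_tokens)  # batch of one example
# output_target[0]     -> that example's prompt plus generated token ids
# range(prompt_len, len(output_target[0])) -> indices of the generated tokens only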

inputs["input_ids"] = output_target[0, 0:i].view(1, -1)
Owner:
It's not clear what output_target refers to and why we access it with output_target[0, 0:i]. To enhance readability and make the code easier to understand without needing to run debug mode or insert print statements, consider encapsulating this logic in a descriptive function. Here's a proposed function:

def get_input_ids_prefix(output_seqs, tok_pos_last: int):
    """Returns a prefix of a sequence of input ids.

    Args:
        output_seqs: The output from Hugging Face's transformers `model.generate`
            function, expected to be a sequence-like data structure.
        tok_pos_last: The last token position to include in the returned sequence.
    """
    return output_seqs[0, 0:tok_pos_last]

inputs["attention_mask"] = torch.tensor(
[[1] * i], device=draft_model.device
)
Owner (on lines +102 to +104):
Why not use torch.ones?

Suggested change:
                inputs["attention_mask"] = torch.ones(i, device=draft_model.device)

                output_draft = draft_model.generate(**inputs, **draft_gen_kwargs)
                if output_draft[-1, i] == output_target[-1, i]:
                    n_matches[-1] += 1
                elif i < len(output_target[0]) - 1:  # new window
                    n_matches.append(0)
Owner:
Please add tests covering the experiment. It seems like this line makes n_matches = [x, 0], then we extend all_n_matches such that all_n_matches += [x, 0]. At the beginning of the examples loop, we initialize a new n_matches = [0]. Is this a bug? Anyway, it is critical to add such tests.
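
For illustration, a brief trace of how n_matches evolves under the loop above (inferred from reading the code, not from running it). Suppose the target generates four new tokens:

# match, match, mismatch, match     -> n_matches == [2, 1]  (the mismatch opens a new window)
# match, match, mismatch, mismatch  -> n_matches == [2]     (the final mismatch pops the last window)
# In both cases the outer loop then extends all_n_matches with n_matches.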

                else:  # at the end, remove last window
                    n_matches.pop()
            all_n_matches += n_matches
        ar = 1 - (1 / (1 + np.array(all_n_matches).mean()))
        return ResultAcceptance(acceptance_rate=[ar])
10 changes: 10 additions & 0 deletions dsi/types/result.py
@@ -53,3 +53,13 @@ class ResultLatency(_Result):

    ttft: list[float] = field(default_factory=list)
    tpot: list[float] = field(default_factory=list)


@dataclass
class ResultAcceptance(_Result):
"""
Args:
acceptance_rate: The acceptance rate
"""

acceptance_rate: list[float] = field(default_factory=list)
170 changes: 170 additions & 0 deletions tests/offline/test_acceptance.py
@@ -0,0 +1,170 @@
from unittest.mock import MagicMock, patch

import pytest
import torch
from transformers import AutoTokenizer

from dsi.configs.experiment.acceptance import ConfigAcceptanteRate
from dsi.configs.experiment.generation import ConfigGen
from dsi.offline.acceptance.experiment import ExperimentAcceptanceRate
from dsi.types.result import ResultAcceptance


@pytest.fixture
def experiment():
    config = ConfigAcceptanteRate(
        model="target_model", dataset="dataset", draft_model="draft_model"
    )
    gen_config = ConfigGen()
    draft_gen_config = ConfigGen()
    return ExperimentAcceptanceRate(config, gen_config, draft_gen_config)


@pytest.fixture
def mock_dependencies():
    with patch(
        "transformers.AutoModelForCausalLM.from_pretrained"
    ) as mock_model, patch(
        "transformers.AutoTokenizer.from_pretrained"
    ) as mock_tokenizer, patch.object(
        ExperimentAcceptanceRate, "_get_random_prompted_examples"
    ) as mock_examples:
        mock_model.return_value = MagicMock()
        mock_model.return_value.device = torch.device("cpu")
        mock_tokenizer.return_value = MagicMock()
        mock_examples.return_value = ["example1"]
        yield mock_model, mock_tokenizer, mock_examples


def test_are_tokenizers_same_identical(experiment):
    tokenizer1 = AutoTokenizer.from_pretrained("double7/vicuna-68m")
    tokenizer2 = AutoTokenizer.from_pretrained("double7/vicuna-68m")
    assert experiment._are_tokenizers_same(tokenizer1, tokenizer2)


def test_are_tokenizers_same_diff_config(experiment):
    tokenizer1 = MagicMock()
    tokenizer2 = MagicMock()
    tokenizer1.config = {"model_type": "bigcode/starcoder"}
    tokenizer2.config = {"model_type": "double7/vicuna-68m"}
    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)


def test_are_tokenizers_same_diff_vocab(experiment):
    tokenizer1 = MagicMock()
    tokenizer2 = MagicMock()
    tokenizer1.get_vocab.return_value = {"hello": 1, "world": 2}
    tokenizer2.get_vocab.return_value = {"hello": 1, "python": 3}
    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)


def test_are_tokenizers_same_diff_special_tokens(experiment):
    tokenizer1 = MagicMock()
    tokenizer2 = MagicMock()
    tokenizer1.eos_token_id = 1
    tokenizer2.eos_token_id = 2
    tokenizer1.pad_token_id = 0
    tokenizer2.pad_token_id = 0
    tokenizer1.bos_token_id = -1
    tokenizer2.bos_token_id = -1
    tokenizer1.unk_token_id = 3
    tokenizer2.unk_token_id = 3
    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)


def test_single_repeat_all_match(experiment, mock_dependencies):
    mock_model, _, _ = mock_dependencies
    # Mock the generate method to simulate target and draft models
    # producing the same output
    mock_model.return_value.generate.side_effect = [
        torch.tensor([[0, 1, 2, 3]]),
        torch.tensor([[0]]),
        torch.tensor([[0, 1]]),
        torch.tensor([[0, 1, 2]]),
        torch.tensor([[0, 1, 2, 3]]),
    ]
    result = experiment._single_repeat()
    assert isinstance(result, ResultAcceptance)
    # Since all tokens match, acceptance rate should be 1
    assert result.acceptance_rate[0] == 0.8
Owner (on lines +88 to +89):
It seems like a discrepancy. The test passes, suggesting that there is a bug.

Collaborator Author:
@keyboardAnt
For the paper, we agreed to ignore the last matching window for each dataset input, since generation can stop because we reach EOS or max length, and not because of a bad draft model prediction.
Not sure we wrote it in the paper (please correct me).
According to the paper formula, AR is 0.8.
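
A worked instance of that formula for this test, assuming the estimator described later in this thread (n is the number of accepted draft tokens, here n = 4):

    AR = 1 - 1/(1 + n) = 1 - 1/(1 + 4) = 0.8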

Owner:
@jmamou
Even when ignoring the test's 4th (and last) match, why is the returned acceptance rate < 1? The first 3 tokens match 100%.

Collaborator Author:
@keyboardAnt
For the paper, we agreed to ignore the last matching window, not the last matching token.
In real scenarios, we will never get AR=1 since n cannot be infinite.

Owner (keyboardAnt, Jul 24, 2024):
@jmamou

  1. What is the definition of acceptance rate for such windows? How is it different from the acceptance rate for single tokens? (For single tokens, the acceptance rate is the probability of accepting the drafter's next token prediction using an exact match or Miao's rejection sampling algorithm, assuming iid target tokens.)
  2. Why is ignoring the last matching token counted as a mismatch in test_single_repeat_all_match? Please correct me if I'm wrong here. In the test, the target outputs the sequence of four input ids [0, 1, 2, 3]. Then, the drafter correctly predicts all four, one by one. However, the calculated acceptance rate equals 0.8. Why isn't the acceptance rate 1?

Owner (keyboardAnt, Jul 24, 2024):
@jmamou, as we discussed today over the phone, the acceptance rate has only one definition:

The acceptance rate is the probability of accepting the drafter's next token prediction using an exact match or Miao's rejection sampling algorithm, assuming iid target tokens.

We estimate the acceptance rate as follows. Let A be the sum of the number of drafts accepted over all iterations, except the last iteration per example, over all the examples, in all datasets. Denote the total number of such iterations by N. Note that N <= the total number of examples. If N == the number of examples, the drafter predicts the example in a single iteration. In other words, we accepted all draft tokens without any rejections. In that case, we can raise a warning and return 1 (recommended) or raise an exception. Otherwise, we calculate the average number of accepted drafts n := A/N (dividing the sum by the total number of iterations). The estimated acceptance rate is then 1 - 1 / (1 + n), as mentioned in the paper. The issue with test_single_repeat_all_match is that 1 - 1 / (1 + n) != 1 for any finite n.

In practice, drafters may have a 100% acceptance rate. For example, an instance of the target running on faster hardware as a drafter, with both the target and drafter sampled greedily. But in such cases the perfect acceptance rate is guaranteed, and I do not see a reason to test it using our experiment. On the contrary, in most practical settings, the acceptance rate is < 1, and estimating it as == 1 means we haven't reached the false discovery rate. In other words, we need to consider more examples. Also, it might indicate a bug. So, this is why I think we should raise a warning or exception.
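
A minimal sketch of the estimator described above (the function name, its argument, and the exact handling of the all-accepted case are assumptions drawn from this comment, not code from the PR):

import warnings


def estimate_acceptance_rate(accepted_per_iteration: list[int]) -> float:
    """Estimate the acceptance rate from accepted-draft counts.

    accepted_per_iteration holds, for every kept iteration (the last
    iteration of each example is excluded), the number of draft tokens
    accepted in that iteration, so A = sum(...) and N = len(...).
    """
    if not accepted_per_iteration:
        # No rejections were observed: 1 - 1/(1 + n) cannot reach 1 for any
        # finite n, so warn and return 1 (raising an exception is the
        # alternative mentioned above).
        warnings.warn("No rejections observed; acceptance rate estimated as 1.")
        return 1.0
    n = sum(accepted_per_iteration) / len(accepted_per_iteration)  # n := A / N
    return 1 - 1 / (1 + n)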



def test_single_repeat_no_match(experiment, mock_dependencies):
    mock_model, _, _ = mock_dependencies
    # Mock the generate method to simulate target and draft models
    # producing different outputs
    mock_model.return_value.generate.side_effect = [
        torch.tensor([[0, 1, 2, 3]]),
        torch.tensor([[4]]),
        torch.tensor([[0, 5]]),
        torch.tensor([[0, 1, 6]]),
        torch.tensor([[0, 1, 2, 7]]),
    ]
    result = experiment._single_repeat()
    assert isinstance(result, ResultAcceptance)
    # Since no tokens match, acceptance rate should be 0
    assert result.acceptance_rate[0] == 0


def test_single_repeat_partial_match(experiment, mock_dependencies):
    mock_model, _, _ = mock_dependencies
    # Mock the generate method to simulate target and draft models
    # producing partially matching outputs
    mock_model.return_value.generate.side_effect = [
        torch.tensor([[0, 1, 2, 3]]),
        torch.tensor([[0]]),
        torch.tensor([[0, 4]]),
        torch.tensor([[0, 1, 2]]),
        torch.tensor([[0, 1, 2, 5]]),
    ]
    result = experiment._single_repeat()
    assert isinstance(result, ResultAcceptance)
    # Since half of the tokens match, acceptance rate should be 0.5
    assert result.acceptance_rate[0] == 0.5


def test_config_acceptance_initialization_defaults():
    config = ConfigAcceptanteRate(model="m", dataset="d", draft_model="dr")
    assert config.draft_model == "dr"
    assert config.draft_dtype == "bfloat16"
    assert config.draft_revision is None
    assert config.draft_compile_model is False


def test_config_acceptance_initialization_custom():
    config = ConfigAcceptanteRate(
        model="m",
        dataset="d",
        draft_model="test_model",
        draft_dtype="float32",
        draft_compile_model=True,
        draft_revision="test_revision",
    )
    assert config.draft_model == "test_model"
    assert config.draft_dtype == "float32"
    assert config.draft_compile_model is True
    assert config.draft_revision == "test_revision"


@pytest.mark.parametrize(
"dtype,expected",
[
("float32", torch.float32),
("float16", torch.float16),
("bfloat16", torch.bfloat16),
],
)
def test_get_torch_dtype(dtype, expected):
config = ConfigAcceptanteRate(
model="m", dataset="d", draft_model="dr", draft_dtype=dtype
)
assert config.get_torch_draft_dtype() == expected


def test_draft_revision_optional():
    config = ConfigAcceptanteRate(model="m", dataset="d", draft_model="dr")
    assert config.draft_revision is None
    config = ConfigAcceptanteRate(
        model="m", dataset="d", draft_model="dr", draft_revision="rev1"
    )
    assert config.draft_revision == "rev1"