diff --git a/dsi/configs/experiment/acceptance.py b/dsi/configs/experiment/acceptance.py
new file mode 100644
index 0000000..639c55d
--- /dev/null
+++ b/dsi/configs/experiment/acceptance.py
@@ -0,0 +1,26 @@
+from typing import Literal
+
+import torch
+from pydantic import Field
+
+from dsi.configs.experiment.latency import ConfigLatency
+
+
+class ConfigAcceptanteRate(ConfigLatency):
+    """Includes all the parameters needed for measuring the acceptance rate
+    of a (target, draft, dataset) triplet.
+    """
+
+    draft_model: str = Field(title="The draft model to use for the experiment")
+    draft_dtype: Literal["float32", "float16", "bfloat16"] = Field(
+        "bfloat16", title="The dtype of the draft model to use for the experiment"
+    )
+    draft_compile_model: bool = Field(
+        False, title="Whether to torch.compile() the draft model"
+    )
+    draft_revision: None | str = Field(
+        None, title="The revision of the draft model to use"
+    )
+
+    def get_torch_draft_dtype(self) -> torch.dtype:
+        return eval(f"torch.{self.draft_dtype}")
diff --git a/dsi/offline/acceptance/experiment.py b/dsi/offline/acceptance/experiment.py
new file mode 100644
index 0000000..8f49e30
--- /dev/null
+++ b/dsi/offline/acceptance/experiment.py
@@ -0,0 +1,114 @@
+import logging
+
+import numpy as np
+import torch
+from tqdm import tqdm
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from dsi.configs.experiment.acceptance import ConfigAcceptanteRate
+from dsi.configs.experiment.generation import ConfigGen
+from dsi.online.latency.experiment import ExperimentLatency
+from dsi.types.result import ResultAcceptance
+
+log = logging.getLogger(__name__)
+
+
+class ExperimentAcceptanceRate(ExperimentLatency):
+    """
+    Measures the generation acceptance rate.
+    """
+
+    def __init__(
+        self,
+        config: ConfigAcceptanteRate,
+        gen_config: ConfigGen,
+        draft_gen_config: ConfigGen,
+    ):
+        self.config: ConfigAcceptanteRate
+        super().__init__(config, gen_config)
+        self.draft_gen_config: ConfigGen = draft_gen_config
+
+    def _load_draft_model_tokenizer(self) -> tuple:
+        log.info(
+            f"Loading model: {self.config.draft_model}, \
+            compile={self.config.draft_compile_model}"
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            self.config.draft_model,
+            device_map="auto",
+            torch_dtype=self.config.get_torch_draft_dtype(),
+            revision=self.config.draft_revision,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(self.config.draft_model)
+        model = torch.compile(model) if self.config.draft_compile_model else model
+        return model, tokenizer
+
+    def _are_tokenizers_same(self, tokenizer1, tokenizer2) -> bool:
+        # Compare vocabularies
+        if tokenizer1.get_vocab() != tokenizer2.get_vocab():
+            return False
+
+        # Compare special tokens
+        special_tokens_to_compare = [
+            "eos_token_id",
+            "pad_token_id",
+            "bos_token_id",
+            "unk_token_id",
+        ]
+        for token in special_tokens_to_compare:
+            if getattr(tokenizer1, token, None) != getattr(tokenizer2, token, None):
+                return False
+
+        return True
+
+    def _single_repeat(self) -> ResultAcceptance:
+        all_n_matches = []
+
+        examples = self._get_random_prompted_examples()
+
+        target_model, target_tokenizer = self._load_model_tokenizer()
+        draft_model, draft_tokenizer = self._load_draft_model_tokenizer()
+
+        # Check if tokenizers are the same
+        if not self._are_tokenizers_same(target_tokenizer, draft_tokenizer):
+            raise ValueError("The target and draft tokenizers are not the same.")
+
+        target_gen_kwargs = dict(
+            do_sample=self.gen_config.do_sample,
+            temperature=self.gen_config.temperature,
+            top_p=self.gen_config.top_p,
+            pad_token_id=target_tokenizer.eos_token_id,
+            max_new_tokens=self.config.max_new_tokens,
+        )
+
+        draft_gen_kwargs = dict(
+            do_sample=self.draft_gen_config.do_sample,
+            temperature=self.draft_gen_config.temperature,
+            top_p=self.draft_gen_config.top_p,
+            pad_token_id=target_tokenizer.eos_token_id,
+            max_new_tokens=1,
+        )
+
+        for ex in tqdm(examples):
+            inputs = target_tokenizer(ex, return_tensors="pt").to(target_model.device)
+            n_matches = [0]
+            output_target = target_model.generate(**inputs, **target_gen_kwargs)
+            prompt_len = len(inputs.input_ids[0])
+
+            # iterate over the tokens generated by the target and
+            # check whether the draft produces the same token
+            for i in range(prompt_len, len(output_target[0])):
+                inputs["input_ids"] = output_target[0, 0:i].view(1, -1)
+                inputs["attention_mask"] = torch.tensor(
+                    [[1] * i], device=draft_model.device
+                )
+                output_draft = draft_model.generate(**inputs, **draft_gen_kwargs)
+                if output_draft[-1, i] == output_target[-1, i]:
+                    n_matches[-1] += 1
+                elif i < len(output_target[0]) - 1:  # new window
+                    n_matches.append(0)
+                else:  # at the end, remove last window
+                    n_matches.pop()
+            all_n_matches += n_matches
+        ar = 1 - (1 / (1 + np.array(all_n_matches).mean()))
+        return ResultAcceptance(acceptance_rate=[ar])
diff --git a/dsi/types/result.py b/dsi/types/result.py
index b6fb496..7835c9a 100644
--- a/dsi/types/result.py
+++ b/dsi/types/result.py
@@ -53,3 +53,13 @@ class ResultLatency(_Result):
 
     ttft: list[float] = field(default_factory=list)
     tpot: list[float] = field(default_factory=list)
+
+
+@dataclass
+class ResultAcceptance(_Result):
+    """
+    Args:
+        acceptance_rate: The acceptance rate
+    """
+
+    acceptance_rate: list[float] = field(default_factory=list)
diff --git a/tests/offline/test_acceptance.py b/tests/offline/test_acceptance.py
new file mode 100644
index 0000000..704d864
--- /dev/null
+++ b/tests/offline/test_acceptance.py
@@ -0,0 +1,170 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from transformers import AutoTokenizer
+
+from dsi.configs.experiment.acceptance import ConfigAcceptanteRate
+from dsi.configs.experiment.generation import ConfigGen
+from dsi.offline.acceptance.experiment import ExperimentAcceptanceRate
+from dsi.types.result import ResultAcceptance
+
+
+@pytest.fixture
+def experiment():
+    config = ConfigAcceptanteRate(
+        model="target_model", dataset="dataset", draft_model="draft_model"
+    )
+    gen_config = ConfigGen()
+    draft_gen_config = ConfigGen()
+    return ExperimentAcceptanceRate(config, gen_config, draft_gen_config)
+
+
+@pytest.fixture
+def mock_dependencies():
+    with patch(
+        "transformers.AutoModelForCausalLM.from_pretrained"
+    ) as mock_model, patch(
+        "transformers.AutoTokenizer.from_pretrained"
+    ) as mock_tokenizer, patch.object(
+        ExperimentAcceptanceRate, "_get_random_prompted_examples"
+    ) as mock_examples:
+        mock_model.return_value = MagicMock()
+        mock_model.return_value.device = torch.device("cpu")
+        mock_tokenizer.return_value = MagicMock()
+        mock_examples.return_value = ["example1"]
+        yield mock_model, mock_tokenizer, mock_examples
+
+
+def test_are_tokenizers_same_identical(experiment):
+    tokenizer1 = AutoTokenizer.from_pretrained("double7/vicuna-68m")
+    tokenizer2 = AutoTokenizer.from_pretrained("double7/vicuna-68m")
+    assert experiment._are_tokenizers_same(tokenizer1, tokenizer2)
+
+
+def test_are_tokenizers_same_diff_config(experiment):
+    tokenizer1 = MagicMock()
+    tokenizer2 = MagicMock()
+    tokenizer1.config = {"model_type": "bigcode/starcoder"}
+    tokenizer2.config = {"model_type": "double7/vicuna-68m"}
+    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)
+
+
+def test_are_tokenizers_same_diff_vocab(experiment):
+    tokenizer1 = MagicMock()
+    tokenizer2 = MagicMock()
+    tokenizer1.get_vocab.return_value = {"hello": 1, "world": 2}
+    tokenizer2.get_vocab.return_value = {"hello": 1, "python": 3}
+    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)
+
+
+def test_are_tokenizers_same_diff_special_tokens(experiment):
+    tokenizer1 = MagicMock()
+    tokenizer2 = MagicMock()
+    tokenizer1.eos_token_id = 1
+    tokenizer2.eos_token_id = 2
+    tokenizer1.pad_token_id = 0
+    tokenizer2.pad_token_id = 0
+    tokenizer1.bos_token_id = -1
+    tokenizer2.bos_token_id = -1
+    tokenizer1.unk_token_id = 3
+    tokenizer2.unk_token_id = 3
+    assert not experiment._are_tokenizers_same(tokenizer1, tokenizer2)
+
+
+def test_single_repeat_all_match(experiment, mock_dependencies):
+    mock_model, _, _ = mock_dependencies
+    # Mock the generate method to simulate target and draft models
+    # producing the same output
+    mock_model.return_value.generate.side_effect = [
+        torch.tensor([[0, 1, 2, 3]]),
+        torch.tensor([[0]]),
+        torch.tensor([[0, 1]]),
+        torch.tensor([[0, 1, 2]]),
+        torch.tensor([[0, 1, 2, 3]]),
+    ]
+    result = experiment._single_repeat()
+    assert isinstance(result, ResultAcceptance)
+    # All four generated tokens match, so the expected acceptance rate is
+    # 1 - 1 / (1 + 4) = 0.8
+    assert result.acceptance_rate[0] == 0.8
+
+
+def test_single_repeat_no_match(experiment, mock_dependencies):
+    mock_model, _, _ = mock_dependencies
+    # Mock the generate method to simulate target and draft models
+    # producing different outputs
+    mock_model.return_value.generate.side_effect = [
+        torch.tensor([[0, 1, 2, 3]]),
+        torch.tensor([[4]]),
+        torch.tensor([[0, 5]]),
+        torch.tensor([[0, 1, 6]]),
+        torch.tensor([[0, 1, 2, 7]]),
+    ]
+    result = experiment._single_repeat()
+    assert isinstance(result, ResultAcceptance)
+    # Since no tokens match, the acceptance rate should be 0
+    assert result.acceptance_rate[0] == 0
+
+
+def test_single_repeat_partial_match(experiment, mock_dependencies):
+    mock_model, _, _ = mock_dependencies
+    # Mock the generate method to simulate target and draft models
+    # producing partially matching outputs
+    mock_model.return_value.generate.side_effect = [
+        torch.tensor([[0, 1, 2, 3]]),
+        torch.tensor([[0]]),
+        torch.tensor([[0, 4]]),
+        torch.tensor([[0, 1, 2]]),
+        torch.tensor([[0, 1, 2, 5]]),
+    ]
+    result = experiment._single_repeat()
+    assert isinstance(result, ResultAcceptance)
+    # The surviving acceptance window contains a single match, so the
+    # expected acceptance rate is 1 - 1 / (1 + 1) = 0.5
+    assert result.acceptance_rate[0] == 0.5
+
+
+def test_config_acceptance_initialization_defaults():
+    config = ConfigAcceptanteRate(model="m", dataset="d", draft_model="dr")
+    assert config.draft_model == "dr"
+    assert config.draft_dtype == "bfloat16"
+    assert config.draft_revision is None
+    assert config.draft_compile_model is False
+
+
+def test_config_acceptance_initialization_custom():
+    config = ConfigAcceptanteRate(
+        model="m",
+        dataset="d",
+        draft_model="test_model",
+        draft_dtype="float32",
+        draft_compile_model=True,
+        draft_revision="test_revision",
+    )
+    assert config.draft_model == "test_model"
+    assert config.draft_dtype == "float32"
+    assert config.draft_compile_model is True
+    assert config.draft_revision == "test_revision"
+
+
+@pytest.mark.parametrize(
+    "dtype,expected",
+    [
+        ("float32", torch.float32),
+        ("float16", torch.float16),
+        ("bfloat16", torch.bfloat16),
+    ],
+)
+def test_get_torch_dtype(dtype, expected):
+    config = ConfigAcceptanteRate(
+        model="m", dataset="d", draft_model="dr", draft_dtype=dtype
+    )
+    assert config.get_torch_draft_dtype() == expected
+
+
+def test_draft_revision_optional():
+    config = ConfigAcceptanteRate(model="m", dataset="d", draft_model="dr")
+    assert config.draft_revision is None
+    config = ConfigAcceptanteRate(
+        model="m", dataset="d", draft_model="dr", draft_revision="rev1"
+    )
+    assert config.draft_revision == "rev1"
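
Usage note (not part of the diff): the snippet below is a minimal sketch of how the new experiment is wired together and of the estimator it reports. The acceptance rate is computed as 1 - 1 / (1 + mean(n_matches)), where n_matches counts consecutive draft tokens that agree with the target before a mismatch. The model and dataset names are placeholders, and the run() entry point is assumed to be inherited from the existing latency experiment base class; neither is defined in this diff.

# Minimal usage sketch for ExperimentAcceptanceRate.
# Placeholder model/dataset names; run() is an assumed inherited entry point.
from dsi.configs.experiment.acceptance import ConfigAcceptanteRate
from dsi.configs.experiment.generation import ConfigGen
from dsi.offline.acceptance.experiment import ExperimentAcceptanceRate

config = ConfigAcceptanteRate(
    model="lmsys/vicuna-7b-v1.3",      # target model (placeholder)
    draft_model="double7/vicuna-68m",  # draft model (placeholder)
    dataset="cnn_dailymail",           # dataset (placeholder)
    draft_dtype="bfloat16",
    draft_compile_model=False,
)
experiment = ExperimentAcceptanceRate(
    config=config,
    gen_config=ConfigGen(),        # sampling settings for the target
    draft_gen_config=ConfigGen(),  # sampling settings for the draft
)
result = experiment.run()  # assumed: run() repeatedly invokes _single_repeat()
print(result.acceptance_rate)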