A fully vectorized beam search for seq2seq transformers, implemented in PyTorch.
Features:
- Line-by-line comments and descriptive docstrings.
- Batched decoding.
- Fully vectorized.
- KV caching.
- Decoding terminates early once every sequence has produced EOS.
- Length penalty/normalization.
- Runs on hardware accelerators (e.g., CUDA).
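To illustrate the length penalty/normalization feature above: a common scheme (and the one `transformers`' `length_penalty` follows in spirit) divides a hypothesis's cumulative log-probability by its length raised to the penalty exponent. A minimal sketch, with the function name chosen here for illustration:

```python
import math

def length_normalized_score(cum_logprob: float, length: int,
                            length_penalty: float = 1.0) -> float:
    """Divide the cumulative log-probability by length**length_penalty.

    With length_penalty > 0, longer hypotheses are penalized less harshly,
    counteracting beam search's inherent bias toward short sequences.
    """
    return cum_logprob / (length ** length_penalty)

# Two hypotheses with the same per-token log-probability:
short = length_normalized_score(cum_logprob=-4.0, length=4)   # -1.0
long = length_normalized_score(cum_logprob=-7.0, length=10)   # -0.7
# After normalization the longer hypothesis wins despite a lower raw score.
```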
For my research I needed a beam search implementation fast enough to compute BLEU scores, but I could not find one on the web. While the `transformers` text generation module fits the requirement well enough, I wanted to implement beam search myself to gain a deeper understanding of it. The final product reproduces both the efficiency and the correctness of the `transformers` beam search implementation, but is better commented and may be more suitable for beginners who want to see how beam search works under the hood.
In the parlance of `transformers`, the beam search implemented in this project corresponds to the following configuration:
```python
from transformers import GenerationConfig

# Check the beam_search.batched_beam_search.beam_search function.
generation_config = GenerationConfig(
    max_length=...,       # the max_length argument
    early_stopping=True,
    do_sample=False,
    num_beams=...,        # the beam_width argument
    use_cache=True,
    length_penalty=...,   # the length_normalization argument
    bos_token_id=...,     # the bos_token_id argument
    eos_token_id=...,     # the eos_token_id argument
    pad_token_id=...,     # the pad_token_id argument
)
```
Batched and vectorized beam search decoding is tricky to implement. Therefore, I start from a naive implementation that is neither batched nor vectorized, but simple enough to ensure correctness. Then, I prompt GPT-4.1 for a batched implementation and iterate until it behaves identically to my naive implementation on real data and pretrained models. Finally, I manually refactor GPT's implementation and add further optimizations such as KV caching to produce the final version, while keeping the decoding results invariant. I profile the code so that changes which add complexity but yield only marginal speedup are not merged into the codebase.
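For reference, the naive starting point described above can be sketched in plain Python (this is an illustrative sketch over a generic next-token function, not the repository's actual `beam_search` code; the toy model below is made up for the example):

```python
import math

def naive_beam_search(step_logprobs, bos, eos, beam_width, max_length):
    """Naive (unbatched, unvectorized) beam search.

    step_logprobs(prefix) -> {token: log-probability of the next token}.
    Returns the hypothesis with the highest cumulative log-probability.
    """
    beams = [([bos], 0.0)]          # (token sequence, cumulative log-prob)
    finished = []
    for _ in range(max_length):
        # Expand every live beam by every possible next token.
        candidates = []
        for seq, score in beams:
            for tok, lp in step_logprobs(tuple(seq)).items():
                candidates.append((seq + [tok], score + lp))
        # Keep only the beam_width best candidates.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_width]:
            # A hypothesis ending in EOS is finished and leaves the beam.
            (finished if seq[-1] == eos else beams).append((seq, score))
        if not beams:               # every surviving hypothesis has ended
            break
    return max(finished + beams, key=lambda c: c[1])

# Toy model: after BOS (0), token 1 is likely, and 1 is usually followed by EOS (9).
TABLE = {
    (0,): {1: math.log(0.6), 2: math.log(0.4)},
    (0, 1): {9: math.log(0.9), 2: math.log(0.1)},
    (0, 2): {9: math.log(0.5), 1: math.log(0.5)},
}
best_seq, best_score = naive_beam_search(TABLE.__getitem__, bos=0, eos=9,
                                         beam_width=2, max_length=4)
# Best hypothesis: [0, 1, 9] with log p = log(0.6 * 0.9)
```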
Since nearly every application of beam search differs slightly from the others, there may be no one-size-fits-all implementation that can be used directly. Hence, read the code, either the naive implementation or the final version, and adapt it to your needs.
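When adapting the batched version, the core trick worth understanding is the vectorized beam-selection step: adding each beam's running score to its next-token log-probabilities, flattening the beam and vocabulary axes, and taking a single `topk` per batch element. A minimal sketch in PyTorch (the function name is chosen here for illustration):

```python
import torch

def vectorized_beam_step(logprobs, beam_scores, beam_width):
    """One vectorized beam-selection step for a whole batch.

    logprobs:    (batch, beam_width, vocab) next-token log-probabilities.
    beam_scores: (batch, beam_width) cumulative log-probability of each beam.
    Returns new beam scores, the beam each survivor extends, and the token
    it appends -- all shaped (batch, beam_width).
    """
    batch, _, vocab = logprobs.shape
    # Add each beam's running score to its next-token log-probs, then flatten
    # beams and vocabulary so one topk ranks all beam_width * vocab candidates.
    scores = (beam_scores.unsqueeze(-1) + logprobs).view(batch, -1)
    top_scores, top_idx = scores.topk(beam_width, dim=-1)
    origin_beam = top_idx // vocab     # which beam the candidate extends
    token = top_idx % vocab            # which token it appends
    return top_scores, origin_beam, token

# batch=1, beam_width=2, vocab=3:
lp = torch.tensor([[[-0.1, -2.0, -3.0],
                    [-0.5, -0.6, -4.0]]])
scores = torch.tensor([[0.0, -1.0]])
top_scores, origin_beam, token = vectorized_beam_step(lp, scores, beam_width=2)
# Survivors: token 0 from beam 0 (score -0.1) and token 0 from beam 1 (-1.5).
```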
DeepLearning.AI courses:
Papers:
- Freitag & Al-Onaizan (2017): "Beam Search Strategies for Neural Machine Translation".
License: MIT.