kkew3/pytorch_beam_search
PyTorch Beam Search

A fully vectorized beam search for seq2seq transformers implemented in PyTorch.

Features:

  • Line-by-line comments and descriptive docstrings.
  • Batched decoding.
  • Fully vectorized.
  • KV caching.
  • Sequence decoding terminates on EOS.
  • Length penalty/normalization.
  • Hardware accelerator support (e.g. CUDA).
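As an illustration of the length penalty/normalization feature, the common convention (also the one used by `transformers`' `length_penalty`) divides a hypothesis's cumulative log-probability by its length raised to a penalty exponent. A minimal sketch under that assumption:

```python
import math

def length_normalized_score(sum_log_probs: float, length: int, penalty: float) -> float:
    """Length-normalized beam score: divide by length ** penalty.

    penalty = 0 disables normalization; penalty = 1 averages log-probs
    per token; penalty > 1 favors longer hypotheses more aggressively.
    """
    return sum_log_probs / (length ** penalty)

# Without normalization, the shorter hypothesis wins on raw log-prob;
# with penalty = 1, the longer one wins on per-token quality.
short = length_normalized_score(math.log(0.4), length=2, penalty=1.0)
long_ = length_normalized_score(math.log(0.3), length=6, penalty=1.0)
print(long_ > short)  # True
```

This is why a length penalty matters in practice: without it, beam search systematically prefers shorter outputs, since every added token can only lower the cumulative log-probability.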

Why this project

In my research I needed a sufficiently fast beam search implementation to compute BLEU scores, but I could not find one on the web. While the transformers text generation module fits the requirement well enough, I wanted to code it myself and gain a deeper understanding of beam search. The final product indeed reproduces the efficiency and correctness of the transformers beam search implementation, but it is better commented and may be more suitable for beginners who want to see how beam search works under the hood.

In the parlance of transformers, the beam search implemented in this project corresponds to the following configuration:

from transformers import GenerationConfig

# See the beam_search.batched_beam_search.beam_search function.
generation_config = GenerationConfig(
    max_length=...,      # the max_length argument
    early_stopping=True,
    do_sample=False,
    num_beams=...,       # the beam_width argument
    use_cache=True,
    length_penalty=...,  # the length_normalization argument
    bos_token_id=...,    # the bos_token_id argument
    eos_token_id=...,    # the eos_token_id argument
    pad_token_id=...,    # the pad_token_id argument
)

How I develop this project

Batched and vectorized beam search decoding is tricky to implement. Therefore, I started from a naive implementation which is neither batched nor vectorized, but simple enough to ensure correctness. Then I prompted GPT-4.1 for a batched implementation, iterating until it behaved identically to my naive implementation on real data and pretrained models. Finally, I manually refactored GPT's implementation, added further optimizations such as KV caching, and produced the final version, all while keeping the decoding results unchanged. I profiled the code to ensure that changes which add complexity for only marginal speedup were not merged into the codebase.
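To illustrate the starting point, here is a minimal sketch of what such a naive, unbatched beam search can look like. The function name and the toy language-model interface (a callable mapping a prefix to per-token log-probabilities) are hypothetical illustrations, not this repository's actual API:

```python
import math

def naive_beam_search(step_log_probs, beam_width, bos, eos, max_length):
    """Naive, unbatched beam search (hypothetical sketch, not this repo's API).

    `step_log_probs(prefix)` returns a dict mapping each possible next token
    to its log-probability given the prefix.
    """
    beams = [([bos], 0.0)]  # (token sequence, cumulative log-prob)
    for _ in range(max_length):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos:  # finished hypotheses carry over unchanged
                candidates.append((seq, score))
                continue
            for tok, lp in step_log_probs(seq).items():
                candidates.append((seq + [tok], score + lp))
        # Keep the top `beam_width` hypotheses by cumulative log-prob.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(seq[-1] == eos for seq, _ in beams):
            break  # early stopping: every surviving beam has emitted EOS
    return max(beams, key=lambda c: c[1])

# Toy 3-token vocabulary: 0 = EOS, 1 = a word, 2 = BOS.
def toy_lm(prefix):
    if prefix == [2]:
        return {1: math.log(0.6), 0: math.log(0.4)}
    return {0: math.log(0.9), 1: math.log(0.1)}

best_seq, best_score = naive_beam_search(toy_lm, beam_width=2, bos=2, eos=0, max_length=5)
print(best_seq)  # [2, 1, 0]: total prob 0.6 * 0.9 = 0.54 beats 0.4 for [2, 0]
```

Turning this into the batched, vectorized version is exactly the tricky part: the per-beam Python loops become tensor operations over a `(batch * beam_width)` dimension, with masking to keep finished beams frozen.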

How to use this project

Since basically every application of beam search differs slightly from the others, there is probably no one-size-fits-all implementation that can be used directly. Hence, read the code, either the naive implementation or the final version, and adapt it to your needs.


License

MIT.
