Making PyTorch Transformer Twice as Fast on Sequence Generation

Paul Gresia
Dec 18, 2020

by Alexandre Matton and Adrian Lam on December 17th, 2020

At Scale AI, we use Machine Learning models in a wide range of applications to empower our data labeling pipeline. We strive for speed and efficiency, and always try to get the best out of the models. Here, we will discuss some tricks we discovered that drastically improve over the PyTorch Transformer implementation in just a few lines of code.

Transformers are Here to Stay

Transformers have become ubiquitous. They were first introduced in Attention is All You Need (Vaswani et al., 2017) and were quickly added to PyTorch. Their popularity grew even more with the development of HuggingFace (Wolf et al., 2019), which made large pre-trained NLP models such as BERT (Devlin et al., 2018) widely accessible and provided recipes for simple fine-tuning on a wide range of tasks. Transformers have been applied successfully to many sequence-to-sequence (Seq2Seq) tasks, including machine translation, text summarization and even image captioning (an image is just a sequence of pixels!). This popularity is warranted, because Transformers have some significant upsides:

  • The Transformer architecture is non-sequential, which makes it parallelizable. Traditional sequence models such as RNNs process sequences one token at a time. This sequential nature prevents parallelization across timesteps and makes training slow. Transformers instead process entire sequences at once, in a highly parallel fashion. This makes them very fast on GPUs and helps them handle long-range dependencies elegantly.
  • Transformers make few assumptions about the data. Traditionally, ML practitioners have tailored their networks to specific types of data. Constraints such as forcing RNNs to process text sequentially from left to right allowed these networks to perform well even on scarce training data, spearheading breakthroughs. However, these constraints introduce biases into the model, since strict sequential ordering is rarely the best way to understand text. Transformers keep data representations generic, which makes them capable of learning more subtle interactions between words. Recent work (Dosovitskiy et al., 2020) shows that the same may hold in computer vision, with Transformers outperforming long-established CNNs once trained on the very large datasets that have recently been collected. With enough data, Transformers learn more complex and accurate representations than the constrained networks that used to be the only viable option.

Sequence-to-Sequence with Transformers

But Transformers also have weaknesses. When generating sequences for Seq2Seq tasks at inference time, Transformers must predict the output tokens one at a time, each conditioned on the tokens predicted before it. Combined with the quadratic complexity of attention, this can make them slower than their counterparts. (During training this is not an issue thanks to teacher forcing: the whole target sequence is known, so all positions are processed in parallel.)

Seq2Seq models typically build an internal high-level representation of the input sequence and then decode (i.e. generate) the output sentence. Given the representation of the input sentence and the words decoded so far, the model estimates the most likely next word. This process is called autoregression, and each step that generates a new word (or token) is called a timestep.
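The decoding loop described above can be sketched in plain Python. This is a minimal greedy sketch under simplifying assumptions. `toy_model` is a hypothetical stand-in for a trained model: it ignores the encoded input and always favors (last token + 1) mod 5. `greedy_decode` keeps the highest-scoring token at each timestep.

```python
from typing import Callable, List


def greedy_decode(
    next_token_scores: Callable[[List[int]], List[float]],
    first_token: int,
    num_steps: int,
) -> List[int]:
    """Greedy autoregressive decoding: each timestep conditions on all tokens decoded so far."""
    decoded = [first_token]
    for _ in range(num_steps):
        scores = next_token_scores(decoded)  # one model call per timestep
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax over the vocabulary
        decoded.append(best)  # the new token becomes part of the next input
    return decoded


def toy_model(prefix: List[int], vocab_size: int = 5) -> List[float]:
    # Hypothetical model: the encoder output is omitted and the score
    # simply favors (last token + 1) mod vocab_size
    target = (prefix[-1] + 1) % vocab_size
    return [1.0 if tok == target else 0.0 for tok in range(vocab_size)]


print(greedy_decode(toy_model, first_token=3, num_steps=4))  # [3, 4, 0, 1, 2]
```

Every call to `next_token_scores` receives the whole prefix. A real Transformer decoder therefore reprocesses all earlier tokens at each timestep unless something is cached, which is the inefficiency discussed below.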

When a Transformer is used as a Seq2Seq model, the input sequence is fed through an Encoder, and the output sequence is then generated by a Decoder, as illustrated in figures 1 and 2.

Decoding Inefficiency of the PyTorch Transformers

The Transformer class in PyTorch is generic. That is great because it gives the ML researchers at Scale AI fine-grained control, but it also means it isn't optimized for speed. Let's take a deeper look.

First, as Figure 1 shows, the encoder output can be computed independently of the decoder. This means it can be computed once and reused at every subsequent timestep. However, nn.Transformer does NOT cache it for you: every call re-runs the encoder, wasting compute at each decoding timestep. To fix this, the Transformer encoder and decoder should always be used separately.

```python
import torch
from torch import nn

# THIS IS THE NAIVE WAY TO USE TRANSFORMERS
# hdim, nhead, num_layers, dim_feedforward, device, src (already embedded),
# first_token, embedding, to_vocab, len_output_to_decode and
# generate_square_subsequent_mask are assumed to be defined elsewhere.

# INITIALIZATION
transformer = nn.Transformer(
    d_model=hdim,
    nhead=nhead,
    num_encoder_layers=num_layers,
    num_decoder_layers=num_layers,
    dim_feedforward=dim_feedforward,
).to(device=device)
transformer.eval()

# INFERENCE LOOP
decoded_tokens = first_token
for i in range(len_output_to_decode):  # generate `len_output_to_decode` tokens
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    # the encoder is re-run on `src` at every timestep
    output = transformer(src, decoded_embeddings, tgt_mask=mask_dec)
    logits = to_vocab(output)  # projection to vocab size
    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
```

The code below is a much more efficient way to get the same results by decoupling the encoder and the decoder. Note that the code corresponding to the inference loop barely changes.

```python
# INITIALIZATION
# (same assumed hyperparameters and helpers as in the previous snippet)
# nn.Transformer applies a final LayerNorm after its encoder and decoder stacks;
# passing `norm=` here keeps both versions architecturally identical.
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
    ),
    num_layers=num_layers,
    norm=nn.LayerNorm(hdim),
).to(device=device)
encoder.eval()
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(
        d_model=hdim, nhead=nhead, dim_feedforward=dim_feedforward
    ),
    num_layers=num_layers,
    norm=nn.LayerNorm(hdim),
).to(device=device)
decoder.eval()

# INFERENCE LOOP
decoded_tokens = first_token
src_embeddings = encoder(src)  # computed once, reused at every timestep
for i in range(len_output_to_decode):
    mask_dec = generate_square_subsequent_mask(
        i + 1, device=first_token.device
    )  # create mask for autoregressive decoding
    decoded_embeddings = embedding(decoded_tokens)
    # the decoder uses the encoder output `src_embeddings`
    output = decoder(decoded_embeddings, src_embeddings, tgt_mask=mask_dec)
    logits = to_vocab(output)  # projection to vocab size
    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
```
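The claim that the split version gives the same results can be checked on a small model. The following is a minimal sketch with arbitrary small dimensions, using random tensors as stand-ins for real embeddings. It relies on the fact that `nn.Transformer` exposes its stacks as the `.encoder` and `.decoder` submodules, and `causal_mask` is a hypothetical helper that builds the usual float causal mask. The source is encoded once, and the decoder output is compared with a full `nn.Transformer` call at several timesteps.

```python
import torch
from torch import nn


def causal_mask(size: int) -> torch.Tensor:
    # Float mask: -inf above the diagonal, so position i only attends to positions <= i
    return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)


torch.manual_seed(0)
d_model, nhead, src_len, batch = 32, 4, 7, 2
transformer = nn.Transformer(
    d_model=d_model, nhead=nhead, num_encoder_layers=2,
    num_decoder_layers=2, dim_feedforward=64,
).eval()  # eval mode disables dropout
src = torch.randn(src_len, batch, d_model)  # (S, N, E): an already-embedded source

max_diff = 0.0
with torch.no_grad():
    memory = transformer.encoder(src)  # encoded once, outside the loop
    for t in range(1, 6):
        tgt = torch.randn(t, batch, d_model)  # stand-in for t decoded-token embeddings
        naive = transformer(src, tgt, tgt_mask=causal_mask(t))  # re-encodes src
        split = transformer.decoder(tgt, memory, tgt_mask=causal_mask(t))  # reuses memory
        max_diff = max(max_diff, (naive - split).abs().max().item())

print(f"largest difference between the two paths: {max_diff:.1e}")
```

Because `nn.Transformer.forward` simply chains these two submodules, the outputs should agree up to floating-point noise. Only the number of encoder passes changes.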

The main inefficiency builds on the previous point. As Figure 2 shows, the embedding of a decoded token (its representation at each decoder layer) only depends on the tokens decoded before it. This follows from the causal mask used for autoregressive decoding. Thus, there is no need to recompute the embeddings of already decoded tokens at every timestep; instead, we can cache them as well. Each timestep then only consists of computing attention for the newest token's embedding.

Figure 3: Decoder self-attention links when decoding tokens. The boxes at the bottom represent the embeddings of the output tokens before self-attention, and the top boxes represent the embeddings of the output tokens after self-attention. Using our trick (right side), most of the embeddings are not recomputed because they are cached. The number of links to take into account becomes linear instead of quadratic.
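To make the caching idea concrete, here is a minimal sketch built from a stock `nn.TransformerDecoder`; it is not the authors' CausalTransformerDecoderLayer. The decoder is evaluated one position at a time with every layer's outputs cached, and the result is compared with a full masked forward pass. The sketch assumes the default post-norm layer (`norm_first=False`), sequence-first tensors (`batch_first=False`) and eval mode. The helper `last_position_forward` is hypothetical. It reads the layer's submodules directly (`self_attn`, `multihead_attn`, `norm1` to `norm3`, `linear1`, `linear2`, the dropout layers and `activation`), so it relies on PyTorch keeping those attribute names.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, nhead, num_layers, batch, n_steps = 32, 4, 3, 2, 6


def last_position_forward(layer: nn.TransformerDecoderLayer, tgt, memory):
    """Hypothetical helper: a post-norm decoder layer evaluated for the newest position only."""
    last = tgt[-1:]  # the only query we need
    # keys/values are all positions so far, so no causal mask is required
    sa = layer.self_attn(last, tgt, tgt, need_weights=False)[0]
    x = layer.norm1(last + layer.dropout1(sa))
    ca = layer.multihead_attn(x, memory, memory, need_weights=False)[0]
    x = layer.norm2(x + layer.dropout2(ca))
    ff = layer.linear2(layer.dropout(layer.activation(layer.linear1(x))))
    return layer.norm3(x + layer.dropout3(ff))


decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward=64), num_layers
).eval()  # eval mode disables dropout
memory = torch.randn(5, batch, d_model)       # encoder output, computed once
steps = torch.randn(n_steps, batch, d_model)  # stand-in for decoded-token embeddings

# cache[i] holds layer i's outputs for every position decoded so far
cache = [torch.empty(0, batch, d_model) for _ in decoder.layers]
with torch.no_grad():
    for t in range(1, n_steps + 1):
        x = steps[:t]  # first-layer input: embeddings of all tokens so far
        for i, layer in enumerate(decoder.layers):
            new = last_position_forward(layer, x, memory)  # newest position only
            cache[i] = torch.cat([cache[i], new])
            x = cache[i]  # layer i's full output is layer i + 1's input
    incremental = x

    mask = torch.triu(torch.full((n_steps, n_steps), float("-inf")), diagonal=1)
    reference = decoder(steps, memory, tgt_mask=mask)  # full recomputation

print(torch.allclose(incremental, reference, atol=1e-5))
```

Only the newest position is pushed through each layer. Earlier positions are read back from `cache`, which is valid because, under the causal mask, their outputs never change once computed.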

From a complexity perspective, generating the n-th output token without our trick involves two costs. The first is self-attention over the entire current output (O(n²)). The second is encoder-decoder attention between the current output and the whole input, whose length we denote M (O(Mn)). Hence each timestep costs O(Mn + n²). Decoding N tokens takes N timesteps, so the total complexity is O(MN² + N³).

Our trick speeds up each timestep: only the parts of the self-attention and encoder-decoder attention responsible for updating the last token are computed. Figure 3 shows how this works for self-attention. Generating the n-th token now costs O(M + n), which is at most O(M + N), so over N timesteps the total complexity drops to O(MN + N²).
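These totals can be checked by simply counting attention scores. The sketch below uses a hypothetical `attention_links` function with a deliberately simplified cost model. It counts only query-key scores, treats masked self-attention scores as computed (as a masked matrix product does), and ignores the projections and feed-forward layers.

```python
def attention_links(M: int, N: int, cached: bool) -> int:
    """Count attention scores computed while decoding N tokens from an M-token input.

    Without caching, timestep n computes all n x n self-attention scores
    (masked ones included) and all n x M encoder-decoder scores.
    With caching, only the newest token's row is computed: n + M scores.
    """
    total = 0
    for n in range(1, N + 1):
        total += (n + M) if cached else (n * n + M * n)
    return total


for N in (10, 100, 1000):
    naive = attention_links(M=500, N=N, cached=False)
    fast = attention_links(M=500, N=N, cached=True)
    print(f"N={N:5d}  naive={naive:>13,}  cached={fast:>9,}  ratio={naive / fast:6.1f}")
```

The sums have closed forms N(N+1)(2N+1)/6 + M·N(N+1)/2 and N(N+1)/2 + MN, matching O(MN² + N³) and O(MN + N²). The printed ratio grows roughly linearly with N.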

The PyTorch Transformer decoder does not assume it is used autoregressively. However, by subclassing the TransformerDecoder and TransformerDecoderLayer classes, we introduce a CausalTransformerDecoder that uses a cache to implement the improvement above. Our code differs from the PyTorch implementation by only a few lines. The new decoder works like the original TransformerDecoder, except that we now have to pass the cache around:

```python
# CausalTransformerDecoder and CausalTransformerDecoderLayer come from our repo;
# `encoder` and the other names are the same as in the previous snippets.
causal_decoder = CausalTransformerDecoder(
    CausalTransformerDecoderLayer(
        d_model=hdim,
        nhead=nhead,
        dim_feedforward=dim_feedforward,
    ),
    num_layers=num_layers,
).to(device=device)
causal_decoder.eval()

decoded_tokens = first_token
src_embeddings = encoder(src)
cache = None
for i in range(len_output_to_decode):
    # no tgt_mask is passed: each step only computes the newest token,
    # which may attend to every token decoded so far
    decoded_embeddings = embedding(decoded_tokens)
    # main change: the cache is passed in and the updated cache is returned
    output, cache = causal_decoder(decoded_embeddings, src_embeddings, cache)
    logits = to_vocab(output)  # projection to vocab size
    # keep most likely tokens
    top_indices = torch.argmax(logits, dim=-1)
    # we only care about the last token that was decoded
    top_indices_last_token = top_indices[-1:]
    # add most likely token to the already decoded tokens
    decoded_tokens = torch.cat(
        [decoded_tokens, top_indices_last_token], dim=0
    )
```

Experiments

We put our changes to the test to see how much faster we could get. We present two different scenarios: translation and generation of long texts.

We compare our three different implementations (see footnotes for details):

  • The most naive PyTorch implementation (the first piece of code), which uses nn.Transformer.
  • The PyTorch encoder-decoder implementation (second piece of code).
  • Our CausalTransformerDecoder (third piece of code).

As a reminder, these are three different implementations of the same model. When initialized with the same weights, they return the same outputs.

Text Translation

The first setting corresponds to translation. In this setting the input and output sequences are generally short and of similar lengths.

The non-linear curves show that the attention mechanisms progressively become the most compute-intensive parts of the model as the number of input and output tokens increases.
Our causal implementation is up to 40% faster than the PyTorch encoder-decoder implementation, and 150% faster than the PyTorch nn.Transformer implementation, for 500 input/output tokens.

Long Text Generation

We now ask the model to generate long sequences from a fixed size input. Such a situation might arise when generating a story from an image or from an initial prompt.

The results below were obtained with a fixed input size of 500 tokens. Increasing the number of input tokens makes the models slower but doesn’t change the overall trends observed.

Our causal model is twice as fast as the PyTorch encoder-decoder implementation when the number of tokens to generate exceeds 1,000.
Beyond 500 decoded tokens, the ratio between the decoding time of the other implementations and that of the causal model grows linearly with the number of tokens. This is consistent with the analysis above, according to which the overall decoding complexity is reduced by a factor of N.

Finally, our CausalTransformerDecoder can also be used without any input sentence (i.e. without an encoder). This is the case in some common generation settings, where the model is typically asked to complete a story or an article. More information about this type of generation can be found in the GPT papers (Radford et al., 2018). The results we find in this setting are similar to the ones above.

Digging deeper…

One might notice that caching the output of each layer is suboptimal. The first stage of each attention layer projects the embeddings into the query, key and value spaces. In both the PyTorch implementation and ours, the same embeddings are projected again at every timestep. Instead, the keys and values could be cached directly. However, this requires substantially more changes, which could break with future PyTorch releases. Moreover, the estimated gains are minor: less than 5% in our experiments.
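Key/value caching itself is easy to sketch for a single attention head. The following is a minimal illustration, not the HuggingFace or PyTorch code. It assumes one head, no batch dimension and random, hypothetical projection matrices `w_q`, `w_k`, `w_v`. Each timestep projects only the newest embedding and appends its key and value to the caches. The result should match full causal attention that re-projects every embedding.

```python
import torch

torch.manual_seed(0)
d, steps = 16, 6
w_q, w_k, w_v = (torch.randn(d, d) / d**0.5 for _ in range(3))  # hypothetical projections


def attend(q, k, v):
    # q: (1, d) query of the newest token; k, v: (t, d) cached keys and values
    weights = torch.softmax(q @ k.T / d**0.5, dim=-1)
    return weights @ v


xs = torch.randn(steps, d)  # embeddings arriving one timestep at a time
k_cache, v_cache = torch.empty(0, d), torch.empty(0, d)
outputs = []
for x in xs:
    x = x.unsqueeze(0)
    # only the newest embedding is projected; earlier keys/values come from the caches
    k_cache = torch.cat([k_cache, x @ w_k])
    v_cache = torch.cat([v_cache, x @ w_v])
    outputs.append(attend(x @ w_q, k_cache, v_cache))
cached = torch.cat(outputs)

# Reference: full causal attention, re-projecting every embedding
q, k, v = xs @ w_q, xs @ w_k, xs @ w_v
future = torch.triu(torch.ones(steps, steps, dtype=torch.bool), diagonal=1)
scores = (q @ k.T / d**0.5).masked_fill(future, float("-inf"))
full = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(cached, full, atol=1e-5))
```

Past queries are never needed, which is why only keys and values are cached.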

For standard NLP use cases, the HuggingFace library already includes these optimizations. Notably, it caches keys and values. It also comes with several decoding strategies, such as beam search and nucleus sampling.

Conclusion

The simple tricks proposed here take advantage of the fact that the PyTorch Transformer implementation is very generic. They provide modest speed improvements when generating a few hundred tokens. These become significant boosts over the original PyTorch implementation as the output length approaches a thousand tokens, since the gains grow with the number of output tokens to decode. And best of all, they can be implemented in just a few lines using our repo.

References:

Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Łukasz Kaiser, and Illia Polosukhin. Attention is all you need. In Advances in Neural Information Processing Systems 30, pp. 5998–6008. Curran Associates, Inc., 2017. URL: https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf

Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. BERT: pre-training of deep bidirectional transformers for language understanding. CoRR, abs/1810.04805, 2018. URL: http://arxiv.org/abs/1810.04805.

Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, et al. An image is worth 16x16 words: Transformers for image recognition at scale. arXiv preprint arXiv:2010.11929, 2020. URL: https://arxiv.org/pdf/2010.11929

Thomas Wolf, Lysandre Debut, Victor Sanh, Julien Chaumond, Clement Delangue, Anthony Moi, Pierric Cistac, et al. Transformers: State-of-the-art natural language processing. arXiv preprint arXiv:1910.03771, 2019. URL: https://arxiv.org/pdf/1910.03771.pdf

Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. Improving language understanding by generative pre-training. OpenAI technical report, 2018. URL: https://www.cs.ubc.ca/~amuham01/LING530/papers/radford2018improving.pdf

Footnotes:

1. The vocabulary size used for these experiments was 30,000, which is the vocabulary size of the original BERT paper (which uses WordPiece tokenization). We also ran the same experiments with a much smaller vocabulary (128 tokens) to imitate a character-level setting; this did not show significant benefits.

2. The hyperparameters of the Transformer architecture are those of the original paper: 6 layers, 8 heads, a hidden dimension of 512 and a feed-forward dimension of 2048, for both the encoder and the decoder. The results should be similar with other configurations, provided that the encoder and decoder have the same size.

3. The displayed results correspond to a batch size of 8 sequences, but we checked that a batch size of 1 gives the same trend. A batch size of 1 is usually the most common choice at inference time, as requests are sent asynchronously. However, in our specific setup the model takes several seconds to generate sentences, so it is more natural to batch requests.

4. As a sanity check, we benchmarked the HuggingFace GPT-2 implementation against our causal decoder, using the same set of hyperparameters and generating up to 1,000 tokens with both models. The speed ratio between the two models was close to 1, oscillating between 0.85 and 1.10.

5. All experiments were run on a V100 GPU.
