11"""
2- Dismantling the `` nn.Transformer`` modules for gains and profits
3- =================================================================
2+ Accelerating PyTorch Transformers by replacing nn.Transformer with Nested Tensors and torch.compile()
3+ =====================================================================================================
44**Author:** `Mikayla Gawarecki <https://github.com/mikaylagawarecki>`_
55
66.. note::
77 This tutorial should be run with the latest nightly, or, when available, 2.6.
88
- The ``torch.nn`` module currently provides various ``Transformer``-related layers.
- In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
- ``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
- of layers was initially implemented following the `Attention is All
- You Need <https://arxiv.org/abs/1706.03762>`_ paper. Since then, various improvements
- were made to try to make these layers more flexible.
-
- While historically these layers intended to provide out-of-the-box, performant
- solutions, we make the observations that
-
- 1. People want to add slight customizations to their transformer layers
- 2. Writing these layers and customizations is not hard
-
-
- Supporting all transformer variants via a small number of out of the box layers would
- yield too many keyword arguments. This tutorial will describe how to build your
- own performant transformer layers following our recommended best practices.
- The technologies used will be the following
+ Over the past few years, the PyTorch team has developed various lower level
+ features that, when composed, can create a variety of transformer variants. These
+ include:

1. Nested Tensors with the ``torch.jagged`` layout (AKA NJTs)
2. ``scaled_dot_product_attention``
3. ``torch.compile()``
4. ``FlexAttention``

+ This tutorial will give a brief overview of the above technologies and
+ demonstrate how they can be composed to yield flexible and performant transformer
+ layers with improved user experience.
+
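+ As a quick taste of how these pieces fit together, here is a minimal sketch
+ (an editorial illustration rather than part of the tutorial proper, assuming a
+ recent nightly where the ``torch.jagged`` layout is available) that packs
+ variable-length sequences into a nested tensor and runs compiled
+ ``scaled_dot_product_attention`` over them without any padding:
+
+ .. code::
+
+     import torch
+     import torch.nn.functional as F
+
+     n_heads, head_dim = 4, 16
+     # three sequences of different lengths, stored jagged rather than padded
+     seqs = [torch.randn(l, n_heads * head_dim) for l in (3, 5, 2)]
+     nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
+
+     # reshape to (batch, heads, ragged_seq_len, head_dim) for attention
+     q = k = v = nt.unflatten(-1, [n_heads, head_dim]).transpose(1, 2)
+
+     # SDPA consumes the nested tensor directly and composes with torch.compile()
+     compiled_sdpa = torch.compile(F.scaled_dot_product_attention)
+     out = compiled_sdpa(q, k, v)
+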
+ One may observe that the ``torch.nn`` module currently provides various ``Transformer``-related layers.
+ In particular ``TransformerEncoderLayer``, ``TransformerEncoder``, ``TransformerDecoderLayer``,
+ ``TransformerDecoder``, ``Transformer`` and ``MultiheadAttention``. This family
+ of layers was initially implemented following the `Attention is All
+ You Need <https://arxiv.org/abs/1706.03762>`_ paper. The components discussed in
+ this tutorial provide improved user experience, flexibility and performance over
+ the existing ``nn`` layers.
+
Is this tutorial for me?
========================

+ If you are wondering about what building blocks the ``torch`` library provides
+ for writing your own transformer layers and best practices, you are in the
+ right place. Please keep reading!
+
If you are looking for an out-of-the-box implementation of a popular transformer
architecture, note that there are many open-source libraries that provide them,
with some examples being:

* `xformers <https://github.com/facebookresearch/xformers>`_
* `torchtune <https://github.com/pytorch/torchtune>`_

- Please head there instead!
-
If you are only interested in performant attention score modifications, please
head to the `FlexAttention blog <https://pytorch.org/blog/flexattention/>`_ that
contains a `gym of masks <https://github.com/pytorch-labs/attention-gym>`_.
- If you are wondering about what building blocks the ``torch`` library provides
- for writing your own transformer layers and best practices, you are in the
- right place, please keep reading!
-

"""

@@ -393,7 +388,7 @@ def benchmark(func, *args, **kwargs):

print(f"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB")
print(f"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB")
- print("Difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
+ print("Max difference between vanilla and nested result", (padded_result - padded_nested_result).abs().max().item())
print(f"Nested speedup: {(padded_time/nested_time):.2f}")
print(f"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB")

@@ -404,7 +399,7 @@ def benchmark(func, *args, **kwargs):
#
#    padded_time=0.03454, padded_peak_memory=4.14 GB
#    nested_time=0.00612, nested_peak_memory=0.76 GB
- #    Difference between vanilla and nested result 0.0
+ #    Max difference between vanilla and nested result 0.0
#    Nested speedup: 5.65
#    Nested peak memory reduction 3.39 GB
#
@@ -432,14 +427,14 @@ def benchmark(func, *args, **kwargs):
#
# .. code::
#
- #    ``padded_bw_time``=2.09337, ``padded_bw_peak_mem``=5.10 GB
- #    ``nested_bw_time``=0.01452, ``nested_bw_peak_mem``=3.24 GB
+ #    padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
+ #    nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
#    Nested backward speedup: 144.13
#    Nested backward peak memory reduction 1.86 GB
- #    Difference in ``out_proj.weight.grad`` 0.000244140625
- #    Difference in ``packed_proj.weight.grad`` 0.001556396484375
- #    Difference in ``out_proj.bias.grad`` 0.0
- #    Difference in ``packed_proj.bias.grad`` 0.001953125
+ #    Difference in out_proj.weight.grad 0.000244140625
+ #    Difference in packed_proj.weight.grad 0.001556396484375
+ #    Difference in out_proj.bias.grad 0.0
+ #    Difference in packed_proj.bias.grad 0.001953125
#

##################################################################################
@@ -493,6 +488,53 @@ def benchmark(func, *args, **kwargs):
493488print (f"Total sequence length in nested key/value { kv_len .sum ().item ()} , max sequence length { kv_len .max ().item ()} " )
494489out = new_mha_layer (query , key , value , is_causal = False )
495490
491+ ########################################################################################
492+ # As above, we can compare this against the vanilla compiled ``nn.MultiheadAttention``.
493+
494+ torch .manual_seed (6 )
495+ query , _ , _ , q_len = gen_batch (N , E_q , E_k , E_v , device )
496+ _ , key , value , kv_len = gen_batch (N , E_q , E_k , E_v , device )
497+ padded_query , padded_key , padded_value = (
498+ t .to_padded_tensor (0.0 ) for t in (query , key , value )
499+ )
500+
501+ key_padding_mask = torch .where (padded_key == 0.0 , - math .inf , 0 )[:, :, 0 ]
502+
503+ # warmup compile
504+ warmup_nested_result = new_mha_layer (query , key , value , is_causal = False )
505+ warmup_vanilla_result = vanilla_mha_layer (padded_query ,
506+ padded_key ,
507+ padded_value ,
508+ key_padding_mask = key_padding_mask ,
509+ need_weights = False ,
510+ is_causal = False )
511+
512+ nested_result , nested_time , nested_peak_memory = benchmark (new_mha_layer , query , key , value , is_causal = False )
513+ (padded_result , _ ), padded_time , padded_peak_memory = benchmark (vanilla_mha_layer ,
514+ padded_query ,
515+ padded_key ,
516+ padded_value ,
517+ key_padding_mask = key_padding_mask ,
518+ need_weights = False ,
519+ is_causal = False )
520+ padded_nested_result = nested_result .to_padded_tensor (0.0 )
521+ for i , entry_length in enumerate (q_len ):
522+ # padding-specific step: remove output projection bias from padded entries for fair comparison
523+ padded_result [i , entry_length :, :] = 0.0
524+
525+ print ("Max difference between vanilla and nested result" , (padded_result - padded_nested_result ).abs ().max ().item ())
526+ print (f"Nested speedup: { (padded_time / nested_time ):.2f} " )
527+ print (f"Nested peak memory reduction { ((padded_peak_memory - nested_peak_memory )/ 1e9 ):.2f} GB" )
528+
529+ ##################################################################################
530+ # Sample outputs on A100:
531+ #
532+ # .. code::
533+ #
534+ # Max difference between vanilla and nested result 0.0
535+ # Nested speedup: 4.01
536+ # Nested peak memory reduction 1.40 GB
537+ #
496538
################################################################################
# Fully masked rows no longer cause NaNs
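+ ################################################################################
+ # To make the problem concrete, the following editorial sketch (not part of the
+ # original tutorial) shows the root cause on the padded path: a row that is fully
+ # masked out becomes a softmax over nothing but ``-inf``, which yields NaN. With
+ # NJTs, padding rows do not exist, so there is no fully masked row to hit this case.
+
+ import torch
+
+ # one attention score row in which every key position has been masked out
+ fully_masked_row = torch.full((4,), float("-inf"))
+ print(fully_masked_row.softmax(dim=-1))  # tensor([nan, nan, nan, nan])
+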
@@ -549,6 +591,29 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)

+ ###############################################################################
+ # In addition, one can use the ``block_mask`` utility of ``FlexAttention``
+ # with NJTs via the ``create_nested_block_mask`` function. This is useful for
+ # taking advantage of the sparsity of the mask to speed up the attention computation.
+ # In the following example, we show how to create a causal block mask using this
+ # utility.
+
+ from torch.nn.attention.flex_attention import create_nested_block_mask
+
+ def causal_mask(b, h, q_idx, kv_idx):
+     return q_idx >= kv_idx
+
+ query, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)
+ block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
+ query = (
+     query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ )
+ key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ value = (
+     value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
+ )
+ out_flex = flex_attention(query, key, value, block_mask=block_mask)
+
###############################################################################
# Packed Projection
# -----------------
@@ -579,8 +644,8 @@ def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

-     def forward(self, query):
-         return self.q_proj(query), self.k_proj(query), self.v_proj(query)
+     def forward(self, x):
+         return self.q_proj(x), self.k_proj(x), self.v_proj(x)

class PackedInputProjection(nn.Module):
    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
@@ -591,7 +656,7 @@ def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):
    def forward(self, query):
        return torch.chunk(self.packed_proj(query), 3, dim=-1)

- B, D, dtype = 256, 4096, torch.bfloat16
+ B, D, dtype = 256, 8192, torch.bfloat16

torch.set_float32_matmul_precision('high')
in_proj = torch.compile(InputProjection(D, D, device='cuda', dtype=torch.bfloat16))
@@ -606,6 +671,7 @@ def forward(self, query):
# benchmark
(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
+ # On my A100 prints 1.05x speedup
print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x")

##################################################
@@ -669,6 +735,7 @@ def forward(self, x):
# benchmark
_, time, _ = benchmark(swigluffn, q)
_, time_packed, _ = benchmark(packed_swigluffn, q)
+ # On my A100 prints 1.08x speedup
print(f"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x")

################################################################################