# Introducing the Building Blocks
# ===============================
# First, we will briefly introduce the four technologies mentioned in the introduction.
#
# * `torch.nested <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
#
# Nested tensors generalize the shape of regular dense tensors, allowing for
# representation of ragged-sized data with the same tensor UX. In the context of
# transformers, we can think of nested tensors as a tool for representing variable
# sequence lengths. They eliminate the need for the bug-prone practices of explicit
# padding and masking (think ``key_padding_mask`` in ``nn.MultiheadAttention``).
#
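# To make this concrete, here is a minimal sketch (shapes and values are arbitrary) of
# packing three variable-length sequences into a single nested tensor with the
# ``torch.jagged`` layout, with no padding tensor or mask in sight:
#
# .. code:: python
#
#    import torch
#
#    # Three sequences of lengths 2, 3 and 4, each with embedding dimension 8.
#    seqs = [torch.randn(2, 8), torch.randn(3, 8), torch.randn(4, 8)]
#
#    # Pack them into one nested tensor with the torch.jagged layout; downstream ops
#    # see the ragged sequence dimension directly instead of padding + key_padding_mask.
#    nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)
#
#    print(nt.shape)  # e.g. torch.Size([3, j1, 8]); the ragged dim is symbolic
#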
# * `scaled_dot_product_attention <https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html>`_
#
# ``scaled_dot_product_attention`` is a primitive for
# :math:`\text{softmax}(\frac{QK^T}{\sqrt{E}} + B)V` that dispatches into either fused
# implementations of the operator or a fallback implementation. It works out of
# the box in eager mode (i.e. the default mode of using PyTorch where operations
# are executed on the fly as they are encountered) and also integrates seamlessly
# with ``torch.compile()``. As of 2.6, it also offers grouped query attention
# natively.
#
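# As a rough sketch of the call (shapes are ``(batch, num_heads, seq_len, head_dim)``;
# the fused or fallback backend is selected automatically for the given inputs):
#
# .. code:: python
#
#    import torch
#    import torch.nn.functional as F
#
#    q = torch.randn(2, 8, 128, 64)  # (batch, num_heads, seq_len, head_dim)
#    k = torch.randn(2, 8, 128, 64)
#    v = torch.randn(2, 8, 128, 64)
#
#    # Dispatches to a fused kernel when one is available for these inputs,
#    # otherwise to the composite math fallback; is_causal applies a causal mask.
#    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
#    print(out.shape)  # torch.Size([2, 8, 128, 64])
#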
# * `torch.compile() <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
#
# ``torch.compile()`` is a compiler introduced in version 2.0 that is able to
# capture a graph of PyTorch code and perform various optimizations on it, such as
# fusing together sequences of ops. Nested tensors with the ``torch.jagged`` layout
# and SDPA work seamlessly with compile; the value add of compiling them is that
# compile can remove the framework overhead one sees in eager mode
# and fuse sequences of ops in transformers together (e.g. projection and
# activation).
#
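# For example, a small feed-forward block can be wrapped in ``torch.compile()`` as-is;
# a minimal sketch (the module and sizes here are illustrative):
#
# .. code:: python
#
#    import torch
#    import torch.nn as nn
#
#    class FeedForward(nn.Module):
#        def __init__(self, dim: int, hidden: int):
#            super().__init__()
#            self.up = nn.Linear(dim, hidden)
#            self.down = nn.Linear(hidden, dim)
#
#        def forward(self, x):
#            # Projection followed by activation -- a sequence compile can fuse.
#            return self.down(torch.relu(self.up(x)))
#
#    ff = torch.compile(FeedForward(256, 1024))
#    x = torch.randn(4, 256)
#    out = ff(x)  # the first call compiles; later calls reuse the compiled graph
#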
# * `FlexAttention <https://pytorch.org/blog/flexattention/>`_
#
# ``FlexAttention`` is a primitive that allows users to modify attention scores
# prior to the softmax operation. It generalizes the additive ``B`` term above
# for ``scaled_dot_product_attention``, allowing for arbitrary, user-defined
# calculations. It requires ``torch.compile()`` to achieve good performance.
#
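# As a rough sketch (assuming a recent PyTorch where ``torch.nn.attention.flex_attention``
# is available; the ALiBi-style slopes below are illustrative), a ``score_mod`` is just a
# function of the score and its batch/head/query/key indices:
#
# .. code:: python
#
#    import torch
#    from torch.nn.attention.flex_attention import flex_attention
#
#    B, H, S, D = 2, 8, 128, 64
#    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
#
#    # One ALiBi slope per head, captured by the score_mod closure.
#    alibi_slopes = 2.0 ** (-torch.arange(1, H + 1, dtype=torch.float32))
#
#    def alibi_score_mod(score, b, h, q_idx, kv_idx):
#        # Add a per-head linear distance penalty to each score before softmax --
#        # a concrete instance of the generalized additive "B" term above.
#        return score + alibi_slopes[h] * (kv_idx - q_idx)
#
#    # Compile for good performance; the uncompiled call is mainly for debugging.
#    out = torch.compile(flex_attention)(q, k, v, score_mod=alibi_score_mod)
#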
# The above building blocks are "All You Need" (as of October 2024)
# ==================================================================
#
# The main premise in this section is that most transformer variations are
# GPT-style, consisting of layers like Embedding, Positional Encoding, Attention
# Blocks and Feed Forward networks. If we were to try to classify the differences
# in this space, we might land on something like:
#
# 1. Layer type (activation functions, e.g. ``SwiGLU``; normalization functions,
#    e.g. ``RMSNorm``; positional encodings, e.g. Sinusoidal, Rotary; see the
#    sketch after this list)
# 2. Layer ordering (where to apply norms, where to apply positional encoding, etc.)
# 3. Modifications to attention score (``ALiBi``, Relative Positional Bias, etc.)
#
#
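# As an example of point 1, a custom layer type such as ``SwiGLU`` is only a few lines
# of plain PyTorch; a minimal sketch (names and sizes here are illustrative):
#
# .. code:: python
#
#    import torch
#    import torch.nn as nn
#    import torch.nn.functional as F
#
#    class SwiGLUFeedForward(nn.Module):
#        """SwiGLU feed-forward: w2(silu(w1(x)) * w3(x))."""
#
#        def __init__(self, dim: int, hidden: int):
#            super().__init__()
#            self.w1 = nn.Linear(dim, hidden, bias=False)
#            self.w3 = nn.Linear(dim, hidden, bias=False)
#            self.w2 = nn.Linear(hidden, dim, bias=False)
#
#        def forward(self, x):
#            # Gated activation: SiLU on one projection gates the other.
#            return self.w2(F.silu(self.w1(x)) * self.w3(x))
#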
# In a pre-compiler world, one might write a custom transformer and observe
# that it works but is slow. Then, one might write a custom fused kernel for
# the specific series of ops. In a compiler world, one can do the former, compile
######################################################################################
# For reference, here are some sample outputs on A100:
#
# .. code::
#
#    padded_time=0.03454, padded_peak_memory=4.14 GB
#    nested_time=0.00612, nested_peak_memory=0.76 GB
#    Difference between vanilla and nested result 0.0
#    Nested speedup: 5.65
#    Nested peak memory reduction 3.39 GB
#
# We can also see the same for the backward pass.
##################################################################################
# Sample outputs on A100:
#
# .. code::
#
#    padded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB
#    nested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB
#    Nested backward speedup: 144.13
#    Nested backward peak memory reduction 1.86 GB
#    Difference in out_proj.weight.grad 0.000244140625
#    Difference in packed_proj.weight.grad 0.001556396484375
#    Difference in out_proj.bias.grad 0.0
#    Difference in packed_proj.bias.grad 0.001953125
#
##################################################################################