@@ -344,7 +344,7 @@ def benchmark(func, *args, **kwargs):
torch.manual_seed(6)
vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')

-# nn.MultiheadAttention uses a non conventional initialization for layers, so do this for exact parity :(
+# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
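# As a quick check of the parity claim above, the two layers can be compared
# directly. This is only a sketch, not part of the original tutorial: it assumes
# the custom ``mha_layer`` accepts ``(query, key, value)`` and returns a single
# output tensor, and that all of its parameters (including any packed projection
# bias) have been copied over as in the lines above.

mha_layer.eval()            # disable dropout so the comparison is deterministic
vanilla_mha_layer.eval()
x = torch.randn(2, 8, E_q, device='cuda')
out_custom = mha_layer(x, x, x)
out_vanilla, _ = vanilla_mha_layer(x, x, x, need_weights=False)
torch.testing.assert_close(out_custom, out_vanilla, atol=1e-3, rtol=1e-3)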
@@ -421,7 +421,7 @@ def benchmark(func, *args, **kwargs):
# gives equivalent results to an ``nn.TransformerEncoderLayer`` with
# ``is_causal=True``.
#
-# We demonstrate examples of implementing the rest of the nn layers
+# We demonstrate examples of implementing the rest of the ``nn`` layers
# `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
# tutorial for brevity.
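# For reference, the equivalence noted above (an ``nn.TransformerEncoderLayer``
# run causally) can be exercised roughly as follows. This is a standalone
# sketch; ``d_model``, ``nhead`` and the tensor shapes are illustrative values,
# not the tutorial's actual configuration.

import torch
import torch.nn as nn

encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
src = torch.randn(2, 10, 512)
causal_mask = nn.Transformer.generate_square_subsequent_mask(10)
out = encoder_layer(src, src_mask=causal_mask, is_causal=True)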
@@ -457,7 +457,7 @@ def benchmark(func, *args, **kwargs):
# * SwiGLU activation in feed-forward network of Transformer Layer
#
# Input projection for MultiheadAttention
-# ----------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Recall that when doing self-attention, the ``query``, ``key`` and ``value``
# are the same tensor. Each of these tensors is projected with a
# ``Linear(E_q, E_total)`` layer. Instead, we can pack this into one layer,
@@ -502,8 +502,9 @@ def forward(self, query):
(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x")
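# For readers following along in the diff only: the two modules benchmarked
# above are defined in elided lines of the tutorial. A minimal sketch of what
# they might look like (the class names and constructor arguments here are
# assumptions, not necessarily the tutorial's exact code):

import torch
import torch.nn as nn

class InputProjection(nn.Module):
    """Three separate Linear projections for query, key and value."""
    def __init__(self, E_q, E_total, bias=True, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)

    def forward(self, x):
        # In self-attention, q, k and v are projections of the same tensor.
        return self.q_proj(x), self.k_proj(x), self.v_proj(x)

class PackedInputProjection(nn.Module):
    """One fused Linear whose output is chunked into q, k and v."""
    def __init__(self, E_q, E_total, bias=True, device=None, dtype=None):
        super().__init__()
        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, device=device, dtype=dtype)

    def forward(self, query):
        # One larger matmul instead of three smaller ones, then split the result.
        return torch.chunk(self.packed_proj(query), 3, dim=-1)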

+##################################################
# SwiGLU feed forward network of Transformer Layer
-# ------------------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
# network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as
@@ -524,6 +525,7 @@ def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device
    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
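# The ``__init__`` that pairs with the ``forward`` above is elided from this
# diff. For reference, a self-contained sketch based on the signature shown in
# the hunk header; the hidden-dimension rounding mirrors the Llama reference
# implementation and is an assumption rather than the tutorial's exact code.

import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFNSketch(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device=None, dtype=None):
        super().__init__()
        # Shrink the hidden dim, then round it up to a multiple of ``multiple_of``.
        hidden_dim = int(2 * hidden_dim / 3)
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        # SwiGLU: SiLU(w1(x)) gates w3(x), then project back down with w2.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))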

+########################################################################
# An alternative way of implementing this that uses packed projection is

class PackedSwiGLUFFN(nn.Module):
@@ -543,6 +545,7 @@ def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)

+################################################################################
# We can compare the performance of the two implementations as follows
# Depending on your hardware, you might see different results. On an A100 I see
# 1.12x speedup for D=128.
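# A sketch of such a comparison, reusing the ``benchmark`` helper defined
# earlier in the tutorial. ``SwiGLUFFN`` is assumed to be the name of the
# unpacked FFN class above, and D, the batch/sequence sizes, and the
# constructor arguments are illustrative values, not the tutorial's exact
# numbers. Only timing is compared, so the two modules need not share weights.

D = 128
swiglu = SwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
packed_swiglu = PackedSwiGLUFFN(D, D * 4, 256, device="cuda", dtype=torch.bfloat16)
x = torch.randn(64, 1024, D, device="cuda", dtype=torch.bfloat16)

_, time_swiglu, _ = benchmark(swiglu, x)
_, time_packed_swiglu, _ = benchmark(packed_swiglu, x)
print(f"SwiGLUFFN: {time_swiglu:5f} s, PackedSwiGLUFFN: {time_packed_swiglu:5f} s, speedup: {time_swiglu/time_packed_swiglu:.2f}x")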
@@ -635,20 +638,14 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
)
out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
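# The construction of ``alibi_score_mod`` used above is elided from this diff.
# For context, here is a self-contained sketch of the ``flex_attention`` +
# ``score_mod`` pattern; the per-head slopes and shapes are illustrative
# choices, not the tutorial's exact ALiBi implementation.

import torch
from torch.nn.attention.flex_attention import flex_attention

B, H, S, head_dim = 2, 8, 256, 64
q = torch.randn(B, H, S, head_dim, device="cuda")
k = torch.randn(B, H, S, head_dim, device="cuda")
v = torch.randn(B, H, S, head_dim, device="cuda")

# ALiBi-style bias: penalize attention to distant key positions, with a
# different slope per head.
slopes = 2.0 ** (-torch.arange(1, H + 1, device="cuda"))

def alibi_like_score_mod(score, b, h, q_idx, kv_idx):
    return score - slopes[h] * (q_idx - kv_idx).abs()

out = flex_attention(q, k, v, score_mod=alibi_like_score_mod)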

-################################################################################
-# And more
-# --------
-#
-# We intend to update this tutorial to demonstrate more examples of how to use
-# the various performant building blocks such as KV-Caching, Grouped Query Attention
-# etc.
-

################################################################################
# Extended examples
# -----------------
#
-# There are several good examples of using various performant building blocks to
+# We intend to update this tutorial to demonstrate more examples of how to use
+# the various performant building blocks such as KV-Caching, Grouped Query Attention
+# etc. Further, there are several good examples of using various performant building blocks to
# implement various transformer architectures. Some examples include
#
# * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_