
Commit cae22fa

more spelling + rendering
1 parent d83f14b commit cae22fa

File tree

2 files changed (+16, -14 lines)

en-wordlist.txt

Lines changed: 6 additions & 1 deletion
@@ -1,6 +1,7 @@
 
 ACL
 ADI
+ALiBi
 AOT
 AOTInductor
 APIs
@@ -80,6 +81,7 @@ FX
 FX's
 FairSeq
 Fastpath
+FFN
 FloydHub
 FloydHub's
 Frobenius
@@ -128,7 +130,7 @@ Kihyuk
 Kiuk
 Kubernetes
 Kuei
-KV-Caching
+KV
 LRSchedulers
 LSTM
 LSTMs
@@ -164,6 +166,7 @@ NLP
 NTK
 NUMA
 NaN
+NaNs
 NanoGPT
 Netron
 NeurIPS
@@ -232,6 +235,7 @@ Sigmoid
 SoTA
 Sohn
 Spacy
+SwiGLU
 TCP
 THP
 TIAToolbox
@@ -467,6 +471,7 @@ nheads
 nightlies
 NJT
 NJTs
+NJT's
 num
 numericalize
 numpy

intermediate_source/transformer_building_blocks.py

Lines changed: 10 additions & 13 deletions
@@ -344,7 +344,7 @@ def benchmark(func, *args, **kwargs):
 torch.manual_seed(6)
 vanilla_mha_layer = nn.MultiheadAttention(E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device='cuda')
 
-# nn.MultiheadAttention uses a non conventional initialization for layers, so do this for exact parity :(
+# ``nn.MultiheadAttention`` uses a non conventional initialization for layers, so do this for exact parity :(
 mha_layer.out_proj.weight = nn.Parameter(vanilla_mha_layer.out_proj.weight.clone().detach())
 mha_layer.packed_proj.weight = nn.Parameter(vanilla_mha_layer.in_proj_weight.clone().detach())
 mha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())
@@ -421,7 +421,7 @@ def benchmark(func, *args, **kwargs):
 # gives equivalent results to an ``nn.TransformerEncoderLayer`` with
 # ``is_causal=True``.
 #
-# We demonstrate examples of implementing the rest of the nn layers
+# We demonstrate examples of implementing the rest of the ``nn`` layers
 # `here <https://github.com/mikaylagawarecki/temp>`_ but omit that from this
 # tutorial for brevity.
 
@@ -457,7 +457,7 @@ def benchmark(func, *args, **kwargs):
 # * SwiGLU activation in feed-forward network of Transformer Layer
 #
 # Input projection for MultiheadAttention
-# ----------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Recall that when doing self-attention, the ``query``, ``key`` and ``value``
 # are the same tensor. Each of these tensors is projected with a
 # ``Linear(E_q, E_total)`` layer. Instead, we can pack this into one layer,
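As a quick illustration of the packed projection idea this hunk describes (a standalone sketch, not the tutorial's actual code; the dimensions and variable names below are made up for the example), the three per-tensor projections can be fused into one wider ``nn.Linear`` and split back apart with ``torch.chunk``:

import torch
import torch.nn as nn

E_q = 512
x = torch.randn(2, 128, E_q)  # (batch, seq_len, embed_dim), same tensor for q/k/v

# Unpacked: three separate projections of the same input.
q_proj, k_proj, v_proj = (nn.Linear(E_q, E_q) for _ in range(3))
q, k, v = q_proj(x), k_proj(x), v_proj(x)

# Packed: one wider projection, then split the last dimension into three.
packed_proj = nn.Linear(E_q, 3 * E_q)
q_p, k_p, v_p = torch.chunk(packed_proj(x), 3, dim=-1)

# Shapes match the unpacked version (values differ since weights are independent here).
assert q_p.shape == q.shape == (2, 128, E_q)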
@@ -502,8 +502,9 @@ def forward(self, query):
 (q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)
 print(f"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x")
 
+##################################################
 # SwiGLU feed forward network of Transformer Layer
-# ------------------------------------------------
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # SwiGLU is a non-linear activation function that is increasingly popular in the feed-forward
 # network of the transformer layer (e.g. Llama). A feed-forward network with SwiGLU activation is defined as
 
@@ -524,6 +525,7 @@ def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None, device
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
+########################################################################
 # An alternative way of implementing this that uses packed projection is
 
 class PackedSwiGLUFFN(nn.Module):
@@ -543,6 +545,7 @@ def forward(self, x):
         x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
         return self.w2(F.silu(x1) * x3)
 
+################################################################################
 # We can compare the performance of the two implementations as follows
 # Depending on your hardware, you might see different results. On an A100 I see
 # 1.12x speedup for D=128.
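For reference, here is a minimal self-contained sketch of the two SwiGLU feed-forward variants these hunks touch (simplified relative to the tutorial's classes, which also handle hidden-dimension rounding via ``multiple_of`` and ``ffn_dim_multiplier``):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    # Plain SwiGLU FFN: w2(silu(w1(x)) * w3(x)), as in the diff above.
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class PackedSwiGLUFFN(nn.Module):
    # Packed variant: w1 and w3 fused into one projection, split with torch.chunk.
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)
        return self.w2(F.silu(x1) * x3)

x = torch.randn(2, 64, 128)
print(SwiGLUFFN(128, 256)(x).shape, PackedSwiGLUFFN(128, 256)(x).shape)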
@@ -635,20 +638,14 @@ def alibi_mod(score, b, h, q_idx, kv_idx):
 )
 out_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)
 
-################################################################################
-# And more
-# --------
-#
-# We intend to update this tutorial to demonstrate more examples of how to use
-# the various performant building blocks such as KV-Caching, Grouped Query Attention
-# etc.
-
 
 ################################################################################
 # Extended examples
 # -----------------
 #
-# There are several good examples of using various performant building blocks to
+# We intend to update this tutorial to demonstrate more examples of how to use
+# the various performant building blocks such as KV-Caching, Grouped Query Attention
+# etc. Further, there are several good examples of using various performant building blocks to
 # implement various transformer architectures. Some examples include
 #
 # * `gpt-fast <https://github.com/pytorch-labs/gpt-fast>`_
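For context on the ``flex_attention`` call shown in this hunk, here is a rough standalone example of an ALiBi-style ``score_mod`` in the same spirit as the tutorial's ``alibi_mod`` (assumes PyTorch 2.5+ and a CUDA device; the slope schedule and shapes are illustrative, not the tutorial's exact values):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Shapes are (batch, num_heads, seq_len, head_dim); requires a GPU here.
B, H, S, D = 2, 8, 256, 64
query = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
key = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)
value = torch.randn(B, H, S, D, device="cuda", dtype=torch.float16)

# One ALiBi slope per head (illustrative geometric schedule).
alibi_slopes = torch.exp2(-((torch.arange(H, device="cuda") + 1) * 8.0 / H))

def alibi_score_mod(score, b, h, q_idx, kv_idx):
    # Add a linear bias proportional to the query/key distance for head h.
    # Typically combined with a causal mask so kv_idx <= q_idx.
    return score + alibi_slopes[h] * (kv_idx - q_idx)

out = flex_attention(query, key, value, score_mod=alibi_score_mod)
print(out.shape)  # torch.Size([2, 8, 256, 64])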
