# Recall that ``nn.MultiheadAttention`` requires ``query``, ``key`` and
# ``value`` to be dense ``torch.Tensor``s. It also provides a
# ``key_padding_mask`` that is used to mask out padding tokens in the ``key``
-# that arise due to different sequence lengths within a batch.
+# that arise due to different sequence lengths within a batch. Since there is
+# no ``query_padding_mask`` in ``nn.MHA``, users have to take care to mask/slice
+# the outputs appropriately to account for query sequence lengths. Nested tensors
+# cleanly remove the need for this sort of error-prone padding mask.
#
# * Memory
#    Instead of materializing a dense ``[B, S, D]`` tensor with a ``[B, S]``
#
# * Performance
#    Since unnecessary computation on padding is skipped, performance improves.
-# We'll demonstrate this by building off the ``MultiheadAttention`` layer in the
+#
+# We'll demonstrate the above by building off the ``MultiheadAttention`` layer in the
# `Nested Tensor tutorial <https://pytorch.org/tutorials/prototype/nestedtensor.html>`_
+# and comparing it to the ``nn.MultiheadAttention`` layer.

import torch
import torch.nn as nn
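# As a quick illustration of the padding overhead described above, here is a
# minimal sketch contrasting a padded dense batch plus ``key_padding_mask``
# with a jagged nested tensor. The embedding size and sequence lengths below
# are arbitrary illustrative values, not taken from the tutorial's benchmarks.

seq_lens = [3, 5, 2]   # per-sample sequence lengths within a batch
E_d = 8                # embedding dim

# Padded approach: a dense [B, S, D] tensor plus a [B, S] boolean mask,
# where True marks padding positions to be ignored.
S = max(seq_lens)
padded = torch.zeros(len(seq_lens), S, E_d)
key_padding_mask = torch.ones(len(seq_lens), S, dtype=torch.bool)
for i, L in enumerate(seq_lens):
    padded[i, :L] = torch.randn(L, E_d)
    key_padding_mask[i, :L] = False

# Nested tensor approach: the batch of varying-length sequences is represented
# directly, with no padding values and no mask to keep in sync.
nt = torch.nested.nested_tensor(
    [torch.randn(L, E_d) for L in seq_lens], layout=torch.jagged
)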
@@ -142,6 +147,7 @@ class MultiHeadAttention(nn.Module):
            has dim E_total // nheads
        nheads (int): Number of heads
        dropout (float, optional): Dropout probability. Default: 0.0
+        bias (bool, optional): Whether to add bias to input projection. Default: True
    """
    def __init__(
        self,
@@ -151,7 +157,7 @@ def __init__(
        E_total: int,
        nheads: int,
        dropout: float = 0.0,
-        bias=False,
+        bias=True,
        device=None,
        dtype=None,
    ):
@@ -163,15 +169,21 @@ def __init__(
        if self._qkv_same_embed_dim:
            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)
        else:
-            self.query_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
-            self.key_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
-            self.value_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
+            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)
+            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)
+            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)
        E_out = E_q
        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
-
-    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_mask=None, is_causal=False) -> torch.Tensor:
+        self.bias = bias
+
+    def forward(self,
+                query: torch.Tensor,
+                key: torch.Tensor,
+                value: torch.Tensor,
+                attn_mask=None,
+                is_causal=False) -> torch.Tensor:
        """
        Forward pass; runs the following process:
            1. Apply input projection
@@ -196,16 +208,16 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a
                query, key, value = torch.chunk(result, 3, dim=-1)
            else:
                q_weight, k_weight, v_weight = torch.chunk(self.packed_proj.weight, 3, dim=0)
-                if bias:
+                if self.bias:
                    q_bias, k_bias, v_bias = torch.chunk(self.packed_proj.bias, 3, dim=0)
                else:
                    q_bias, k_bias, v_bias = None, None, None
                query, key, value = F.linear(query, q_weight, q_bias), F.linear(key, k_weight, k_bias), F.linear(value, v_weight, v_bias)

        else:
-            query = self.query_proj(query)
-            key = self.key_proj(key)
-            value = self.value_proj(value)
+            query = self.q_proj(query)
+            key = self.k_proj(key)
+            value = self.v_proj(value)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
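# As a standalone sanity check of the weight-chunking trick in Step 1 above:
# for a packed input projection, chunking the packed weight (and bias) and
# applying three separate ``F.linear`` calls matches running the packed
# projection and chunking its output. The sizes here are arbitrary.
import torch
import torch.nn as nn
import torch.nn.functional as F

E, B, S = 16, 2, 4
packed = nn.Linear(E, 3 * E, bias=True)
x = torch.randn(B, S, E)

# Option A: packed projection, then chunk the output along the feature dim
q1, k1, v1 = torch.chunk(packed(x), 3, dim=-1)

# Option B: chunk the packed weight/bias, then three separate projections
q_w, k_w, v_w = torch.chunk(packed.weight, 3, dim=0)
q_b, k_b, v_b = torch.chunk(packed.bias, 3, dim=0)
q2, k2, v2 = F.linear(x, q_w, q_b), F.linear(x, k_w, k_b), F.linear(x, v_w, v_b)

# The two options agree up to floating point error
assert torch.allclose(q1, q2, atol=1e-6) and torch.allclose(v1, v2, atol=1e-6)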
@@ -219,7 +231,7 @@ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, a
        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attn_mask, dropout=self.dropout, is_causal=is_causal)
+            query, key, value, dropout_p=self.dropout, is_causal=is_causal)
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = attn_output.transpose(1, 2).flatten(-2)

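# For reference, a minimal standalone call to SDPA with the corrected keyword
# (``dropout_p``, not ``dropout``); the head count and sizes here are arbitrary
# illustrative values.
import torch
import torch.nn.functional as F

N, nheads, L_t, E_head = 2, 4, 6, 8
q = torch.randn(N, nheads, L_t, E_head)
k = torch.randn(N, nheads, L_t, E_head)
v = torch.randn(N, nheads, L_t, E_head)

# is_causal=True applies a lower-triangular mask so that position i only
# attends to positions <= i; attn_mask is left as None in that case.
out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
print(out.shape)  # torch.Size([2, 4, 6, 8])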
@@ -395,11 +407,10 @@ def benchmark(func, *args, **kwargs):
# followed by a feed-forward network (FFN) with skip connections. Implementing
# this is fairly straightforward using the ``MultiheadAttention`` layer above and
# is actually the same as an ``nn.TransformerEncoderLayer`` with ``is_causal=True``.
-#

-# We will demonstrate examples of implementing the rest of the nn layers but will
-# omit that from this tutorial for brevity. The full code is available
-# `here <https://github.com/mikaylagawarecki/temp>`_.
+# We show how to implement the rest of the ``nn`` layers
+# `here <https://github.com/mikaylagawarecki/temp>`_ but omit them from this
+# tutorial for brevity.
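# As a rough sketch of the layer described above (the class name, argument
# names, and post-norm composition here are illustrative choices, not the exact
# implementation from the linked repository):
import torch.nn as nn
import torch.nn.functional as F

class CausalSelfAttentionLayer(nn.Module):
    """Causal self-attention followed by an FFN, each with a skip connection."""
    def __init__(self, d_model: int, nheads: int, dim_feedforward: int, dropout: float = 0.0):
        super().__init__()
        # Reuse the MultiheadAttention layer defined earlier in this tutorial
        self.self_attn = MultiHeadAttention(d_model, d_model, d_model, d_model, nheads, dropout=dropout)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Causal self-attention block with skip connection
        x = self.norm1(x + self.self_attn(x, x, x, is_causal=True))
        # Feed-forward block with skip connection
        x = self.norm2(x + self.linear2(F.relu(self.linear1(x))))
        return x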

###############################################################################
# Going one step further