I am developing intuition for the heuristic used to compute the number of warps.
The `_layer_norm_fwd_fused` kernel uses `BLOCK_SIZE` to accumulate over the input vector in for loops. But the `forward` method of the `LayerNorm` class sets `BLOCK_SIZE` to the next power of two of the input vector's dimensionality, so each for loop in `_layer_norm_fwd_fused` runs for exactly one iteration and there is nothing to accumulate.
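For concreteness, this is roughly how the forward pass picks `BLOCK_SIZE` (a plain-Python sketch of what `triton.next_power_of_2` computes; the function name and the example `N` are mine):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, mirroring triton.next_power_of_2.
    return 1 << (n - 1).bit_length() if n > 1 else 1

# For a feature dimension N = 768, BLOCK_SIZE = 1024 covers the whole
# row at once, so the accumulation loop runs exactly one iteration.
BLOCK_SIZE = next_power_of_2(768)
```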
There is also no accumulation in the `_layer_norm_bwd_dx_fused` kernel, which takes its `BLOCK_SIZE` value from the context saved in the forward pass; this suggests that the for loops are not meant as hints to the compiler either. And because `BLOCK_SIZE_N` in `_layer_norm_bwd_dx_fused` can never be smaller than the input vector's dimensionality, the for loops in `_layer_norm_fwd_fused` look like a potential source of bugs.
An instance of the `_layer_norm_{fwd_fused, bwd_dx_fused}` kernels processes one entire input vector, with the number of warps chosen by the heuristic. Up to 8 warps are used, i.e. up to 256 threads. To get a warp count other than 8, `BLOCK_SIZE` is divided by 256. This suggests that (i) 256 contiguous 16-bit elements are accessed by the 32 threads of a warp, and (ii) each thread accesses 8 such elements in a single 128-bit vectorized load/store transaction.
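As I read the tutorial, the heuristic amounts to the following (a plain-Python sketch, not the tutorial's exact code; the function name is mine):

```python
def num_warps_heuristic(block_size: int) -> int:
    # One warp per 256 elements, clamped to the range [1, 8],
    # so a program instance gets at most 8 * 32 = 256 threads.
    return min(max(block_size // 256, 1), 8)

# BLOCK_SIZE = 1024 -> 4 warps = 128 threads, i.e. 8 elements per
# thread: 8 fp16 elements * 2 bytes = 16 bytes = one 128-bit access.
warps = num_warps_heuristic(1024)
```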
In the `_layer_norm_bwd_dwdb` kernel, the partial weight and bias gradients are accessed in contiguous row segments of 128 16-bit elements. One half of a warp would access one contiguous row segment with a single 128-bit vectorized transaction per thread, while the other half of the warp would similarly access another, non-adjacent, contiguous row segment. Note that the default number of warps appears to be used here, in contrast to the `_layer_norm_{fwd_fused, bwd_dx_fused}` kernels.
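Spelling out the arithmetic behind this reading (my own back-of-the-envelope numbers, not code from the tutorial):

```python
ELEM_BITS = 16       # fp16/bf16 partial gradients
VEC_BITS = 128       # one vectorized load/store per thread
SEGMENT_ELEMS = 128  # contiguous row-segment length in _layer_norm_bwd_dwdb

# 8 elements per 128-bit transaction ...
elems_per_thread = VEC_BITS // ELEM_BITS
# ... so 16 threads (half a warp) cover one 128-element segment,
# leaving the other half-warp for a second, non-adjacent segment.
threads_per_segment = SEGMENT_ELEMS // elems_per_thread
```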
Based on this analysis, the heuristic for computing the number of warps seems to follow these criteria:

- one 128-bit vectorized load/store transaction per thread per data block,
- the threads of a warp access one contiguous segment, or two non-adjacent contiguous segments, and
- a thread block preferably has up to 256 threads; 256 threads may be suitable for high occupancy across NVIDIA architectures.
I also assume that the `.sum` and `+=` accumulations are automatically optimized by the compiler into parallel reductions with O(log N) step complexity.
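For intuition, such a sum over `BLOCK_SIZE` elements needs only log2(BLOCK_SIZE) parallel steps. Here is a plain-Python sketch of the tree-shaped reduction I assume the compiler generates (the helper is mine, not Triton's):

```python
def tree_sum(values):
    # Pairwise tree reduction: each step halves the number of live
    # partial sums, so ceil(log2(len(values))) steps suffice.
    values = list(values)
    steps = 0
    while len(values) > 1:
        values = [values[i] + values[i + 1] if i + 1 < len(values) else values[i]
                  for i in range(0, len(values), 2)]
        steps += 1
    return values[0], steps

# Summing a BLOCK_SIZE of 1024 elements takes 10 steps, not 1023.
total, steps = tree_sum(range(1024))
```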
The purpose of the for loops in the `_layer_norm_fwd_fused` kernel remains unclear to me. Any comments regarding the heuristic are also appreciated. Thank you.