     compute_scale,
 )
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
+from triton_kernels.tensor_details.layout_details.hopper_scale import unswizzle_mxfp4_scale_hopper
+from triton_kernels.tensor_details.layout_details.hopper_value import mxfp4_to_bf16_triton
 from ._common import (
     compute_offsets,
     get_scaled_dot_format_string,
@@ -112,25 +114,48 @@ def _p_matmul(
     if Y_TMA_MODE is not None:
         Y = tl.make_tensor_descriptor(YPtr, Y.shape, Y.strides[:-1] + (1,), Y.block_shape)
 
+    w_type: tl.constexpr = get_dtype(W)
     is_w_microscaled: tl.constexpr = WMxScale is not None
+    is_x_microscaled: tl.constexpr = XMxScale is not None
+    is_w_mxfp4: tl.constexpr = w_type == tl.uint8 and is_w_microscaled
     tl.static_assert(not is_w_microscaled or W_TRANSPOSE, "NYI. Non-transposed mxfp4 weights")
     MX_PACK_DIVISOR: tl.constexpr = MXFP_BLOCK_SIZE
     if is_w_microscaled:
-        w_type: tl.constexpr = get_dtype(W)
         tl.static_assert(w_type == tl.uint8 or (w_type == tl.float8e4nv or w_type == tl.float8e5),
                          "mx_weight_ptr must be uint8 or fp8")
         tl.static_assert(get_dtype(WMxScale) == tl.uint8, "mx_scale_ptr must be uint8")
         tl.static_assert(BLOCK_K % MX_PACK_DIVISOR == 0, "BLOCK_K must be a multiple of MX_PACK_DIVISOR")
-        tl.static_assert(SWIZZLE_MX_SCALE == "BLACKWELL_SCALE" or SWIZZLE_MX_SCALE is None, "Only Blackwell swizzling is supported for scales")
 
         # We pack 2 fp4 values in a byte
-        W_PACK_DIVISOR: tl.constexpr = 2 if w_type == tl.uint8 else 1
-        PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K // W_PACK_DIVISOR
         MX_SCALE_BLOCK_K: tl.constexpr = BLOCK_K // MX_PACK_DIVISOR
+        if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+            tl.static_assert(is_w_mxfp4, "Only mxfp4 is supported for HOPPER swizzling")
+            tl.static_assert(not is_x_microscaled)
+            # We pack 2 fp4 values in a byte, but the K dimension is divided
+            # by 2 when swizzling
+            W_K_DIVISOR: tl.constexpr = 1
+            W_K_MULTIPLIER: tl.constexpr = 2
+            W_N_DIVISOR: tl.constexpr = 4
+        else:
+            # We pack 2 fp4 values in a byte
+            W_K_DIVISOR: tl.constexpr = 2 if is_w_mxfp4 else 1
+            W_K_MULTIPLIER: tl.constexpr = 1
+            W_N_DIVISOR: tl.constexpr = 1
+
+        if W_TRANSPOSE:
+            # When the weight is transposed, 2 fp4 values are packed per byte
+            # along the contiguous dimension, K.
+            PACKED_BLOCK_K_W: tl.constexpr = (BLOCK_K // W_K_DIVISOR) * W_K_MULTIPLIER
+            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_N_DIVISOR
+        else:
+            # When the weight is not transposed, fp4 values are *not* packed
+            # along the contiguous dimension, N.
+            PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+            PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N // W_K_DIVISOR
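+        # Worked example (illustrative): for mxfp4 with BLOCK_K=128 and
+        # BLOCK_N=128, the non-swizzled transposed path gives
+        # PACKED_BLOCK_K_W = 128 // 2 = 64 and PACKED_BLOCK_N_W = 128, while
+        # HOPPER_VALUE swizzling gives PACKED_BLOCK_K_W = 128 * 2 = 256 and
+        # PACKED_BLOCK_N_W = 128 // 4 = 32 -- the same number of bytes per
+        # block, laid out differently.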
     else:
         PACKED_BLOCK_K_W: tl.constexpr = BLOCK_K
+        PACKED_BLOCK_N_W: tl.constexpr = BLOCK_N
         tl.static_assert(SWIZZLE_MX_SCALE is None)
-    is_x_microscaled: tl.constexpr = XMxScale is not None
     if is_x_microscaled:
         x_type: tl.constexpr = get_dtype(X)
         tl.static_assert(x_type == tl.float8e4nv, "mx_act_ptr must be float8e4nv")
@@ -202,6 +227,7 @@ def _p_matmul(
     else:
         shape_m = M
     off_n = BLOCK_N * pid_n
+    off_w_n = PACKED_BLOCK_N_W * pid_n
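+    # Note: W is indexed in packed coordinates, so off_w_n can differ from
+    # off_n (e.g. N shrinks by W_N_DIVISOR under HOPPER_VALUE swizzling).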
 
     # ---- offset x ------
     if USE_GATHER_TMA:
@@ -283,7 +309,7 @@ def _p_matmul(
         x_format: tl.constexpr = get_scaled_dot_format_string(x.dtype)
         if is_x_microscaled:
             if XMxScalePtrs is not None:  # not using TMA for x scale load
-                off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR)
+                off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_K_DIVISOR)
                 if EVEN_K:
                     mask_k_scale = tl.full([MX_SCALE_BLOCK_K], True, dtype=tl.int1)
                 else:
@@ -306,30 +332,47 @@ def _p_matmul(
 
         # --- load w ---
         if W_TRANSPOSE:
-            w = tl.reshape(W.load([off_w_z, off_n, off_k_w]), W.block_shape[1:]).T
+            w = tl.reshape(W.load([off_w_z, off_w_n, off_k_w]), W.block_shape[1:]).T
         else:
-            w = tl.reshape(W.load([off_w_z, off_k_w, off_n]), W.block_shape[1:])
+            w = tl.reshape(W.load([off_w_z, off_k_w, off_w_n]), W.block_shape[1:])
 
         # --- load w_scale ---
         w_format: tl.constexpr = get_scaled_dot_format_string(w.dtype)
         if is_w_microscaled:
-            off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_PACK_DIVISOR)
-            tl.static_assert(MX_PACK_DIVISOR % W_PACK_DIVISOR == 0)
+            off_k_mx = off_k_w // (MX_PACK_DIVISOR // W_K_DIVISOR)
+            tl.static_assert(MX_PACK_DIVISOR % W_K_DIVISOR == 0)
             if SWIZZLE_MX_SCALE == "BLACKWELL_SCALE":
                 flattened_expt_n_idx = off_w_z * ((N + 127) // 128) + (off_n // 128)
                 w_scales = WMxScale.load([0, flattened_expt_n_idx, off_k_mx // 4, 0, 0])
                 w_scales = w_scales.reshape((w_scales.shape[1], w_scales.shape[2] * w_scales.shape[-2] * w_scales.shape[-1]))
                 w_scales = unswizzle_mx_scale_bw(w_scales)
+            elif SWIZZLE_MX_SCALE == "HOPPER_SCALE":
+                # NYI: Hopper swizzling with non-transposed W
+                tl.static_assert(W_TRANSPOSE)
+                off_n_scale = pid_n * (BLOCK_N // 32)
+                off_k_scale = (off_k_w // PACKED_BLOCK_K_W) * MX_SCALE_BLOCK_K * 32
+                w_scales = WMxScale.load([off_w_z, off_n_scale, off_k_scale])
+                w_scales = tl.reshape(w_scales, *w_scales.shape[1:])
+                num_warps: tl.constexpr = tl.extra.cuda.num_warps()
+                w_scales = unswizzle_mxfp4_scale_hopper(w_scales, mx_axis=1, num_warps=num_warps)
             else:
                 w_scales = WMxScale.load([off_w_z, off_k_mx, off_n])
                 w_scales = tl.reshape(w_scales, *w_scales.shape[1:]).T
 
         # --- update accumulator ---
         if is_w_microscaled:
-            if SWAP_XW:
-                acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+            if SWIZZLE_MX_VALUE == "HOPPER_VALUE":
+                tl.static_assert(x_format == "bf16")
+                tl.static_assert(w_format == "e2m1")
+                tl.static_assert(SWAP_XW)
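+                # Hopper has no native mxfp4 MMA, so upconvert the fp4 weights
+                # to bf16 in-kernel and use a regular dot.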
+                wT = mxfp4_to_bf16_triton(w.T, w_scales, mx_axis=1)
+                tl.static_assert(wT.dtype == tl.bfloat16)
+                acc = tl.dot(wT, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)
             else:
-                acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
+                if SWAP_XW:
+                    acc = tl.dot_scaled(w.T, w_scales, w_format, x.T, x_scales, x_format, acc=acc, fast_math=True)
+                else:
+                    acc = tl.dot_scaled(x, x_scales, x_format, w, w_scales, w_format, acc=acc, fast_math=True)
         else:
             if SWAP_XW:
                 acc = tl.dot(w.T, x.T, acc, max_num_imprecise_acc=MAX_NUM_IMPRECISE_ACC, allow_tf32=ALLOW_TF32)