@@ -95,8 +95,11 @@ def weight_dequant_block(
9595
9696
9797def weight_dequant (x : torch .Tensor , s : torch .Tensor , dtype = torch .bfloat16 ):
98- if s .shape [1 ] == 1 :
99- # this is row quantized weight, just simple multiplication suffices
98+ # Per-tensor scale: single value for entire weight matrix
99+ if s .numel () == 1 :
100+ return x .to (dtype ) * s .view (1 , 1 ).to (dtype )
101+ # Row quantized weight: scale shape is (m, 1) or (n, 1)
102+ elif s .ndim == 2 and s .shape [1 ] == 1 :
100103 if x .shape [0 ] == s .shape [0 ]:
101104 y = x .to (dtype ) * s .to (dtype )
102105 elif x .shape [1 ] == s .shape [0 ]:
@@ -106,8 +109,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
106109 else :
107110 raise ValueError (f"Incompatible shapes { x .shape = } , { s .shape = } " )
108111 return y
112+ # Block quantized weight: scale shape is (ceil(m/block_m), ceil(n/block_n))
109113 else :
110- # this is block quantized weight
111114 return weight_dequant_block (x , s , dtype = dtype )
112115
113116
@@ -238,44 +241,29 @@ def w8a8_block_fp8_matmul_triton(
238241 block_size : list [int ],
239242 output_dtype : torch .dtype = torch .float32 ,
240243) -> torch .Tensor :
241- """This function performs matrix multiplication with block-wise
242- quantization.
243- It takes two input tensors `A` and `B` with scales `As` and `Bs`.
244- The output is returned in the specified `output_dtype`.
245- Args:
246- A: The input tensor, e.g., activation.
247- B: The input tensor, e.g., weight.
248- As: The per-token-group quantization scale for `A`.
249- Bs: The per-block quantization scale for `B`.
250- block_size: The block size for per-block quantization. It should
251- be 2-dim, e.g., [128, 128].
252- output_dytpe: The dtype of the returned tensor.
253- Returns:
254- torch.Tensor: The result of matmul.
255- """
256- assert len (block_size ) == 2
257- block_n , block_k = block_size [0 ], block_size [1 ]
244+ """Block-wise FP8 matmul."""
245+ if block_size is None :
246+ block_n , block_k = 128 , 128
247+ else :
248+ assert len (block_size ) == 2
249+ block_n , block_k = block_size [0 ], block_size [1 ]
258250
251+ N , K = B .shape
259252 assert A .shape [- 1 ] == B .shape [- 1 ]
260253 assert A .shape [:- 1 ] == As .shape [:- 1 ] and A .is_contiguous ()
261254 assert triton .cdiv (A .shape [- 1 ], block_k ) == As .shape [- 1 ]
262- M = A .numel () // A .shape [- 1 ]
263-
264255 assert B .ndim == 2 and B .is_contiguous () and Bs .ndim == 2
265- N , K = B .shape
266256 assert triton .cdiv (N , block_n ) == Bs .shape [0 ]
267257 assert triton .cdiv (K , block_k ) == Bs .shape [1 ]
268258
259+ M = A .numel () // A .shape [- 1 ]
269260 C_shape = A .shape [:- 1 ] + (N ,)
270261 C = A .new_empty (C_shape , dtype = output_dtype )
271262
272263 BLOCK_SIZE_M = 128
273264 if M < BLOCK_SIZE_M :
274- BLOCK_SIZE_M = triton .next_power_of_2 (M )
275- BLOCK_SIZE_M = max (BLOCK_SIZE_M , 16 )
276- BLOCK_SIZE_K = block_k
277- assert block_k % BLOCK_SIZE_K == 0
278- BLOCK_SIZE_N = block_n
265+ BLOCK_SIZE_M = max (triton .next_power_of_2 (M ), 16 )
266+ BLOCK_SIZE_K , BLOCK_SIZE_N = block_k , block_n
279267
280268 def grid (META ):
281269 return (
@@ -342,29 +330,41 @@ def torchao_block_matmul(
342330class FP8BlockQuantLinear (torch .autograd .Function ):
343331 @staticmethod
344332 def forward (ctx , X , weight , weight_scale ):
345- # block_size = getattr(weight, 'block_size', [128,128])
346333 m , n = weight .shape
347- p , q = weight_scale .shape
348- block_size = getattr (weight , "block_size" , None ) or getattr (
349- weight_scale , "block_size" , [128 , 128 ]
350- )
351- assert block_size is not None , "block_size is not set"
352- if triton .cdiv (m , block_size [0 ]) != p or triton .cdiv (n , block_size [1 ]) != q :
353- if (
354- triton .cdiv (m , block_size [0 ]) == q
355- and triton .cdiv (n , block_size [1 ]) == p
356- ):
357- # weights are transposed during backward pass for training :)
358- # We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
359- weight_scale = weight_scale .T
360- else :
361- raise ValueError (
362- f"Weight shape { weight .shape } and scales shape { weight_scale .shape } is not compatible with block size { block_size } "
363- )
334+
335+ # Save original scale for backward (before any transformation)
336+ original_weight_scale = weight_scale
337+
338+ # Handle per-tensor quantization: expand scalar to block scale shape
339+ if weight_scale .numel () == 1 :
340+ block_size = [128 , 128 ]
341+ # Expand scalar to (ceil(m/128), ceil(n/128)) - same value for all blocks
342+ num_blocks_m = triton .cdiv (m , block_size [0 ])
343+ num_blocks_n = triton .cdiv (n , block_size [1 ])
344+ weight_scale = weight_scale .expand (num_blocks_m , num_blocks_n ).contiguous ()
345+ else :
346+ # Block quantization path
347+ p , q = weight_scale .shape
348+ block_size = getattr (weight , "block_size" , None ) or getattr (
349+ weight_scale , "block_size" , [128 , 128 ]
350+ )
351+ assert block_size is not None , "block_size is not set"
352+ if triton .cdiv (m , block_size [0 ]) != p or triton .cdiv (n , block_size [1 ]) != q :
353+ if (
354+ triton .cdiv (m , block_size [0 ]) == q
355+ and triton .cdiv (n , block_size [1 ]) == p
356+ ):
357+ weight_scale = weight_scale .T
358+ original_weight_scale = weight_scale # Update for transposed case
359+ else :
360+ raise ValueError (
361+ f"Weight shape { weight .shape } and scales shape { weight_scale .shape } is not compatible with block size { block_size } "
362+ )
364363
365364 if not weight .is_contiguous ():
366365 weight = weight .contiguous ()
367- # this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353
366+
367+ # Quantize input and run FP8 matmul
368368 qinput , scale = act_quant (X , block_size [1 ])
369369 output = fp8_block_matmul (
370370 qinput ,
@@ -375,8 +375,7 @@ def forward(ctx, X, weight, weight_scale):
375375 output_dtype = X .dtype ,
376376 )
377377 ctx .weight = weight
378- ctx .weight_scale = weight_scale
379- ctx .block_size = block_size
378+ ctx .weight_scale = original_weight_scale # Save original for backward
380379 return output .to (X .dtype )
381380
382381 @staticmethod
@@ -592,11 +591,14 @@ def test_has_fbgemm():
592591
@torch_compile
def fp8_linear(X, weight, weight_scale, bias=None):
    """Dispatch an FP8 linear op to the backend matching the scale layout.

    The layout of ``weight_scale`` encodes the quantization scheme:
      * a single element            -> per-tensor quantization
      * 2-D with more than 1 column -> block quantization
      * anything else (e.g. (n, 1)) -> row/channel quantization
    """
    per_tensor = weight_scale.numel() == 1
    block_quant = weight_scale.ndim == 2 and weight_scale.shape[1] > 1
    if per_tensor or block_quant:
        # NOTE(review): bias is not forwarded on this path — presumably the
        # per-tensor/block-quant callers never pass one; confirm upstream.
        return fp8_block_quant_linear(X, weight, weight_scale)
    # Row/channel quantized FP8: scale shaped like (n, 1).
    return fbgemm_fp8_linear(X, weight, weight_scale, bias)
602604
0 commit comments