Skip to content

Commit 18d020e

Browse files
[Feat] FP8 per tensor quant support (#4043)
* FP8 per tensor quant support * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5557283 commit 18d020e

File tree

1 file changed

+54
-52
lines changed

1 file changed

+54
-52
lines changed

unsloth/kernels/fp8.py

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,11 @@ def weight_dequant_block(
9595

9696

9797
def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
98-
if s.shape[1] == 1:
99-
# this is row quantized weight, just simple multiplication suffices
98+
# Per-tensor scale: single value for entire weight matrix
99+
if s.numel() == 1:
100+
return x.to(dtype) * s.view(1, 1).to(dtype)
101+
# Row quantized weight: scale shape is (m, 1) or (n, 1)
102+
elif s.ndim == 2 and s.shape[1] == 1:
100103
if x.shape[0] == s.shape[0]:
101104
y = x.to(dtype) * s.to(dtype)
102105
elif x.shape[1] == s.shape[0]:
@@ -106,8 +109,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, dtype = torch.bfloat16):
106109
else:
107110
raise ValueError(f"Incompatible shapes {x.shape = }, {s.shape = }")
108111
return y
112+
# Block quantized weight: scale shape is (ceil(m/block_m), ceil(n/block_n))
109113
else:
110-
# this is block quantized weight
111114
return weight_dequant_block(x, s, dtype = dtype)
112115

113116

@@ -238,44 +241,29 @@ def w8a8_block_fp8_matmul_triton(
238241
block_size: list[int],
239242
output_dtype: torch.dtype = torch.float32,
240243
) -> torch.Tensor:
241-
"""This function performs matrix multiplication with block-wise
242-
quantization.
243-
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
244-
The output is returned in the specified `output_dtype`.
245-
Args:
246-
A: The input tensor, e.g., activation.
247-
B: The input tensor, e.g., weight.
248-
As: The per-token-group quantization scale for `A`.
249-
Bs: The per-block quantization scale for `B`.
250-
block_size: The block size for per-block quantization. It should
251-
be 2-dim, e.g., [128, 128].
252-
output_dytpe: The dtype of the returned tensor.
253-
Returns:
254-
torch.Tensor: The result of matmul.
255-
"""
256-
assert len(block_size) == 2
257-
block_n, block_k = block_size[0], block_size[1]
244+
"""Block-wise FP8 matmul."""
245+
if block_size is None:
246+
block_n, block_k = 128, 128
247+
else:
248+
assert len(block_size) == 2
249+
block_n, block_k = block_size[0], block_size[1]
258250

251+
N, K = B.shape
259252
assert A.shape[-1] == B.shape[-1]
260253
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
261254
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
262-
M = A.numel() // A.shape[-1]
263-
264255
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
265-
N, K = B.shape
266256
assert triton.cdiv(N, block_n) == Bs.shape[0]
267257
assert triton.cdiv(K, block_k) == Bs.shape[1]
268258

259+
M = A.numel() // A.shape[-1]
269260
C_shape = A.shape[:-1] + (N,)
270261
C = A.new_empty(C_shape, dtype = output_dtype)
271262

272263
BLOCK_SIZE_M = 128
273264
if M < BLOCK_SIZE_M:
274-
BLOCK_SIZE_M = triton.next_power_of_2(M)
275-
BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
276-
BLOCK_SIZE_K = block_k
277-
assert block_k % BLOCK_SIZE_K == 0
278-
BLOCK_SIZE_N = block_n
265+
BLOCK_SIZE_M = max(triton.next_power_of_2(M), 16)
266+
BLOCK_SIZE_K, BLOCK_SIZE_N = block_k, block_n
279267

280268
def grid(META):
281269
return (
@@ -342,29 +330,41 @@ def torchao_block_matmul(
342330
class FP8BlockQuantLinear(torch.autograd.Function):
343331
@staticmethod
344332
def forward(ctx, X, weight, weight_scale):
345-
# block_size = getattr(weight, 'block_size', [128,128])
346333
m, n = weight.shape
347-
p, q = weight_scale.shape
348-
block_size = getattr(weight, "block_size", None) or getattr(
349-
weight_scale, "block_size", [128, 128]
350-
)
351-
assert block_size is not None, "block_size is not set"
352-
if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:
353-
if (
354-
triton.cdiv(m, block_size[0]) == q
355-
and triton.cdiv(n, block_size[1]) == p
356-
):
357-
# weights are transposed during backward pass for training :)
358-
# We transpose weight scale to counter that. Note that transposing weight would cause issues with matmul with input X
359-
weight_scale = weight_scale.T
360-
else:
361-
raise ValueError(
362-
f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}"
363-
)
334+
335+
# Save original scale for backward (before any transformation)
336+
original_weight_scale = weight_scale
337+
338+
# Handle per-tensor quantization: expand scalar to block scale shape
339+
if weight_scale.numel() == 1:
340+
block_size = [128, 128]
341+
# Expand scalar to (ceil(m/128), ceil(n/128)) - same value for all blocks
342+
num_blocks_m = triton.cdiv(m, block_size[0])
343+
num_blocks_n = triton.cdiv(n, block_size[1])
344+
weight_scale = weight_scale.expand(num_blocks_m, num_blocks_n).contiguous()
345+
else:
346+
# Block quantization path
347+
p, q = weight_scale.shape
348+
block_size = getattr(weight, "block_size", None) or getattr(
349+
weight_scale, "block_size", [128, 128]
350+
)
351+
assert block_size is not None, "block_size is not set"
352+
if triton.cdiv(m, block_size[0]) != p or triton.cdiv(n, block_size[1]) != q:
353+
if (
354+
triton.cdiv(m, block_size[0]) == q
355+
and triton.cdiv(n, block_size[1]) == p
356+
):
357+
weight_scale = weight_scale.T
358+
original_weight_scale = weight_scale # Update for transposed case
359+
else:
360+
raise ValueError(
361+
f"Weight shape {weight.shape} and scales shape {weight_scale.shape} is not compatible with block size {block_size}"
362+
)
364363

365364
if not weight.is_contiguous():
366365
weight = weight.contiguous()
367-
# this is replica of https://github.com/huggingface/transformers/blob/01c9e1ba683b3e50d7c76bf92f2d470759fd5e81/src/transformers/integrations/finegrained_fp8.py#L331-L353
366+
367+
# Quantize input and run FP8 matmul
368368
qinput, scale = act_quant(X, block_size[1])
369369
output = fp8_block_matmul(
370370
qinput,
@@ -375,8 +375,7 @@ def forward(ctx, X, weight, weight_scale):
375375
output_dtype = X.dtype,
376376
)
377377
ctx.weight = weight
378-
ctx.weight_scale = weight_scale
379-
ctx.block_size = block_size
378+
ctx.weight_scale = original_weight_scale # Save original for backward
380379
return output.to(X.dtype)
381380

382381
@staticmethod
@@ -592,11 +591,14 @@ def test_has_fbgemm():
592591

593592
@torch_compile
594593
def fp8_linear(X, weight, weight_scale, bias = None):
595-
if weight_scale.ndim == 2 and weight_scale.shape[1] > 1:
596-
# This is block quantized FP8 matmul
594+
# Per-tensor quantization: single scalar scale for entire weight
595+
# Block quantized FP8: 2D scale tensor with multiple columns
596+
if weight_scale.numel() == 1 or (
597+
weight_scale.ndim == 2 and weight_scale.shape[1] > 1
598+
):
597599
out = fp8_block_quant_linear(X, weight, weight_scale)
600+
# Row/channel quantized FP8: 2D scale with shape (n, 1)
598601
else:
599-
# Row quantized FP8
600602
out = fbgemm_fp8_linear(X, weight, weight_scale, bias)
601603
return out
602604

0 commit comments

Comments
 (0)