
Commit 2726dac

1009 round down largest power of 2 (#8413)
The OCP standard recommends: "Set X to be the largest power-of-two less than or equal to max_{Vi ∈ V}(|Vi|), divided by the largest power-of-two representable in the element data type." This PR changes the behavior of DequantScaleRoundingMode.ROUND_DOWN to follow this. Since ROUND_UP was the default and the old ROUND_DOWN would have caused a lot of clipping, we assume it was not in use. This PR should make the ROUND_DOWN mode more useful.

# New contributor declaration
- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent 5d2a7a9 commit 2726dac
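For intuition, the quoted OCP rule can be written out directly in plain Python. This is an illustrative sketch, not the kernel code: the helper name is made up, and the 4.0 argument assumes e2m1's largest representable power of two (mirroring `_get_max_power_of_2_quant_val` in this PR).

```python
import math

def ocp_round_down_scale(block, elem_max_pow2):
    # X = (largest power of two <= max |Vi|) divided by the largest
    # power of two representable in the element data type.
    max_abs = max(abs(v) for v in block)
    return 2.0 ** math.floor(math.log2(max_abs)) / elem_max_pow2

# A block whose max magnitude is 33: the largest power of two <= 33 is 32,
# and e2m1's largest power of two is 4, so the scale is 32 / 4 = 8 = 2**3
# (stored as biased exponent 127 + 3).
assert ocp_round_down_scale([33.0, 24.0, 16.0, 8.0, 4.0, 0.0, -32.0, 0.0], 4.0) == 8.0
```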

3 files changed: +47 -2 lines changed


python/triton_kernels/tests/test_mxfp.py

Lines changed: 28 additions & 0 deletions
@@ -44,6 +44,34 @@ def test_mxfp4_rounding_cases(dst_dtype, device):
     dequant_torch = upcast_from_mxfp_torch(quant_torch, scale_torch, dst_dtype, axis=1)
     assert_equal(dequant_torch, dequant)
 
+    # ROUND_DOWN should use the max power-of-two when computing the scale.
+    # Choose a block whose max is 33 so the chosen scale is
+    # 2**floor(log2(33/(e2m1 max power of 2 = 4))) = 2**3 = 8 (exponent 127+3),
+    # and the other values are multiples of representable FP4 values times 8
+    # that allow exact reconstruction.
+    x = torch.tensor([33.0, 24.0, 16.0, 8.0, 4.0, 0.0, -32.0, 0.0], device=device).bfloat16().view(1, -1, 1)
+    quant, scale = downcast_to_mxfp(
+        x,
+        torch.uint8,
+        axis=1,
+        DEQUANT_SCALE_ROUNDING_MODE=DequantScaleRoundingMode.ROUND_DOWN,
+    )
+    dequant = upcast_from_mxfp(quant, scale, dst_dtype, axis=1)
+    assert_equal(dequant[0, 1:, :], x[0, 1:, :])
+
+    # Golden: scale exponent is 127 + 3 for 2**3 = 8
+    assert scale.item() == 127 + 3
+
+    # Torch reference path should match
+    quant_torch, scale_torch = downcast_to_mxfp_torch(
+        x,
+        torch.uint8,
+        axis=1,
+        DEQUANT_SCALE_ROUNDING_MODE=DequantScaleRoundingMode.ROUND_DOWN,
+    )
+    assert_equal(quant_torch, quant)
+    assert_equal(scale_torch, scale)
+
 
 @pytest.mark.parametrize("src_dtype", ["float4_e2m1", "float8_e5m2", "float8_e4m3fn"])
 @pytest.mark.parametrize("dst_dtype", ["float16", "bfloat16", "float32"])

python/triton_kernels/triton_kernels/numerics_details/mxfp.py

Lines changed: 8 additions & 1 deletion
@@ -1,6 +1,7 @@
 # isort: off
 # fmt: off
 from enum import Enum
+import math
 import triton
 import torch
 import torch.nn.functional as F
@@ -13,7 +14,10 @@
 
 
 class DequantScaleRoundingMode(Enum):
+    # 2^round_up(log2(max/max_q)) avoids clipping the max value
     ROUND_UP = 0
+    # 2^round_down(log2(max/max_power_of_2_q)) follows the OCP standard;
+    # ~50% chance of clipping the max value.
     ROUND_DOWN = 1
 
 
@@ -176,7 +180,10 @@ def downcast_to_mxfp_torch(src_tensor: torch.Tensor, out_quant_type: torch.dtype
 
     # Choose a max quantization value depending on type.
     max_quant_val = get_max_quant_val(out_quant_type)
-    dequant_scale = max_val / max_quant_val  # shape: (..., padded_axis_shape//32, 1)
+    if DEQUANT_SCALE_ROUNDING_MODE == DequantScaleRoundingMode.ROUND_UP:
+        dequant_scale = max_val / max_quant_val  # shape: (..., padded_axis_shape//32, 1)
+    else:
+        dequant_scale = max_val / (2 ** math.floor(math.log2(max_quant_val)))
 
     # Convert to int to round the FP32 scale, prior to quantization!
     ds_int = dequant_scale.view(torch.int32)
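To see why dividing by the largest power of two rather than the full max makes ROUND_DOWN more useful, compare the old and new behavior on the test block's max of 33. This sketch assumes e2m1's max representable value is 6.0 (what `get_max_quant_val` would return for this type):

```python
import math

max_val = 33.0
# Old ROUND_DOWN: floor the exponent of max_val / 6.0 = 5.5 -> scale 4,
# so 33 / 4 = 8.25 clips hard against e2m1's max of 6 (dequantizes to 24).
old_scale = 2.0 ** math.floor(math.log2(max_val / 6.0))
# New ROUND_DOWN (this PR): floor the exponent of max_val / 4.0 = 8.25 -> scale 8,
# so 33 / 8 = 4.125 merely rounds to 4 (dequantizes to 32).
new_scale = 2.0 ** math.floor(math.log2(max_val / 4.0))
assert (old_scale, new_scale) == (4.0, 8.0)
```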

python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py

Lines changed: 11 additions & 1 deletion
@@ -18,6 +18,15 @@ def _get_max_quant_val(dtype: tl.constexpr):
     else:
         tl.static_assert(False, f"Invalid {dtype=}")
 
+@triton.jit
+def _get_max_power_of_2_quant_val(dtype: tl.constexpr):
+    if dtype == tl.uint8:
+        return 4.0
+    elif dtype == tl.float8e5:
+        return 32768.0
+    elif dtype == tl.float8e4nv:
+        return 256.0
+
 @triton.jit
 def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
                              DEQUANT_SCALE_ROUNDING_MODE: tl.constexpr = 0):
@@ -32,18 +41,19 @@ def _compute_quant_and_scale(src_tensor, valid_src_mask, mx_tensor_dtype: tl.constexpr,
     abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)  # Don't consider padding tensors in scale computation
     abs_tensor = tl.reshape(abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, MXFP_BLOCK_SIZE])
     max_val = tl.max(abs_tensor, axis=2, keep_dims=True)
-    dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
     if DEQUANT_SCALE_ROUNDING_MODE == 0:
         # DequantScaleRoundingMode.ROUND_UP
         # compute 2 ** ceil(log2(dequant_scale))
         # Adding 0x007FFFFF increments the exponent by 1 unless the mantissa is all zeros.
         # A corner case: an exponent of 0xFF would overflow, but that value is already NaN, so assume we don't care.
+        dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
         dequant_scale_exponent = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
     else:
         # DequantScaleRoundingMode.ROUND_DOWN
         # compute 2 ** floor(log2(dequant_scale))
         assert DEQUANT_SCALE_ROUNDING_MODE == 1
+        dequant_scale = max_val / _get_max_power_of_2_quant_val(mx_tensor_dtype)
         dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
     dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
     quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)
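The exponent bit tricks used in the kernel can be sanity-checked with plain IEEE-754 bit manipulation; a standalone sketch, not Triton code:

```python
import struct

def f32_bits(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_f32(b: int) -> float:
    return struct.unpack("<f", struct.pack("<I", b))[0]

x = 8.25
# ROUND_DOWN: masking off the mantissa keeps 2**floor(log2(x)).
assert bits_f32(f32_bits(x) & 0x7F800000) == 8.0
# ROUND_UP: adding 0x007FFFFF carries into the exponent unless the
# mantissa is already all zeros, yielding 2**ceil(log2(x)).
assert bits_f32((f32_bits(x) + 0x007FFFFF) & 0x7F800000) == 16.0
# Exact powers of two are unchanged by either rounding mode.
assert bits_f32((f32_bits(8.0) + 0x007FFFFF) & 0x7F800000) == 8.0
```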
