
Commit 3d33f74

[mxfp] adjust num_stages for bf16/fp16 x mxfp (#8773)
For fp16/bf16 x mxfp, we upcast the weight on the fly, so we should size `smem_capacity` accordingly. Without this change, we get the following error:

"triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 263356, Hardware limit: 232448. Reducing block sizes or `num_stages` may help"

for x.shape = [2048, 5120] bf16 x [32, 5120, 5120] float8_e4m3fn with block_m=64, block_n=256, block_k=128, split_k=1, is_persistent=True, which leads to num_stages=4.

# New contributor declaration

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [ ] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [x] This PR does not need a test because I wasn't able to find a shape that runs reliably without OOMs. The example shape above (32 x 5120 x 5120) is too big. Will try to see if I can enable a test only on GB200.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
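To make the shared-memory accounting concrete, here is a rough back-of-the-envelope sketch of the per-stage sizing for the failing configuration quoted above. It only models the two tile terms that change in this commit; the real `compute_num_stages` in `opt_flags_nvidia.py` adds further terms (e.g. mx scales), so the numbers are illustrative rather than exact.

```python
# Back-of-the-envelope sizing for the failing shape quoted above
# (block_m=64, block_n=256, block_k=128, bf16 activations x float8_e4m3fn weights).
block_m, block_n, block_k = 64, 256, 128
lhs_itemsize = 2        # bf16 activations
rhs_storage_size = 1    # float8_e4m3fn weights as stored
smem_capacity = 232448  # hardware limit reported in the error message

# Old accounting: the weight tile is counted at its storage width.
stage_old = block_m * block_k * lhs_itemsize + block_k * block_n * rhs_storage_size
print(stage_old, min(smem_capacity // stage_old, 4))  # 49152 bytes/stage -> 4 stages

# New accounting: for fp16/bf16 x mxfp the weight tile is upcast on the fly,
# so it occupies 2 bytes per element in shared memory.
stage_new = block_m * block_k * lhs_itemsize + block_k * block_n * 2
print(stage_new, min(smem_capacity // stage_new, 4))  # 81920 bytes/stage -> 2 stages

# With 4 stages, the upcast weight tiles alone need 4 * 65536 = 262144 bytes,
# already over the 232448-byte limit -- consistent with the ~263 KB "Required"
# figure in the OutOfResources error.
```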
1 parent 3ac3994 commit 3d33f74

File tree

1 file changed: +17 -3 lines

python/triton_kernels/triton_kernels/matmul_details/opt_flags_details/opt_flags_nvidia.py

Lines changed: 17 additions & 3 deletions
@@ -1,9 +1,10 @@
+import warnings
+
 import torch
 import triton
 from triton_kernels import target_info
 from triton_kernels.numerics_details.mxfp_details._downcast_to_mxfp import MXFP_BLOCK_SIZE
-from triton_kernels.tensor import FP4, bitwidth, get_layout
-from triton_kernels.tensor import Tensor
+from triton_kernels.tensor import FP4, Tensor, bitwidth, get_layout
 from triton_kernels.tensor_details.layout import HopperMXScaleLayout
 from triton_kernels.tensor_details.layout_details.blackwell_scale import BlackwellActMXScaleLayout
 

@@ -98,6 +99,14 @@ def compute_num_stages(
     if precision_config.max_num_imprecise_acc is not None:
         return 3
     weight_size = bitwidth(rhs_dtype) / 8
+    if precision_config.b_mx_scale is not None and lhs_dtype in [torch.float16, torch.bfloat16]:
+        # For fp16/bf16 x mxfp, we upcast weight on the fly, so size
+        # smem_capacity accordingly.
+        # w/o this, gets the following error:
+        # "triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 263356, Hardware limit: 232448. Reducing block sizes or `num_stages` may help"
+        # for x.shape = [2048, >=4096] bf16 x [32, >=4096, >=4096] float8_e4m3fn
+        # block_m=64, block_n=256, block_k=128, split_k=1, is_persistent=True -> leading to num_stages=4
+        weight_size = 2
     stage_size = block_m * block_k * lhs_dtype.itemsize + block_k * block_n * weight_size
     device_props = torch.cuda.get_device_properties(0)
     smem_capacity = device_props.shared_memory_per_block_optin

@@ -132,5 +141,10 @@ def compute_num_stages(
     elif has_native_mxfp:
         # mx scales
         stage_size += block_n * (block_k // int(MXFP_BLOCK_SIZE))
-    num_stages = min(4, smem_capacity // int(stage_size))
+    num_stages = min(smem_capacity // int(stage_size), 4)
+    if num_stages == 0:
+        warnings.warn(f"num_stages computed is 0 with {stage_size=} and {smem_capacity=}, "
+                      "bumping up to 1 but this may lead to out of shared memory errors, "
+                      "and in that case consider reducing block sizes.")
+        num_stages = 1
     return num_stages
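The last hunk also keeps `num_stages` from collapsing to 0 when even a single stage no longer fits in shared memory. Below is a minimal standalone sketch of that clamping behaviour; the `pick_num_stages` helper is hypothetical and made up here for illustration, while the real logic lives inside `compute_num_stages`.

```python
import warnings

def pick_num_stages(stage_size: int, smem_capacity: int, max_stages: int = 4) -> int:
    # Cap the pipeline depth at max_stages, but never return 0: a zero-stage
    # pipeline is invalid, so bump to 1 and warn that it may still OOM.
    num_stages = min(smem_capacity // int(stage_size), max_stages)
    if num_stages == 0:
        warnings.warn(f"num_stages computed is 0 with {stage_size=} and {smem_capacity=}, "
                      "bumping up to 1 but this may lead to out of shared memory errors, "
                      "and in that case consider reducing block sizes.")
        num_stages = 1
    return num_stages

print(pick_num_stages(stage_size=81920, smem_capacity=232448))   # 2
print(pick_num_stages(stage_size=300000, smem_capacity=232448))  # 1 (with a warning)
```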

0 commit comments
