Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions python/triton_kernels/bench/bench_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import triton_kernels.roofline as roofline
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
from triton_kernels.target_info import get_cdna_version
from triton_kernels.target_info import get_cdna_version, has_native_mxfp
import distributed as triton_dist
from triton_kernels.tensor_details import layout
from bench_utils import quantize_weight
Expand Down Expand Up @@ -111,12 +111,11 @@ def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d


if __name__ == "__main__":
has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
batch_sizes_dense = [*range(128, 8192, 128)]
batch_ranges_moe = [(2**(2 + k), 2**(3 + k), min(2**k, 32)) for k in range(8)]
batch_sizes_moe = list(chain(*[range(*r) for r in batch_ranges_moe]))
dense_dtypes = ["fp8", "fp8"]
quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
quantized_dtypes = ["fp8", "mx4"] if has_native_mxfp() else ["bf16", "mx4"]
rank, world_size = triton_dist.setup()
if world_size > 1:
# Running all workloads at once may cause OOM on some GPUs such as H100 80GB.
Expand Down
7 changes: 2 additions & 5 deletions python/triton_kernels/bench/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from triton_kernels.topk import topk_torch
from triton_kernels.topk import topk
from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq, has_native_mxfp
from triton_kernels.tensor_details import layout
from triton_kernels.tensor import BIT, SparseMatrix, Bitmatrix, make_ragged_tensor_metadata

Expand Down Expand Up @@ -716,9 +716,6 @@ def distributed(x):
dist.destroy_process_group()


has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4


@pytest.mark.parametrize(
"batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP",
[
Expand All @@ -741,7 +738,7 @@ def distributed(x):
(1024, 1024, 1024, 128, 2, "fp8", "mx4", 4, 1),
(1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 4),
(1024, 1024, 1024, 128, 2, "fp8", "mx4", 2, 2),
] if has_native_mx4 else [
] if has_native_mxfp() else [
(1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1),
(1024, 1024, 1024, 128, 2, "bf16", "mx4", 4, 1),
(1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 4),
Expand Down
12 changes: 6 additions & 6 deletions python/triton_kernels/tests/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mo
if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
gs0 = None
gs1 = None
if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] != 10:
w = w.transpose(-1, -2).contiguous().transpose(-1, -2)

def _apply_padding_and_fill_unused_part_with_nan(t, is_padded):
Expand Down Expand Up @@ -324,8 +324,8 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
if "float16" in act_dtype_str and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
pytest.skip("float16 x mx not supported with cuda capability >= 10")
if weight_dtype_str.startswith("mx"):
if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
pytest.skip("float8 x mx not supported with cuda capability < 10")
if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] != 10:
pytest.skip("float8 x mx not supported with cuda capability != 10")
if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
pytest.skip("Not enough memory on A100")

Expand Down Expand Up @@ -372,7 +372,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
torch.manual_seed(0)

block_k = None
if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] != 10:
# Override block_k for testing correctness. The default is temporarily 128 for
# performance reasons which doesn't work with persistent matmul.
# TODO: revisit when Triton is better for H100 + MXFP4
Expand Down Expand Up @@ -472,7 +472,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
else:
if torch.cuda.get_device_capability()[0] < 10:
if torch.cuda.get_device_capability()[0] != 10:
pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
if block_m == 16:
pytest.skip("PassManager::run failed from Triton compiler")
Expand Down Expand Up @@ -645,7 +645,7 @@ def _make_tensor(shape, dtype, trans):
(torch.float16, torch.bfloat16, torch.float8_e5m2),
):
if (
torch.cuda.get_device_capability()[0] < 10
torch.cuda.get_device_capability()[0] != 10
and dtype is torch.float8_e5m2
and (not w_transpose)
):
Expand Down
8 changes: 4 additions & 4 deletions python/triton_kernels/triton_kernels/matmul_ogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ def matmul_ogs(x, w, bias,
# TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
dtype = FP4 if w.dtype == torch.uint8 else w.dtype
w = wrap_torch_tensor(w, dtype=dtype)
if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
if w_has_mx and (torch.cuda.get_device_capability()[0] != 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not sm100)"
if w_scale is not None and not isinstance(w_scale, Tensor):
w_scale = Tensor(w_scale)
if w_scale is not None:
Expand Down Expand Up @@ -536,8 +536,8 @@ def matmul_ogs(x, w, bias,
(inner_routing_data is None or w.stride(-1) == 1 or inner_routing_data.w_is_padded)
)
has_gather_tma = has_gather and target_info.has_tma_gather()
# hopper w/ mxfp4 doesn't support TMA
can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
# hopper or sm120 w/ mxfp4 doesn't support TMA
can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] == 10 or bitwidth(w.dtype) != 4)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add a separate helper for this logic? I'm pretty sure we will enable TMA on Hopper at some point, so this will break.

Copy link
Collaborator Author

@masahi masahi Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but this one is a bit different since this is an ad-hoc check due to a kernel limitation rather than an architecture one. We could add something like target_info.supports_tma(), but that needs to return False for Hopper today, which is a bit odd. So when the kernel supports TMA for Hopper in the future, we need to update the helper anyway.

As a middle ground, how about something like this?

    # hopper or sm120 w/ mxfp4 doesn't support TMA
    supports_tma = [10] # Add 9 when the Hopper impl supports TMA
    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] in supports_tma or bitwidth(w.dtype) != 4)

This way, when Hopper supports TMA, we can safely update it without breaking sm120. The condition torch.cuda.get_device_capability()[0] >= 9 might not be correct depending on how well sm120 TMA is supported by the kernel.

can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
batch_size, M, N, w.shape[-2], routing_data,
Expand Down
13 changes: 12 additions & 1 deletion python/triton_kernels/triton_kernels/target_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import triton.language as tl

from triton.language.target_info import (
cuda_capability_geq,
is_cuda,
is_hip,
is_hip_cdna3,
is_hip_cdna4,
current_target,
)

from triton.language.target_info import cuda_capability_geq as _cuda_capability_geq

__all__ = [
"cuda_capability_geq",
"get_cdna_version",
Expand All @@ -23,6 +25,15 @@
]


@triton.constexpr_function
def cuda_capability_geq(major, minor=0):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What other properties are incorrect for sm_120?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what you mean by "other" or "incorrect properties". Without this workaround, the kernel tries to use native mxfp and TMA, assuming that sm120 has the full feature set of sm100. But those are the only things that are currently breaking gpt-oss on sm120 / 121.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant in addition to the checks you modified. Do you know which use of cuda_capability_geq is causing problems?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard to say. I've seen two kinds of errors - one is the use of TMA gather4 / scatter4, and the other is a shape mismatch in dot. cuda_capability_geq is used in many places and the options supported by the kernel are very broad, so I don't know which of them are actually problematic. Indeed, if we want to optimize for sm120 / 121, we need a more fine-grained approach to the capability check rather than falling back everything to sm80.

For example, the determination of the weight layout is highly architecture specific: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout.py#L22-L27. Even if we allowed has_native_mxfp to evaluate to True for sm120, I don't know if BlackwellMXValueLayout is compatible with the dot shape of MMAv2.

target = current_target()
if target.arch // 10 == 12 and major > 8:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand this is a workaround, but the function name doesn't reflect what it's really doing. sm80 and sm120 still have subtle differences in the instructions.

Is it possible to separate the logic from this function?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark-related changes look good to me. Thanks for catching these problems!

Copy link
Collaborator Author

@masahi masahi Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sm80 and sm120 still have subtle differences in the instructions

Yes in terms of the architecture, but what really matters is if those differences are recognized by the compiler or the kernel. Support for sm120 in the compiler is very limited, so from the compiler / kernel perspectives, sm80 and sm120 are pretty much the same.

We could introduce another helper to distinguish those kernel / compiler limitations. The Hopper limitation on TMA #8484 (comment) is another good example. But cuda_capability_geq is already used in so many places, and adding another condition makes things even more complicated.

The pervasive use of cuda_capability_geq indicates that the kernel treats "higher compute capability" as "more features". But as of sm120 this is no longer true. Checking compute capability is also meaningless when the relevant support is not available in the compiler or the kernel. So rather than adding more ad-hoc helpers / checks, we should revisit the use of compute capability as a criterion for feature selection.

I think we need some kind of "Backend" class from which all supported SM variants are derived. We can encode all target-specific available feature sets supported by the kernel there. We can cleanly express idiosyncrasies of the kernel, like

  • "SM90" backend does not support TMA with mxfp4 due to a kernel limitation, despite the support by HW
  • "SM120" backend does not support native MXFP or TMA due to compiler limitation, despite the support by HW

# Pretend sm120 as sm80 for now
return False
return _cuda_capability_geq(major, minor)


@triton.constexpr_function
def get_cdna_version():
"""
Expand Down
Loading