
Conversation

@masahi (Collaborator) commented Oct 19, 2025

The main motivation is to support running gpt-oss on DGX Spark (sm121). Until we properly enable mixed-precision MXFP and TMA for sm120 / sm121, we need to fall back to the sm80 compilation path.

All MoE tests pass on RTX6000 with this change:


============================================================================ 2532 passed, 6621 skipped in 19.60s ============================================================================

Related issue
#8335

@masahi masahi requested a review from Jokeren October 19, 2025 22:59
@masahi masahi requested a review from ptillet as a code owner October 19, 2025 22:59


@triton.constexpr_function
def cuda_capability_geq(major, minor=0):
Collaborator

What other properties are incorrect for sm_120?

Collaborator Author

Not sure what you mean by "other" or "incorrect properties". Without this workaround, the kernel tries to use native mxfp and TMA, assuming that sm120 has the full feature set of sm100. But those are the only things that are currently breaking gpt-oss on sm120 / 121.

Collaborator

I meant in addition to the checks you modified. Do you know which use of cuda_capability_geq is causing problems?

Collaborator Author

It's hard to say. I've seen two kinds of errors: one is the use of TMA gather4 / scatter4, and the other is a shape mismatch in dot. cuda_capability_geq is used in many places and the options supported by the kernel are very broad, so I don't know which of them are actually problematic. Indeed, if we want to optimize for sm120 / 121, we need a more fine-grained approach to the capability check rather than falling everything back to sm80.

For example, the determination of the weight layout is highly architecture-specific: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout.py#L22-L27. Even if we allowed has_native_mxfp to evaluate to True for sm120, I don't know if BlackwellMXValueLayout is compatible with the dot shape of MMAv2.
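
To make the coupling concrete, a minimal self-contained sketch (apart from the name BlackwellMXValueLayout, everything here is made up and is not the actual layout.py logic):

    class StridedLayoutStub:
        """Plain strided layout, compatible with the sm80 / MMAv2 dot path."""

    class BlackwellMXValueLayoutStub:
        """Stand-in for BlackwellMXValueLayout, the native-MXFP weight layout."""

    def pick_mxfp4_weight_layout(has_native_mxfp: bool):
        # In the real code this choice is driven by the target; the point here is
        # only that the layout and the dot path must agree. Letting has_native_mxfp
        # be True on sm120 would select the Blackwell layout, but the sm120 kernel
        # currently lowers to MMAv2-style dots, which may not accept that layout.
        if has_native_mxfp:
            return BlackwellMXValueLayoutStub()
        return StridedLayoutStub()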

-    # hopper w/ mxfp4 doesn't support TMA
-    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+    # hopper or sm120 w/ mxfp4 doesn't support TMA
+    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] == 10 or bitwidth(w.dtype) != 4)
Collaborator

Shouldn't this be separate helper logic? I'm pretty sure we will enable TMA on Hopper at some point, so this will break.

@masahi (Collaborator Author) commented Oct 20, 2025

Yeah, but this one is a bit different since it is an ad-hoc check due to a kernel limitation rather than an architectural one. We could add something like target_info.supports_tma(), but that would need to return False for Hopper today, which is a bit odd. So when the kernel supports TMA for Hopper in the future, we would need to update the helper anyway.

As a middle ground, how about something like this?

    # hopper or sm120 w/ mxfp4 doesn't support TMA
    supports_tma = [10] # Add 9 when the Hopper impl supports TMA
    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] in supports_tma or bitwidth(w.dtype) != 4)

This way, when Hopper supports TMA, we can safely update it without breaking sm120. The condition torch.cuda.get_device_capability()[0] >= 9 might not be correct depending on how well sm120 TMA is supported by the kernel.
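
Or, if we do want a helper later, one possible shape (sketch only; supports_tma_for_mxfp4 is a made-up name, not an existing target_info API):

    import torch

    # Major compute capabilities whose mxfp4 kernel path can use TMA today.
    # Add 9 once the Hopper implementation supports it; keep 12 out while
    # sm120 / sm121 fall back to the sm80 path.
    TMA_MXFP4_MAJORS = (10,)

    def supports_tma_for_mxfp4() -> bool:
        major, _minor = torch.cuda.get_device_capability()
        return major in TMA_MXFP4_MAJORS

    # The call site would then become:
    # can_use_tma = can_use_tma and (supports_tma_for_mxfp4() or bitwidth(w.dtype) != 4)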




@triton.constexpr_function
def cuda_capability_geq(major, minor=0):
    target = current_target()
    if target.arch // 10 == 12 and major > 8:
Contributor

I understand this is a workaround, but the function name doesn't reflect what it's really doing. sm80 and sm120 still have subtle differences in the instructions.

Is it possible to separate the logic from this function?
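
For example, something along these lines (just a sketch: falls_back_to_sm80 is a made-up name and the body of cuda_capability_geq is approximated, not copied from target_info.py):

    @triton.constexpr_function
    def falls_back_to_sm80():
        # sm120 / sm121 currently take the sm80 compilation path because native
        # MXFP and TMA are not enabled for them yet.
        target = current_target()
        return target is not None and target.backend == "cuda" and target.arch // 10 == 12

    @triton.constexpr_function
    def cuda_capability_geq(major, minor=0):
        if falls_back_to_sm80() and major > 8:
            return False  # advertise only sm80-level features on sm120 / sm121
        target = current_target()
        if target is None or target.backend != "cuda":
            return False
        return target.arch >= major * 10 + minor

That would keep the name of cuda_capability_geq honest and give the workaround an explicit name.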

Contributor

Benchmark-related changes look good to me. Thanks for catching these problems!

@masahi (Collaborator Author) commented Oct 20, 2025

sm80 and sm120 still have subtle differences in the instructions

Yes, in terms of the architecture, but what really matters is whether those differences are recognized by the compiler or the kernel. Support for sm120 in the compiler is very limited, so from the compiler / kernel perspective, sm80 and sm120 are pretty much the same.

We could introduce another helper to distinguish those kernel / compiler limitations; the Hopper limitation on TMA (#8484 (comment)) is another good example. But cuda_capability_geq is already used in so many places that adding more conditions would make things even more complicated.

The pervasive use of cuda_capability_geq indicates that the kernel treats "higher compute capability" as "more features", but as of sm120 this is no longer true. Checking compute capability is also meaningless when the relevant support is not available in the compiler or the kernel. So rather than adding more ad-hoc helpers and checks, we should revisit the use of compute capability as the criterion for feature selection.

I think we need some kind of "Backend" class from which all supported SM variants are derived. We can encode there the target-specific feature sets the kernel actually supports and cleanly express kernel idiosyncrasies such as (see the sketch after this list):

  • The "SM90" backend does not support TMA with mxfp4 due to a kernel limitation, despite HW support.
  • The "SM120" backend does not support native MXFP or TMA due to a compiler limitation, despite HW support.
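
Roughly something like this (just a sketch with made-up names, not existing triton_kernels code):

    class Backend:
        """Feature set that the kernel can actually use on a given target."""
        has_native_mxfp = False
        supports_tma = False
        supports_tma_with_mxfp4 = False

    class SM80(Backend):
        pass

    class SM90(Backend):
        supports_tma = True
        # supports_tma_with_mxfp4 stays False: kernel limitation, despite HW support

    class SM100(Backend):
        has_native_mxfp = True
        supports_tma = True
        supports_tma_with_mxfp4 = True

    class SM120(Backend):
        # Blackwell-class HW, but compiler support is limited today, so it
        # exposes the same feature set as SM80 for now.
        pass

    def backend_for(capability_major: int) -> Backend:
        return {8: SM80(), 9: SM90(), 10: SM100(), 12: SM120()}.get(capability_major, SM80())

Call sites would then branch on, e.g., backend_for(major).supports_tma_with_mxfp4 instead of comparing capability numbers directly.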
