[TRITON_KERNELS] Support sm120 / 121 via sm80 fallback #8484
base: main
@@ -3,13 +3,15 @@
import triton.language as tl
from triton.language.target_info import (
    cuda_capability_geq,
    is_cuda,
    is_hip,
    is_hip_cdna3,
    is_hip_cdna4,
    current_target,
)
from triton.language.target_info import cuda_capability_geq as _cuda_capability_geq

__all__ = [
    "cuda_capability_geq",
    "get_cdna_version",

@@ -23,6 +25,15 @@
]


@triton.constexpr_function
def cuda_capability_geq(major, minor=0):
|
Comment: What other properties are incorrect for sm_120?

Reply: Not sure what you mean by "other" or "incorrect properties". Without this workaround, the kernel tries to use native mxfp and TMA, assuming that sm120 has the full feature set of sm100. But those are the only things that are currently breaking gpt-oss on sm120 / 121.

Comment: I meant in addition to the checks you modified. Do you know which use of …

Reply: It's hard to say. I've seen two kinds of errors: one is the use of TMA gather4 / scatter4, and the other is a shape mismatch in dot. For example, the determination of the weight layout is highly architecture-specific: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout.py#L22-L27. Even if we allowed …
    target = current_target()
    if target.arch // 10 == 12 and major > 8:
Comment: I understand this is a workaround, but the function name doesn't reflect what it's really doing. sm80 and sm120 still have subtle differences in the instructions. Is it possible to separate this logic from the function?

Comment: Benchmark-related changes look good to me. Thanks for catching these problems!

Reply: Yes in terms of the architecture, but what really matters is whether those differences are recognized by the compiler or the kernel. Support for sm120 in the compiler is very limited, so from the compiler / kernel perspective, sm80 and sm120 are pretty much the same. We could introduce another helper to distinguish those kernel / compiler limitations; the Hopper limitation on TMA (#8484 (comment)) is another good example. But the pervasive use of … I think we need some kind of "Backend" class from which all supported SM variants are derived. We can encode all target-specific feature sets supported by the kernel there and cleanly express idiosyncrasies of the kernel, like … (a rough sketch of this idea follows the diff below).
        # Pretend sm120 as sm80 for now
        return False
    return _cuda_capability_geq(major, minor)


@triton.constexpr_function
def get_cdna_version():
    """
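To make the "Backend" idea from the discussion above concrete, here is a minimal sketch, purely as an illustration; the names (Backend, backend_for_arch, supports_tma, supports_native_mxfp) are hypothetical and not part of triton_kernels:

```python
# Hypothetical sketch: describe, per target, which features the triton_kernels
# code paths are allowed to use, instead of inferring them from raw
# compute-capability comparisons scattered across the kernels.
from dataclasses import dataclass


@dataclass(frozen=True)
class Backend:
    arch: int                            # e.g. 80, 90, 100, 120
    supports_tma: bool = False           # TMA (incl. gather4 / scatter4) usable by the kernels
    supports_native_mxfp: bool = False   # native mxfp support assumed by the sm100 paths


def backend_for_arch(arch: int) -> Backend:
    # sm120 / sm121 deliberately get the sm80 feature set until the compiler
    # and kernels support their native features.
    if arch // 10 == 12:
        return Backend(arch=arch)
    if arch >= 100:
        return Backend(arch=arch, supports_tma=True, supports_native_mxfp=True)
    if arch >= 90:
        # The kernels do not use TMA on Hopper today; flip this when they do.
        return Backend(arch=arch, supports_tma=False)
    return Backend(arch=arch)
```

Kernel launch code could then branch on backend.supports_tma or backend.supports_native_mxfp instead of comparing compute capabilities directly, which keeps the "treat sm120 as sm80" decision in one place.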
Comment: Do we need separate helper logic for this? I'm pretty sure we will enable TMA on Hopper at some point, so this will break.

Reply: Yeah, but this one is a bit different, since it is an ad-hoc check due to a kernel limitation rather than an architecture one. We could add something like target_info.supports_tma(), but that would need to return False for Hopper today, which is a bit odd. So when the kernel supports TMA for Hopper in the future, we would need to update the helper anyway. As a middle ground, how about something like this?

This way, when Hopper supports TMA, we can safely update it without breaking sm120. The condition torch.cuda.get_device_capability()[0] >= 9 might not be correct depending on how well sm120 TMA is supported by the kernel.
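A minimal host-side sketch of what such a helper could look like, assuming the hypothetical target_info.supports_tma() named above; this is an illustration under those assumptions, not the exact snippet proposed in the review:

```python
import torch


def supports_tma() -> bool:
    # Kernel-level capability rather than a pure architecture check: the
    # kernels do not use TMA on Hopper today, and sm120 / sm121 should keep
    # the sm80 fallback path, so only sm100-class devices report TMA support.
    # If/when the kernels enable TMA on Hopper, relax this (e.g. to major >= 9)
    # without touching the sm120 handling.
    major = torch.cuda.get_device_capability()[0]
    return major >= 10 and major != 12
```

Whether the condition should be major >= 10, major >= 9, or an explicit exclusion of the 12.x family depends on how well sm120 TMA is supported by the kernel, as noted above.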