diff --git a/python/triton_kernels/bench/bench_mlp.py b/python/triton_kernels/bench/bench_mlp.py
index a640e07d7125..8540bd2abfdc 100644
--- a/python/triton_kernels/bench/bench_mlp.py
+++ b/python/triton_kernels/bench/bench_mlp.py
@@ -8,7 +8,7 @@
 import triton_kernels.roofline as roofline
 import triton_kernels.swiglu
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.target_info import get_cdna_version
+from triton_kernels.target_info import get_cdna_version, has_native_mxfp
 import distributed as triton_dist
 from triton_kernels.tensor_details import layout
 from bench_utils import quantize_weight
@@ -111,12 +111,11 @@ def roofline_mlp(batch_sizes, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_d
 
 
 if __name__ == "__main__":
-    has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
     batch_sizes_dense = [*range(128, 8192, 128)]
     batch_ranges_moe = [(2**(2 + k), 2**(3 + k), min(2**k, 32)) for k in range(8)]
     batch_sizes_moe = list(chain(*[range(*r) for r in batch_ranges_moe]))
     dense_dtypes = ["fp8", "fp8"]
-    quantized_dtypes = ["fp8", "mx4"] if has_native_mx4 else ["bf16", "mx4"]
+    quantized_dtypes = ["fp8", "mx4"] if has_native_mxfp() else ["bf16", "mx4"]
     rank, world_size = triton_dist.setup()
     if world_size > 1:
         # Running all workloads at once may cause OOM on some GPUs such as H100 80GB.
diff --git a/python/triton_kernels/bench/distributed.py b/python/triton_kernels/bench/distributed.py
index 75099b2949a9..659af967b9c1 100644
--- a/python/triton_kernels/bench/distributed.py
+++ b/python/triton_kernels/bench/distributed.py
@@ -15,7 +15,7 @@
 from triton_kernels.topk import topk_torch
 from triton_kernels.topk import topk
 from triton_kernels.matmul_ogs import matmul_ogs, PrecisionConfig, FlexCtx, FnSpecs, FusedActivation
-from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq
+from triton_kernels.target_info import get_cdna_version, is_hip, is_cuda, cuda_capability_geq, has_native_mxfp
 from triton_kernels.tensor_details import layout
 from triton_kernels.tensor import BIT, SparseMatrix, Bitmatrix, make_ragged_tensor_metadata
 
@@ -716,9 +716,6 @@ def distributed(x):
     dist.destroy_process_group()
 
 
-has_native_mx4 = torch.cuda.get_device_capability(0)[0] >= 10 or get_cdna_version() == 4
-
-
 @pytest.mark.parametrize(
     "batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP",
     [
@@ -741,7 +738,7 @@ def distributed(x):
         (1024, 1024, 1024, 128, 2, "fp8", "mx4", 4, 1),
         (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 4),
         (1024, 1024, 1024, 128, 2, "fp8", "mx4", 2, 2),
-    ] if has_native_mx4 else [
+    ] if has_native_mxfp() else [
         (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1),
         (1024, 1024, 1024, 128, 2, "bf16", "mx4", 4, 1),
         (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 4),
diff --git a/python/triton_kernels/tests/test_matmul.py b/python/triton_kernels/tests/test_matmul.py
index ecced6626874..3827c6af2eed 100644
--- a/python/triton_kernels/tests/test_matmul.py
+++ b/python/triton_kernels/tests/test_matmul.py
@@ -75,7 +75,7 @@ def init_compute_data(m, n, k, rdata, gindx, sindx, n_expts_tot, n_expts_act, mo
     if mode == 'batched' or (not has_y_gammas) or (has_y_gammas and (gindx is not None) and act_dtype.itemsize >= 2):
         gs0 = None
         gs1 = None
-    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] < 10:
+    if "float8" in str(weight_dtype) and torch.cuda.get_device_capability()[0] != 10:
         w = w.transpose(-1, -2).contiguous().transpose(-1, -2)
 
 def _apply_padding_and_fill_unused_part_with_nan(t, is_padded):
@@ -325,8 +325,8 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
     if act_dtype_str == "float16" and "mx" in weight_dtype_str and torch.cuda.get_device_capability()[0] >= 10:
         pytest.skip("float16 x mx not supported with cuda capability >= 10")
     if weight_dtype_str.startswith("mx"):
-        if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] < 10:
-            pytest.skip("float8 x mx not supported with cuda capability < 10")
+        if "float8" in act_dtype_str and torch.cuda.get_device_capability()[0] != 10:
+            pytest.skip("float8 x mx not supported with cuda capability != 10")
     if n == 2880 and k == 2880 and torch.cuda.get_device_capability()[0] < 9:
         pytest.skip("Not enough memory on A100")
 
@@ -373,7 +373,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
 
     torch.manual_seed(0)
     block_k = None
-    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] < 10:
+    if is_persistent and weight_dtype_str.startswith("mx") and torch.cuda.get_device_capability()[0] != 10:
         # Override block_k for testing correctness. The default is temporarily 128 for
         # performance reasons which doesn't work with persistent matmul.
         # TODO: revisit when Triton is better for H100 + MXFP4
@@ -473,7 +473,7 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, inner_expt_o
         w_tri = convert_layout(w_tri, w_layout, **w_layout_opts)
         w_scale_tri = convert_layout(w_scale_tri, w_scale_layout, **w_scale_layout_opts)
     else:
-        if torch.cuda.get_device_capability()[0] < 10:
+        if torch.cuda.get_device_capability()[0] != 10:
             pytest.skip("transposed mxfp weight not supported with cuda capability < 10")
         if block_m == 16:
             pytest.skip("PassManager::run failed from Triton compiler")
@@ -646,7 +646,7 @@ def _make_tensor(shape, dtype, trans):
         (torch.float16, torch.bfloat16, torch.float8_e5m2),
     ):
         if (
-            torch.cuda.get_device_capability()[0] < 10
+            torch.cuda.get_device_capability()[0] != 10
             and dtype is torch.float8_e5m2
            and (not w_transpose)
        ):
diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py
index 685b44323343..6185212d264b 100644
--- a/python/triton_kernels/triton_kernels/matmul_ogs.py
+++ b/python/triton_kernels/triton_kernels/matmul_ogs.py
@@ -489,8 +489,8 @@ def matmul_ogs(x, w, bias,
         # TODO: remove this code path; using uint8 for mxfp4 weight will bite us when we want to support uint8 for real
         dtype = FP4 if w.dtype == torch.uint8 else w.dtype
         w = wrap_torch_tensor(w, dtype=dtype)
-    if w_has_mx and (torch.cuda.get_device_capability()[0] < 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
-        assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not on >=Blackwell)"
+    if w_has_mx and (torch.cuda.get_device_capability()[0] != 10 or w.storage.layout is not None and not isinstance(w.storage.layout, StridedLayout)):
+        assert w.stride(-2) == 1, "`w` must be column-major when it has data-type mxfp and (swizzled or not sm100)"
     if w_scale is not None and not isinstance(w_scale, Tensor):
         w_scale = Tensor(w_scale)
     if w_scale is not None:
@@ -536,8 +536,8 @@ def matmul_ogs(x, w, bias,
         (inner_routing_data is None or w.stride(-1) == 1 or inner_routing_data.w_is_padded)
     )
     has_gather_tma = has_gather and target_info.has_tma_gather()
-    # hopper w/ mxfp4 doesn't support TMA
-    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+    # hopper or sm120 w/ mxfp4 doesn't support TMA
+    can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] == 10 or bitwidth(w.dtype) != 4)
     can_use_fused_scatter = has_scatter and (fused_activation.specs.fn is None) and (epilogue.specs.fn is None) and (routing_data.n_expts_act == 1)
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config, batch_size, M, N, w.shape[-2], routing_data,
diff --git a/python/triton_kernels/triton_kernels/target_info.py b/python/triton_kernels/triton_kernels/target_info.py
index 4350efa8b78a..8c856b8cde98 100644
--- a/python/triton_kernels/triton_kernels/target_info.py
+++ b/python/triton_kernels/triton_kernels/target_info.py
@@ -3,13 +3,15 @@
 import triton.language as tl
 
 from triton.language.target_info import (
-    cuda_capability_geq,
     is_cuda,
     is_hip,
     is_hip_cdna3,
     is_hip_cdna4,
+    current_target,
 )
 
+from triton.language.target_info import cuda_capability_geq as _cuda_capability_geq
+
 __all__ = [
     "cuda_capability_geq",
     "get_cdna_version",
@@ -23,6 +25,15 @@
 ]
 
 
+@triton.constexpr_function
+def cuda_capability_geq(major, minor=0):
+    target = current_target()
+    if target.arch // 10 == 12 and major > 8:
+        # Pretend sm120 as sm80 for now
+        return False
+    return _cuda_capability_geq(major, minor)
+
+
 @triton.constexpr_function
 def get_cdna_version():
     """
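
Note: the hunks above import and call has_native_mxfp() from triton_kernels.target_info, but its definition is not part of this excerpt. As a rough standalone sketch only (an assumption, not the patch's actual code), the helper presumably centralizes the inline check it replaces, with ">= 10" tightened to "== 10" so that sm120 is excluded, consistent with the "< 10" -> "!= 10" edits elsewhere in the diff:

import torch

from triton_kernels.target_info import get_cdna_version


def has_native_mxfp():
    # Hypothetical sketch: native MXFP (e.g. mxfp4) support is assumed to mean
    # NVIDIA sm100 (compute capability major == 10) or AMD CDNA4; sm120 is excluded.
    return torch.cuda.get_device_capability(0)[0] == 10 or get_cdna_version() == 4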