 import triton.tools.experimental_descriptor
 from test_mxfp import MXFP4Tensor, MXScaleTensor
 import re
-from triton._internal_testing import is_cuda, is_hip, is_hip_mi200
+from triton._internal_testing import is_cuda, is_hip, is_hip_mi200, is_hip_mi350, is_hip_cdna


 def f8_to_f16(x, dtype):
@@ -711,8 +711,18 @@ def block_scale_fp4_matmul(  #
                          (128, 256, 256), (128, 128, 64), (128, 64, 128)])
 @pytest.mark.parametrize(("scale_type", "VEC_SIZE"), [("float8_e8m0fnu", 32), ("float8_e4m3fn", 16)],
                          ids=["mxfp4", "nvfp4"])
-@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
-def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_type, device):
+@pytest.mark.parametrize("nonKDim", ([0, 16, 32] if is_hip_cdna() else [0]))
+def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_type, nonKDim, device):
+    if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
+        pytest.skip("Requires compute capability >= 10")
+    elif is_hip():
+        if not is_hip_mi350():
+            pytest.skip("Scaled fp4 matmul is only natively supported on MI350")
+        if scale_type != 'float8_e8m0fnu':
+            pytest.skip("MI350 only supports E8M0 scale")
+        if (nonKDim == 16 and BLOCK_K < 128) or (nonKDim == 32 and BLOCK_K < 64):
+            pytest.skip(f"MI350 does not support {BLOCK_K=} for scaled mfma {nonKDim=} variants")
+
     NUM_STAGES = 1
     torch.manual_seed(42)
     a_mxfp4 = MXFP4Tensor(size=(M, K), device=device).random()
@@ -744,9 +754,12 @@ def test_block_scale_fp4(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, VEC_SIZE, scale_typ

     output = a.new_empty((M, N), dtype=torch.float32)
     grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1)
+    kernel_kwargs = {}
+    if is_hip():
+        kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
     block_scale_fp4_matmul[grid](a, b, output, a_scale, b_scale, M, N, K, a_scale.stride(0), a.stride(0), a.stride(1),
                                  b.stride(0), b.stride(1), output.stride(0), output.stride(1), VEC_SIZE, BLOCK_M,
-                                 BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES)
+                                 BLOCK_N, BLOCK_K, NUM_STAGES=NUM_STAGES, **kernel_kwargs)

     torch.testing.assert_close(ref_out, output, atol=1e-2, rtol=1e-2)

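A note on the kwargs indirection at the launch site: `matrix_instr_nonkdim` is an AMD-backend launch option (0 lets the compiler pick the MFMA shape; 16 and 32 force the mfma16x16 and mfma32x32 variants), and other backends do not accept it, so it can only be passed on HIP. Below is a minimal sketch of the same pattern on a toy kernel; `copy_kernel` and `launch` are hypothetical names, not part of this PR.

```python
import torch
import triton
import triton.language as tl

from triton._internal_testing import is_hip


@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Trivial element-wise copy; stands in for block_scale_fp4_matmul.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)


def launch(x: torch.Tensor, y: torch.Tensor, nonKDim: int = 0):
    # Build the launch kwargs conditionally: matrix_instr_nonkdim is only
    # understood by the AMD backend, so it must not reach other backends.
    kernel_kwargs = {}
    if is_hip():
        kernel_kwargs["matrix_instr_nonkdim"] = nonKDim
    n = x.numel()
    copy_kernel[(triton.cdiv(n, 1024), )](x, y, n, BLOCK=1024, **kernel_kwargs)
```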
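For context on why the `skipif` decorator became an in-body `pytest.skip`: the decorator's condition is evaluated once, at collection time, and on ROCm `torch.cuda.get_device_capability()` reports the gfx ISA version (gfx9-class devices report a major version of 9), so the old gate would presumably have skipped every HIP run before the new MI350 path could be reached. A sketch of the difference, assuming a CUDA-or-HIP test session (`test_old_style` and `test_new_style` are illustrative names only):

```python
import pytest
import torch

from triton._internal_testing import is_cuda


# Collection-time gate: the condition runs once, when pytest collects the
# module, on whatever device is visible. On ROCm it also evaluates to True,
# since gfx major versions are below 10, so HIP would never run the test.
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10,
                    reason="Requires compute capability >= 10")
def test_old_style(device):
    ...


# Runtime gate, as in this PR: each backend takes its own branch, and the
# capability check only runs on CUDA devices.
def test_new_style(device):
    if is_cuda() and torch.cuda.get_device_capability()[0] < 10:
        pytest.skip("Requires compute capability >= 10")
```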