17 | 17 | is_hopper_or_newer,
18 | 18 | is_hopper,
19 | 19 | )
| 20 | +from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor |
20 | 21 | from triton.experimental import gluon |
21 | 22 | from triton.experimental.gluon import language as ttgl |
22 | 23 | from triton.experimental.gluon.language.nvidia.ampere import async_copy, mma_v2 |
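The new import pulls in Triton's microscaling test helpers. A minimal sketch of the API surface this diff relies on, using only the calls the new test below exercises:

```python
import torch
from triton.tools.mxfp import MXFP4Tensor, MXScaleTensor

# Random e2m1 (FP4) operand: packing along dim 1 fuses two 4-bit values per
# byte, so a (32, 128) logical tensor packs to a (32, 64) uint8 tensor.
v = MXFP4Tensor(size=(32, 128)).random()
packed = v.to_packed_tensor(1)
v_ref = v.to(torch.float32)  # exact float32 view of the logical values

# One shared E8M0 scale per 32 elements along K: .data is the raw uint8
# exponent tensor handed to the kernel; .to(float32) decodes it.
s = MXScaleTensor(size=(32, 128 // 32)).random(1 / 32, 32)
s_ref = s.to(torch.float32)
```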
@@ -629,145 +630,108 @@ def kernel(a_ptr, b_ptr, c_ptr, #
629 | 630 |
|
630 | 631 |
|
631 | 632 | @pytest.mark.skipif(not is_hip_cdna4(), reason="Requires CDNA4") |
632 | | -@pytest.mark.parametrize("M, N, K, rhs_scale, mxfp_type, normal_type", [(32, 32, 128, rhs_scale, mxfp_type, normal_type) |
633 | | - for rhs_scale in [True, False] |
634 | | - for mxfp_type in ["e2m1"] |
635 | | - for normal_type in ["e4m3", "e5m2"]]) |
636 | | -def test_amd_mfma_scaled(M, N, K, rhs_scale, mxfp_type, normal_type): |
637 | | - device = 'cuda' |
638 | | - |
639 | | - @triton.jit |
640 | | - def triton_kernel(a_base, stride_am, stride_ak, a_scale, # |
641 | | - b_base, stride_bk, stride_bn, b_scale, # |
642 | | - out, # |
643 | | - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # |
644 | | - type_a: tl.constexpr, type_b: tl.constexpr): |
645 | | - DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 |
646 | | - DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 |
647 | | - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A |
648 | | - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B |
649 | | - a_ptr = a_base + tl.arange(0, BLOCK_M)[:, None] * stride_am + \ |
650 | | - tl.arange(0, PACKED_BLOCK_K_A)[None, :] * stride_ak |
651 | | - b_ptr = b_base + tl.arange(0, PACKED_BLOCK_K_B)[:, None] * stride_bk + \ |
652 | | - tl.arange(0, BLOCK_N)[None, :] * stride_bn |
653 | | - |
654 | | - a = tl.load(a_ptr) |
655 | | - b = tl.load(b_ptr) |
656 | | - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 |
657 | | - if a_scale is not None: |
658 | | - scale_a_ptr = a_scale + tl.arange(0, BLOCK_M)[:, None] * SCALE_BLOCK_K + tl.arange(0, |
659 | | - SCALE_BLOCK_K)[None, :] |
660 | | - a_scale = tl.load(scale_a_ptr) |
661 | | - if b_scale is not None: |
662 | | - scale_b_ptr = b_scale + tl.arange(0, BLOCK_N)[:, None] * SCALE_BLOCK_K + tl.arange(0, |
663 | | - SCALE_BLOCK_K)[None, :] |
664 | | - b_scale = tl.load(scale_b_ptr) |
665 | | - c = tl.dot_scaled(a, a_scale, type_a, b, b_scale, type_b) |
666 | | - out_ptr = out + tl.arange(0, BLOCK_M)[:, None] * BLOCK_N + tl.arange(0, BLOCK_N)[None, :] |
667 | | - tl.store(out_ptr, c.to(tl.bfloat16)) |
| 633 | +@pytest.mark.parametrize("M, N, K", [(32, 32, 128)]) |
| 634 | +@pytest.mark.parametrize("a_type, b_type", [(a_type, b_type) |
| 635 | + for a_type in ["e2m1", "e4m3", "e5m2"] |
| 636 | + for b_type in ["e2m1", "e4m3", "e5m2"]]) |
| 637 | +@pytest.mark.parametrize("has_scale", [True, False]) |
| 638 | +def test_amd_mfma_scaled(M, N, K, a_type, b_type, has_scale, device='cuda'): |
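The new parametrization covers the full `a_type` × `b_type` product (including e2m1 × e2m1) plus a `has_scale` toggle, replacing the old `rhs_scale`/`mxfp_type`/`normal_type` triple; correctness is now checked against a plain `torch.matmul` reference instead of a second Triton kernel.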
668 | 639 |
|
669 | 640 | @gluon.jit |
670 | | - def gluon_kernel(a_base, stride_am, stride_ak, a_scale, # |
671 | | - b_base, stride_bk, stride_bn, b_scale, # |
672 | | - out, # |
673 | | - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # |
674 | | - type_a: tl.constexpr, type_b: tl.constexpr): |
675 | | - DIV_FACTOR_A: tl.constexpr = 2 if type_a == "e2m1" else 1 |
676 | | - DIV_FACTOR_B: tl.constexpr = 2 if type_b == "e2m1" else 1 |
677 | | - PACKED_BLOCK_K_A: tl.constexpr = BLOCK_K // DIV_FACTOR_A |
678 | | - PACKED_BLOCK_K_B: tl.constexpr = BLOCK_K // DIV_FACTOR_B |
679 | | - SCALE_BLOCK_K: tl.constexpr = BLOCK_K // 32 |
| 641 | + def kernel(out_ptr, a_ptr, b_ptr, a_scale_ptr, b_scale_ptr, # |
| 642 | + M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr, # |
| 643 | + a_type: tl.constexpr, b_type: tl.constexpr): |
| 644 | + DIV_FACTOR_A: tl.constexpr = 2 if a_type == "e2m1" else 1 |
| 645 | + DIV_FACTOR_B: tl.constexpr = 2 if b_type == "e2m1" else 1 |
| 646 | + K_A: tl.constexpr = K // DIV_FACTOR_A |
| 647 | + K_B: tl.constexpr = K // DIV_FACTOR_B |
| 648 | + |
| 649 | + mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 128], transposed=True, |
| 650 | + warps_per_cta=[2, 2]) |
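Note: `instr_shape=[16, 16, 128]` matches the scaled MFMA instruction asserted on at the end of the test (`v_mfma_scale_f32_16x16x128_f8f6f4`), and `warps_per_cta=[2, 2]` accounts for the four warps launched with `num_warps=4`.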
680 | 651 |
|
681 | 652 | a_unpacked_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [8, 8], [4, 1], [1, 0]) |
682 | 653 | a_packed_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [8, 8], [4, 1], [1, 0]) |
683 | | - a_layout: ttgl.constexpr = a_packed_layout if type_a == "e2m1" else a_unpacked_layout |
684 | | - |
685 | | - a_scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout( |
686 | | - reg_bases=[], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp_bases=[[0, 0], [16, 0]], |
687 | | - block_bases=[], shape=[32, 4]) |
| 654 | + a_load_layout: ttgl.constexpr = a_packed_layout if a_type == "e2m1" else a_unpacked_layout |
| 655 | + a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16) |
| 656 | + a_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(a_layout, [M, K // 32]) |
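`get_mfma_scale_layout` derives the scale operand's layout from the dot-operand layout and the scale shape, replacing the hand-written `DistributedLinearLayout` bases (`lane_bases`/`warp_bases`) the deleted kernel spells out below.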
688 | 657 |
|
689 | 658 | b_unpacked_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 16], [32, 2], [4, 1], [1, 0]) |
690 | 659 | b_packed_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0]) |
691 | | - b_layout: ttgl.constexpr = b_packed_layout if type_b == "e2m1" else b_unpacked_layout |
692 | | - |
693 | | - b_scale_layout: ttgl.constexpr = ttgl.DistributedLinearLayout( |
694 | | - reg_bases=[], lane_bases=[[1, 0], [2, 0], [4, 0], [8, 0], [0, 1], [0, 2]], warp_bases=[[16, 0], [0, 0]], |
695 | | - block_bases=[], shape=[32, 4]) |
696 | | - |
697 | | - mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=4, instr_shape=[16, 16, 128], transposed=True, |
698 | | - warps_per_cta=[2, 2]) |
699 | | - |
700 | | - zero = ttgl.zeros([BLOCK_M, BLOCK_N], dtype=ttgl.float32, layout=mfma_layout) |
701 | | - |
702 | | - a_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, a_layout))[:, None] * stride_am + \ |
703 | | - ttgl.arange(0, PACKED_BLOCK_K_A, layout=ttgl.SliceLayout(0, a_layout))[None, :] * stride_ak |
704 | | - a = ttgl.amd.cdna4.buffer_load(a_base, a_offsets) |
705 | | - a = ttgl.convert_layout(a, ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=16)) |
706 | | - |
707 | | - b_offsets = ttgl.arange(0, PACKED_BLOCK_K_B, layout=ttgl.SliceLayout(1, b_layout))[:, None] * stride_bk + \ |
708 | | - ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, b_layout))[None, :] * stride_bn |
709 | | - b = ttgl.amd.cdna4.buffer_load(b_base, b_offsets) |
710 | | - b = ttgl.convert_layout(b, ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16)) |
711 | | - |
712 | | - if a_scale is not None: |
713 | | - a_scale_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, a_scale_layout))[:, None] * SCALE_BLOCK_K + \ |
714 | | - ttgl.arange(0, SCALE_BLOCK_K, layout=ttgl.SliceLayout(0, a_scale_layout))[None, :] |
715 | | - a_scale = ttgl.amd.cdna4.buffer_load(a_scale, a_scale_offsets) |
| 660 | + b_load_layout: ttgl.constexpr = b_packed_layout if b_type == "e2m1" else b_unpacked_layout |
| 661 | + b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=16) |
| 662 | + b_scale_layout: ttgl.constexpr = ttgl.amd.cdna4.get_mfma_scale_layout(b_layout, [N, K // 32]) |
| 663 | + |
| 664 | + a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, a_load_layout))[:, None] |
| 665 | + a_offs_k = ttgl.arange(0, K_A, layout=ttgl.SliceLayout(0, a_load_layout))[None, :] |
| 666 | + a = ttgl.amd.cdna4.buffer_load(a_ptr, a_offs_m * K_A + a_offs_k) |
| 667 | + a = ttgl.convert_layout(a, a_layout) |
| 668 | + |
| 669 | + b_offs_k = ttgl.arange(0, K_B, layout=ttgl.SliceLayout(1, b_load_layout))[:, None] |
| 670 | + b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, b_load_layout))[None, :] |
| 671 | + b = ttgl.amd.cdna4.buffer_load(b_ptr, b_offs_k * N + b_offs_n) |
| 672 | + b = ttgl.convert_layout(b, b_layout) |
| 673 | + |
| 674 | + a_scale = None |
| 675 | + if a_scale_ptr is not None: |
| 676 | + a_scale_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, a_scale_layout))[:, None] |
| 677 | + a_scale_offs_k = ttgl.arange(0, K // 32, layout=ttgl.SliceLayout(0, a_scale_layout))[None, :] |
| 678 | + a_scale = ttgl.amd.cdna4.buffer_load(a_scale_ptr, a_scale_offs_m * (K // 32) + a_scale_offs_k) |
| 679 | + |
| 680 | + b_scale = None |
| 681 | + if b_scale_ptr is not None: |
| 682 | + b_scale_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(1, b_scale_layout))[:, None] |
| 683 | + b_scale_offs_k = ttgl.arange(0, K // 32, layout=ttgl.SliceLayout(0, b_scale_layout))[None, :] |
| 684 | + b_scale = ttgl.amd.cdna4.buffer_load(b_scale_ptr, b_scale_offs_n * (K // 32) + b_scale_offs_k) |
| 685 | + |
| 686 | + zero = ttgl.zeros([M, N], dtype=ttgl.float32, layout=mfma_layout) |
| 687 | + c = ttgl.amd.cdna4.mfma_scaled(a, a_scale, a_type, b, b_scale, b_type, zero) |
| 688 | + c = c.to(out_ptr.dtype.element_ty) |
| 689 | + |
| 690 | + out_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, mfma_layout))[:, None] |
| 691 | + out_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, mfma_layout))[None, :] |
| 692 | + ttgl.amd.cdna4.buffer_store(c, out_ptr, out_offs_m * N + out_offs_n) |
| 693 | + |
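For context, `mfma_scaled` computes a microscaled dot product: every 32 contiguous elements along K share one E8M0 scale, i.e. `out[m, n] = sum_k (a[m, k] * sa[m, k // 32]) * (b[k, n] * sb[n, k // 32])`. A minimal sketch of the E8M0 decode, assuming the standard OCP MX encoding (a pure power-of-two exponent with bias 127, consistent with the deleted kernel using 127 as the identity scale):

```python
import torch

def e8m0_to_float(x: torch.Tensor) -> torch.Tensor:
    # E8M0 stores only a biased exponent: byte value s decodes to 2**(s - 127),
    # so 127 decodes to exactly 1.0 (0xFF encodes NaN in the OCP spec).
    return torch.exp2(x.to(torch.float32) - 127.0)
```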
| 694 | + def _create_mxfp_operand(operand: int, m: int, n: int, dtype: str): |
| 695 | + size = (m, n) |
| 696 | + if dtype == 'e4m3': |
| 697 | + v = torch.randint(20, 40, size, dtype=torch.uint8) |
| 698 | + v_ref = v.view(torch.float8_e4m3fn).to(torch.float32) |
| 699 | + elif dtype == 'e5m2': |
| 700 | + v = torch.randint(20, 40, size, dtype=torch.uint8) |
| 701 | + v_ref = v.view(torch.float8_e5m2).to(torch.float32) |
716 | 702 | else: |
717 | | - a_scale = ttgl.full([BLOCK_M, SCALE_BLOCK_K], 127, dtype=ttgl.int8, layout=a_scale_layout) |
718 | | - |
719 | | - if b_scale is not None: |
720 | | - b_scale_offsets = ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(1, b_scale_layout))[:, None] * SCALE_BLOCK_K + \ |
721 | | - ttgl.arange(0, SCALE_BLOCK_K, layout=ttgl.SliceLayout(0, b_scale_layout))[None, :] |
722 | | - b_scale = ttgl.amd.cdna4.buffer_load(b_scale, b_scale_offsets) |
723 | | - else: |
724 | | - b_scale = ttgl.full([BLOCK_M, SCALE_BLOCK_K], 127, dtype=ttgl.int8, layout=b_scale_layout) |
725 | | - |
726 | | - c = ttgl.amd.cdna4.mfma_scaled(a, a_scale, type_a, b, b_scale, type_b, zero) |
727 | | - c = c.to(out.dtype.element_ty) |
728 | | - |
729 | | - out_offsets = ttgl.arange(0, BLOCK_M, layout=ttgl.SliceLayout(1, mfma_layout))[:, None] * BLOCK_N + \ |
730 | | - ttgl.arange(0, BLOCK_N, layout=ttgl.SliceLayout(0, mfma_layout))[None, :] |
731 | | - ttgl.amd.cdna4.buffer_store(c, out, out_offsets) |
| 703 | + assert dtype == 'e2m1' |
| 704 | + pack_dim = 1 if operand == 0 else 0 |
| 705 | + v_mxfp4 = MXFP4Tensor(size=size).random() |
| 706 | + v = v_mxfp4.to_packed_tensor(pack_dim) |
| 707 | + v_ref = v_mxfp4.to(torch.float32) |
| 708 | + return v.to(device), v_ref.to(device) |
| 709 | + |
| 710 | + def _create_mxfp_scale(operand: int, m: int, n: int): |
| 711 | + size = (m, n // 32) |
| 712 | + scale = MXScaleTensor(size=tuple(size)).random(1 / 32, 32) |
| 713 | + scale_ref = scale.to(torch.float32).repeat_interleave(32, dim=1) |
| 714 | + scale_ref = scale_ref.T.contiguous() if operand == 1 else scale_ref |
| 715 | + return scale.data.to(device), scale_ref.to(device) |
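The `e2m1` branch packs two FP4 values per byte along the K dimension (`pack_dim` is 1 for A and 0 for B), which is why the kernel above loads only `K // 2` bytes per row for FP4 operands. A hypothetical illustration of the shape contract of `to_packed_tensor(dim)` — the actual nibble order is whatever the helper defines; this sketch only shows how the packed dim halves:

```python
import torch

def pack_nibbles(codes: torch.Tensor, dim: int) -> torch.Tensor:
    # `codes` holds 4-bit FP4 codes (0..15) as uint8; fuse adjacent pairs
    # along `dim` into one byte, e.g. (M, K) codes -> (M, K // 2) for dim=1.
    lo = codes.index_select(dim, torch.arange(0, codes.size(dim), 2))
    hi = codes.index_select(dim, torch.arange(1, codes.size(dim), 2))
    return (lo | (hi << 4)).to(torch.uint8)
```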
732 | 716 |
|
733 | 717 | torch.manual_seed(0) |
734 | | - |
735 | | - type_a = normal_type if rhs_scale else mxfp_type |
736 | | - type_b = mxfp_type if rhs_scale else normal_type |
737 | | - |
738 | | - DIV_FACTOR_A = 2 if type_a == "e2m1" else 1 |
739 | | - DIV_FACTOR_B = 2 if type_b == "e2m1" else 1 |
740 | | - x = torch.randint(20, 40, (M, K // DIV_FACTOR_A), dtype=torch.uint8, device=device) |
741 | | - y = torch.randint(20, 40, (K // DIV_FACTOR_B, N), dtype=torch.uint8, device=device) |
742 | | - |
743 | | - min_scale, max_scale = (0, 142) |
744 | | - scale_x = torch.randint(min_scale, max_scale + 1, (M, K // 32), dtype=torch.uint8, device=device) |
745 | | - scale_y = torch.randint(min_scale, max_scale + 1, (N, K // 32), dtype=torch.uint8, device=device) |
746 | | - if rhs_scale: |
747 | | - scale_x = None |
| 718 | + a, a_ref = _create_mxfp_operand(0, M, K, a_type) |
| 719 | + b, b_ref = _create_mxfp_operand(1, K, N, b_type) |
| 720 | + |
| 721 | + if has_scale: |
| 722 | + a_scale, a_scale_ref = _create_mxfp_scale(0, M, K) |
| 723 | + b_scale, b_scale_ref = _create_mxfp_scale(1, N, K) |
| 724 | + out = torch.empty((M, N), dtype=torch.float32, device=device) |
| 725 | + compiled = kernel[(1, )](out, a, b, a_scale, b_scale, M, N, K, a_type, b_type, num_warps=4) |
| 726 | + out_ref = torch.matmul(a_ref * a_scale_ref, b_ref * b_scale_ref) |
| 727 | + torch.testing.assert_close(out, out_ref) |
748 | 728 | else: |
749 | | - scale_y = None |
750 | | - |
751 | | - def make_finite(x, dtype): |
752 | | - if dtype not in ("e5m2", "e4m3"): |
753 | | - return x |
754 | | - mask = 0x7C if dtype == "e5m2" else 0x7F |
755 | | - finite = torch.arange(x.numel(), device=device, dtype=torch.uint8).reshape_as(x) % mask |
756 | | - x_finite = torch.where(x & mask == mask, finite | (0x80 & x), x) |
757 | | - x.copy_(x_finite) |
758 | | - return x |
759 | | - |
760 | | - x = make_finite(x, type_a) |
761 | | - y = make_finite(y, type_b) |
762 | | - |
763 | | - z = torch.zeros((M, N), dtype=torch.bfloat16, device=device) |
764 | | - pgm = gluon_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z, M, N, K, type_a, type_b) |
765 | | - assert "v_mfma_scale_f32_16x16x128_f8f6f4" in pgm.asm["amdgcn"] |
766 | | - |
767 | | - z_ref = torch.zeros((M, N), dtype=torch.bfloat16, device=device) |
768 | | - triton_kernel[(1, )](x, *x.stride(), scale_x, y, *y.stride(), scale_y, z_ref, M, N, K, type_a, type_b) |
| 729 | + out = torch.empty((M, N), dtype=torch.float32, device=device) |
| 730 | + compiled = kernel[(1, )](out, a, b, None, None, M, N, K, a_type, b_type, num_warps=4) |
| 731 | + out_ref = torch.matmul(a_ref, b_ref) |
| 732 | + torch.testing.assert_close(out, out_ref) |
769 | 733 |
|
770 | | - torch.testing.assert_close(z, z_ref, rtol=1e-5, atol=1e-5) |
| 734 | + assert 'v_mfma_scale_f32_16x16x128_f8f6f4' in compiled.asm['amdgcn'] |
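Asserting on the generated AMDGCN pins the test to the scaled-MFMA path: if codegen ever fell back to an unscaled MFMA plus separate multiplies, the numeric comparison could still pass, but this check would catch the regression.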
771 | 735 |
|
772 | 736 |
|
773 | 737 | def test_math_fast_expf(): |
|