|
 from packaging import version
 
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer
 
 QUARK_MXFP4_AVAILABLE = importlib.util.find_spec(
     "quark") is not None and version.parse(

 TRTLLM_GEN_MXFP4_AVAILABLE = current_platform.is_cuda(
 ) and current_platform.is_device_capability(100)
 
+HOPPER_MXFP4_BF16_AVAILABLE = (current_platform.is_cuda()
+                               and current_platform.is_device_capability(90)
+                               and has_flashinfer())
+
 if TRTLLM_GEN_MXFP4_AVAILABLE:
     from flashinfer import (fp4_quantize, mxfp8_quantize,
                             next_positive_power_of_2,
@@ -542,3 +547,317 @@ def test_trtllm_gen_mxfp4_fused_moe(
|
         transpose_optimized=transpose_optimized)
     # relatively loose check since the mxfp4 quantization is less accurate
     check_accuracy(ref_result, tg_result, atol=0, rtol=0.3, percent=0.8)
+
+
+def _interleave_scales_lastdim_by4(scales: torch.Tensor) -> torch.Tensor:
+    """Interleave scales on the last dimension by groups of 4, matching
+    the transformation in mxfp4.py's BF16 (Hopper) path."""
+    s = scales.to(torch.uint8)
+    s_shape = s.shape
+    assert s_shape[-1] % 4 == 0
+    s = s.reshape(*s_shape[:-1], s_shape[-1] // 4, 4)
+    # Move the 4-group dimension before the row dimension
+    permuted = s.permute(0, 2, 1, 3)
+    # Merge the row dim with the 4-group dim
+    return permuted.reshape(s_shape[0], s_shape[-1] // 4, s_shape[1] * 4)
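+
+# Worked example (illustrative only): for scales of shape (32, 6144, 96), as
+# produced for the w13 scales in the BF16 test below, the reshape gives
+# (32, 6144, 24, 4), the permute gives (32, 24, 6144, 4), and the final
+# reshape returns (32, 24, 24576), i.e.
+# (E, rows, cols) -> (E, cols // 4, rows * 4).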
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
+                                              (1.702, 1.0, 7.0)])
+@pytest.mark.skipif(
+    not HOPPER_MXFP4_BF16_AVAILABLE,
+    reason="nvidia gpu sm90 and flashinfer are required for this test",
+)
+def test_flashinfer_cutlass_mxfp4_fused_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    alpha: float,
+    beta: float,
+    limit: Optional[float],
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
+    # Inputs
+    hidden_states = torch.randn(num_tokens,
+                                hidden_size,
+                                device=device,
+                                dtype=torch.bfloat16)
+    # Random MXFP4 weights and scales (uint8), contiguous [w1; w3]
+    w13_q = torch.randint(
+        0,
+        256, (num_experts, 2 * intermediate_size, hidden_size // 2),
+        device=device,
+        dtype=torch.uint8)
+    w13_scale = torch.randint(
+        118,
+        123, (num_experts, 2 * intermediate_size, hidden_size // 32),
+        device=device,
+        dtype=torch.uint8)
+
+    w2_q = torch.randint(0,
+                         256,
+                         (num_experts, hidden_size, intermediate_size // 2),
+                         device=device,
+                         dtype=torch.uint8)
+    w2_scale = torch.randint(
+        118,
+        123, (num_experts, hidden_size, intermediate_size // 32),
+        device=device,
+        dtype=torch.uint8)
+    # Bias contiguous [b1; b3]
+    bias13 = (torch.randn(num_experts,
+                          2 * intermediate_size,
+                          device=device,
+                          dtype=torch.bfloat16) * 10)
+    bias2 = (torch.randn(
+        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
+    router_logits = torch.rand(num_tokens,
+                               num_experts,
+                               dtype=torch.float32,
+                               device=device)
+
+    w13_ref = mxfp4_dequantize(w13_q.clone(), w13_scale.clone()).reshape(
+        num_experts, 2 * intermediate_size, hidden_size)
+    w2_ref = mxfp4_dequantize(w2_q.clone(), w2_scale.clone()).reshape(
+        num_experts, hidden_size, intermediate_size)
+    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
+                        hidden_states.to(torch.float32), w13_ref,
+                        bias13.to(torch.float32), w2_ref,
+                        bias2.to(torch.float32), alpha, beta, limit, 'bf16')
+
+    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+    # Swap halves to arrange as [w3; w1] (kernel expectation)
+    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
+    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
+    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+    w1_s, w3_s = torch.chunk(w13_scale, 2, dim=1)
+    w13_s = torch.cat([w3_s, w1_s], dim=1)
+    w13_s_inter = _interleave_scales_lastdim_by4(w13_s)
+    w2_s_inter = _interleave_scales_lastdim_by4(w2_scale)
+
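+    # Renormalized top-k routing: softmax over all experts, then keep the
+    # top-k weights and rescale them to sum to 1 (intended to mirror the
+    # routing that reference_moe derives from the same router_logits).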
+    routing_weights = torch.nn.functional.softmax(router_logits,
+                                                  dim=1,
+                                                  dtype=torch.float32)
+    token_final_scales, token_selected_experts = torch.topk(routing_weights,
+                                                            topk,
+                                                            dim=-1)
+    token_final_scales = (token_final_scales /
+                          token_final_scales.sum(dim=-1, keepdim=True))
+    token_selected_experts = token_selected_experts.to(torch.int).contiguous()
+
+    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
+    if alpha is not None:
+        alpha = torch.full((num_experts, ), alpha, device=hidden_states.device)
+    if beta is not None:
+        beta = torch.full((num_experts, ), beta, device=hidden_states.device)
+    if limit is not None:
+        limit = torch.full((num_experts, ), limit, device=hidden_states.device)
+
+    _ = flashinfer_cutlass_fused_moe(
+        input=hidden_states,
+        token_selected_experts=token_selected_experts,
+        token_final_scales=token_final_scales,
+        fc1_expert_weights=w13_q_swapped,
+        fc2_expert_weights=w2_q,
+        output_dtype=torch.bfloat16,
+        output=out,
+        quant_scales=[w13_s_inter.to(torch.uint8),
+                      w2_s_inter.to(torch.uint8)],
+        fc1_expert_biases=w13_b,
+        fc2_expert_biases=bias2.to(torch.bfloat16),
+        swiglu_alpha=alpha,
+        swiglu_beta=beta,
+        swiglu_limit=limit,
+        tp_size=1,
+        tp_rank=0,
+        ep_size=1,
+        ep_rank=0,
+        use_w4_group_scaling=True,
+    )
+
+    # Allow some mismatch due to MXFP4 quantization
+    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)
+
+
+@pytest.mark.parametrize("topk", [1, 4])
+@pytest.mark.parametrize("num_experts", [32])
+@pytest.mark.parametrize("num_tokens", [1, 128])
+@pytest.mark.parametrize("intermediate_size,hidden_size", [(3072, 3072)])
+@pytest.mark.parametrize("alpha,beta,limit", [(1.0, 1.0, None),
+                                              (1.702, 1.0, 7.0)])
+@pytest.mark.skipif(
+    not (current_platform.is_cuda()
+         and current_platform.is_device_capability(100) and has_flashinfer()),
+    reason="NVIDIA GPU sm100 and flashinfer are required for this test",
+)
+def test_flashinfer_cutlass_mxfp4_mxfp8_fused_moe(
+    topk: int,
+    num_experts: int,
+    num_tokens: int,
+    intermediate_size: int,
+    hidden_size: int,
+    alpha: Optional[float],
+    beta: Optional[float],
+    limit: Optional[float],
+):
+    torch.manual_seed(42)
+    device = "cuda:0"
+
+    # Inputs
+    hidden_states = torch.randn(num_tokens,
+                                hidden_size,
+                                device=device,
+                                dtype=torch.bfloat16)
+    # Float weights in w13 format [w1; w3]
+    w13 = (torch.randn(num_experts,
+                       2 * intermediate_size,
+                       hidden_size,
+                       device=device,
+                       dtype=torch.bfloat16) / 10)
+    w2 = (torch.randn(num_experts,
+                      hidden_size,
+                      intermediate_size,
+                      device=device,
+                      dtype=torch.bfloat16) / 10)
+    # Bias contiguous [b1; b3]
+    bias13 = (torch.randn(num_experts,
+                          2 * intermediate_size,
+                          device=device,
+                          dtype=torch.bfloat16) * 10)
+    bias2 = (torch.randn(
+        num_experts, hidden_size, device=device, dtype=torch.bfloat16) * 10)
+    router_logits = torch.rand(num_tokens,
+                               num_experts,
+                               dtype=torch.float32,
+                               device=device)
+
+    # Quantize weights to MXFP4 per expert (SM100 path)
+    from flashinfer import mxfp4_quantize
+
+    def quant_mxfp4_batches(a: torch.Tensor, e: int):
+        qs, sfs = [], []
+        for i in range(e):
+            q, sf = mxfp4_quantize(a[i].cuda())
+            qs.append(q)
+            sfs.append(sf)
+        return torch.stack(qs), torch.stack(sfs)
+
+    def dequant_mxfp4_batches(mat_fp4: torch.Tensor,
+                              scale_tensor: torch.Tensor):
+        num_batches = mat_fp4.size(0)
+        scale_tensor = scale_tensor.view(num_batches, -1)
+        from flashinfer import mxfp4_dequantize
+        return torch.stack([
+            mxfp4_dequantize(mat_fp4[b, :, :], scale_tensor[b, :])
+            for b in range(num_batches)
+        ])
+
+    w13_q, w13_scale = quant_mxfp4_batches(w13, num_experts)
+    w2_q, w2_scale = quant_mxfp4_batches(w2, num_experts)
+
+    # Reference result using dequantized tensors and reference_moe
+    w13_ref = dequant_mxfp4_batches(
+        w13_q.view(torch.uint8),
+        w13_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
+            num_experts, 2 * intermediate_size, hidden_size)
+    w2_ref = dequant_mxfp4_batches(
+        w2_q.view(torch.uint8),
+        w2_scale.view(torch.uint8).reshape(-1)).to(torch.float32).reshape(
+            num_experts, hidden_size, intermediate_size)
+
+    # Quantize activations for SM100 path and dequantize for reference
+    hidden_states_q, hidden_states_sf = mxfp8_quantize(hidden_states, True, 32)
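+    # hidden_states_q / hidden_states_sf are the MXFP8-quantized activations
+    # and their block scale factors; they are fed to the fused-MoE call below
+    # as `input` and `input_sf`.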
+    # Reference uses BF16 input but quantizes intermediate activation to MXFP8
+    ref = reference_moe(router_logits.to(torch.float32), topk, num_experts,
+                        hidden_states.to(torch.float32), w13_ref,
+                        bias13.to(torch.float32), w2_ref,
+                        bias2.to(torch.float32), alpha, beta, limit, 'mxfp8')
+
+    # Prepare inputs for FlashInfer CUTLASS fused MoE
+    from vllm.utils.flashinfer import flashinfer_cutlass_fused_moe
+
+    # Swap halves to arrange as [w3; w1] (kernel expectation)
+    w1_w, w3_w = torch.chunk(w13_q, 2, dim=1)
+    w13_q_swapped = torch.cat([w3_w, w1_w], dim=1)
+
+    # Swap scales halves to match swapped weights
+    s1, s3 = torch.chunk(w13_scale, 2, dim=1)
+    w13_scale_swapped = torch.cat([s3, s1], dim=1)
+
+    b1, b3 = torch.chunk(bias13.to(torch.float32), 2, dim=-1)
+    w13_b = torch.cat([b3, b1], dim=-1).to(torch.bfloat16)
+
+    # Build routing for kernel
+    routing_weights = torch.nn.functional.softmax(router_logits,
+                                                  dim=1,
+                                                  dtype=torch.float32)
+    token_final_scales, token_selected_experts = torch.topk(routing_weights,
+                                                            topk,
+                                                            dim=-1)
+    token_final_scales = (token_final_scales /
+                          token_final_scales.sum(dim=-1, keepdim=True))
+    token_selected_experts = token_selected_experts.to(torch.int).contiguous()
+
+    out = torch.empty_like(hidden_states, dtype=torch.bfloat16)
+    if alpha is not None:
+        alpha_t = torch.full((num_experts, ),
+                             alpha,
+                             device=hidden_states.device)
+    else:
+        alpha_t = None
+    if beta is not None:
+        beta_t = torch.full((num_experts, ), beta, device=hidden_states.device)
+    else:
+        beta_t = None
+    if limit is not None:
+        limit_t = torch.full((num_experts, ),
+                             limit,
+                             device=hidden_states.device)
+    else:
+        limit_t = None
+
+    # Quant scales for SM100 MXFP8+MXFP4 path
+    fake_input_scale = torch.ones(num_experts, device=device)
+    quant_scales = [
+        w13_scale_swapped.view(torch.int32),
+        fake_input_scale,
+        w2_scale.view(torch.int32),
+        fake_input_scale,
+    ]
+
+    _ = flashinfer_cutlass_fused_moe(
+        input=hidden_states_q,
+        token_selected_experts=token_selected_experts,
+        token_final_scales=token_final_scales,
+        fc1_expert_weights=w13_q_swapped.contiguous().view(torch.long),
+        fc2_expert_weights=w2_q.contiguous().view(torch.long),
+        output_dtype=torch.bfloat16,
+        output=out,
+        quant_scales=quant_scales,
+        fc1_expert_biases=w13_b,
+        fc2_expert_biases=bias2.to(torch.bfloat16),
+        swiglu_alpha=alpha_t,
+        swiglu_beta=beta_t,
+        swiglu_limit=limit_t,
+        tp_size=1,
+        tp_rank=0,
+        ep_size=1,
+        ep_rank=0,
+        use_mxfp8_act_scaling=True,
+        input_sf=hidden_states_sf,
+    )
+
+    # Allow some mismatch due to MXFP4 quantization
+    check_accuracy(ref, out, atol=0, rtol=0.3, percent=0.8)