-
-
Notifications
You must be signed in to change notification settings - Fork 13.3k
[Perf] Optimize cutlass moe problem size calculation, 5.3% E2E Throughput improvement, 2.2% TTFT improvement #31830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
027b0d8
3299058
54be3d7
905f611
45079c4
45dfa3b
1a00e99
2c089da
1902481
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -108,15 +108,7 @@ def run_cutlass_moe_fp8( | |
| assert global_num_experts != -1 | ||
| assert a1q_scale is not None | ||
|
|
||
| if expert_map is not None: | ||
| "Translate info from expert_map to topk_ids" | ||
| local_topk_ids = torch.where( | ||
| expert_map[topk_ids] != -1, expert_map[topk_ids], -1 | ||
| ) | ||
|
Comment on lines
-111
to
-115
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please verify correctness for removing the expert_map logic here. I assume this works because moe_permute already handles the mapping, but I'm not sure. I think you should test accuracy with EP and EPLB to properly exercise this case
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9462|± |0.0062|
| | |strict-match | 5|exact_match|↑ |0.9454|± |0.0063|Tested with EPLB, added in the PR description as well |
||
| else: | ||
| local_topk_ids = topk_ids | ||
|
|
||
| topk = local_topk_ids.size(1) | ||
| topk = topk_ids.size(1) | ||
| local_E = w1.size(0) | ||
|
|
||
| if use_batched_format: | ||
|
|
@@ -164,12 +156,8 @@ def run_cutlass_moe_fp8( | |
| # during offset calculations | ||
| expert_offsets = expert_offsets.to(torch.int64) | ||
| else: | ||
| problem_sizes1 = torch.empty( | ||
| (global_num_experts, 3), dtype=torch.int32, device=device | ||
| ) | ||
| problem_sizes2 = torch.empty( | ||
| (global_num_experts, 3), dtype=torch.int32, device=device | ||
| ) | ||
| problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) | ||
| problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) | ||
|
|
||
| num_expert = global_num_experts if expert_map is None else expert_map.size(0) | ||
| # permuted a1q reuses workspace2 | ||
|
|
@@ -182,11 +170,12 @@ def run_cutlass_moe_fp8( | |
| expert_map, | ||
| permuted_hidden_states=a1q_perm, | ||
| ) | ||
| expert_offsets = expert_first_token_offset[:-1] | ||
|
|
||
| ops.get_cutlass_moe_mm_problem_sizes( | ||
| local_topk_ids, problem_sizes1, problem_sizes2, global_num_experts, N, K | ||
| # swap_ab is a CUTLASS grouped-GEMM optimization (M <= 64 reduces padding). | ||
| swap_ab = a1q.size(0) <= 64 | ||
yewentao256 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( | ||
| expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, swap_ab | ||
|
Comment on lines
+175
to
+176
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could we use it or cutlass moe fp4 too?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems out of this PR's scope, I can test it and if could be used, I will have a following up PR for that. |
||
| ) | ||
| expert_offsets = expert_first_token_offset[:-1] | ||
|
|
||
| if not per_act_token and (expert_map is not None or use_batched_format): | ||
| # this is necessary to avoid imprecise scale calculation caused by | ||
|
|
@@ -240,9 +229,7 @@ def run_cutlass_moe_fp8( | |
| permuted_hidden_states=mm2_out, | ||
| topk_weights=topk_weights, | ||
| inv_permuted_idx=inv_perm, | ||
| expert_first_token_offset=( | ||
| expert_first_token_offset if expert_map is not None else None | ||
| ), | ||
| expert_first_token_offset=expert_first_token_offset, | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -772,15 +759,7 @@ def run_cutlass_moe_w4a8_fp8( | |
| f"w1 hidden size mismatch: got {w1.size(2) * 8}, expected {K=}" | ||
| ) | ||
|
|
||
| # Translate info from expert_map to topk_ids | ||
| if expert_map is not None: | ||
| local_topk_ids = torch.where( | ||
| expert_map[topk_ids] != -1, expert_map[topk_ids], -1 | ||
| ) | ||
| else: | ||
| local_topk_ids = topk_ids | ||
|
|
||
| topk = local_topk_ids.size(1) | ||
| topk = topk_ids.size(1) | ||
| a1q_perm = _resize_cache(workspace2.view(dtype=torch.float8_e4m3fn), (M * topk, K)) | ||
| mm1_out = _resize_cache(workspace13, (M * topk, N * 2)) | ||
| act_out = _resize_cache(workspace2, (M * topk, N)) | ||
|
|
@@ -790,12 +769,8 @@ def run_cutlass_moe_w4a8_fp8( | |
| ) | ||
| mm2_out = _resize_cache(workspace2, (M * topk, K)) | ||
|
|
||
| problem_sizes1 = torch.empty( | ||
| (global_num_experts, 3), dtype=torch.int32, device=device | ||
| ) | ||
| problem_sizes2 = torch.empty( | ||
| (global_num_experts, 3), dtype=torch.int32, device=device | ||
| ) | ||
| problem_sizes1 = torch.empty((local_E, 3), dtype=torch.int32, device=device) | ||
| problem_sizes2 = torch.empty((local_E, 3), dtype=torch.int32, device=device) | ||
|
|
||
| num_expert = global_num_experts if expert_map is None else expert_map.size(0) | ||
| # permuted a1q reuses workspace2 | ||
|
|
@@ -808,18 +783,11 @@ def run_cutlass_moe_w4a8_fp8( | |
| expert_map, | ||
| permuted_hidden_states=a1q_perm, | ||
| ) | ||
| expert_offsets = expert_first_token_offset[:-1] | ||
|
|
||
| # For RS gemm SwapAB is always enabled (swap logical M, N in the problem shape) | ||
| ops.get_cutlass_moe_mm_problem_sizes( | ||
| local_topk_ids, | ||
| problem_sizes1, | ||
| problem_sizes2, | ||
| global_num_experts, | ||
| N, | ||
| K, | ||
| force_swap_ab=True, | ||
| # for RS gemm SwapAB is always enabled (swap logical M, N in the problem shape). | ||
| ops.get_cutlass_moe_mm_problem_sizes_from_expert_offsets( | ||
| expert_first_token_offset, problem_sizes1, problem_sizes2, N, K, True | ||
| ) | ||
| expert_offsets = expert_first_token_offset[:-1] | ||
|
|
||
| ops.cutlass_w4a8_moe_mm( | ||
| mm1_out, | ||
|
|
@@ -866,9 +834,7 @@ def run_cutlass_moe_w4a8_fp8( | |
| permuted_hidden_states=mm2_out, | ||
| topk_weights=topk_weights, | ||
| inv_permuted_idx=inv_perm, | ||
| expert_first_token_offset=( | ||
| expert_first_token_offset if expert_map is not None else None | ||
| ), | ||
| expert_first_token_offset=expert_first_token_offset, | ||
| ) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.