@@ -10,6 +10,7 @@
 from tqdm import tqdm

 import vllm.envs as envs
+from vllm.distributed.parallel_state import get_dp_group
 from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
 from vllm.model_executor.layers.fused_moe.deep_gemm_utils import (
     compute_aligned_M, deep_gemm_block_shape)
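The only import change pulls in the data-parallel group accessor; the warmup code below uses its world size to bound the token count. A minimal sketch of what the new import provides (assuming an already-initialized vLLM distributed context):

```python
from vllm.distributed.parallel_state import get_dp_group

# With distributed state initialized, the DP group coordinator exposes the
# data-parallel world size that the warmup uses to scale max_tokens.
dp_world_size = get_dp_group().world_size
```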
@@ -131,11 +132,9 @@ def _deepgemm_fp8_gemm_nt_warmup(w: torch.Tensor, ws: torch.Tensor,
 GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE: set[torch.Size] = set()


-def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
-                                                    w2: torch.Tensor,
-                                                    w1_scale: torch.Tensor,
-                                                    w2_scale: torch.Tensor,
-                                                    num_topk: int):
+def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
+        w1: torch.Tensor, w2: torch.Tensor, w1_scale: torch.Tensor,
+        w2_scale: torch.Tensor, num_topk: int, max_tokens: int):
     if (w1.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE
             and w2.size() in GROUPED_FP8_GEMM_NT_CONTIGUOUS_WARMUP_CACHE):
         return
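The guard at the top of the function is shape-keyed memoization: each unique weight shape is warmed up once per process. A minimal standalone sketch of the same pattern (names here are illustrative, not vLLM API):

```python
import torch

_WARMUP_CACHE: set[torch.Size] = set()  # torch.Size is hashable, so it can key a set

def warmup_once(w: torch.Tensor) -> None:
    # Skip shapes that have already been warmed up in this process.
    if w.size() in _WARMUP_CACHE:
        return
    # ... run the expensive one-time kernel warmup for this shape ...
    _WARMUP_CACHE.add(w.size())
```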
@@ -147,9 +146,13 @@ def _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(w1: torch.Tensor,
     num_experts = w1.size(0)
     device = w1.device

+    # Assumes all ranks have the same max_num_batched_tokens
+    max_tokens_across_dp = get_dp_group().world_size * max_tokens
+    max_tokens = min(max_tokens_across_dp, envs.VLLM_FUSED_MOE_CHUNK_SIZE)
+
     # This is the maximum GroupedGemm M size that we expect to run
     # the grouped_gemm with.
-    MAX_M = compute_aligned_M(envs.VLLM_FUSED_MOE_CHUNK_SIZE,
+    MAX_M = compute_aligned_M(max_tokens,
                               num_topk,
                               num_experts,
                               block_m,
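The intent of the new lines: under data parallelism, a rank may see up to `world_size * max_tokens` tokens at the MoE layer, while the fused-MoE chunk size remains a hard upper bound because larger batches are processed in chunks. A worked example of the arithmetic with hypothetical values:

```python
# Hypothetical values; the real code reads VLLM_FUSED_MOE_CHUNK_SIZE from envs.
dp_world_size = 4             # data-parallel degree
max_tokens = 8192             # per-rank max_num_batched_tokens (assumed equal on all ranks)
fused_moe_chunk_size = 32768  # stand-in for envs.VLLM_FUSED_MOE_CHUNK_SIZE

max_tokens_across_dp = dp_world_size * max_tokens                # 32768
warmup_tokens = min(max_tokens_across_dp, fused_moe_chunk_size)  # 32768
# warmup_tokens is then passed to compute_aligned_M as the expected M bound.
```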
@@ -201,7 +204,8 @@ def deepgemm_fp8_gemm_nt_warmup(model: torch.nn.Module, max_tokens: int):
         _deepgemm_fp8_gemm_nt_warmup(w=w, ws=ws, max_tokens=max_tokens)


-def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
+def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module,
+                                                   max_tokens: int):
     dg_modules = [
         m for m in model.modules()
         if _fused_moe_grouped_gemm_may_use_deep_gemm(m)
@@ -211,9 +215,9 @@ def deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model: torch.nn.Module):
         w13, w13_scale, w2, w2_scale, num_topk = (
             _extract_data_from_fused_moe_module(dgm))
         _deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(
-            w13, w2, w13_scale, w2_scale, num_topk)
+            w13, w2, w13_scale, w2_scale, num_topk, max_tokens)


 def deep_gemm_warmup(model: torch.nn.Module, max_tokens: int):
     deepgemm_fp8_gemm_nt_warmup(model, max_tokens)
-    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model)
+    deepgemm_grouped_fp8_gemm_nt_contiguous_warmup(model, max_tokens)
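With the plumbing in place, the public entry point takes the per-rank token bound and threads it through both warmup paths. A hedged usage sketch (the call site and the source of `max_num_batched_tokens` are assumptions, not shown in this diff):

```python
# Assumed call site: after model load, before serving the first batch.
max_num_batched_tokens = 8192  # typically taken from the scheduler config
deep_gemm_warmup(model, max_tokens=max_num_batched_tokens)
```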