@@ -13,8 +13,7 @@
     MoEPrepareAndFinalizeNoEP)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
-from vllm.model_executor.layers.fused_moe.utils import (_fp8_perm,
-                                                        _fp8_quantize,
+from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize,
                                                         _resize_cache)
 from vllm.scalar_type import scalar_types
 
@@ -34,6 +33,10 @@ def run_cutlass_moe_fp8(
     w2_scale: Optional[torch.Tensor],
     a1q_scale: Optional[torch.Tensor],
     a2_scale: Optional[torch.Tensor],
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     workspace13: torch.Tensor,
     workspace2: torch.Tensor,
     expert_num_tokens: Optional[torch.Tensor],
@@ -152,27 +155,11 @@ def run_cutlass_moe_fp8(
                                 problem_sizes1, problem_sizes2, a_map,
                                 c_map, global_num_experts, N, K)
 
-    a1q = _fp8_perm(a1q, a_map)
-    a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale
+    a1q = ops.shuffle_rows(a1q, a_map)
+    a1q_scale = (ops.shuffle_rows(a1q_scale, a_map)
+                 if per_act_token else a1q_scale)
     expert_offsets = expert_offsets[:-1]
 
-    ab_strides1 = torch.full((w1.size(0), ),
-                             K,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides1 = torch.full((w1.size(0), ),
-                            2 * N,
-                            device=device,
-                            dtype=torch.int64)
-    ab_strides2 = torch.full((w1.size(0), ),
-                             N,
-                             device=device,
-                             dtype=torch.int64)
-    c_strides2 = torch.full((w1.size(0), ),
-                            K,
-                            device=device,
-                            dtype=torch.int64)
-
     if use_batched_format:
         c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2))
         c2 = _resize_cache(workspace2, (local_E * padded_M, N))
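
Note: the stride construction removed above is not deleted outright; it becomes the caller's responsibility, so the tensors can be allocated once up front rather than rebuilt on every forward pass. A minimal sketch of the caller-side setup, reusing the constants from the removed block (`make_fp8_moe_strides` is a hypothetical helper, not part of this PR):

import torch

def make_fp8_moe_strides(num_experts: int, K: int, N: int,
                         device: torch.device):
    # One int64 stride entry per expert, matching the values the removed
    # block computed on each call: K and 2*N for the first gemm, N and K
    # for the second.
    ab_strides1 = torch.full((num_experts, ), K, device=device, dtype=torch.int64)
    c_strides1 = torch.full((num_experts, ), 2 * N, device=device, dtype=torch.int64)
    ab_strides2 = torch.full((num_experts, ), N, device=device, dtype=torch.int64)
    c_strides2 = torch.full((num_experts, ), K, device=device, dtype=torch.int64)
    return ab_strides1, ab_strides2, c_strides1, c_strides2
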
@@ -209,7 +196,8 @@ def run_cutlass_moe_fp8(
     else:
         # We can't do this inplace because output may point to the same tensor
         # as c3.
-        output.copy_(c3[c_map].view(M * topk, K), non_blocking=True)
+        output.copy_(ops.shuffle_rows(c3, c_map).view(M * topk, K),
+                     non_blocking=True)
 
 
 # TODO (bnell): split class batched vs. non-batched?
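
For readers unfamiliar with `ops.shuffle_rows`: it is a custom-op row gather, `out[i] = t[row_map[i]]`, used here in place of fancy indexing. A rough pure-PyTorch reference (illustrative only; `shuffle_rows_reference` is not part of this PR), which also shows why the old code needed the now-removed `_fp8_perm` helper:

import torch

def shuffle_rows_reference(t: torch.Tensor,
                           row_map: torch.Tensor) -> torch.Tensor:
    # Gather rows so that out[i] = t[row_map[i]]. fp8 dtypes do not support
    # fancy indexing, so view them as uint8 first; this is essentially what
    # the removed _fp8_perm helper did.
    if t.is_floating_point() and torch.finfo(t.dtype).bits == 8:
        return t.view(dtype=torch.uint8)[row_map].view(dtype=t.dtype)
    return t[row_map]
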
@@ -222,6 +210,10 @@ def __init__(
         out_dtype: Optional[torch.dtype],
         per_act_token_quant: bool,
         per_out_ch_quant: bool,
+        ab_strides1: torch.Tensor,
+        ab_strides2: torch.Tensor,
+        c_strides1: torch.Tensor,
+        c_strides2: torch.Tensor,
         block_shape: Optional[list[int]] = None,
         num_dispatchers: Optional[int] = None,
         use_batched_format: bool = False,
@@ -238,6 +230,10 @@ def __init__(
         self.max_experts_per_worker = max_experts_per_worker
         self.num_dispatchers = num_dispatchers
         self.out_dtype = out_dtype
+        self.ab_strides1 = ab_strides1
+        self.ab_strides2 = ab_strides2
+        self.c_strides1 = c_strides1
+        self.c_strides2 = c_strides2
         self.use_batched_format = use_batched_format
 
     @property
@@ -316,7 +312,8 @@ def apply(self, output: torch.Tensor, hidden_states: torch.Tensor,
         run_cutlass_moe_fp8(
             output, hidden_states, w1, w2, topk_ids, activation_callable,
             global_num_experts, expert_map, w1_scale, w2_scale, a1q_scale,
-            a2_scale, workspace13, workspace2, expert_num_tokens,
+            a2_scale, self.ab_strides1, self.ab_strides2, self.c_strides1,
+            self.c_strides2, workspace13, workspace2, expert_num_tokens,
             self.out_dtype if self.out_dtype is not None else in_dtype,
             self.per_act_token_quant, self.per_out_ch_quant,
             self.use_batched_format)
@@ -330,6 +327,10 @@ def cutlass_moe_fp8(
     topk_ids: torch.Tensor,
     w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
+    ab_strides1: torch.Tensor,
+    ab_strides2: torch.Tensor,
+    c_strides1: torch.Tensor,
+    c_strides2: torch.Tensor,
     per_act_token: Optional[bool] = None,
     activation: str = "silu",
     a1_scale: Optional[torch.Tensor] = None,
@@ -357,6 +358,17 @@ def cutlass_moe_fp8(
         Shape: [num_experts] or [num_experts, 2N]
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts] or [num_experts, K]
+    - ab_strides1 (torch.Tensor): The input/weight strides for the first gemm.
+        Shape: [num_experts]
+    - ab_strides2 (torch.Tensor): The input/weight strides for the second gemm.
+        Shape: [num_experts]
+    - c_strides1 (torch.Tensor): The output strides for the first gemm.
+        Shape: [num_experts]
+    - c_strides2 (torch.Tensor): The output strides for the second gemm.
+        Shape: [num_experts]
+    - per_act_token (Optional[bool]): Whether the scale is per-token or
+        per-tensor.
+    - activation (str): The activation function to use.
     - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
         Shape: scalar or [M]
     - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
@@ -389,6 +401,10 @@ def cutlass_moe_fp8(
             out_dtype=a.dtype,
             per_act_token_quant=per_act_token,
             per_out_ch_quant=per_out_ch,
+            ab_strides1=ab_strides1,
+            ab_strides2=ab_strides2,
+            c_strides1=c_strides1,
+            c_strides2=c_strides2,
             use_batched_format=False,
         ),
     )