@@ -114,24 +114,24 @@ def all_gather(x: torch.Tensor, dim=0) -> torch.Tensor:
 def reduce_scatter(
     input_tensor: torch.Tensor,
     n_expts_act: int,
-    metadata: ReduceScatterMetadata,
+    metadata: Optional[ReduceScatterMetadata] = None,
     expt_assignment: Optional[ExptAssignment] = None,
     dim: int = 0,
     op: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM,
 ) -> torch.Tensor:
-    if _is_distributed_launch():
-        if metadata.mode and metadata.mode == "ep_sharding":
+    if metadata and _is_distributed_launch():
+        if metadata.mode == "ep_sharding":
             if dim != 0 or op != dist.ReduceOp.SUM:
                 raise NotImplementedError("Only dim=0 and op=SUM are supported for MoE reduce_scatter.")
             output = convert_ep_to_dp(input_tensor, expt_assignment, metadata.active_indx, metadata.combine_indx)
-            # weighted average of the output token from experts
-            output = output.view(-1, n_expts_act, output.shape[-1])
-            output, _ = reduce(output, dim=1)
-            return output
         else:
             raise NotImplementedError(f"Distributed reduce_scatter mode {metadata.mode} is not implemented yet.")
     else:
-        return input_tensor
+        output = input_tensor
+    # weighted average of the output token from experts
+    output = output.view(-1, n_expts_act, output.shape[-1])
+    output, _ = reduce(output, dim=1)
+    return output


 # TODO: support TP > 1
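For context, a minimal standalone sketch (not part of the diff) of what the fallback path now does when metadata is None: the incoming tensor holds one row per (token, active expert) pair, and the final combine collapses those rows back to one row per token. The plain mean below is only a stand-in for the repo's reduce() kernel, which the comment describes as a weighted average.

import torch

# Sketch only: mean() approximates the weighted combine done by reduce().
def combine_expert_outputs(x: torch.Tensor, n_expts_act: int) -> torch.Tensor:
    x = x.view(-1, n_expts_act, x.shape[-1])  # [n_tokens, n_expts_act, hidden]
    return x.mean(dim=1)                      # one combined row per token

x = torch.randn(8 * 2, 16)                    # 8 tokens, 2 active experts each
assert combine_expert_outputs(x, 2).shape == (8, 16)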
@@ -283,7 +283,8 @@ def single(x):
         else:
             rdata = gi = si = None
         x = matmul(x, w1_full, b1_full, rdata, gather_indx=gi, precision_config=pc1_full, fused_activation=act)
-        return matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
+        x = matmul(x, w2_full, b2_full, rdata, scatter_indx=si, precision_config=pc2_full)
+        return reduce_scatter(x, n_expts_act, metadata=None, expt_assignment=None)

     # distributed pass
     def distributed(x):
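Since single() now ends with the same reduce_scatter combine as the distributed path, both return one row per token and can be compared directly. A hypothetical check along those lines (names taken from the diff; tolerances are placeholders):

ref = single(x)          # local reference path, one row per token after the combine
out = distributed(x)     # EP/DP path, same shape after reduce_scatter
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)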
@@ -328,17 +329,17 @@ def distributed(x):
     + [
         (128, 1024, 1024, 128, 2, "bf16", "bf16", 1, 1),
         (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 1),
-        (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 4),
+        (1024, 1024, 1024, 128, 2, "bf16", "bf16", 1, 2),
     ] +
     # moe cases - test precision
     ([
         (128, 1024, 1024, 128, 2, "fp8", "mx4", 1, 1),
         (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 1),
-        (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 4),
+        (1024, 1024, 1024, 128, 2, "fp8", "mx4", 1, 2),
     ] if has_native_mx4 else [
         (128, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1),
         (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 1),
-        (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 4),
+        (1024, 1024, 1024, 128, 2, "bf16", "mx4", 1, 2),
     ]),
 )
 def test_mlp_mp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype, TP, EP, monkeypatch):