@@ -371,27 +371,28 @@ def __init__(self):
 
     def _prepare_inputput_fn(self, mod, inputs, device_mesh):
         top_scores, selected_experts_indices = inputs
-
-        top_scores = DTensor.from_local(top_scores, device_mesh, (Replicate(),))
-        selected_experts_indices = DTensor.from_local(
-            selected_experts_indices, device_mesh, (Replicate(),)
-        )
         self.num_tokens = top_scores.shape[0]
 
-        # TODO: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
+        # NOTE: If needed, we can pad tokens in case bs*slen is not divisible by TP degree
         # if top_scores.shape[0] % device_mesh.size() != 0:
         #     num_tokens = top_scores.shape[0]
         #     tp_size = device_mesh.size()
         #     n_pad = (num_tokens // tp_size + 1) * tp_size - num_tokens
         #     selected_experts_indices = F.pad(selected_experts_indices, [0, 0, 0, n_pad])
         #     top_scores = F.pad(top_scores, [0, 0, 0, n_pad])
-        assert self.num_tokens % device_mesh.size() == 0
 
-        # split on the bs*slen dimension
-        top_scores = top_scores.redistribute(device_mesh, (Shard(0),)).to_local()
-        selected_experts_indices = selected_experts_indices.redistribute(
-            device_mesh, (Shard(0),)
-        ).to_local()
+        def _split_along_first_dim(x: torch.Tensor) -> torch.Tensor:
+            assert x.is_contiguous()
+            assert self.num_tokens % device_mesh.size() == 0
+            local_num_tokens = self.num_tokens // device_mesh.size()
+            local_rank = device_mesh.get_local_rank()
+            offset = local_rank * local_num_tokens
+            output = x[offset : offset + local_num_tokens]
+
+            return output
+
+        top_scores = _split_along_first_dim(top_scores)
+        selected_experts_indices = _split_along_first_dim(selected_experts_indices)
 
         return top_scores, selected_experts_indices
 
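For context, the change above drops the `DTensor.from_local(..., Replicate())` → `redistribute(Shard(0))` → `to_local()` round trip and instead slices each rank's contiguous chunk of the bs*slen dimension directly. Below is a minimal standalone sketch of that slicing, where `rank` and `world_size` are hypothetical stand-ins for `device_mesh.get_local_rank()` and `device_mesh.size()`; it is an illustration of the technique, not the PR's code.

```python
import torch

def split_along_first_dim(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Each rank keeps its contiguous chunk of rows along dim 0, matching the
    # Shard(0) placement that the removed DTensor redistribute produced.
    assert x.is_contiguous()
    assert x.shape[0] % world_size == 0  # bs*slen must be divisible by TP degree
    local_num_tokens = x.shape[0] // world_size
    offset = rank * local_num_tokens
    return x[offset : offset + local_num_tokens]

# 8 tokens across a TP degree of 4 -> 2 tokens per rank; rank 1 sees rows 2-3.
top_scores = torch.arange(8.0).reshape(8, 1)
print(split_along_first_dim(top_scores, rank=1, world_size=4))
```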