     destroy_distributed,
     get_dp_group,
     get_draft_dp_group,
+    get_draft_sp_group,
     get_tp_group,
     init_distributed,
 )
@@ -335,10 +336,6 @@ def sanity_check(args: Namespace) -> None:
     args.draft_accumulation_steps = (
         args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size
     )
-    if args.attention_backend == "usp":
-        assert (
-            args.train_hidden_states_path is not None
-        ), "train_hidden_states_path should not be None for usp"
 
 
 def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
@@ -410,6 +407,9 @@ def build_dataloaders(
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
     train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    is_online = (
+        args.train_data_path is not None and args.train_hidden_states_path is None
+    )
     with rank_0_priority():
         train_eagle3_dataset = build_eagle3_dataset(
             dataset=train_dataset,
@@ -431,7 +431,7 @@ def build_dataloaders(
             cache_key=cache_key,
         )
 
-    if args.train_hidden_states_path is not None:
+    if not is_online:
         train_eagle3_dataset = build_offline_eagle3_dataset(
             args.train_hidden_states_path,
             args.max_length,
@@ -443,7 +443,9 @@ def build_dataloaders(
         num_workers=args.dataloader_num_workers,
         shuffle=True,
         process_group=(
-            get_draft_dp_group() if args.attention_backend == "usp" else get_dp_group()
+            get_draft_dp_group()
+            if args.attention_backend == "usp" and not is_online
+            else get_dp_group()
         ),
         is_vlm=args.is_vlm,
     )
@@ -473,7 +475,7 @@ def build_dataloaders(
         shuffle=False,
         process_group=(
             get_draft_dp_group()
-            if args.attention_backend == "usp"
+            if args.attention_backend == "usp" and not is_online
             else get_dp_group()
         ),
         is_vlm=args.is_vlm,
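
The two hunks above apply the same group-selection rule to both the train and eval dataloaders. Below is a minimal sketch of that rule pulled out as a pure function so it can be checked without initializing torch.distributed; the helper name `pick_dp_process_group` and the placeholder backend/group values are illustrative, not part of the patch.

```python
def pick_dp_process_group(attention_backend, is_online, draft_dp_group, dp_group):
    """Shard the dataloader over the draft DP group only for offline USP runs."""
    if attention_backend == "usp" and not is_online:
        return draft_dp_group
    return dp_group


# Online USP training (train_data_path set, no precomputed hidden states) keeps
# the regular DP group; offline USP training switches to the draft DP group.
assert pick_dp_process_group("usp", is_online=True,
                             draft_dp_group="draft_dp", dp_group="dp") == "dp"
assert pick_dp_process_group("usp", is_online=False,
                             draft_dp_group="draft_dp", dp_group="dp") == "draft_dp"
assert pick_dp_process_group("other", is_online=False,
                             draft_dp_group="draft_dp", dp_group="dp") == "dp"
```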
@@ -630,13 +632,56 @@ def record_metrcs(
     tracker.log(logdict, step=global_step)
 
 
-def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor:
+def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     """
-    Get the data shard from the tensor.
+    Process: TP split -> pad to max length -> SP gather.
     """
-    tp_size = dist.get_world_size(get_tp_group())
-    tp_rank = dist.get_rank(get_tp_group())
-    return tensor.chunk(tp_size, dim=0)[tp_rank]
+    # 1. TP: slice the tensor along the batch dimension
+    tp_group = get_tp_group()
+    tp_size = dist.get_world_size(tp_group)
+    tp_rank = dist.get_rank(tp_group)
+
+    local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank]
+
+    # 2. SP: handle dynamic sequence lengths and gather
+    sp_group = get_draft_sp_group()
+
+    if sp_group is not None and dist.get_world_size(sp_group) > 1:
+        sp_world_size = dist.get_world_size(sp_group)
+        local_seq_len = local_tp_shard.size(sp_dim)
+
+        # Find the global max sequence length in the SP group
+        len_tensor = torch.tensor(
+            [local_seq_len], device=local_tp_shard.device, dtype=torch.long
+        )
+        dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
+        max_seq_len = len_tensor.item()
+
+        # Pad the local tensor if necessary
+        # Shape is [Batch, Seq, Hidden] or [Batch, Seq], and sp_dim=1
+        if local_seq_len < max_seq_len:
+            pad_size = max_seq_len - local_seq_len
+
+            pad_config = [0] * (local_tp_shard.ndim * 2)
+
+            pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1
+            pad_config[pad_idx] = pad_size
+
+            # Pad value: 0 is standard; ensure it matches your pad_token_id logic if needed
+            local_tp_shard_padded = nn.functional.pad(local_tp_shard, pad_config, value=0)
+        else:
+            local_tp_shard_padded = local_tp_shard
+
+        gathered_shards = [
+            torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)
+        ]
+        dist.all_gather(
+            gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group
+        )
+
+        return torch.cat(gathered_shards, dim=sp_dim)
+
+    return local_tp_shard
 
 
 def main():
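
For reference, here is a self-contained, single-process sketch of the pad index arithmetic used in the new `get_dp_data_shard_from_tp`: `torch.nn.functional.pad` reads (left, right) pairs starting from the last dimension, so the right-pad slot for the sequence dimension lands at `(ndim - 1 - sp_dim) * 2 + 1`. The shapes and the two-rank concatenation below are illustrative only; no process groups are involved.

```python
import torch
import torch.nn.functional as F

sp_dim = 1                       # sequence dimension, as in the patch
x = torch.ones(2, 5, 8)          # [Batch, Seq, Hidden] shard, local_seq_len = 5
max_seq_len = 7                  # pretend another SP rank holds a longer sequence
pad_size = max_seq_len - x.size(sp_dim)

# F.pad reads pairs from the last dimension backwards, so index 3 means
# "right side of dim 1" for a 3-D tensor: (3 - 1 - 1) * 2 + 1 == 3.
pad_config = [0] * (x.ndim * 2)
pad_config[(x.ndim - 1 - sp_dim) * 2 + 1] = pad_size
padded = F.pad(x, pad_config, value=0)
assert padded.shape == (2, 7, 8)

# Once every SP rank is padded to max_seq_len, all_gather produces equal-shaped
# shards; concatenating them along sp_dim (simulated here with two local copies
# standing in for two ranks) yields the full sequence.
gathered = torch.cat([padded, padded], dim=sp_dim)
assert gathered.shape == (2, 14, 8)
```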