     get_dp_group,
     get_draft_dp_group,
     get_tp_group,
-    init_distributed,
+    init_distributed, get_draft_sp_group,
 )
 from specforge.modeling.target import (
     Eagle3TargetModel,
@@ -335,10 +335,6 @@ def sanity_check(args: Namespace) -> None:
         args.draft_accumulation_steps = (
             args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size
         )
-    if args.attention_backend in ("usp", "usp_fa"):
-        assert (
-            args.train_hidden_states_path is not None
-        ), "train_hidden_states_path should not be None for usp"


 def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]:
@@ -410,6 +406,9 @@ def build_dataloaders(
     )
     cache_key = hashlib.md5(cache_params_string.encode()).hexdigest()
     train_dataset = load_dataset("json", data_files=args.train_data_path)["train"]
+    is_online = (
+        args.train_data_path is not None and args.train_hidden_states_path is None
+    )
     with rank_0_priority():
         train_eagle3_dataset = build_eagle3_dataset(
             dataset=train_dataset,
@@ -431,7 +430,7 @@ def build_dataloaders(
             cache_key=cache_key,
         )

-    if args.train_hidden_states_path is not None:
+    if not is_online:
         train_eagle3_dataset = build_offline_eagle3_dataset(
             args.train_hidden_states_path,
             args.max_length,
@@ -444,7 +443,7 @@ def build_dataloaders(
         shuffle=True,
         process_group=(
             get_draft_dp_group()
-            if args.attention_backend == "usp"
+            if args.attention_backend == "usp" and not is_online
             else get_dp_group()
         ),
         is_vlm=args.is_vlm,
@@ -475,7 +474,7 @@ def build_dataloaders(
         shuffle=False,
         process_group=(
             get_draft_dp_group()
-            if args.attention_backend == "usp"
+            if args.attention_backend == "usp" and not is_online
             else get_dp_group()
         ),
         is_vlm=args.is_vlm,
@@ -632,13 +631,59 @@ def record_metrcs(
     tracker.log(logdict, step=global_step)


-def get_dp_data_shard_from_tp(tensor: torch.Tensor) -> torch.Tensor:
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+
+
+def get_dp_data_shard_from_tp(tensor: torch.Tensor, sp_dim: int = 1) -> torch.Tensor:
     """
-    Get the data shard from the tensor.
+    Process: TP split -> pad to max length -> SP gather.
     """
-    tp_size = dist.get_world_size(get_tp_group())
-    tp_rank = dist.get_rank(get_tp_group())
-    return tensor.chunk(tp_size, dim=0)[tp_rank]
+    # 1. TP: slice the tensor along the batch dimension
+    tp_group = get_tp_group()
+    tp_size = dist.get_world_size(tp_group)
+    tp_rank = dist.get_rank(tp_group)
+
+    local_tp_shard = tensor.chunk(tp_size, dim=0)[tp_rank]
+
+    # 2. SP: handle dynamic sequence lengths and gather
+    sp_group = get_draft_sp_group()
+
+    if sp_group is not None and dist.get_world_size(sp_group) > 1:
+        sp_world_size = dist.get_world_size(sp_group)
+
+        # --- Fix for variable sequence lengths ---
+        local_seq_len = local_tp_shard.size(sp_dim)
+
+        # Find the global max sequence length in the SP group
+        len_tensor = torch.tensor([local_seq_len], device=local_tp_shard.device, dtype=torch.long)
+        dist.all_reduce(len_tensor, op=dist.ReduceOp.MAX, group=sp_group)
+        max_seq_len = len_tensor.item()
+
+        # Pad the local tensor if necessary.
+        # Assumes shape is [Batch, Seq, Hidden] or [Batch, Seq], with sp_dim=1.
+        if local_seq_len < max_seq_len:
+            pad_size = max_seq_len - local_seq_len
+
+            # Construct the pad list for F.pad (pairs apply from the last dim backwards);
+            # initialize with all zeros so no other dims are padded.
+            pad_config = [0] * (local_tp_shard.ndim * 2)
+
+            pad_idx = (local_tp_shard.ndim - 1 - sp_dim) * 2 + 1
+            pad_config[pad_idx] = pad_size
+
+            # Pad value: 0 is standard; make sure it matches your pad_token_id logic if needed.
+            local_tp_shard_padded = F.pad(local_tp_shard, pad_config, value=0)
+        else:
+            local_tp_shard_padded = local_tp_shard
+
+        gathered_shards = [torch.empty_like(local_tp_shard_padded) for _ in range(sp_world_size)]
+        dist.all_gather(gathered_shards, local_tp_shard_padded.contiguous(), group=sp_group)
+
+        return torch.cat(gathered_shards, dim=sp_dim)
+
+    return local_tp_shard


 def main():
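For reference, a minimal single-process sketch of the padding step introduced in get_dp_data_shard_from_tp above: it shows how the pad_config index maps sp_dim to the right-side pad slot of F.pad, which consumes (left, right) pairs starting from the last dimension. The toy shape and sp_dim=1 are illustrative assumptions; the TP slicing and SP all-gather are omitted here since they require an initialized process group.

import torch
import torch.nn.functional as F

# Toy hidden-states shard: [batch=2, seq=5, hidden=4]; sp_dim=1 is the sequence dim.
local_tp_shard = torch.randn(2, 5, 4)
sp_dim = 1
max_seq_len = 8  # pretend another SP rank holds a longer sequence

pad_size = max_seq_len - local_tp_shard.size(sp_dim)

# F.pad reads (left, right) pairs starting from the LAST dimension,
# so the right-pad slot for sp_dim sits at (ndim - 1 - sp_dim) * 2 + 1.
pad_config = [0] * (local_tp_shard.ndim * 2)
pad_config[(local_tp_shard.ndim - 1 - sp_dim) * 2 + 1] = pad_size

padded = F.pad(local_tp_shard, pad_config, value=0)
assert padded.shape == (2, 8, 4)  # only the sequence dim grew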