diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index 7cc6e89038..b48e8172a4 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -22,7 +22,6 @@
     ScheduleDualPipeV,
     ScheduleZBVZeroBubble,
 )
-
 from torchtitan.components.loss import LossFunction
 from torchtitan.config import (
     ActivationCheckpointConfig,
@@ -40,9 +39,12 @@

 __all__ = [
     "pipeline_llm",
+    "get_pipeline_metadata",
     "build_pipeline_schedule",
     "generate_llm_fqn_per_model_part",
     "pipeline_module_split",
+    "split_module",
+    "get_pp_rank_to_stage_indices_mapping",
 ]


@@ -63,6 +65,82 @@ def pipeline_llm(
 ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
     pp_mesh = parallel_dims.get_mesh("pp")

+    num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
+        parallel_dims, parallelism, model_config
+    )
+
+    module_names_per_stage = parallelism.module_fqns_per_model_part
+    if module_names_per_stage is None:
+        module_names_per_stage = generate_llm_fqn_per_model_part(
+            num_virtual_stages, num_layers, input_weight, output_weight
+        )
+    for i, stage_ms in enumerate(module_names_per_stage):
+        logger.debug(f"Stage {i}: {stage_ms}")
+
+    stages, model_parts = pipeline_module_split(
+        model,
+        pp_mesh,
+        parallelism.pipeline_parallel_schedule,
+        device,
+        module_names_per_stage,
+    )
+
+    # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
+    # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
+    # optimizer, and checkpointing
+    for i, m in enumerate(model_parts):
+        # apply SPMD-style PT-D techniques
+        m = parallelize_fn(
+            m,
+            parallel_dims=parallel_dims,
+            training=training,
+            model_converters=model_converters,
+            parallelism=parallelism,
+            compile_config=compile_config,
+            ac_config=ac_config,
+            dump_folder=dump_folder,
+        )
+        model_parts[i] = m
+        # NOTE: this is to update the model in the stage
+        # in case the model is modified e.g. by torch.compile
+        stages[i].submod = m
+
+    pp_schedule = build_pipeline_schedule(
+        parallelism=parallelism,
+        local_batch_size=training.local_batch_size,
+        stages=stages,
+        loss_fn=loss_fn,
+    )
+
+    # This is used in the train loop to determine whether to pass in the input_ids and labels
+    has_first_stage = False
+    has_last_stage = False
+    for stage in stages:
+        if stage.is_first:
+            has_first_stage = True
+        if stage.is_last:
+            has_last_stage = True
+
+    return pp_schedule, model_parts, has_first_stage, has_last_stage
+
+
+def get_pipeline_metadata(
+    parallel_dims: ParallelDims,
+    parallelism: ParallelismConfig,
+    model_config: BaseModel.Config,
+) -> tuple[int, int, int, int]:
+    """
+    Determine the number of virtual stages, the number of layers, and the input/output weights for splitting the model.
+
+    Args:
+        parallel_dims (ParallelDims): Parallel dimensions.
+        parallelism (ParallelismConfig): Parallelism configuration.
+        model_config (BaseModel.Config): Model configuration.
+
+    Returns:
+        tuple: A tuple containing the number of virtual stages, the number of layers in the model,
+            the input weight, and the output weight.
+ """ # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class(parallelism.pipeline_parallel_schedule) is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) @@ -123,60 +201,7 @@ def pipeline_llm( # For single-stage schedules, default is 1 virtual stage per rank stages_per_rank = 1 if is_single_stage_schedule else 2 num_virtual_stages = parallel_dims.pp * stages_per_rank - - module_names_per_stage = parallelism.module_fqns_per_model_part - if module_names_per_stage is None: - module_names_per_stage = generate_llm_fqn_per_model_part( - num_virtual_stages, num_layers, input_weight, output_weight - ) - for i, stage_ms in enumerate(module_names_per_stage): - logger.debug(f"Stage {i}: {stage_ms}") - - stages, model_parts = pipeline_module_split( - model, - pp_mesh, - parallelism.pipeline_parallel_schedule, - device, - module_names_per_stage, - ) - - # For PP with looped schedules, each item in model_parts is one stage-model-chunk. - # We need to iterate through model_parts to apply SPMD parallelisms, compilation, - # optimizer, and checkpointing - for i, m in enumerate(model_parts): - # apply SPMD-style PT-D techniques - m = parallelize_fn( - m, - parallel_dims=parallel_dims, - training=training, - model_converters=model_converters, - parallelism=parallelism, - compile_config=compile_config, - ac_config=ac_config, - dump_folder=dump_folder, - ) - model_parts[i] = m - # NOTE: this is to update the model in the stage - # in case the model is modified e.g. by torch.compile - stages[i].submod = m - - pp_schedule = build_pipeline_schedule( - parallelism=parallelism, - local_batch_size=training.local_batch_size, - stages=stages, - loss_fn=loss_fn, - ) - - # This is used in the train loop to determine whether to pass in the input_ids and labels - has_first_stage = False - has_last_stage = False - for stage in stages: - if stage.is_first: - has_first_stage = True - if stage.is_last: - has_last_stage = True - - return pp_schedule, model_parts, has_first_stage, has_last_stage + return num_virtual_stages, num_layers, input_weight, output_weight def build_pipeline_schedule( @@ -185,6 +210,8 @@ def build_pipeline_schedule( local_batch_size: int, stages: list[PipelineStage], loss_fn: Callable, + backward_requires_autograd: bool = True, + scale_grads: bool = False, ) -> _PipelineSchedule: """Builds a pipeline schedule for the given job configuration and stages. @@ -233,7 +260,8 @@ def build_pipeline_schedule( stages if looped_schedule else stages[0], n_microbatches=n_microbatches, loss_fn=loss_fn, - scale_grads=False, + backward_requires_autograd=backward_requires_autograd, + scale_grads=scale_grads, ) logger.info( f"Using pipeline schedule {parallelism.pipeline_parallel_schedule} " @@ -371,6 +399,108 @@ def generate_llm_fqn_per_model_part( return module_names_per_stage +def split_module( + whole_model: nn.Module, + module_names: list[str], +) -> nn.Module: + """ + Splits a whole model into a module based on the specified module names. 
+
+    Args:
+        whole_model: The complete model to be split
+        module_names: List of module names to include in the split
+
+    Returns:
+        The split module
+
+    Example usage:
+        module_names = ["tok_embeddings", "layers.0", "layers.1", "norm", "output"]
+        split_module(whole_model, module_names)
+    """
+    model = copy.deepcopy(whole_model)
+    # Create a set of modules to keep for faster lookup
+    modules_to_keep = set(module_names)
+    for module_name, module_value in model.named_children():
+        # Handle layer-like structures (e.g., "layers.0", "layers.1")
+        if isinstance(
+            module_value, (nn.ModuleDict, nn.ModuleList, ModuleDict, ModuleList)
+        ):
+            layers_to_keep = {
+                name.split(".", 1)[1]
+                for name in modules_to_keep
+                if name.startswith(f"{module_name}.")
+            }
+            if layers_to_keep:
+                # Keep only specified layers
+                if isinstance(module_value, nn.ModuleDict):
+                    for layer_name in list(module_value.keys()):
+                        if layer_name not in layers_to_keep:
+                            del module_value[layer_name]
+                elif isinstance(module_value, nn.ModuleList):
+                    indices_to_keep = {
+                        int(idx) for idx in layers_to_keep if idx.isdigit()
+                    }
+                    new_layers = ModuleList(
+                        [
+                            layer
+                            for i, layer in enumerate(module_value)
+                            if i in indices_to_keep
+                        ]
+                    )
+                    setattr(model, module_name, new_layers)
+            else:
+                # No layers from this structure needed, set to empty structure
+                if isinstance(module_value, (nn.ModuleDict, ModuleDict)):
+                    setattr(model, module_name, ModuleDict())
+                elif isinstance(module_value, (nn.ModuleList, ModuleList)):
+                    setattr(model, module_name, ModuleList())
+        # Handle simple module attributes (e.g., "linear", "norm")
+        elif module_name not in modules_to_keep:
+            # Replace with None
+            setattr(model, module_name, None)
+    return model
+
+
+def get_pp_rank_to_stage_indices_mapping(
+    pp_rank: int,
+    pp_degree: int,
+    pp_schedule: str,
+    num_stages: int,
+) -> tuple[int, ...]:
+    """
+    Returns the stage indices assigned to the given PP rank under the given pipeline schedule.
+
+    Args:
+        pp_rank: Pipeline parallel rank
+        pp_degree: Number of pipeline parallel ranks
+        pp_schedule: Name of pipeline parallelism schedule
+        num_stages: Number of pipeline stages
+
+    Returns:
+        Tuple of stage indices owned by the given PP rank
+    """
+    schedule_class = get_schedule_class(pp_schedule)
+    style = (
+        "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
+    )
+    assert (
+        num_stages % pp_degree == 0
+    ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
+    stages_per_rank = num_stages // pp_degree
+    if style == "loop":
+        return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
+    elif style == "v":
+        assert (
+            stages_per_rank == 2
+        ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
+        stage_v_pairs = list(
+            zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
+        )
+        return tuple(stage_v_pairs[pp_rank])
+    else:
+        raise ValueError(f"Unknown style {style}")
+
+
 def pipeline_module_split(
     whole_model: nn.Module,
     pp_mesh: DeviceMesh,
     pp_schedule: str,
@@ -412,97 +542,21 @@
     """
     pp_rank = pp_mesh.get_local_rank()
     pp_degree = pp_mesh.size()
-
-    def _build_stage_from_modules(
-        stage_idx: int, module_names: list[str], num_stages: int
-    ) -> tuple[PipelineStage, nn.Module]:
-        model = copy.deepcopy(whole_model)
-
-        # Create a set of modules to keep for faster lookup
-        modules_to_keep = set(module_names)
-        for module_name, module_value in model.named_children():
-            # Handle layer-like structures (e.g., "layers.0", "layers.1")
-            if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
-                layers_to_keep = {
-                    name.split(".", 1)[1]
-                    for name in modules_to_keep
-                    if name.startswith(f"{module_name}.")
-                }
-                if layers_to_keep:
-                    # Keep only specified layers
-                    if isinstance(module_value, nn.ModuleDict):
-                        for layer_name in list(module_value.keys()):
-                            if layer_name not in layers_to_keep:
-                                del module_value[layer_name]
-                    elif isinstance(module_value, nn.ModuleList):
-                        indices_to_keep = {
-                            int(idx) for idx in layers_to_keep if idx.isdigit()
-                        }
-                        new_layers = ModuleList(
-                            [
-                                layer
-                                for i, layer in enumerate(module_value)
-                                if i in indices_to_keep
-                            ]
-                        )
-                        setattr(model, module_name, new_layers)
-                else:
-                    # No layers from this structure needed, set to empty structure
-                    if isinstance(module_value, nn.ModuleDict):
-                        setattr(model, module_name, ModuleDict())
-                    elif isinstance(module_value, nn.ModuleList):
-                        setattr(model, module_name, ModuleList())
-            # Handle simple module attributes (e.g., "linear", "norm")
-            elif module_name not in modules_to_keep:
-                # Replace with None
-                setattr(model, module_name, None)
-
-        stage = PipelineStage(
-            model,
-            stage_idx,
-            num_stages,
-            device,
-            group=pp_mesh.get_group("pp"),
-        )
-        return stage, model
-
     num_stages = len(module_names_per_stage)
     stages = []
     models = []
-
-    schedule_class = get_schedule_class(pp_schedule)
-    style = (
-        "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
+    pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping(
+        pp_rank, pp_degree, pp_schedule, num_stages
     )
-
-    def _get_stage_indices() -> tuple[int, ...]:
-        """
-        Compute the stage ids for the stages that will run on this pp rank
-        for either a looped or V style schedule
-        """
-        assert (
-            num_stages % pp_degree == 0
-        ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
-        stages_per_rank = num_stages // pp_degree
-        if style == "loop":
-            return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
-        elif style == "v":
-            assert (
-                stages_per_rank == 2
-            ), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
-            stage_v_pairs = list(
-                zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
-            )
-            return stage_v_pairs[pp_rank]
-        else:
-            raise ValueError(f"Unknown style {style}")
-
-    for stage_idx in _get_stage_indices():
+    for stage_idx in pp_rank_to_stage_indices:
         module_names = module_names_per_stage[stage_idx]
-        stage, model_chunk = _build_stage_from_modules(
+        model_chunk = split_module(whole_model, module_names)
+        stage = PipelineStage(
+            model_chunk,
             stage_idx,
-            module_names,
             num_stages,
+            device,
+            group=pp_mesh.get_group("pp"),
        )
         logger.info(
             f"PP rank {pp_rank} is building stage_idx {stage_idx} "
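To make the new stage assignment concrete, here is a small usage sketch of `get_pp_rank_to_stage_indices_mapping` for a looped schedule versus a V-shaped one. It assumes torchtitan with this patch is importable and that `"Interleaved1F1B"` and `"ZBVZeroBubble"` are schedule names accepted by `get_schedule_class`; the expected tuples follow directly from the arithmetic in the function above.

```python
# Illustrative only, not part of the patch.
from torchtitan.distributed.pipeline_parallel import get_pp_rank_to_stage_indices_mapping

pp_degree, num_stages = 4, 8

# Looped schedules stride stages across ranks: rank r owns r, r + pp_degree, ...
for rank in range(pp_degree):
    print(rank, get_pp_rank_to_stage_indices_mapping(rank, pp_degree, "Interleaved1F1B", num_stages))
# 0 (0, 4)
# 1 (1, 5)
# 2 (2, 6)
# 3 (3, 7)

# V-shaped schedules pair stage i with stage (num_stages - 1 - i) on the same rank.
for rank in range(pp_degree):
    print(rank, get_pp_rank_to_stage_indices_mapping(rank, pp_degree, "ZBVZeroBubble", num_stages))
# 0 (0, 7)
# 1 (1, 6)
# 2 (2, 5)
# 3 (3, 4)
```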
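Similarly, a minimal sketch of what `split_module` produces for one model chunk. The `ToyLLM` class below is hypothetical; it only mirrors the `tok_embeddings`/`layers`/`norm`/`output` layout named in the docstring example and is not a torchtitan model.

```python
import torch.nn as nn

from torchtitan.distributed.pipeline_parallel import split_module


class ToyLLM(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4, vocab: int = 100):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))
        self.norm = nn.LayerNorm(dim)
        self.output = nn.Linear(dim, vocab)


whole = ToyLLM()

# First stage keeps the embedding plus layers 0-1; everything else is pruned.
stage0 = split_module(whole, ["tok_embeddings", "layers.0", "layers.1"])
print(len(stage0.layers))          # 2  (only the requested layers survive)
print(stage0.norm, stage0.output)  # None None  (unlisted top-level modules are dropped)

# split_module deep-copies, so the original model is left untouched.
print(len(whole.layers))           # 4
```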
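Finally, a rough sketch of how the refactored `pipeline_llm` flow composes these pieces when `module_fqns_per_model_part` is not set: generate the per-stage FQN lists, then look up which stages the local rank builds. It assumes `generate_llm_fqn_per_model_part` keeps the `(num_stages, num_layers, input_weight, output_weight)` signature used in the hunks above; the numbers are arbitrary.

```python
# Sketch only; assumes torchtitan with this change is installed.
from torchtitan.distributed.pipeline_parallel import (
    generate_llm_fqn_per_model_part,
    get_pp_rank_to_stage_indices_mapping,
)

pp_degree = 2
num_virtual_stages, num_layers, input_weight, output_weight = 4, 6, 1, 1

# Per-stage module FQNs, e.g. the first stage holds "tok_embeddings" plus its share of layers.
module_names_per_stage = generate_llm_fqn_per_model_part(
    num_virtual_stages, num_layers, input_weight, output_weight
)
for stage_idx, fqns in enumerate(module_names_per_stage):
    print(f"stage {stage_idx}: {fqns}")

# Which of those stages each PP rank builds under an interleaved (looped) schedule.
for rank in range(pp_degree):
    owned = get_pp_rank_to_stage_indices_mapping(
        rank, pp_degree, "Interleaved1F1B", num_virtual_stages
    )
    print(f"rank {rank} builds stages {owned}")
```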