Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
334 changes: 194 additions & 140 deletions torchtitan/distributed/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
ScheduleDualPipeV,
ScheduleZBVZeroBubble,
)

from torchtitan.components.loss import LossFunction
from torchtitan.config import (
ActivationCheckpointConfig,
Expand All @@ -40,9 +39,12 @@

__all__ = [
"pipeline_llm",
"get_pipeline_metadata",
"build_pipeline_schedule",
"generate_llm_fqn_per_model_part",
"pipeline_module_split",
"split_module",
"get_pp_rank_to_stage_indices_mapping",
]
Comment on lines 40 to 48
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do all these fields need to be public and exposed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry this should have been marked as draft PR, it is not ready for review yet.



Expand All @@ -63,6 +65,82 @@ def pipeline_llm(
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
pp_mesh = parallel_dims.get_mesh("pp")

num_virtual_stages, num_layers, input_weight, output_weight = get_pipeline_metadata(
parallel_dims, parallelism, model_config
)

module_names_per_stage = parallelism.module_fqns_per_model_part
if module_names_per_stage is None:
module_names_per_stage = generate_llm_fqn_per_model_part(
num_virtual_stages, num_layers, input_weight, output_weight
)
for i, stage_ms in enumerate(module_names_per_stage):
logger.debug(f"Stage {i}: {stage_ms}")

stages, model_parts = pipeline_module_split(
model,
pp_mesh,
parallelism.pipeline_parallel_schedule,
device,
module_names_per_stage,
)

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for i, m in enumerate(model_parts):
# apply SPMD-style PT-D techniques
m = parallelize_fn(
m,
parallel_dims=parallel_dims,
training=training,
model_converters=model_converters,
parallelism=parallelism,
compile_config=compile_config,
ac_config=ac_config,
dump_folder=dump_folder,
)
model_parts[i] = m
# NOTE: this is to update the model in the stage
# in case the model is modified e.g. by torch.compile
stages[i].submod = m

pp_schedule = build_pipeline_schedule(
parallelism=parallelism,
local_batch_size=training.local_batch_size,
stages=stages,
loss_fn=loss_fn,
)

# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, model_parts, has_first_stage, has_last_stage


def get_pipeline_metadata(
parallel_dims: ParallelDims,
parallelism: ParallelismConfig,
model_config: BaseModel.Config,
) -> tuple[int, int, int, int]:
"""
Determine the number of virtual stages and the number of layers in the model.

Args:
parallel_dims (ParallelDims): Parallel dimensions.
parallelism (ParallelismConfig): Parallelism configuration.
model_config (BaseModel.Config): Model configuration.

Returns:
tuple: A tuple containing the number of virtual stages, the number of layers in the model,
the input weight, and the output weight.
"""
# Determine the number of virtual stages based on schedule type
schedule_class = get_schedule_class(parallelism.pipeline_parallel_schedule)
is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle)
Expand Down Expand Up @@ -123,60 +201,7 @@ def pipeline_llm(
# For single-stage schedules, default is 1 virtual stage per rank
stages_per_rank = 1 if is_single_stage_schedule else 2
num_virtual_stages = parallel_dims.pp * stages_per_rank

module_names_per_stage = parallelism.module_fqns_per_model_part
if module_names_per_stage is None:
module_names_per_stage = generate_llm_fqn_per_model_part(
num_virtual_stages, num_layers, input_weight, output_weight
)
for i, stage_ms in enumerate(module_names_per_stage):
logger.debug(f"Stage {i}: {stage_ms}")

stages, model_parts = pipeline_module_split(
model,
pp_mesh,
parallelism.pipeline_parallel_schedule,
device,
module_names_per_stage,
)

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for i, m in enumerate(model_parts):
# apply SPMD-style PT-D techniques
m = parallelize_fn(
m,
parallel_dims=parallel_dims,
training=training,
model_converters=model_converters,
parallelism=parallelism,
compile_config=compile_config,
ac_config=ac_config,
dump_folder=dump_folder,
)
model_parts[i] = m
# NOTE: this is to update the model in the stage
# in case the model is modified e.g. by torch.compile
stages[i].submod = m

pp_schedule = build_pipeline_schedule(
parallelism=parallelism,
local_batch_size=training.local_batch_size,
stages=stages,
loss_fn=loss_fn,
)

# This is used in the train loop to determine whether to pass in the input_ids and labels
has_first_stage = False
has_last_stage = False
for stage in stages:
if stage.is_first:
has_first_stage = True
if stage.is_last:
has_last_stage = True

return pp_schedule, model_parts, has_first_stage, has_last_stage
return num_virtual_stages, num_layers, input_weight, output_weight


def build_pipeline_schedule(
Expand All @@ -185,6 +210,8 @@ def build_pipeline_schedule(
local_batch_size: int,
stages: list[PipelineStage],
loss_fn: Callable,
backward_requires_autograd: bool = True,
scale_grads: bool = False,
) -> _PipelineSchedule:
"""Builds a pipeline schedule for the given job configuration and stages.

Expand Down Expand Up @@ -233,7 +260,8 @@ def build_pipeline_schedule(
stages if looped_schedule else stages[0],
n_microbatches=n_microbatches,
loss_fn=loss_fn,
scale_grads=False,
backward_requires_autograd=backward_requires_autograd,
scale_grads=scale_grads,
)
logger.info(
f"Using pipeline schedule {parallelism.pipeline_parallel_schedule} "
Expand Down Expand Up @@ -371,6 +399,108 @@ def generate_llm_fqn_per_model_part(
return module_names_per_stage


def split_module(
    whole_model: nn.Module,
    module_names: list[str],
) -> nn.Module:
    """
    Build a model chunk containing only the requested top-level modules.

    The original model is deep-copied and then pruned: container children
    (``ModuleList``/``ModuleDict``) keep only the named entries, while any
    other top-level child not listed is replaced with ``None``.

    Args:
        whole_model: The complete model to be split
        module_names: List of module names to include in the split

    Returns:
        The split module

    Example usage:
        module_names = ["tok_embeddings", "layers.0", "layers.1", "norm", "output"]
        split_module(whole_model, module_names)
    """
    pruned = copy.deepcopy(whole_model)
    keep = set(module_names)  # set for O(1) membership tests

    for child_name, child in pruned.named_children():
        # Container children hold layer-like entries ("layers.0", "layers.1", ...)
        if isinstance(child, (nn.ModuleDict, nn.ModuleList, ModuleDict, ModuleList)):
            prefix = f"{child_name}."
            wanted = {name.split(".", 1)[1] for name in keep if name.startswith(prefix)}

            if not wanted:
                # Nothing in this container is needed: swap in an empty one.
                if isinstance(child, (nn.ModuleDict, ModuleDict)):
                    setattr(pruned, child_name, ModuleDict())
                elif isinstance(child, (nn.ModuleList, ModuleList)):
                    setattr(pruned, child_name, ModuleList())
                continue

            if isinstance(child, nn.ModuleDict):
                # Drop every entry that was not requested.
                for key in list(child.keys()):
                    if key not in wanted:
                        del child[key]
            elif isinstance(child, nn.ModuleList):
                # ModuleList entries are addressed by integer index.
                keep_indices = {int(s) for s in wanted if s.isdigit()}
                survivors = ModuleList(
                    [m for idx, m in enumerate(child) if idx in keep_indices]
                )
                setattr(pruned, child_name, survivors)
        elif child_name not in keep:
            # Plain child module (e.g. "norm", "output") not requested: drop it.
            setattr(pruned, child_name, None)

    return pruned


def get_pp_rank_to_stage_indices_mapping(
pp_rank: int,
pp_degree,
pp_schedule: str,
num_stages: int,
) -> tuple[int, ...]:
"""
Returns a mapping from PP rank to stage indices for the given pipeline schedule.

Args:
pp_rank: Pipeline parallel rank
pp_degree: Number of pipeline parallel ranks
pp_schedule: Name of pipeline parallelism schedule
num_stages: Number of pipeline stages

Returns:
Mapping from PP rank to stage indices
"""
schedule_class = get_schedule_class(pp_schedule)
style = (
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
)
assert (
num_stages % pp_degree == 0
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
stages_per_rank = num_stages // pp_degree
if style == "loop":
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
elif style == "v":
assert (
stages_per_rank == 2
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
stage_v_pairs = list(
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
)
return tuple(stage_v_pairs[pp_rank])
else:
raise ValueError(f"Unknown style {style}")


def pipeline_module_split(
whole_model: nn.Module,
pp_mesh: DeviceMesh,
Expand Down Expand Up @@ -412,97 +542,21 @@ def pipeline_module_split(
"""
pp_rank = pp_mesh.get_local_rank()
pp_degree = pp_mesh.size()

def _build_stage_from_modules(
stage_idx: int, module_names: list[str], num_stages: int
) -> tuple[PipelineStage, nn.Module]:
model = copy.deepcopy(whole_model)

# Create a set of modules to keep for faster lookup
modules_to_keep = set(module_names)
for module_name, module_value in model.named_children():
# Handle layer-like structures (e.g., "layers.0", "layers.1")
if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)):
layers_to_keep = {
name.split(".", 1)[1]
for name in modules_to_keep
if name.startswith(f"{module_name}.")
}
if layers_to_keep:
# Keep only specified layers
if isinstance(module_value, nn.ModuleDict):
for layer_name in list(module_value.keys()):
if layer_name not in layers_to_keep:
del module_value[layer_name]
elif isinstance(module_value, nn.ModuleList):
indices_to_keep = {
int(idx) for idx in layers_to_keep if idx.isdigit()
}
new_layers = ModuleList(
[
layer
for i, layer in enumerate(module_value)
if i in indices_to_keep
]
)
setattr(model, module_name, new_layers)
else:
# No layers from this structure needed, set to empty structure
if isinstance(module_value, nn.ModuleDict):
setattr(model, module_name, ModuleDict())
elif isinstance(module_value, nn.ModuleList):
setattr(model, module_name, ModuleList())
# Handle simple module attributes (e.g., "linear", "norm")
elif module_name not in modules_to_keep:
# Replace with None
setattr(model, module_name, None)

stage = PipelineStage(
model,
stage_idx,
num_stages,
device,
group=pp_mesh.get_group("pp"),
)
return stage, model

num_stages = len(module_names_per_stage)
stages = []
models = []

schedule_class = get_schedule_class(pp_schedule)
style = (
"v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop"
pp_rank_to_stage_indices = get_pp_rank_to_stage_indices_mapping(
pp_rank, pp_degree, pp_schedule, num_stages
)

def _get_stage_indices() -> tuple[int, ...]:
"""
Compute the stage ids for the stages that will run on this pp rank
for either a looped or V style schedule
"""
assert (
num_stages % pp_degree == 0
), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}"
stages_per_rank = num_stages // pp_degree
if style == "loop":
return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank))
elif style == "v":
assert (
stages_per_rank == 2
), f"v schedules assume 2 stages per rank, got {stages_per_rank}"
stage_v_pairs = list(
zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1))
)
return stage_v_pairs[pp_rank]
else:
raise ValueError(f"Unknown style {style}")

for stage_idx in _get_stage_indices():
for stage_idx in pp_rank_to_stage_indices:
module_names = module_names_per_stage[stage_idx]
stage, model_chunk = _build_stage_from_modules(
model_chunk = split_module(whole_model, module_names)
stage = PipelineStage(
model_chunk,
stage_idx,
module_names,
num_stages,
device,
group=pp_mesh.get_group("pp"),
)
logger.info(
f"PP rank {pp_rank} is building stage_idx {stage_idx} "
Expand Down
Loading