|
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This file applies the PT-D parallelisms (except pipeline parallelism) and various
# training techniques (e.g. activation checkpointing and compile) to the Qwen3 model.

import torch
import torch.nn as nn

from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)

from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.distributed import ParallelDims
from torchtitan.distributed.expert_parallel import NoParallel
from torchtitan.models.llama3.infra.parallelize import (
    apply_ac,
    apply_compile,
    apply_ddp,
    apply_fsdp,
)
from torchtitan.tools.logging import logger


def parallelize_qwen3(
    model: nn.Module,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
):
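    """
    Apply tensor parallelism, activation checkpointing, torch.compile, and data
    parallelism (FSDP/HSDP or DDP) to the Qwen3 model, in that order.
    """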

    world_mesh = parallel_dims.world_mesh
    assert (
        job_config.training.seq_len % parallel_dims.seq_len_divisor == 0
    ), f"""
    Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree
    ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
    """
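    # Illustrative check with hypothetical degrees: tp=4 and cp=2 give a divisor of
    # 4 * (2 * 2) = 16, so seq_len=8192 passes while seq_len=8200 would fail.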
    if parallel_dims.tp_enabled:
        if (
            job_config.parallelism.enable_async_tensor_parallel
            and not job_config.training.compile
        ):
            raise RuntimeError("Async TP requires --training.compile")

        enable_float8_linear = "float8" in job_config.model.converters
        float8_is_rowwise = job_config.float8.recipe_name in (
            "rowwise",
            "rowwise_with_gw_hp",
        )

        # For now, float8 all-gather with TP is only supported for tensorwise
        # float8 scaling recipes. For rowwise recipes, we use regular TP and
        # all-gather happens in high precision.
        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
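        # For example, converters containing "float8" together with a tensorwise recipe
        # enables float8 all-gather TP below; "rowwise" / "rowwise_with_gw_hp" recipes
        # keep regular TP with high-precision all-gathers.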

        apply_tp(
            model,
            world_mesh["tp"],
            loss_parallel=not job_config.parallelism.disable_loss_parallel,
            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
        )

    if job_config.activation_checkpoint.mode != "none":
        apply_ac(model, job_config.activation_checkpoint)

    # turn on per-TransformerBlock compile after AC wrapping and before FSDP
    if job_config.training.compile:
        apply_compile(model)

    if parallel_dims.fsdp_enabled:
        # apply FSDP or HSDP, potentially with Context Parallel
        if parallel_dims.dp_replicate_enabled:
            dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp")
        else:
            dp_mesh_dim_names = ("dp_shard_cp",)

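        # Hypothetical example: dp_replicate=2 with dp_shard=4 yields a 2D
        # ("dp_replicate", "dp_shard_cp") mesh, i.e. HSDP with 2-way replication over
        # 4-way sharding; dp_replicate=1 yields a 1D mesh and plain FSDP.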
        apply_fsdp(
            model,
            world_mesh[tuple(dp_mesh_dim_names)],
            param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param],
            reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce],
            pp_enabled=parallel_dims.pp_enabled,
            cpu_offload=job_config.training.enable_cpu_offload,
            reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
        )

        if parallel_dims.dp_replicate_enabled:
            logger.info("Applied HSDP to the model")
        else:
            logger.info("Applied FSDP to the model")

        if parallel_dims.cp_enabled:
            logger.info("Applied Context Parallel to the model")

        if job_config.training.enable_cpu_offload:
            logger.info("Applied CPU Offloading to the model")
    elif parallel_dims.dp_replicate_enabled:
        if world_mesh.ndim > 1:
            raise RuntimeError("DDP has not supported > 1D parallelism")
        apply_ddp(
            model,
            world_mesh,
            enable_compile=job_config.training.compile,
            enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
        )

    return model


def apply_tp(
    model: nn.Module,
    tp_mesh: DeviceMesh,
    loss_parallel: bool,
    enable_float8_tensorwise_tp: bool,
    enable_async_tp: bool,
):
    """Apply tensor parallelism."""
    # 1. Parallelize the embedding and shard its outputs (which are the first
    # transformer block's inputs)
    # 2. Parallelize the root norm layer over the sequence dim
    # 3. Parallelize the final linear output layer
    parallelize_module(
        model,
        tp_mesh,
        {
            "tok_embeddings": RowwiseParallel(
                input_layouts=Replicate(),
                output_layouts=Shard(1),
            ),
            "norm": SequenceParallel(),
            "output": ColwiseParallel(
                input_layouts=Shard(1),
                output_layouts=Shard(-1) if loss_parallel else Replicate(),
                use_local_output=not loss_parallel,
            ),
        },
    )

    # Parallel styles used for transformer block linear weights and their
    # inputs may be different for float8 linears with tensorwise scaling.
    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        rowwise_parallel, colwise_parallel, prepare_module_input = (
            Float8RowwiseParallel,
            Float8ColwiseParallel,
            PrepareFloat8ModuleInput,
        )
    else:
        rowwise_parallel, colwise_parallel, prepare_module_input = (
            RowwiseParallel,
            ColwiseParallel,
            PrepareModuleInput,
        )

    # Apply tensor + sequence parallelism to every transformer block
    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
    # by folding (and unfolding) the batch dimension and the sequence dimension.
    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
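    # Per-block dataflow under TP + SP: the norms keep activations sequence-sharded
    # (Shard(1)); PrepareModuleInput all-gathers them to Replicate before attention and
    # feed_forward; wq/wk/wv and w1/w3 are column-sharded, while wo and w2 are row-sharded
    # with Shard(1) outputs, so the next SequenceParallel norm again sees sequence-sharded
    # activations.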
    for transformer_block in model.layers.values():
        layer_plan = {
            "attention_norm": SequenceParallel(),
            "attention": prepare_module_input(
                input_layouts=(Shard(1), Replicate()),
                desired_input_layouts=(Replicate(), Replicate()),
            ),
            "attention.wq": colwise_parallel(use_local_output=False),
            "attention.wk": colwise_parallel(use_local_output=False),
            "attention.wv": colwise_parallel(use_local_output=False),
            "attention.q_norm": NoParallel(use_local_output=False),
            "attention.k_norm": NoParallel(use_local_output=False),
            "attention.wo": rowwise_parallel(output_layouts=Shard(1)),
            "ffn_norm": SequenceParallel(),
            "feed_forward": prepare_module_input(
                input_layouts=(Shard(1),),
                desired_input_layouts=(Replicate(),),
            ),
            "feed_forward.w1": colwise_parallel(),
            "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)),
            "feed_forward.w3": colwise_parallel(),
        }

        parallelize_module(
            module=transformer_block,
            device_mesh=tp_mesh,
            parallelize_plan=layer_plan,
        )

    if enable_async_tp:
        from torch.distributed._symmetric_memory import enable_symm_mem_for_group

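        # Async TP (assumed behavior of these knobs): inductor micro-pipelines TP
        # collectives with the neighboring matmuls, backed by symmetric memory on the
        # TP process group, which is why compile is required above.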
        torch._inductor.config._micro_pipeline_tp = True
        enable_symm_mem_for_group(tp_mesh.get_group().group_name)

    logger.info(
        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
        "Tensor Parallelism to the model"
    )
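

# Minimal usage sketch (assumption: `model`, `parallel_dims`, and `job_config` are built
# by the torchtitan trainer from the TOML/CLI config; their construction is omitted here):
#
#   model = parallelize_qwen3(model, parallel_dims, job_config)
#
# TP (if enabled) is applied first, then activation checkpointing and compile, and
# finally FSDP/HSDP (or DDP) wraps the result.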