From 7e60a6161175f5a3ca75bc6e07477648ffd203fa Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Thu, 28 Aug 2025 16:27:25 -0700
Subject: [PATCH] Support llama3 autoparallel + pipelining

So far this has only been tested locally:

`LOG_RANK=4 CONFIG_FILE=././torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name llama3_auto_parallel --parallelism.pipeline_parallel_degree 2 --training.steps 100`

It runs and the loss converges. One TODO is left about global batch size and
gradient accumulation.
---
 .../auto_parallel/parallelize_llama.py | 84 ++++++++++++++-----
 torchtitan/train.py                    |  6 +-
 2 files changed, 65 insertions(+), 25 deletions(-)

diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
index 6648f29ab..1228389cc 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_llama.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -33,28 +33,65 @@ def parallelize_llama(
         the model must fit on GPU or CPU memory.
     """
     world_mesh = parallel_dims.world_mesh
-    def input_fn():
-        global_batch_size = job_config.training.global_batch_size
-        if global_batch_size < 0:
-            # This global batch size results in 1 gradient accumulation
-            # step.
-            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
-            global_batch_size = job_config.training.local_batch_size * dp_degree
-        return (
-            torch.randint(
-                0,
-                # job_config.training.vocab_size,
-                model.vocab_size,
-                (global_batch_size, job_config.training.seq_len),
-                device=torch.device("cuda"),
-            ),
-        )
 
     # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
     assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
     assert parallel_dims.cp_enabled is False, "CP not supported yet"
-    assert parallel_dims.pp_enabled is False, "PP not supported yet"
+    pp_degree = job_config.parallelism.pipeline_parallel_degree
+    local_batch_size = job_config.training.local_batch_size
+    spmd_batch_size = local_batch_size
+    spmd_mesh = world_mesh
+    if parallel_dims.pp_enabled:
+        pp_rank = world_mesh["pp"].get_local_rank()
+        spmd_dims = []
+        if parallel_dims.dp_replicate_enabled:
+            spmd_dims.append("dp_replicate")
+        if parallel_dims.dp_shard_enabled:
+            spmd_dims.append("dp_shard")
+        if parallel_dims.tp_enabled:
+            spmd_dims.append("tp")
+        spmd_mesh = world_mesh[spmd_dims]
+
+        dp_degree = 1
+        for dim in ["dp_replicate", "dp_shard"]:
+            if dim in spmd_mesh.mesh_dim_names:
+                dp_degree *= spmd_mesh[dim].size()
+
+        microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
+        n_microbatches = local_batch_size // microbatch_size
+        assert n_microbatches >= 1, f"invalid config {local_batch_size=}, {n_microbatches=}"
+        spmd_batch_size = microbatch_size * dp_degree
+        logger.info(f"Using {spmd_batch_size=}")
+
+    def input_fn():
+        # TODO(whc) - I am not sure why this code was here in the first place
+        # if global_batch_size < 0:
+        #     # This global batch size results in 1 gradient accumulation
+        #     # step.
+        #     dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
+        #     global_batch_size = job_config.training.local_batch_size * dp_degree
+        if parallel_dims.pp_enabled and pp_rank > 0:
+            # TODO: which dtype here?
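+            # NOTE: bfloat16 is assumed here to match the activation dtype produced
+            # under the mp_policy passed to AutoParallel below; if that policy
+            # changes, this stage-input dtype should likely follow it.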
+            return (
+                torch.randn(
+                    (spmd_batch_size, job_config.training.seq_len, model.model_args.dim),
+                    device=torch.device("cuda"),
+                    dtype=torch.bfloat16,
+                    # important, otherwise autoparallel module will not produce grad_inputs, and pipelining will be sad
+                    requires_grad=True,
+                ),
+            )
+        else:
+            return (
+                torch.randint(
+                    0,
+                    # job_config.training.vocab_size,
+                    model.vocab_size,
+                    (spmd_batch_size, job_config.training.seq_len),
+                    device=torch.device("cuda"),
+                ),
+            )
 
     torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
         lambda bucket_idx: 500 / parallel_dims.tp
     )
@@ -75,7 +112,7 @@ def input_fn():
     with AutoParallel(
         model,
         input_fn,
-        world_mesh,
+        spmd_mesh,
         mp_policy=mp_policy,
         compile=job_config.compile,
     ) as autop:
@@ -94,19 +131,19 @@ def input_fn():
             "tp": Shard(2),
         }
         assert all(
-            name in possible_input_shardings for name in world_mesh.mesh_dim_names
+            name in possible_input_shardings for name in spmd_mesh.mesh_dim_names
         ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
         x_sharding = tuple(
-            possible_input_shardings[name] for name in world_mesh.mesh_dim_names
+            possible_input_shardings[name] for name in spmd_mesh.mesh_dim_names
         )
         out_sharding = x_sharding
         loss_parallel_enabled = (
-            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
+            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel and not parallel_dims.pp_enabled
         )
         if loss_parallel_enabled:
             out_sharding = tuple(
                 possible_output_shardings[name]
-                for name in world_mesh.mesh_dim_names
+                for name in spmd_mesh.mesh_dim_names
                 if name != "dp_replicate"
             )
         autop.add_input_constraints([x_sharding])
@@ -118,6 +155,7 @@ def input_fn():
     parallel_mod = autop.apply_placement(sharding_placement)
 
     if loss_parallel_enabled:
+        # don't return DTensors for pipeline parallelism! they won't work
 
         # current PyTorch's implementation of loss parallel assumes
         # that the DTensor has a 1d device mesh. This is not true
@@ -127,7 +165,7 @@ def input_fn():
         # it would require putting the loss inside the model as well
         def _return_as_dtensor_for_loss_parallel(module, args, output):
             return torch.distributed.tensor.DTensor.from_local(
-                output, world_mesh["tp"], (Shard(2),)
+                output, spmd_mesh["tp"], (Shard(2),)
             )
 
         # not keeping a reference to the hook, don't plan on
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 2829aa3c5..b4f43263e 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -492,11 +492,13 @@ def forward_backward_step(
                 )
                 if self.pp_has_first_stage:
                     self.pp_schedule.step(
-                        inputs, target=targets, losses=losses, input_batch=inputs
+                        # TODO: input_batch kwarg is only needed for CP, but
+                        # autoparallel doesn't accept kwargs in its forward
+                        inputs, target=targets, losses=losses  # , input_batch=inputs
                     )
                 else:
                     self.pp_schedule.step(
-                        target=targets, losses=losses, input_batch=inputs
+                        target=targets, losses=losses  # , input_batch=inputs
                     )
 
             # accumulate losses across pipeline microbatches
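
Reviewer note on the new batch-size bookkeeping: with pipelining enabled, AutoParallel is traced with an example input of size `microbatch_size * dp_degree` rather than the full local batch, and non-first pipeline stages are traced with hidden states instead of token ids. Below is a minimal sketch of that arithmetic with made-up sizes; `make_stage_input` is a hypothetical stand-in for the patch's `input_fn`, not code from this PR.

```python
import torch

# Illustrative values (assumptions, not taken from any real config)
local_batch_size = 8   # job_config.training.local_batch_size
microbatch_size = 2    # job_config.parallelism.pipeline_parallel_microbatch_size
dp_degree = 2          # product of the dp_replicate / dp_shard mesh sizes
seq_len, dim, vocab_size = 16, 64, 1000

n_microbatches = local_batch_size // microbatch_size  # 4 microbatches per step
assert n_microbatches >= 1, "local batch must hold at least one microbatch"
spmd_batch_size = microbatch_size * dp_degree  # what AutoParallel traces with


def make_stage_input(pp_rank: int) -> torch.Tensor:
    """Example input AutoParallel sees for a given pipeline rank (CPU for brevity)."""
    if pp_rank > 0:
        # Later stages consume hidden states from the previous stage;
        # requires_grad=True so the traced module emits grad_inputs for backward.
        return torch.randn(
            (spmd_batch_size, seq_len, dim), dtype=torch.bfloat16, requires_grad=True
        )
    # The first stage consumes token ids.
    return torch.randint(0, vocab_size, (spmd_batch_size, seq_len))


print(make_stage_input(0).shape)  # torch.Size([4, 16])
print(make_stage_input(1).shape)  # torch.Size([4, 16, 64])
```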