Support llama3 autoparallel + pipelining #1657
base: autoparallel
@@ -33,28 +33,65 @@ def parallelize_llama(
        the model must fit on GPU or CPU memory.
    """
    world_mesh = parallel_dims.world_mesh

    def input_fn():
        global_batch_size = job_config.training.global_batch_size
        if global_batch_size < 0:
            # This global batch size results in 1 gradient accumulation
            # step.
            dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
            global_batch_size = job_config.training.local_batch_size * dp_degree
        return (
            torch.randint(
                0,
                # job_config.training.vocab_size,
                model.vocab_size,
                (global_batch_size, job_config.training.seq_len),
                device=torch.device("cuda"),
            ),
        )
    # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
    assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
    assert parallel_dims.cp_enabled is False, "CP not supported yet"
    assert parallel_dims.pp_enabled is False, "PP not supported yet"

    pp_degree = job_config.parallelism.pipeline_parallel_degree
    local_batch_size = job_config.training.local_batch_size
    spmd_batch_size = local_batch_size
Review comment: oops this is a bug for the non-pp case. should be
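One plausible fix for the non-pp case, assuming the intent is to cover the full data-parallel batch the way the removed input_fn did (a guess that reuses names from this hunk, not code from the PR):

```python
# Hypothetical sketch of the fix hinted at above (assumption, not from the PR):
# in the non-PP case, scale the SPMD batch size by the data-parallel degree.
dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
spmd_batch_size = local_batch_size * dp_degree
```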
    spmd_mesh = world_mesh
    if parallel_dims.pp_enabled:
        pp_rank = world_mesh["pp"].get_local_rank()
        spmd_dims = []
        if parallel_dims.dp_replicate_enabled:
            spmd_dims.append("dp_replicate")
        if parallel_dims.dp_shard_enabled:
            spmd_dims.append("dp_shard")
        if parallel_dims.tp_enabled:
            spmd_dims.append("tp")
        spmd_mesh = world_mesh[spmd_dims]

        dp_degree = 1
Review comment: same, config could specify dp_degree
        for dim in ["dp_replicate", "dp_shard"]:
            if dim in spmd_mesh.mesh_dim_names:
                dp_degree *= spmd_mesh[dim].size()

        microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
        n_microbatches = local_batch_size // microbatch_size
        assert microbatch_size >= 1, f"invalid config {local_batch_size=}, {n_microbatches=}"
        spmd_batch_size = microbatch_size * dp_degree
        logger.info(f"Using {spmd_batch_size=}")
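For concreteness, the batch-size bookkeeping under some assumed values (illustrative numbers only, not taken from the PR):

```python
# Assumed illustrative values: per-rank batch of 8, pipeline microbatches of 2,
# and 4 data-parallel ranks.
local_batch_size = 8
microbatch_size = 2
dp_degree = 4
n_microbatches = local_batch_size // microbatch_size  # 4 microbatches per training step
spmd_batch_size = microbatch_size * dp_degree          # 8: one microbatch across all DP ranks
```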
    def input_fn():
        # TODO(whc) - i am not clear what we put this code here for in the first place
        # if global_batch_size < 0:
        #     # This global batch size results in 1 gradient accumulation
        #     # step.
        #     dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
        #     global_batch_size = job_config.training.local_batch_size * dp_degree
        if parallel_dims.pp_enabled and pp_rank > 0:
Review comment: What a mess. No action here needed, but it's definitely worth thinking about what the terminal UX state here should be.
            # TODO: which dtype here?
            return (
                torch.randn(
                    (spmd_batch_size, job_config.training.seq_len, model.model_args.dim),
                    device=torch.device("cuda"),
                    dtype=torch.bfloat16,
                    # important, otherwise autoparallel module will not produce grad_inputs, and pipelining will be sad
                    requires_grad=True,
                ),
            )
        else:
            return (
                torch.randint(
                    0,
                    # job_config.training.vocab_size,
                    model.vocab_size,
                    (spmd_batch_size, job_config.training.seq_len),
                    device=torch.device("cuda"),
                ),
            )
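One possible way to resolve the dtype TODO would be to derive the hidden-state dtype from the mixed-precision policy rather than hard-coding bfloat16; a sketch, assuming mp_policy is an FSDP MixedPrecisionPolicy (not part of this PR):

```python
# Sketch only (assumes mp_policy exposes a param_dtype attribute, as
# torch.distributed.fsdp.MixedPrecisionPolicy does); fall back to bfloat16.
hidden_dtype = getattr(mp_policy, "param_dtype", None) or torch.bfloat16
```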
    torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
        lambda bucket_idx: 500 / parallel_dims.tp
    )
@@ -75,7 +112,7 @@ def input_fn():
    with AutoParallel(
        model,
        input_fn,
        world_mesh,
        spmd_mesh,
        mp_policy=mp_policy,
        compile=job_config.compile,
    ) as autop:
@@ -94,19 +131,19 @@ def input_fn():
            "tp": Shard(2),
        }
        assert all(
            name in possible_input_shardings for name in world_mesh.mesh_dim_names
            name in possible_input_shardings for name in spmd_mesh.mesh_dim_names
        ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
        x_sharding = tuple(
            possible_input_shardings[name] for name in world_mesh.mesh_dim_names
            possible_input_shardings[name] for name in spmd_mesh.mesh_dim_names
        )
        out_sharding = x_sharding
        loss_parallel_enabled = (
            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
            parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel and not job_config.parallelism.pp_enabled
        )
        if loss_parallel_enabled:
            out_sharding = tuple(
                possible_output_shardings[name]
                for name in world_mesh.mesh_dim_names
                for name in spmd_mesh.mesh_dim_names
                if name != "dp_replicate"
            )
        autop.add_input_constraints([x_sharding])
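For intuition about the two tuple comprehensions, an illustration with an assumed 2-D SPMD mesh; the placement values below are made up for the example (the real dicts sit just above this hunk and are only partially visible):

```python
from torch.distributed.tensor import Replicate, Shard

# Assumed example values, not the PR's actual dicts.
possible_input_shardings = {"dp_shard": Shard(0), "tp": Replicate()}
mesh_dim_names = ("dp_shard", "tp")  # assumed spmd_mesh.mesh_dim_names
x_sharding = tuple(possible_input_shardings[name] for name in mesh_dim_names)
# -> (Shard(0), Replicate()): batch dim sharded over dp_shard, replicated over tp
```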
@@ -118,6 +155,7 @@ def input_fn():
        parallel_mod = autop.apply_placement(sharding_placement)

    if loss_parallel_enabled:
        # don't return DTensors for pipeline parallelism! they won't work

        # current PyTorch's implementation of loss parallel assumes
        # that the DTensor has a 1d device mesh. This is not true
@@ -127,7 +165,7 @@ def input_fn():
        # it would require putting the loss inside the model as well
        def _return_as_dtensor_for_loss_parallel(module, args, output):
            return torch.distributed.tensor.DTensor.from_local(
                output, world_mesh["tp"], (Shard(2),)
                output, spmd_mesh["tp"], (Shard(2),)
            )

        # not keeping a reference to the hook, don't plan on
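The registration of this hook falls outside the visible hunk; presumably it is installed as a standard nn.Module forward hook on the parallelized module, roughly like the sketch below (an assumption, not shown in this diff):

```python
# Sketch (assumed): re-wrap the module's local output as a vocab-dim-sharded
# DTensor so loss parallel can consume it.
parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)
```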
@@ -492,11 +492,13 @@ def forward_backward_step(
            )
            if self.pp_has_first_stage:
                self.pp_schedule.step(
                    inputs, target=targets, losses=losses, input_batch=inputs
                    # TODO: input_batch kwarg only needed for CP, but
                    # autoparallel doesn't accept kwargs in its forward
Review comment: Can we just fix this LOL
                    inputs, target=targets, losses=losses #, input_batch=inputs
Review comment: Curious, why does CP need
Reply: I assumed you would know. Am I wrong?
                )
            else:
                self.pp_schedule.step(
                    target=targets, losses=losses, input_batch=inputs
                    target=targets, losses=losses #, input_batch=inputs
                )

            # accumulate losses across pipeline microbatches
Review comment: unused pp degree config, should probably raise an error when it's not the local world size
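A sketch of the check this comment seems to ask for; the use of LOCAL_WORLD_SIZE and the exact comparison are guesses at the reviewer's intent, not code from the PR:

```python
import os

# Hypothetical validation (assumed interpretation of the review comment):
# fail fast if the configured pipeline degree does not match the local world size.
pp_degree = job_config.parallelism.pipeline_parallel_degree
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])  # set by torchrun
if pp_degree != local_world_size:
    raise ValueError(f"{pp_degree=} must equal {local_world_size=} for this setup")
```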