84 changes: 61 additions & 23 deletions torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -33,28 +33,65 @@ def parallelize_llama(
the model must fit on GPU or CPU memory.
"""
world_mesh = parallel_dims.world_mesh
def input_fn():
global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
# This global batch size results in 1 gradient accumulation
# step.
dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
global_batch_size = job_config.training.local_batch_size * dp_degree
return (
torch.randint(
0,
# job_config.training.vocab_size,
model.vocab_size,
(global_batch_size, job_config.training.seq_len),
device=torch.device("cuda"),
),
)

# TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP
assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet"
assert parallel_dims.cp_enabled is False, "CP not supported yet"
assert parallel_dims.pp_enabled is False, "PP not supported yet"

pp_degree = job_config.parallelism.pipeline_parallel_degree
Review comment (Member): unused pp degree config, should probably raise an error when it's not the local world size
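A minimal sketch of the kind of check the reviewer seems to be asking for, written as if it sat inside parallelize_llama so that job_config, parallel_dims, and world_mesh are the names already in scope above. It is hypothetical and not part of this PR; the exact condition to enforce is an assumption.

# Hypothetical validation sketched from the review comment above (not in the PR).
pp_degree = job_config.parallelism.pipeline_parallel_degree
if parallel_dims.pp_enabled and pp_degree != world_mesh["pp"].size():
    raise ValueError(
        f"pipeline_parallel_degree={pp_degree} does not match the pp mesh size "
        f"{world_mesh['pp'].size()}"
    )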

local_batch_size = job_config.training.local_batch_size
spmd_batch_size = local_batch_size
Review comment (Contributor Author): oops, this is a bug for the non-PP case. It should be local_batch_size * dp_degree and put in an 'else' branch.
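A minimal sketch of the fix the author describes, hypothetical and not what this diff currently does: when pipeline parallelism is off, the SPMD batch should cover the full data-parallel batch, reusing the dp_degree expression from the deleted input_fn above.

# Hypothetical fix following the comment above (not part of this diff).
local_batch_size = job_config.training.local_batch_size
if parallel_dims.pp_enabled:
    ...  # microbatch-based sizing, as in the PR code below
else:
    dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
    spmd_batch_size = local_batch_size * dp_degree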

spmd_mesh = world_mesh
if parallel_dims.pp_enabled:
pp_rank = world_mesh["pp"].get_local_rank()
spmd_dims = []
if parallel_dims.dp_replicate_enabled:
spmd_dims.append("dp_replicate")
if parallel_dims.dp_shard_enabled:
spmd_dims.append("dp_shard")
if parallel_dims.tp_enabled:
spmd_dims.append("tp")
spmd_mesh = world_mesh[spmd_dims]

dp_degree = 1
Review comment (Member): same as above, the config could specify dp_degree

for dim in ["dp_replicate", "dp_shard"]:
if dim in spmd_mesh.mesh_dim_names:
dp_degree *= spmd_mesh[dim].size()

microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
n_microbatches = local_batch_size // microbatch_size
assert microbatch_size >= 1, f"invalid config {local_batch_size=}, {n_microbatches=}"
spmd_batch_size = microbatch_size * dp_degree
logger.info(f"Using {spmd_batch_size=}")
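As a worked example with hypothetical numbers: with dp_replicate=1, dp_shard=4, local_batch_size=8, and pipeline_parallel_microbatch_size=2, each rank runs 4 microbatches per step and the SPMD graph is traced with a batch of 8.

# Hypothetical numbers, for illustration only.
dp_degree = 1 * 4                                      # dp_replicate * dp_shard
local_batch_size, microbatch_size = 8, 2
n_microbatches = local_batch_size // microbatch_size   # 4 microbatches per local batch
spmd_batch_size = microbatch_size * dp_degree          # 2 * 4 = 8 samples in the traced graph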

def input_fn():
# TODO(whc) - i am not clear what we put this code here for in the first place
# if global_batch_size < 0:
# # This global batch size results in 1 gradient accumulation
# # step.
# dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
# global_batch_size = job_config.training.local_batch_size * dp_degree
if parallel_dims.pp_enabled and pp_rank > 0:
Review comment: What a mess. No action needed here, but it's definitely worth thinking about what the terminal UX state here should be.

# TODO: which dtype here?
return (
torch.randn(
(spmd_batch_size, job_config.training.seq_len, model.model_args.dim),
device=torch.device("cuda"),
dtype=torch.bfloat16,
# important, otherwise autoparallel module will not produce grad_inputs, and pipelining will be sad
requires_grad=True,
),
)
else:
return (
torch.randint(
0,
# job_config.training.vocab_size,
model.vocab_size,
(spmd_batch_size, job_config.training.seq_len),
device=torch.device("cuda"),
),
)
torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
lambda bucket_idx: 500 / parallel_dims.tp
)
@@ -75,7 +112,7 @@ def input_fn():
with AutoParallel(
model,
input_fn,
world_mesh,
spmd_mesh,
mp_policy=mp_policy,
compile=job_config.compile,
) as autop:
@@ -94,19 +131,19 @@ def input_fn():
"tp": Shard(2),
}
assert all(
name in possible_input_shardings for name in world_mesh.mesh_dim_names
name in possible_input_shardings for name in spmd_mesh.mesh_dim_names
), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel"
x_sharding = tuple(
possible_input_shardings[name] for name in world_mesh.mesh_dim_names
possible_input_shardings[name] for name in spmd_mesh.mesh_dim_names
)
out_sharding = x_sharding
loss_parallel_enabled = (
parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel
parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel and not job_config.parallelism.pp_enabled
)
if loss_parallel_enabled:
out_sharding = tuple(
possible_output_shardings[name]
for name in world_mesh.mesh_dim_names
for name in spmd_mesh.mesh_dim_names
if name != "dp_replicate"
)
autop.add_input_constraints([x_sharding])
@@ -118,6 +155,7 @@ def input_fn():
parallel_mod = autop.apply_placement(sharding_placement)

if loss_parallel_enabled:
# don't return DTensors for pipeline parallelism! they won't work

# current PyTorch's implementation of loss parallel assumes
# that the DTensor has a 1d device mesh. This is not true
Expand All @@ -127,7 +165,7 @@ def input_fn():
# it would require putting the loss inside the model as well
def _return_as_dtensor_for_loss_parallel(module, args, output):
return torch.distributed.tensor.DTensor.from_local(
output, world_mesh["tp"], (Shard(2),)
output, spmd_mesh["tp"], (Shard(2),)
)

# not keeping a reference to the hook, don't plan on
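The diff view is truncated below this point. Presumably the hook defined above is registered on parallel_mod without keeping the returned handle; a sketch of what that typically looks like with PyTorch's standard nn.Module.register_forward_hook API, offered only as an assumption about the hidden lines:

# Assumption: registration roughly like this follows in the truncated lines.
# A forward hook that returns a non-None value replaces the module's output, so
# the model's output becomes a DTensor sharded on dim 2 for loss parallel.
parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)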
6 changes: 4 additions & 2 deletions torchtitan/train.py
@@ -492,11 +492,13 @@ def forward_backward_step(
)
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs, target=targets, losses=losses, input_batch=inputs
# TODO: input_batch kwarg only needed for CP, but
# autoparallel doesn't accept kwargs in its forward
Review comment: Can we just fix this LOL

inputs, target=targets, losses=losses #, input_batch=inputs
Review comment (Contributor): Curious, why does CP need input_batch?

Review comment (Contributor Author): I assumed you would know. Am I wrong?

)
else:
self.pp_schedule.step(
target=targets, losses=losses, input_batch=inputs
target=targets, losses=losses #, input_batch=inputs
)

# accumulate losses across pipeline microbatches