Commit 2025abb

Async TP minor fix (#1629)
Follow-up of #1619 to fix the remaining errors; also fixes a TODO.
1 parent cd337db · commit 2025abb

5 files changed: +22 −30 lines
torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 4 additions & 4 deletions
@@ -63,9 +63,6 @@ def parallelize_llama(
     ):
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
-    model_compile_enabled = (
-        job_config.compile.enable and "model" in job_config.compile.components
-    )
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
@@ -104,6 +101,9 @@ def parallelize_llama(
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
 
+    model_compile_enabled = (
+        job_config.compile.enable and "model" in job_config.compile.components
+    )
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
         # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
@@ -160,7 +160,7 @@ def parallelize_llama(
         apply_ddp(
             model,
             dp_mesh,
-            enable_compile=job_config.training.compile,
+            enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
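For readers skimming the hunks above: the moved check is just a boolean derived from the compile section of the job config, computed once after AC wrapping so that per-TransformerBlock compile and DDP's enable_compile read the same flag. Below is a minimal, self-contained sketch of that logic; the Compile and JobConfig dataclasses are simplified stand-ins for torchtitan's config objects, not the real classes.

from dataclasses import dataclass, field


@dataclass
class Compile:  # simplified stand-in for torchtitan's compile config section
    enable: bool = True
    components: list = field(default_factory=lambda: ["model", "loss"])


@dataclass
class JobConfig:  # simplified stand-in with only the field this diff touches
    compile: Compile = field(default_factory=Compile)


job_config = JobConfig()
model_compile_enabled = (
    job_config.compile.enable and "model" in job_config.compile.components
)
print(model_compile_enabled)  # True -> both per-block compile and DDP compile are on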

torchtitan/experiments/simple_fsdp/parallelize.py

Lines changed: 3 additions & 12 deletions
@@ -9,6 +9,7 @@
 
 from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
 from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_tp
 from torchtitan.tools.logging import logger
 
@@ -37,16 +38,7 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
     """
 
-    model_compile_enabled = (
-        job_config.compile.enable and "model" in job_config.compile.components
-    )
     if parallel_dims.tp_enabled:
-        if (
-            job_config.parallelism.enable_async_tensor_parallel
-            and not model_compile_enabled
-        ):
-            raise RuntimeError("Async TP requires torch.compile")
-
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
             "rowwise",
@@ -64,8 +56,8 @@ def parallelize_llama(
             tp_mesh,
             loss_parallel=not job_config.parallelism.disable_loss_parallel,
             enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
-            enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )
+        maybe_enable_async_tp(job_config, tp_mesh)
 
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
@@ -98,11 +90,10 @@ def parallelize_llama(
             mode=dp_mode,
             ac_mode=job_config.activation_checkpoint.mode,
             mp_policy=mp_policy,
-            tp_mesh=tp_mesh if parallel_dims.tp_enabled else None,
         )
         logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)
 
-    if model_compile_enabled:
+    if job_config.compile.enable and "model" in job_config.compile.components:
         torch._inductor.config.reorder_for_peak_memory = False
         model = torch.compile(model, fullgraph=True)
 
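The removed call-site check ("Async TP requires torch.compile") now lives inside maybe_enable_async_tp, which is called once right after apply_tp. The sketch below mirrors the guard logic removed in this diff and is only a hedged illustration of what such a helper plausibly does, not torchtitan's actual implementation; the SymmetricMemory-based enablement step is indicated only by a comment.

def maybe_enable_async_tp(job_config, tp_mesh):
    """Illustrative guard only; not the torchtitan implementation."""
    if not job_config.parallelism.enable_async_tensor_parallel:
        return

    model_compile_enabled = (
        job_config.compile.enable and "model" in job_config.compile.components
    )
    if not model_compile_enabled:
        # same invariant the removed inline check enforced
        raise RuntimeError("Async TP requires torch.compile")

    # ... enable SymmetricMemory-based async TP on tp_mesh's process group here ...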

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 7 additions & 9 deletions
@@ -185,7 +185,12 @@ def _custom_policy(ctx, func, *args, **kwargs):
 
 class ReplicateComputation(torch.nn.Module):
     def __init__(
-        self, device_mesh, param_sharding, mode, regional_ac, mp_policy, tp_mesh
+        self,
+        device_mesh,
+        param_sharding,
+        mode,
+        regional_ac,
+        mp_policy,
     ):
         super().__init__()
         self.device_mesh = device_mesh
@@ -197,7 +202,6 @@ def __init__(
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.tp_mesh = tp_mesh
 
     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
@@ -207,10 +211,7 @@ def replicate_compute(self, x):
         # support for FSDP/DDP/HSDP + TP (assuming TP shards the inner-most dim)
         if x._spec.mesh.mesh_dim_names[-1] == "tp":
             tp_placement = x._spec.placements[-1]
-            # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
-            # after DeviceMesh supports slicing a non-root mesh
-            # dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]
-            dp_mesh, tp_mesh = self.device_mesh, self.tp_mesh
+            dp_mesh, tp_mesh = self.device_mesh, x._spec.mesh["tp"]
 
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
@@ -270,7 +271,6 @@ def data_parallel(
    mode="replicate",
    ac_mode: str = "none",
    mp_policy: Optional[MixedPrecisionPolicy] = None,
-    tp_mesh: Optional[DeviceMesh] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -314,7 +314,6 @@ def data_parallel(
     #             mode,
     #             regional_ac,
     #             mp_policy=mp_policy,
-    #             tp_mesh=tp_mesh,
     #         ),
     #         unsafe=True,
     #     )
@@ -328,7 +327,6 @@ def data_parallel(
             mode,
             regional_ac,
             mp_policy=mp_policy,
-            tp_mesh=tp_mesh,
         ),
     )
     return model
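The resolved TODO hinges on DeviceMesh now supporting slicing a non-root mesh, so replicate_compute can take the TP submesh straight from the DTensor's own spec (x._spec.mesh["tp"]) instead of a tp_mesh argument threaded through data_parallel. Below is a small illustrative sketch of that named slicing on a plain 2D mesh, run under torchrun (e.g. --nproc_per_node=4); it is not torchtitan code.

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 2D ("dp", "tp") mesh on CPU/gloo; slicing by dimension name returns the submesh
# directly, which is what x._spec.mesh["tp"] does for the mesh stored on a DTensor.
mesh_2d = init_device_mesh("cpu", (2, 2), mesh_dim_names=("dp", "tp"))
dp_mesh, tp_mesh = mesh_2d["dp"], mesh_2d["tp"]
print(f"rank {dist.get_rank()}: dp mesh {dp_mesh}, tp mesh {tp_mesh}")
dist.destroy_process_group()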

torchtitan/models/deepseek_v3/infra/parallelize.py

Lines changed: 5 additions & 2 deletions
@@ -92,7 +92,10 @@ def parallelize_deepseekv3(
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
 
-    if job_config.compile.enable and "model" in job_config.compile.components:
+    model_compile_enabled = (
+        job_config.compile.enable and "model" in job_config.compile.components
+    )
+    if model_compile_enabled:
         # NOTE: needed for torch.compile to work with dynamic shapes in token-choice MoE
         torch._dynamo.config.capture_scalar_outputs = True
         apply_compile(model)
@@ -147,7 +150,7 @@ def parallelize_deepseekv3(
         apply_ddp(
             model,
             dp_mesh,
-            enable_compile=job_config.training.compile,
+            enable_compile=model_compile_enabled,
             enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd,
         )
 
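The NOTE about capture_scalar_outputs is carried over unchanged. For context, here is a hedged, illustrative sketch of why token-choice MoE routing benefits from it (data-dependent scalars read via .item() inside a compiled region); it is not the deepseek_v3 routing code.

import torch

# Without this flag, .item() on a data-dependent value forces a graph break under
# torch.compile; token-choice MoE routing reads such scalars (e.g. token counts).
torch._dynamo.config.capture_scalar_outputs = True


@torch.compile
def num_routed_tokens(router_scores: torch.Tensor, threshold: float = 0.5) -> int:
    return int((router_scores > threshold).sum().item())


print(num_routed_tokens(torch.rand(16)))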

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 3 additions & 3 deletions
@@ -64,9 +64,6 @@ def parallelize_llama(
     ):
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 
-    model_compile_enabled = (
-        job_config.compile.enable and "model" in job_config.compile.components
-    )
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.float8.recipe_name in (
@@ -90,6 +87,9 @@ def parallelize_llama(
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
 
+    model_compile_enabled = (
+        job_config.compile.enable and "model" in job_config.compile.components
+    )
     # turn on per-TransformerBlock compile after AC wrapping and before FSDP
     if model_compile_enabled:
         apply_compile(model)
