
Commit 8a749c6

Move the call to init_attention_mask to trainer (#1616)
One perspective on the attention mask is that it should be coupled with the dataloader rather than the modeling component. Therefore, this PR moves the creation of the attention mask to the trainer, removing it from the model itself. This PR also fixes #1612.

```
-> % LOG_RANK=6 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --training.steps=100 --parallelism.pipeline_parallel_degree=4
+ NGPU=8
+ export LOG_RANK=6
+ LOG_RANK=6
+ CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ PYTORCH_ALLOC_CONF=expandable_segments:True
+ TORCHFT_LIGHTHOUSE=http://localhost:29510
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 6 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml --training.steps=100 --parallelism.pipeline_parallel_degree=4
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803]
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] *****************************************
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0821 10:14:43.689000 1062175 torch/distributed/run.py:803] *****************************************
[rank6]:[titan] 2025-08-21 10:14:50,681 - root - INFO - Starting job: DeepSeek-V3 16B model training
[rank6]:[titan] 2025-08-21 10:14:53,248 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank6]:[titan] 2025-08-21 10:14:53,250 - root - INFO - Building 2-D device mesh with ['pp', 'dp_shard'], [4, 2]
[rank6]:[titan] 2025-08-21 10:14:53,265 - root - INFO - [GC] Initial GC collection 0.00 seconds
[rank6]:[titan] 2025-08-21 10:14:56,937 - root - INFO - Loading tokenizer from tokenizer.json
[rank6]:[titan] 2025-08-21 10:14:57,076 - root - INFO - Preparing c4 dataset from allenai/c4
[rank6]:[titan] 2025-08-21 10:15:00,743 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, vocab_size=102400, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, moe_args=MoEArgs(num_experts=64, num_shared_experts=2, score_func='softmax', route_norm=True, route_scale=1.0, score_before_experts=False, top_k=6, use_grouped_mm=True, load_balance_coeff=0.001), n_expert_groups=1, n_limited_groups=1, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=True, attn_mask_type='block_causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7)
[rank6]:[titan] 2025-08-21 10:15:00,966 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory
[rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Total parameter count: dense 858,385,920, sparse 14,848,098,304, active 2,661,150,208
[rank6]:[titan] 2025-08-21 10:15:01,008 - root - INFO - Model deepseek_v3 16B size: 15,706,484,224 total parameters
[rank6]:Stage 3: Modules to keep: {'layers.14', 'layers.13', 'layers.11', 'layers.12'}
[rank6]:Stage 7: Modules to keep: {'output', 'norm', 'layers.26', 'layers.25'}
[rank6]:[titan] 2025-08-21 10:15:01,029 - root - INFO - PP rank 3 is building stage_idx 3 with modules ['layers.11', 'layers.12', 'layers.13', 'layers.14']
[rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - PP rank 3 is building stage_idx 7 with modules ['layers.25', 'layers.26', 'norm', 'output']
[rank6]:[titan] 2025-08-21 10:15:01,048 - root - INFO - Applied full activation checkpointing to the model
[rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied FSDP to the model
[rank6]:[titan] 2025-08-21 10:15:01,072 - root - INFO - Applied full activation checkpointing to the model
[rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Applied FSDP to the model
[rank6]:[titan] 2025-08-21 10:15:01,080 - root - INFO - Using pipeline schedule Interleaved1F1B with 8 microbatches and 8 stages.
[rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14
[rank6]:[titan] 2025-08-21 10:15:01,488 - root - INFO - CUDA memory usage for model: 6.94GiB(7.31%)
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup steps (200) exceed total training steps (100). Adjusting warmup steps to 100.
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - WARNING - Warmup (100) + decay (80) steps exceed total training steps (100). Adjusting decay steps to 0.
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Mixed precision training is handled by fully_shard
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Trainer is initialized with local batch size 8, global batch size 16, gradient accumulation steps 1, sequence length 4096, total steps 100 (warmup 200)
[rank6]:[titan] 2025-08-21 10:15:01,489 - root - INFO - Training starts at step 1
[rank6]:[rank6]:[W821 10:15:10.781655306 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4. In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator())
[rank6]:NCCL version 2.27.5+cuda12.6
[rank6]:[rank6]:[W821 10:15:16.977607954 ProcessGroupNCCL.cpp:3993] Warning: An unbatched P2P op (send/recv) was called on this ProcessGroup with size 4. In lazy initialization mode, this will result in a new 2-rank NCCL communicator to be created. (function operator())
[rank6]:/data/users/chienchin/mywork/pytorch/torch/__init__.py:1539: UserWarning: Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /data/users/chienchin/mywork/pytorch/aten/src/ATen/Context.cpp:80.)
[rank6]:  return _C._get_float32_matmul_precision()
[rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - step: 1 loss: 12.0194 grad_norm: 1.8958 memory: 53.94GiB(56.78%) tps: 296 tflops: 5.16 mfu: 0.52%
[rank6]:[titan] 2025-08-21 10:15:28,674 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank6]:[titan] 2025-08-21 10:15:43,154 - root - INFO - step: 10 loss: 10.3629 grad_norm: 3.0762 memory: 67.11GiB(70.64%) tps: 5,092 tflops: 88.73 mfu: 8.97%
[rank6]:[titan] 2025-08-21 10:15:59,017 - root - INFO - step: 20 loss: 8.9238 grad_norm: 2.5020 memory: 67.11GiB(70.64%) tps: 5,165 tflops: 90.00 mfu: 9.10%
[rank6]:[titan] 2025-08-21 10:16:15,051 - root - INFO - step: 30 loss: 7.8167 grad_norm: 1.7460 memory: 67.11GiB(70.64%) tps: 5,109 tflops: 89.04 mfu: 9.00%
[rank6]:[titan] 2025-08-21 10:16:31,989 - root - INFO - step: 40 loss: 7.1761 grad_norm: 1.1432 memory: 67.11GiB(70.64%) tps: 4,837 tflops: 84.29 mfu: 8.52%
[rank6]:[titan] 2025-08-21 10:16:48,455 - root - INFO - step: 50 loss: 6.7850 grad_norm: 1.4950 memory: 67.11GiB(70.64%) tps: 4,975 tflops: 86.70 mfu: 8.77%
[rank6]:[titan] 2025-08-21 10:17:04,602 - root - INFO - step: 60 loss: 6.8310 grad_norm: 1.2972 memory: 67.11GiB(70.64%) tps: 5,074 tflops: 88.42 mfu: 8.94%
[rank6]:[titan] 2025-08-21 10:17:22,231 - root - INFO - step: 70 loss: 6.6627 grad_norm: 1.1630 memory: 67.11GiB(70.64%) tps: 4,647 tflops: 80.98 mfu: 8.19%
[rank6]:[titan] 2025-08-21 10:17:41,358 - root - INFO - step: 80 loss: 6.3542 grad_norm: 0.8215 memory: 67.11GiB(70.64%) tps: 4,283 tflops: 74.64 mfu: 7.55%
[rank6]:[titan] 2025-08-21 10:17:58,336 - root - INFO - step: 90 loss: 6.4442 grad_norm: 1.2542 memory: 67.11GiB(70.64%) tps: 4,825 tflops: 84.09 mfu: 8.50%
[rank6]:[titan] 2025-08-21 10:18:12,542 - root - INFO - [GC] Peforming periodical GC collection 0.07 seconds
[rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - step: 100 loss: 6.7519 grad_norm: 1.3966 memory: 67.11GiB(70.64%) tps: 5,048 tflops: 87.97 mfu: 8.89%
[rank6]:[titan] 2025-08-21 10:18:14,566 - root - INFO - Training completed
[rank6]:[titan] 2025-08-21 10:18:17,159 - root - INFO - Process group destroyed
```
1 parent 7d744b2 commit 8a749c6
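
To see why the mask pairs naturally with the data rather than with the model: with `attn_mask_type='block_causal'`, the FlexAttention mask is derived from where the EOS tokens fall in each packed batch. Below is a minimal, self-contained sketch of that idea (illustrative only, not torchtitan's implementation; the helper name and the exact EOS-boundary handling are assumptions):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask


def block_causal_mask_from_batch(tokens: torch.Tensor, eos_id: int):
    # Assign a document id to every token: the id increments at each EOS, so
    # tokens from different documents packed into the same sequence never
    # attend to each other.
    docs = (tokens == eos_id).cumsum(dim=-1)

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = docs[b, q_idx] == docs[b, kv_idx]
        return causal & same_doc

    bsz, seq_len = tokens.shape
    return create_block_mask(
        mask_mod, B=bsz, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=tokens.device
    )
```

Because the mask depends only on the batch contents, the trainer (which already owns the batch) is a natural place to build it once per step, for every pipeline stage.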

11 files changed: +38 additions, -61 deletions


torchtitan/distributed/utils.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -206,9 +206,7 @@ def context(cp_context: Generator[None, None, None] | None = None):

             if SDPBackend.MATH in ScaledDotProductAttention.backends:
                 ScaledDotProductAttention.backends.remove(SDPBackend.MATH)
-                assert (
-                    ScaledDotProductAttention.backends
-                ), "No valid SDPA backends with CP."
+
             stack.enter_context(cp_context)

         yield
```

torchtitan/experiments/forge/example_train.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -160,6 +160,13 @@ def forward_backward_step(
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["input"]
+        # Create the FlexAttention mask according to the input
+        if getattr(self.model_args, "use_flex_attn", False):
+            cp_mesh = (
+                parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
+            )
+            init_attention_mask(inputs, self.tokenizer.eos_id, cp_mesh)
+
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
                 cp_mesh=parallel_dims.world_mesh["cp"],
```

torchtitan/experiments/llama4/model/args.py

Lines changed: 0 additions & 9 deletions

```diff
@@ -71,15 +71,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
                 "CP support for FlexAttention is still in progress."
             )

-        if (
-            job_config.parallelism.pipeline_parallel_degree > 1
-            and self.use_flex_attn
-            and self.attn_mask_type == "block_causal"
-        ):
-            raise RuntimeError(
-                "PP + block causal FlexAttention support will be fixed soon."
-            )
-
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:
```

torchtitan/experiments/llama4/model/model.py

Lines changed: 1 addition & 7 deletions

```diff
@@ -9,7 +9,7 @@
 import torch.nn.functional as F
 from torch import nn

-from torchtitan.models.attention import build_attention, init_attention_mask
+from torchtitan.models.attention import build_attention
 from torchtitan.models.moe import MoE
 from torchtitan.protocols import ModelProtocol

@@ -451,7 +451,6 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
     def forward(
         self,
         tokens: torch.Tensor,
-        eos_id: int | None = None,
         input_batch: torch.Tensor | None = None,
     ):
         """
@@ -471,11 +470,6 @@
             torch.Tensor: Output logits after applying the Transformer model.

         """
-        if self.model_args.use_flex_attn:
-            init_attention_mask(
-                input_batch if input_batch is not None else tokens, eos_id=eos_id
-            )
-
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
```

torchtitan/experiments/qwen3/model/model.py

Lines changed: 1 addition & 9 deletions

```diff
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 from torch import nn

-from torchtitan.models.attention import build_attention, init_attention_mask
+from torchtitan.models.attention import build_attention
 from torchtitan.protocols.train_spec import ModelProtocol

 from .args import Qwen3ModelArgs
@@ -411,7 +411,6 @@ def forward(
         self,
         tokens: torch.Tensor,
         input_batch: torch.Tensor | None = None,
-        eos_id: int | None = None,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -425,18 +424,11 @@
                 This will always be the input batch regardless of the pipeline stage.
                 This field is required for non-first PP stages to perform document
                 masking attention (to analyze the boundary of the document).
-            eos_id (int | None): End-of-sequence token ID. If not provided, uses self.eos_id.

         Returns:
             torch.Tensor: Output logits after applying the Transformer model.

         """
-        if self.model_args.use_flex_attn:
-            init_attention_mask(
-                input_batch if input_batch is not None else tokens,
-                eos_id=eos_id,
-            )
-
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
```

torchtitan/models/attention.py

Lines changed: 16 additions & 1 deletion

```diff
@@ -6,10 +6,12 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.

+import functools
 from typing import Callable, ClassVar

 import torch
 import torch.nn.functional as F
+from torch.distributed.tensor.experimental._attention import create_cp_block_mask
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
@@ -239,5 +241,18 @@
         return ScaledDotProductAttention(attn_mask_type)


-def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None:
+def init_attention_mask(
+    batch: torch.Tensor,
+    eos_id: int | None,
+    cp_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
+) -> None:
+
+    # This is not functional yet because we currently gate the use of Flex + CP
+    # while we continue debugging accuracy issues. However, we want to evaluate
+    # the user experience with CP enabled.
+    if cp_mesh is not None:
+        FlexAttention.compiled_create_block_mask = functools.partial(
+            create_cp_block_mask, device_mesh=cp_mesh
+        )
+
     FlexAttention.init_attention_mask(batch, eos_id)
```
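
The CP branch above works by rebinding the class-level block-mask constructor: `functools.partial` bakes in `device_mesh`, so the partially applied `create_cp_block_mask` keeps the call signature the rest of the FlexAttention code expects. Here is a tiny standalone sketch of that pattern (the function names are made-up stand-ins, not torchtitan or PyTorch APIs):

```python
import functools


def create_block_mask_local(mask_mod, bsz, heads, q_len, kv_len):
    """Stand-in for the default block-mask constructor."""
    return ("local", bsz, q_len, kv_len)


def create_block_mask_cp(mask_mod, bsz, heads, q_len, kv_len, *, device_mesh):
    """Stand-in for a CP-aware constructor that needs the device mesh."""
    return ("cp", device_mesh, bsz, q_len, kv_len)


# Default: call the local constructor directly.
make_mask = create_block_mask_local

# With CP enabled: bind the mesh once. The partial has the same positional
# signature as the original, so existing call sites stay unchanged.
cp_mesh = "cp_mesh"  # placeholder for a DeviceMesh
make_mask = functools.partial(create_block_mask_cp, device_mesh=cp_mesh)

print(make_mask(None, 8, None, 4096, 4096))  # ('cp', 'cp_mesh', 8, 4096, 4096)
```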

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 0 additions & 9 deletions

```diff
@@ -105,15 +105,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
                 "CP support for FlexAttention is still in progress."
             )

-        if (
-            job_config.parallelism.pipeline_parallel_degree > 1
-            and self.use_flex_attn
-            and self.attn_mask_type == "block_causal"
-        ):
-            raise RuntimeError(
-                "PP + block causal FlexAttention support will be fixed soon."
-            )
-
     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
         """
         Adopted from llama4 implementation.
```

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 1 addition & 6 deletions

```diff
@@ -10,7 +10,7 @@
 import torch
 from torch import nn

-from torchtitan.models.attention import build_attention, init_attention_mask
+from torchtitan.models.attention import build_attention
 from torchtitan.models.moe import FeedForward, MoE
 from torchtitan.protocols.train_spec import ModelProtocol

@@ -364,7 +364,6 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None:
     def forward(
         self,
         tokens: torch.Tensor,
-        eos_id: int | None = None,
         input_batch: torch.Tensor | None = None,
     ):
         """
@@ -383,10 +382,6 @@
         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
-        if self.model_args.use_flex_attn:
-            init_attention_mask(
-                input_batch if input_batch is not None else tokens, eos_id=eos_id
-            )

         h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens
```

torchtitan/models/llama3/model/args.py

Lines changed: 0 additions & 8 deletions

```diff
@@ -50,14 +50,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
                 "CP support for FlexAttention is still in progress."
            )

-        if (
-            job_config.parallelism.pipeline_parallel_degree > 1
-            and self.use_flex_attn
-            and self.attn_mask_type == "block_causal"
-        ):
-            raise RuntimeError(
-                "PP + block causal FlexAttention support will be fixed soon."
-            )
         self.max_seq_len = seq_len

     def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]:
```

torchtitan/models/llama3/model/model.py

Lines changed: 1 addition & 7 deletions

```diff
@@ -11,7 +11,7 @@
 import torch.nn.functional as F
 from torch import nn

-from torchtitan.models.attention import build_attention, init_attention_mask
+from torchtitan.models.attention import build_attention
 from torchtitan.protocols.train_spec import ModelProtocol

 from .args import TransformerModelArgs
@@ -395,7 +395,6 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
     def forward(
         self,
         tokens: torch.Tensor,
-        eos_id: int | None = None,
         input_batch: torch.Tensor | None = None,
     ):
         """
@@ -415,11 +414,6 @@
             torch.Tensor: Output logits after applying the Transformer model.

         """
-        if self.model_args.use_flex_attn:
-            init_attention_mask(
-                input_batch if input_batch is not None else tokens, eos_id=eos_id
-            )
-
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
```
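
Taken together, the model-side removals above and the trainer-side addition in `example_train.py` leave a calling convention roughly like the sketch below (hedged: `run_forward` and its arguments are illustrative; only `init_attention_mask` comes from this diff):

```python
from torchtitan.models.attention import init_attention_mask


def run_forward(model, model_args, tokens, eos_id, cp_mesh=None):
    # The trainer now owns mask creation: derive the FlexAttention mask from
    # the batch before calling the model, optionally in a CP-aware way.
    if getattr(model_args, "use_flex_attn", False):
        init_attention_mask(tokens, eos_id, cp_mesh)

    # Model.forward no longer takes eos_id; non-first pipeline stages still
    # receive input_batch so document boundaries remain recoverable downstream.
    return model(tokens, input_batch=tokens)
```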
