
Commit 46c781a

[Do Not Land] Debug for SDPA + CP nan issue in DeepSeekV3
ghstack-source-id: e0c733f
Pull Request resolved: #1566
1 parent ed288bc · commit 46c781a
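For context, this commit is hunting NaNs that appear when scaled dot-product attention (SDPA) runs under context parallel (CP) in the DeepSeekV3 model. As a point of reference only (not code from this commit), the underlying call being debugged can be exercised standalone and checked for non-finite outputs; the shapes below are purely illustrative:

import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, n_heads, seq_len, head_dim)
q = torch.randn(1, 2, 128, 64)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.isfinite(out).all(), "NaN/Inf in SDPA output"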

4 files changed (+11, -7 lines)

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -73,9 +73,12 @@
         dim=2048,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=27,
+        # n_layers=27,
+        n_layers=1,
         n_dense_layers=1,
-        n_heads=16,
+        # n_heads=16,
+        # n_heads=1, # n_heads=2 reproduces the nan error
+        n_heads=2,
         n_routed_experts=64,
         n_shared_experts=2,
         n_activated_experts=6,
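Shrinking the model args to a single layer and two heads narrows the repro (the inline comment notes that n_heads=2 reproduces the NaN). A generic way to localize where the first non-finite value appears in such a reduced model is a forward hook on every submodule; this is a debugging sketch with a placeholder model, not code from this commit:

import torch
import torch.nn as nn

def nan_hook(module, inputs, output):
    # Report any module whose output contains NaN/Inf.
    outs = output if isinstance(output, tuple) else (output,)
    for t in outs:
        if torch.is_tensor(t) and not torch.isfinite(t).all():
            print(f"non-finite output from {module.__class__.__name__}")

# Placeholder model standing in for the reduced DeepSeekV3 block.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
for mod in model.modules():
    mod.register_forward_hook(nan_hook)

model(torch.randn(2, 8))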

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
 
 [metrics]
-log_freq = 10
+log_freq = 1
 disable_color_printing = false
 enable_tensorboard = false
 save_tb_folder = "tb"
@@ -37,10 +37,10 @@ decay_type = "cosine"
 min_lr_factor = 0.1
 
 [training]
-local_batch_size = 8
+local_batch_size = 1
 seq_len = 4096
 max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 2
 compile = false
 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
 
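These config tweaks just make the failure quick to hit: log every step, one sample per rank, and only two training steps. A quick sanity check of the edited TOML before launching, assuming Python 3.11+ and a working directory at the repo root:

import tomllib

with open("torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml", "rb") as f:
    cfg = tomllib.load(f)

# Values set by this commit.
print(cfg["metrics"]["log_freq"])           # 1
print(cfg["training"]["local_batch_size"])  # 1
print(cfg["training"]["steps"])             # 2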

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ def _apply_ac_to_transformer_block(
     )
 
     if ac_config.mode == "full":
-        return ptd_checkpoint_wrapper(module, preserve_rng_state=False)
+        return ptd_checkpoint_wrapper(module, preserve_rng_state=False, debug=True)
 
     assert ac_config.mode == "selective", f"{ac_config.mode}"
     use_op_sac = ac_config.selective_ac_option == "op"
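If I read the wrapper correctly, extra keyword arguments passed to ptd_checkpoint_wrapper are forwarded to torch.utils.checkpoint.checkpoint, whose debug flag (non-reentrant mode, recent PyTorch) records per-op metadata so that any divergence between the original forward and the recomputed forward is reported in detail. A standalone illustration of that underlying flag, independent of torchtitan:

import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    return torch.relu(x) @ x.t()

x = torch.randn(4, 4, requires_grad=True)
# debug=True asks the checkpoint machinery to keep extra metadata and op
# traces so recompute divergence can be pinpointed; it slows things down.
y = checkpoint(block, x, use_reentrant=False, debug=True)
y.sum().backward()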

torchtitan/train.py

Lines changed: 2 additions & 1 deletion
@@ -11,9 +11,9 @@
 from typing import Any, Generator, Iterable, Optional
 
 import torch
-from torch.distributed.elastic.multiprocessing.errors import record
 
 import torchtitan.protocols.train_spec as train_spec_module
+from torch.distributed.elastic.multiprocessing.errors import record
 from torchtitan.components.checkpoint import CheckpointManager
 from torchtitan.components.dataloader import DataloaderStopIteration
 from torchtitan.components.ft import FTManager, maybe_semi_sync_training
@@ -448,6 +448,7 @@ def forward_backward_step(
         loss = self.loss_fn(pred, labels)
         # need to free to before bwd to avoid peaking memory
         del pred
+        logger.info("backward")
         loss.backward()
 
         return loss
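The added logger.info("backward") only brackets where the failure occurs. Two stock PyTorch tools that go further in this kind of investigation are a finiteness check on the loss and autograd anomaly detection; this is a generic sketch, not what the commit does:

import torch

def guarded_backward(loss: torch.Tensor) -> None:
    # Hypothetical guard around the backward pass of a training step.
    if not torch.isfinite(loss).all():
        raise RuntimeError(f"non-finite loss: {loss.item()}")
    # Anomaly detection raises as soon as a backward function produces NaN,
    # at the cost of a much slower backward pass.
    with torch.autograd.set_detect_anomaly(True):
        loss.backward()

loss = (torch.randn(8, requires_grad=True) ** 2).mean()
guarded_backward(loss)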

0 commit comments
