Skip to content

Commit b109f7d

Browse files
authored
[DSV3] Add output.contiguous() in model to match llama3 (#1504) (#1513)
Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom): * __->__ #1504 **Summary** ## ~~Change tokenizer size~~ This is resolved by downloading the right tokenizer Before the change: ``` File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl return forward_call(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/xilunwu/pytorch/torch/nn/modules/normalization.py", line 414, in forward return F.rms_norm(x, self.normalized_shape, self.weight, self.eps) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File "/data/users/xilunwu/pytorch/torch/nn/functional.py", line 2924, in rms_norm return torch.rms_norm(input, normalized_shape, weight, eps) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ torch.AcceleratorError: CUDA error: device-side assert triggered Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information. CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect. For debugging consider passing CUDA_LAUNCH_BLOCKING=1 Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions. ``` Adding CUDA_LAUNCH_BLOCKING=1 to launch command shows the real error is in embedding. After fixing the tokenizer size the training works fine. ## Add `.contiguous()` to output after calling transpose() Command: `NGPU=8 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.context-parallel-degree 2` Error: ``` [rank0]:[rank0]: File "/data/users/xilunwu/pytorch/torch/nn/modules/module.py", line 1784, in _call_impl [rank0]:[rank0]: return forward_call(*args, **kwargs) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: File "/data/users/xilunwu/oss/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 237, in forward [rank0]:[rank0]: output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) [rank0]:[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [rank0]:[rank0]: RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead. ``` The model code didn't match with llama3. After adding `.contiguous()` it runs correctly. ``` NGPU=8 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" ./run_train.sh --parallelism.context-parallel-degree 2 + NGPU=8 + export LOG_RANK=0 + LOG_RANK=0 + CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml + overrides= + '[' 2 -ne 0 ']' + overrides='--parallelism.context-parallel-degree 2' + TORCHFT_LIGHTHOUSE=http://localhost:29510 + PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + TORCHFT_LIGHTHOUSE=http://localhost:29510 + torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train --job.config_file ./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml --parallelism.context-parallel-degree 2 W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] ***************************************** W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. W0731 11:31:18.279000 3852560 torch/distributed/run.py:803] ***************************************** [rank0]:[titan] 2025-07-31 11:31:25,671 - root - INFO - Starting job: DeepSeek-V3 16B model training [rank0]:[titan] 2025-07-31 11:31:27,890 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config [rank0]:[titan] 2025-07-31 11:31:27,891 - root - INFO - Building 2-D device mesh with ['dp_shard', 'cp'], [4, 2] [rank0]:[titan] 2025-07-31 11:31:27,897 - root - INFO - [GC] Initial GC collection 0.00 seconds [rank0]:NCCL version 2.27.5+cuda12.6 [rank0]:[titan] 2025-07-31 11:31:32,956 - root - INFO - Loading tokenizer from tokenizer.json [rank0]:[titan] 2025-07-31 11:31:33,170 - root - INFO - Preparing c4 dataset from allenai/c4 [rank0]:[titan] 2025-07-31 11:31:38,681 - root - INFO - Building deepseek_v3 16B with DeepSeekV3ModelArgs(_enforced='This field is used to enforce all fields have defaults.', max_batch_size=8, max_seq_len=4096, dtype='bf16', vocab_size=129280, dim=2048, inter_dim=10944, moe_inter_dim=1408, n_layers=27, n_dense_layers=1, n_heads=16, norm_eps=1e-05, n_routed_experts=64, n_shared_experts=2, n_activated_experts=6, n_expert_groups=1, n_limited_groups=1, score_func='softmax', route_scale=1.0, use_grouped_mm=True, load_balance_coeff=0.001, q_lora_rank=0, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, use_flex_attn=False, attn_mask_type='causal', original_seq_len=4096, rope_theta=10000.0, rope_factor=40, beta_fast=32, beta_slow=1, mscale=0.7) [rank0]:[titan] 2025-07-31 11:31:38,855 - root - INFO - CUDA capacity: NVIDIA H100 with 95.00GiB memory [rank0]:[titan] 2025-07-31 11:31:38,929 - root - INFO - Total parameter count: dense 968,486,400, sparse 14,848,098,304, active 2,771,250,688 [rank0]:[titan] 2025-07-31 11:31:38,929 - root - INFO - Model deepseek_v3 16B size: 15,816,584,704 total parameters [rank0]:[titan] 2025-07-31 11:31:38,930 - root - INFO - Applied full activation checkpointing to the model [rank0]:[titan] 2025-07-31 11:31:39,021 - root - INFO - Applied FSDP to the model [rank0]:[titan] 2025-07-31 11:31:39,021 - root - INFO - Applied Context Parallel to the model [rank0]:[titan] 2025-07-31 11:31:39,398 - root - INFO - Peak FLOPS used for computing MFU: 9.890e+14 [rank0]:[titan] 2025-07-31 11:31:39,399 - root - INFO - CUDA memory usage for model: 8.84GiB(9.30%) [rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Mixed precision training is handled by fully_shard [rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Trainer is initialized with local batch size 8, global batch size 32, gradient accumulation steps 1, sequence length 4096, total steps 1000 (warmup 200) [rank0]:[titan] 2025-07-31 11:31:39,400 - root - INFO - Training starts at step 1 [rank0]:[titan] 2025-07-31 11:31:49,242 - root - INFO - step: 1 loss: 12.2584 grad_norm: 1.2466 memory: 53.49GiB(56.30%) tps: 1,589 tflops: 28.21 mfu: 2.85% [rank0]:[titan] 2025-07-31 11:31:49,242 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank0]:[titan] 2025-07-31 11:32:13,707 - root - INFO - step: 10 loss: 11.5358 grad_norm: 1.4495 memory: 71.08GiB(74.82%) tps: 6,027 tflops: 107.02 mfu: 10.82% [rank0]:[titan] 2025-07-31 11:32:40,848 - root - INFO - step: 20 loss: 10.0093 grad_norm: 7.7745 memory: 71.08GiB(74.82%) tps: 6,037 tflops: 107.20 mfu: 10.84% ```
1 parent d655e16 commit b109f7d

File tree

1 file changed

+3
-1
lines changed
  • torchtitan/models/deepseek_v3/model

1 file changed

+3
-1
lines changed

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,9 @@ def forward(
233233
output = self.sdpa(q, k, v, scale=self.softmax_scale)
234234

235235
# Reshape and project output
236-
output = output.transpose(1, 2) # (bsz, seqlen, n_heads, v_head_dim)
236+
output = output.transpose(
237+
1, 2
238+
).contiguous() # (bsz, seqlen, n_heads, v_head_dim)
237239
output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim)
238240
return self.wo(output) # (bsz, seqlen, dim)
239241

0 commit comments

Comments
 (0)