Skip to content

Commit 6b8562d

Browse files
committed
Make MoE models non-strict tracing friendly
1 parent 0187d5f commit 6b8562d

File tree

2 files changed: +386 −3 lines

torchtitan/experiments/graph_trainer/tests/test_trace_module.py

Lines changed: 345 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import contextlib
78
import unittest
89
from collections import Counter
910

1011
import torch
1112
import torch.nn as nn
13+
from torch.testing._internal.common_fsdp import FSDPTest
1214

1315
from torchtitan.experiments.graph_trainer.make_fx_tracer import (
1416
_copy_fwd_metadata_to_bw_nodes,
@@ -38,7 +40,7 @@ def forward(self, *args):
3840
loss = self.loss_fn(logits, labels)
3941
# Must look up params in forward (not __init__) so that
4042
# _reparametrize_module's swapped parameters are captured during tracing.
41-
params = [p for _, p in self.model.named_parameters(remove_duplicate=False)]
43+
params = list(self.model.parameters())
4244
grads = torch.autograd.grad(loss, params)
4345
return [loss] + list(grads)
4446

@@ -352,5 +354,347 @@ def test_patch_engine_restores_original(self):
352354
self.assertIs(torch.autograd._engine_run_backward, orig_fn)
353355

354356

357+
@contextlib.contextmanager
def _use_raw_flex_attn():
    """Temporarily swap FlexAttentionWrapper's compiled flex_attention for
    the plain, uncompiled implementation.

    The wrapper defaults to a torch.compile'd flex_attention, and make_fx
    cannot symbolically trace through a dynamo-optimized function (it raises
    "Detected that you are using FX to symbolically trace a
    dynamo-optimized function."). With the raw op, make_fx decomposes
    flex_attention into plain aten ops (bmm, softmax, etc.), which trace
    correctly.

    Note: under make_fx(..., pre_dispatch=True) the raw flex_attention is
    kept in the graph as a FlexAttentionHOP higher-order op rather than
    decomposed — the same behavior torch.export exhibits.
    """
    from torch.nn.attention.flex_attention import flex_attention as _raw_flex

    from torchtitan.models.common.attention import FlexAttentionWrapper

    # Save the class attribute, install the raw op, and guarantee the
    # original is restored even if the body raises.
    saved = FlexAttentionWrapper._compiled_flex_attn
    FlexAttentionWrapper._compiled_flex_attn = staticmethod(_raw_flex)
    try:
        yield
    finally:
        FlexAttentionWrapper._compiled_flex_attn = saved
382+
383+
384+
@unittest.skipUnless(torch.cuda.is_available(), "CUDA required")
class TestTraceModels(unittest.TestCase):
    """Bitwise-parity tests between an eager train step and its make_fx trace.

    For each model family: trace a full forward/backward step once, then run
    several optimizer steps twice in parallel — eagerly on ``model_ref`` and
    through the traced graph on ``model_copy`` — asserting the loss and every
    gradient are *bitwise* equal at each step.
    """

    # Shared test configuration for all single-GPU model tests.
    DEVICE = "cuda"
    DTYPE = torch.float32
    BATCH_SIZE = 2
    SEQ_LEN = 128
    NUM_STEPS = 5
    LR = 1e-3

    def setUp(self):
        # Bitwise comparison requires full determinism: fixed seed plus
        # deterministic kernel selection.
        torch.manual_seed(42)
        torch.use_deterministic_algorithms(True)

    def tearDown(self):
        # Undo the global determinism flag so other tests are unaffected.
        torch.use_deterministic_algorithms(False)

    def _run_bitwise_test(
        self,
        model_ref,
        model_copy,
        fwd_args,
        labels,
        check_collective_ops=False,
        num_steps=5,
        lr=1e-3,
    ):
        """Trace one train step, then compare eager vs. traced execution.

        Args:
            model_ref: model executed eagerly (the reference).
            model_copy: identically-initialized model driven through the
                traced graph.
            fwd_args: positional forward inputs (tokens, optional masks).
            labels: targets for the loss.
            check_collective_ops: when True, assert the traced graph contains
                all_gather / reduce_scatter nodes (FSDP sanity check).
            num_steps: optimizer steps to compare.
            lr: Adam learning rate for both optimizers.
        """
        train_step_ref = TrainStepModule(model_ref, get_loss)

        # Trace once, up front; the raw (uncompiled) flex_attention is
        # required because make_fx cannot trace through torch.compile.
        with _use_raw_flex_attn():
            traced_result = trace_module(train_step_ref, (*fwd_args, labels))

        if check_collective_ops:
            # Count FSDP collectives captured in the graph by node target name.
            ag = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "all_gather_into_tensor" in str(n.target)
            )
            rs = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "reduce_scatter_tensor" in str(n.target)
            )
            self.assertTrue(
                ag > 0 and rs > 0,
                f"Expected collective ops in FSDP graph (ag={ag}, rs={rs})",
            )

        opt_ref = torch.optim.Adam(model_ref.parameters(), lr=lr)
        opt_copy = torch.optim.Adam(model_copy.parameters(), lr=lr)

        for step in range(1, num_steps + 1):
            # Eager reference step, also under raw flex_attention so eager
            # and traced execution hit identical kernels (bitwise compare).
            with _use_raw_flex_attn():
                logits_ref = model_ref(*fwd_args)
                loss_ref = get_loss(logits_ref, labels)
                loss_ref.backward()
            grads_ref = [p.grad.clone() for p in model_ref.parameters()]
            opt_ref.step()
            opt_ref.zero_grad()

            # Traced step: functional call of the graph with the copy's
            # current parameters/buffers; output is [loss, *grads].
            train_step_copy = TrainStepModule(model_copy, get_loss)
            pab = _get_params_and_buffers(train_step_copy)
            wrapped = run_traced_module(traced_result, pab, (*fwd_args, labels))
            loss_tr = wrapped[0]
            grads_tr = wrapped[1:]
            # Hand the traced grads to the copy's optimizer manually.
            for p, g in zip(model_copy.parameters(), grads_tr, strict=True):
                p.grad = g
            opt_copy.step()
            opt_copy.zero_grad()

            # torch.equal => exact bitwise equality, not allclose.
            self.assertTrue(
                torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch"
            )
            for gr, gt in zip(grads_ref, grads_tr, strict=True):
                self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch")

    def _run_model_test(self, config_cls, model_config, use_attn_masks=False):
        """Build ref/copy models plus random token/label batches, then run
        the bitwise comparison; optionally pass a causal attention mask."""
        vocab_size = model_config.vocab_size
        model_ref = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        model_copy = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        # Identical starting weights are a precondition for bitwise parity.
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )

        if use_attn_masks:
            from torchtitan.models.common.attention import (
                create_attention_mask,
                get_causal_mask_mod,
            )

            attn_masks = create_attention_mask(
                get_causal_mask_mod(), 1, None, self.SEQ_LEN, self.SEQ_LEN
            )
            self._run_bitwise_test(
                model_ref,
                model_copy,
                (tokens, attn_masks),
                labels,
                num_steps=self.NUM_STEPS,
                lr=self.LR,
            )
            return

        self._run_bitwise_test(
            model_ref,
            model_copy,
            (tokens,),
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )

    def test_llama3(self):
        from torchtitan.models.llama3 import llama3_configs, Llama3Model

        self._run_model_test(Llama3Model, llama3_configs["debugmodel"])

    def test_qwen3(self):
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel"])

    def test_qwen3_moe(self):
        # MoE variant — the main subject of this change ("non-strict tracing
        # friendly" MoE models).
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel_moe"])

    def test_deepseek_v3(self):
        from torchtitan.models.deepseek_v3 import deepseekv3_configs
        from torchtitan.models.deepseek_v3.model import DeepSeekV3Model

        self._run_model_test(DeepSeekV3Model, deepseekv3_configs["debugmodel"])

    def test_llama4(self):
        from torchtitan.models.llama4 import llama4_configs
        from torchtitan.models.llama4.model import Llama4Model

        # Llama4 requires an explicit attention mask argument.
        self._run_model_test(
            Llama4Model, llama4_configs["debugmodel"], use_attn_masks=True
        )

    def test_gpt_oss(self):
        # GptOss takes a dict of two masks (full causal + sliding window),
        # so it builds its inputs inline instead of using _run_model_test.
        from torch.nn.attention.flex_attention import and_masks

        from torchtitan.models.common.attention import (
            create_attention_mask,
            get_causal_mask_mod,
            get_sliding_window_mask_mod,
        )
        from torchtitan.models.gpt_oss import gptoss_configs
        from torchtitan.models.gpt_oss.model import GptOssModel

        config = gptoss_configs["debugmodel"]
        vocab_size = config.vocab_size
        model_ref = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        causal = get_causal_mask_mod()
        sw_size = config.layer.attention.sliding_window_size
        basic_mask = create_attention_mask(causal, 1, None, self.SEQ_LEN, self.SEQ_LEN)
        # Sliding-window mask = causal AND window constraint.
        sliding_window_mask = create_attention_mask(
            and_masks(causal, get_sliding_window_mask_mod(sw_size)),
            1,
            None,
            self.SEQ_LEN,
            self.SEQ_LEN,
        )
        attn_masks = {
            "basic_mask": basic_mask,
            "sliding_window_mask": sliding_window_mask,
        }
        self._run_bitwise_test(
            model_ref,
            model_copy,
            (tokens, attn_masks),
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )
574+
575+
576+
class TestTraceFSDP(FSDPTest):
    """Multi-GPU variant: same eager-vs-traced bitwise parity check, but with
    both models wrapped via simple_fsdp ``data_parallel`` so the traced graph
    must contain the FSDP collectives (all_gather / reduce_scatter).

    NOTE(review): the trace-and-compare loop below duplicates
    TestTraceModels._run_bitwise_test(check_collective_ops=True); consider
    sharing it through a mixin to keep the two in sync.
    """

    @property
    def world_size(self):
        # Cap at 4 ranks regardless of how many GPUs the host exposes.
        return min(torch.cuda.device_count(), 4)

    def _setup(self):
        from torchtitan.distributed import ParallelDims

        # FSDP-only layout: every other parallel dim is 1.
        # dp_shard=-1 presumably infers the shard degree from world_size —
        # TODO confirm against ParallelDims.
        self.parallel_dims = ParallelDims(
            dp_shard=-1,
            dp_replicate=1,
            cp=1,
            tp=1,
            pp=1,
            ep=1,
            etp=1,
            world_size=self.world_size,
        )

    def _run_fsdp_model_test(self, config_cls, model_config, use_attn_masks=False):
        """Shard ref/copy models with simple FSDP, trace one train step, and
        assert bitwise-equal losses/grads between eager and traced execution
        across 5 optimizer steps."""
        from torchtitan.experiments.graph_trainer.simple_fsdp import data_parallel

        self._setup()
        fsdp_mesh = self.parallel_dims.get_mesh("fsdp")

        model_ref = create_model(config_cls, model_config, "cuda", torch.float32)
        model_copy = create_model(config_cls, model_config, "cuda", torch.float32)
        # Identical starting weights are a precondition for bitwise parity;
        # copy the state BEFORE sharding.
        model_copy.load_state_dict(model_ref.state_dict())
        data_parallel(model_ref, device_mesh=fsdp_mesh, mode="fully_shard")
        data_parallel(model_copy, device_mesh=fsdp_mesh, mode="fully_shard")

        vocab_size = model_config.vocab_size
        seq_len = 128
        tokens = torch.randint(0, vocab_size, (2, seq_len), device="cuda")
        labels = torch.randint(0, vocab_size, (2, seq_len), device="cuda")

        if use_attn_masks:
            from torchtitan.models.common.attention import (
                create_attention_mask,
                get_causal_mask_mod,
            )

            attn_masks = create_attention_mask(
                get_causal_mask_mod(), 1, None, seq_len, seq_len
            )
            fwd_args = (tokens, attn_masks)
        else:
            fwd_args = (tokens,)

        train_step_ref = TrainStepModule(model_ref, get_loss)

        # Trace once; raw flex_attention because make_fx cannot trace through
        # torch.compile.
        with _use_raw_flex_attn():
            traced_result = trace_module(train_step_ref, (*fwd_args, labels))

        # FSDP sanity check: the graph must have captured the sharding
        # collectives, identified by node target name.
        ag = sum(
            1
            for n in traced_result.gm.graph.nodes
            if "all_gather_into_tensor" in str(n.target)
        )
        rs = sum(
            1
            for n in traced_result.gm.graph.nodes
            if "reduce_scatter_tensor" in str(n.target)
        )
        self.assertTrue(
            ag > 0 and rs > 0,
            f"Expected collective ops in FSDP graph (ag={ag}, rs={rs})",
        )

        opt_ref = torch.optim.Adam(model_ref.parameters(), lr=1e-3)
        opt_copy = torch.optim.Adam(model_copy.parameters(), lr=1e-3)

        for step in range(1, 6):
            # Eager reference step under raw flex_attention so eager and
            # traced execution use identical kernels (bitwise compare).
            with _use_raw_flex_attn():
                logits_ref = model_ref(*fwd_args)
                loss_ref = get_loss(logits_ref, labels)
                loss_ref.backward()
            grads_ref = [p.grad.clone() for p in model_ref.parameters()]
            opt_ref.step()
            opt_ref.zero_grad()

            # Traced step: functional call with the copy's current
            # parameters/buffers; output is [loss, *grads].
            train_step_copy = TrainStepModule(model_copy, get_loss)
            pab = _get_params_and_buffers(train_step_copy)
            wrapped = run_traced_module(traced_result, pab, (*fwd_args, labels))
            loss_tr = wrapped[0]
            grads_tr = wrapped[1:]
            for p, g in zip(model_copy.parameters(), grads_tr, strict=True):
                p.grad = g
            opt_copy.step()
            opt_copy.zero_grad()

            # torch.equal => exact bitwise equality, not allclose.
            self.assertTrue(
                torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch"
            )
            for gr, gt in zip(grads_ref, grads_tr, strict=True):
                self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch")

    def test_llama3_fsdp(self):
        from torchtitan.models.llama3 import llama3_configs, Llama3Model

        self._run_fsdp_model_test(Llama3Model, llama3_configs["debugmodel"])

    def test_qwen3_fsdp(self):
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_fsdp_model_test(Qwen3Model, qwen3_configs["debugmodel"])

    def test_deepseek_v3_fsdp(self):
        from torchtitan.models.deepseek_v3 import deepseekv3_configs
        from torchtitan.models.deepseek_v3.model import DeepSeekV3Model

        self._run_fsdp_model_test(DeepSeekV3Model, deepseekv3_configs["debugmodel"])

    def test_llama4_fsdp(self):
        from torchtitan.models.llama4 import llama4_configs
        from torchtitan.models.llama4.model import Llama4Model

        # Llama4 requires an explicit attention mask argument.
        self._run_fsdp_model_test(
            Llama4Model, llama4_configs["debugmodel"], use_attn_masks=True
        )
697+
698+
355699
if __name__ == "__main__":
356700
unittest.main()

0 commit comments

Comments
 (0)