Skip to content

Commit b516c04

Browse files
committed
Make MoE models non-strict tracing friendly
1 parent 0187d5f commit b516c04

File tree

2 files changed: +261 −2 lines changed

torchtitan/experiments/graph_trainer/tests/test_trace_module.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import contextlib
78
import unittest
89
from collections import Counter
910

@@ -352,5 +353,224 @@ def test_patch_engine_restores_original(self):
352353
self.assertIs(torch.autograd._engine_run_backward, orig_fn)
353354

354355

@contextlib.contextmanager
def _use_raw_flex_attn():
    """Temporarily route FlexAttentionWrapper through uncompiled flex_attention.

    FlexAttentionWrapper normally calls a torch.compile'd flex_attention,
    and torch.compile cannot run under make_fx symbolic tracing (it raises
    "Detected that you are using FX to symbolically trace a
    dynamo-optimized function."). With the raw kernel swapped in, make_fx
    decomposes flex_attention into plain aten ops (bmm, softmax, etc.) and
    tracing succeeds.

    Note: make_fx(..., pre_dispatch=True) with raw flex_attention keeps it
    as a FlexAttentionHOP higher-order op in the graph (matching what
    torch.export produces) rather than decomposing it.
    """
    from torch.nn.attention.flex_attention import flex_attention as raw_flex_attention

    from torchtitan.models.common.attention import FlexAttentionWrapper

    saved_attn = FlexAttentionWrapper._compiled_flex_attn
    FlexAttentionWrapper._compiled_flex_attn = staticmethod(raw_flex_attention)
    try:
        yield
    finally:
        # Always restore the compiled kernel, even if the body raised.
        FlexAttentionWrapper._compiled_flex_attn = saved_attn
@unittest.skipUnless(torch.cuda.is_available(), "CUDA required")
class TestTraceModels(unittest.TestCase):
    """Bitwise-equivalence tests: eager training vs. replaying a traced graph.

    For each model family, a reference model is trained eagerly while an
    identical copy is trained by executing the make_fx-traced train step;
    losses and per-parameter gradients must match bit-for-bit every step.
    """

    # Shared hyperparameters for all model tests.
    DEVICE = "cuda"
    DTYPE = torch.float32
    BATCH_SIZE = 2
    SEQ_LEN = 128
    NUM_STEPS = 5
    LR = 1e-3

    def setUp(self):
        # Fixed seed + deterministic kernels so the eager and traced runs
        # can be compared with torch.equal (exact, bitwise).
        torch.manual_seed(42)
        torch.use_deterministic_algorithms(True)

    def tearDown(self):
        torch.use_deterministic_algorithms(False)

    def _run_bitwise_test(
        self,
        model_ref,
        model_copy,
        fwd_args,
        labels,
        check_collective_ops=False,
        num_steps=5,
        lr=1e-3,
    ):
        """Trace model_ref's train step once, then run num_steps optimizer
        steps eagerly (model_ref) and via the traced graph (model_copy),
        asserting losses and gradients are bitwise equal each step.

        fwd_args: tuple of positional forward inputs (tokens, and optionally
        attention masks); labels: target token ids for the loss.
        """
        train_step_ref = TrainStepModule(model_ref, get_loss)

        # Tracing must see the raw (uncompiled) flex_attention; see
        # _use_raw_flex_attn for why.
        with _use_raw_flex_attn():
            traced_result = trace_module(train_step_ref, (*fwd_args, labels))

        if check_collective_ops:
            # Sanity check that FSDP comms were captured into the graph.
            ag = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "all_gather_into_tensor" in str(n.target)
            )
            rs = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "reduce_scatter_tensor" in str(n.target)
            )
            self.assertTrue(
                ag > 0 and rs > 0,
                f"Expected collective ops in FSDP graph (ag={ag}, rs={rs})",
            )

        # Independent optimizers so the two models evolve separately.
        opt_ref = torch.optim.Adam(model_ref.parameters(), lr=lr)
        opt_copy = torch.optim.Adam(model_copy.parameters(), lr=lr)

        for step in range(1, num_steps + 1):
            # Eager reference step (raw flex_attn for parity with the trace).
            # NOTE(review): the exact scope of this context manager was
            # reconstructed from a mangled diff — confirm against the repo.
            with _use_raw_flex_attn():
                logits_ref = model_ref(*fwd_args)
            loss_ref = get_loss(logits_ref, labels)
            loss_ref.backward()
            grads_ref = [p.grad.clone() for p in model_ref.parameters()]
            opt_ref.step()
            opt_ref.zero_grad()

            # Traced step: run the captured graph against model_copy's
            # current parameters/buffers; it returns (loss, *grads).
            train_step_copy = TrainStepModule(model_copy, get_loss)
            pab = _get_params_and_buffers(train_step_copy)
            wrapped = run_traced_module(traced_result, pab, (*fwd_args, labels))
            loss_tr = wrapped[0]
            grads_tr = wrapped[1:]
            # Install the graph-produced grads so the optimizer applies them.
            for p, g in zip(model_copy.parameters(), grads_tr, strict=True):
                p.grad = g
            opt_copy.step()
            opt_copy.zero_grad()

            # Bitwise comparison — deterministic mode makes this valid.
            self.assertTrue(
                torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch"
            )
            for gr, gt in zip(grads_ref, grads_tr, strict=True):
                self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch")

    def _run_model_test(self, config_cls, model_config, use_attn_masks=False):
        """Build two identical models from (config_cls, model_config), random
        token/label batches, and run the bitwise train-step comparison.

        use_attn_masks: pass a causal flex-attention mask as a second forward
        argument (for models whose forward requires explicit masks).
        """
        vocab_size = model_config.vocab_size
        model_ref = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        model_copy = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        # Identical starting weights so divergence can only come from tracing.
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )

        if use_attn_masks:
            from torchtitan.models.common.attention import (
                create_attention_mask,
                get_causal_mask_mod,
            )

            attn_masks = create_attention_mask(
                get_causal_mask_mod(), 1, None, self.SEQ_LEN, self.SEQ_LEN
            )
            self._run_bitwise_test(
                model_ref,
                model_copy,
                (tokens, attn_masks),
                labels,
                num_steps=self.NUM_STEPS,
                lr=self.LR,
            )
            return

        self._run_bitwise_test(
            model_ref,
            model_copy,
            (tokens,),
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )

    def test_llama3(self):
        from torchtitan.models.llama3 import llama3_configs, Llama3Model

        self._run_model_test(Llama3Model, llama3_configs["debugmodel"])

    def test_qwen3(self):
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel"])

    def test_qwen3_moe(self):
        # MoE variant — exercises the custom-op path in moe/kernels.py.
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel_moe"])

    def test_deepseek_v3(self):
        from torchtitan.models.deepseek_v3 import deepseekv3_configs
        from torchtitan.models.deepseek_v3.model import DeepSeekV3Model

        self._run_model_test(DeepSeekV3Model, deepseekv3_configs["debugmodel"])

    def test_llama4(self):
        from torchtitan.models.llama4 import llama4_configs
        from torchtitan.models.llama4.model import Llama4Model

        # Llama4's forward requires an explicit attention mask argument.
        self._run_model_test(
            Llama4Model, llama4_configs["debugmodel"], use_attn_masks=True
        )

    def test_gpt_oss(self):
        """GPT-OSS needs two masks (full causal + sliding-window causal),
        so it builds its inputs by hand instead of using _run_model_test."""
        from torch.nn.attention.flex_attention import and_masks

        from torchtitan.models.common.attention import (
            create_attention_mask,
            get_causal_mask_mod,
            get_sliding_window_mask_mod,
        )
        from torchtitan.models.gpt_oss import gptoss_configs
        from torchtitan.models.gpt_oss.model import GptOssModel

        config = gptoss_configs["debugmodel"]
        vocab_size = config.vocab_size
        model_ref = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        causal = get_causal_mask_mod()
        sw_size = config.layer.attention.sliding_window_size
        basic_mask = create_attention_mask(causal, 1, None, self.SEQ_LEN, self.SEQ_LEN)
        # Sliding-window mask = causal AND within-window.
        sliding_window_mask = create_attention_mask(
            and_masks(causal, get_sliding_window_mask_mod(sw_size)),
            1,
            None,
            self.SEQ_LEN,
            self.SEQ_LEN,
        )
        attn_masks = {
            "basic_mask": basic_mask,
            "sliding_window_mask": sliding_window_mask,
        }
        self._run_bitwise_test(
            model_ref,
            model_copy,
            (tokens, attn_masks),
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )
355575
if __name__ == "__main__":
356576
unittest.main()

torchtitan/models/common/moe/kernels.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _fill_indices_kernel(
6868
# ==============
6969

7070

71-
def fill_indices_wrapper(
71+
def _fill_indices_impl(
7272
tokens_per_expert_group: torch.Tensor,
7373
start_index_values: torch.Tensor,
7474
write_offsets: torch.Tensor,
@@ -77,7 +77,7 @@ def fill_indices_wrapper(
7777
max_len: int,
7878
block_size: int = 128,
7979
max_blocks: int = 1024, # cap on total number of blocks to launch
80-
):
80+
) -> torch.Tensor:
8181
# preallocate output
8282
permuted_indices = torch.full(
8383
(max_len,), -1, dtype=torch.int64, device=tokens_per_expert_group.device
@@ -104,6 +104,45 @@ def fill_indices_wrapper(
104104
return permuted_indices
105105

106106

@torch.library.custom_op("torchtitan::fill_indices", mutates_args=())
def fill_indices_wrapper(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,
) -> torch.Tensor:
    """Custom-op wrapper around the Triton index-fill launch.

    Registering the kernel launch as a torch.library custom op makes it
    opaque to tracing, which (per this commit's intent) is what makes the
    MoE models non-strict-tracing friendly. mutates_args=() declares the
    op as functional: it only returns a fresh tensor.

    Returns a 1-D int64 tensor of length max_len with permuted token
    indices (the eager impl pre-fills it with -1 for unused slots).
    """
    return _fill_indices_impl(
        tokens_per_expert_group,
        start_index_values,
        write_offsets,
        experts_per_rank,
        num_ranks,
        max_len,
        block_size,
        max_blocks,
    )


@fill_indices_wrapper.register_fake
def _fill_indices_fake(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,
) -> torch.Tensor:
    # Fake (meta) kernel used under FakeTensor tracing/compile: only output
    # metadata matters, so torch.empty is fine even though the real impl
    # fills the buffer with -1 before writing.
    return torch.empty(
        max_len, dtype=torch.int64, device=tokens_per_expert_group.device
    )
107146
# reference
108147
def fill_indices_cpu(
109148
tokens_per_expert_group: torch.Tensor,

0 commit comments

Comments
 (0)