Skip to content

Commit 6adec7b

Browse files
committed
Make MoE models non-strict tracing friendly
1 parent 6ad427e commit 6adec7b

File tree

2 files changed

+261
-2
lines changed

2 files changed

+261
-2
lines changed

torchtitan/experiments/graph_trainer/tests/test_trace_module.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import contextlib
78
import unittest
89

910
import torch
@@ -255,5 +256,224 @@ def test_dtensor_train_step(self):
255256
self.assertTrue(torch.equal(gr.full_tensor(), gt.full_tensor()))
256257

257258

259+
@contextlib.contextmanager
def _use_raw_flex_attn():
    """Swap the compiled flex_attention with the raw (uncompiled) version.

    FlexAttentionWrapper uses torch.compile'd flex_attention by default.
    torch.compile inside make_fx tracing is not supported and raises:
    "Detected that you are using FX to symbolically trace a
    dynamo-optimized function."
    Using the raw version lets make_fx decompose flex_attention into
    plain aten ops (bmm, softmax, etc.) which trace correctly.

    Note: make_fx(..., pre_dispatch=True) with raw flex_attention preserves
    it as a FlexAttentionHOP higher-order op in the graph instead of
    decomposing it, which is what torch.export also does.
    """
    import inspect

    from torch.nn.attention.flex_attention import flex_attention as raw_flex_attention

    from torchtitan.models.common.attention import FlexAttentionWrapper

    # Capture the class attribute WITHOUT triggering the descriptor protocol:
    # plain attribute access unwraps a staticmethod to its underlying function,
    # and restoring that bare function would later bind as an instance method
    # on attribute access through an instance. getattr_static returns the
    # staticmethod descriptor itself (and also handles inherited attributes),
    # so the restore in `finally` is an exact round-trip.
    original = inspect.getattr_static(FlexAttentionWrapper, "_compiled_flex_attn")
    FlexAttentionWrapper._compiled_flex_attn = staticmethod(raw_flex_attention)
    try:
        yield
    finally:
        FlexAttentionWrapper._compiled_flex_attn = original
284+
285+
286+
@unittest.skipUnless(torch.cuda.is_available(), "CUDA required")
class TestTraceModels(unittest.TestCase):
    """Bitwise-equivalence tests: eager train steps vs. the traced graph.

    For each model family, a reference model runs eagerly while an identical
    copy is driven through the make_fx-traced TrainStepModule graph; losses
    and per-parameter gradients must match exactly (torch.equal) every step.
    """

    # Shared test configuration; individual tests reuse these via self.*.
    DEVICE = "cuda"
    DTYPE = torch.float32
    BATCH_SIZE = 2
    SEQ_LEN = 128
    NUM_STEPS = 5
    LR = 1e-3

    def setUp(self):
        # Fixed seed + deterministic kernels are required for the bitwise
        # (torch.equal) comparisons below to be meaningful.
        torch.manual_seed(42)
        torch.use_deterministic_algorithms(True)

    def tearDown(self):
        # Undo the global determinism toggle so other tests are unaffected.
        torch.use_deterministic_algorithms(False)

    def _run_bitwise_test(
        self,
        model_ref,
        model_copy,
        fwd_args,
        labels,
        check_collective_ops=False,
        num_steps=5,
        lr=1e-3,
    ):
        """Trace model_ref's train step, then run `num_steps` optimizer steps
        both eagerly (model_ref) and through the traced graph (model_copy),
        asserting bitwise-equal losses and gradients at every step.
        """
        train_step_ref = TrainStepModule(model_ref, get_loss)

        # Trace with the raw (uncompiled) flex_attention — see
        # _use_raw_flex_attn for why torch.compile'd attn cannot be traced.
        with _use_raw_flex_attn():
            traced_result = trace_module(train_step_ref, (*fwd_args, labels))

        if check_collective_ops:
            # Sanity check for FSDP-wrapped models: the traced graph must
            # contain both all-gather and reduce-scatter collectives.
            ag = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "all_gather_into_tensor" in str(n.target)
            )
            rs = sum(
                1
                for n in traced_result.gm.graph.nodes
                if "reduce_scatter_tensor" in str(n.target)
            )
            self.assertTrue(
                ag > 0 and rs > 0,
                f"Expected collective ops in FSDP graph (ag={ag}, rs={rs})",
            )

        # Separate optimizers so the eager and traced parameter updates
        # evolve independently from identical initial states.
        opt_ref = torch.optim.Adam(model_ref.parameters(), lr=lr)
        opt_copy = torch.optim.Adam(model_copy.parameters(), lr=lr)

        for step in range(1, num_steps + 1):
            # Eager reference forward uses the same raw flex_attention kernel
            # as the traced graph so the comparison can be bitwise.
            with _use_raw_flex_attn():
                logits_ref = model_ref(*fwd_args)
            loss_ref = get_loss(logits_ref, labels)
            loss_ref.backward()
            # Clone grads BEFORE step/zero_grad clobbers them.
            grads_ref = [p.grad.clone() for p in model_ref.parameters()]
            opt_ref.step()
            opt_ref.zero_grad()

            # Re-fetch params/buffers each step so the traced graph sees the
            # post-optimizer-step values of model_copy.
            train_step_copy = TrainStepModule(model_copy, get_loss)
            pab = _get_params_and_buffers(train_step_copy)
            wrapped = run_traced_module(traced_result, pab, (*fwd_args, labels))
            # Traced output layout: (loss, *grads) in parameter order.
            loss_tr = wrapped[0]
            grads_tr = wrapped[1:]
            for p, g in zip(model_copy.parameters(), grads_tr, strict=True):
                p.grad = g
            opt_copy.step()
            opt_copy.zero_grad()

            self.assertTrue(
                torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch"
            )
            for gr, gt in zip(grads_ref, grads_tr, strict=True):
                self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch")

    def _run_model_test(self, config_cls, model_config, use_attn_masks=False):
        """Build a reference model and an identical copy for `config_cls`,
        then run the bitwise eager-vs-traced comparison on random tokens.

        use_attn_masks: pass a causal FlexAttention mask as a second forward
        argument (for models whose forward requires explicit attn masks).
        """
        vocab_size = model_config.vocab_size
        model_ref = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        model_copy = create_model(config_cls, model_config, self.DEVICE, self.DTYPE)
        # Identical weights so eager and traced runs start from the same state.
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )

        if use_attn_masks:
            from torchtitan.models.common.attention import (
                create_attention_mask,
                get_causal_mask_mod,
            )

            attn_masks = create_attention_mask(
                get_causal_mask_mod(), 1, None, self.SEQ_LEN, self.SEQ_LEN
            )
            self._run_bitwise_test(
                model_ref,
                model_copy,
                (tokens, attn_masks),
                labels,
                num_steps=self.NUM_STEPS,
                lr=self.LR,
            )
            return

        self._run_bitwise_test(
            model_ref,
            model_copy,
            (tokens,),
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )

    def test_llama3(self):
        from torchtitan.models.llama3 import llama3_configs, Llama3Model

        self._run_model_test(Llama3Model, llama3_configs["debugmodel"])

    def test_qwen3(self):
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel"])

    def test_qwen3_moe(self):
        # MoE variant — exercises the expert-routing path under tracing.
        from torchtitan.models.qwen3 import qwen3_configs
        from torchtitan.models.qwen3.model import Qwen3Model

        self._run_model_test(Qwen3Model, qwen3_configs["debugmodel_moe"])

    def test_deepseek_v3(self):
        from torchtitan.models.deepseek_v3 import deepseekv3_configs
        from torchtitan.models.deepseek_v3.model import DeepSeekV3Model

        self._run_model_test(DeepSeekV3Model, deepseekv3_configs["debugmodel"])

    def test_llama4(self):
        from torchtitan.models.llama4 import llama4_configs
        from torchtitan.models.llama4.model import Llama4Model

        # Llama4's forward takes an explicit attention mask argument.
        self._run_model_test(
            Llama4Model, llama4_configs["debugmodel"], use_attn_masks=True
        )

    def test_gpt_oss(self):
        """GPT-OSS needs two masks (full causal + sliding-window causal),
        so it builds its own inputs instead of going through _run_model_test.
        """
        from torch.nn.attention.flex_attention import and_masks

        from torchtitan.models.common.attention import (
            create_attention_mask,
            get_causal_mask_mod,
            get_sliding_window_mask_mod,
        )
        from torchtitan.models.gpt_oss import gptoss_configs
        from torchtitan.models.gpt_oss.model import GptOssModel

        config = gptoss_configs["debugmodel"]
        vocab_size = config.vocab_size
        model_ref = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy = create_model(GptOssModel, config, self.DEVICE, self.DTYPE)
        model_copy.load_state_dict(model_ref.state_dict())
        tokens = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        labels = torch.randint(
            0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE
        )
        causal = get_causal_mask_mod()
        sw_size = config.layer.attention.sliding_window_size
        basic_mask = create_attention_mask(causal, 1, None, self.SEQ_LEN, self.SEQ_LEN)
        # Sliding-window mask = causal AND within-window.
        sliding_window_mask = create_attention_mask(
            and_masks(causal, get_sliding_window_mask_mod(sw_size)),
            1,
            None,
            self.SEQ_LEN,
            self.SEQ_LEN,
        )
        attn_masks = {
            "basic_mask": basic_mask,
            "sliding_window_mask": sliding_window_mask,
        }
        self._run_bitwise_test(
            model_ref,
            model_copy,
            (tokens, attn_masks),
            labels,
            num_steps=self.NUM_STEPS,
            lr=self.LR,
        )
476+
477+
258478
if __name__ == "__main__":
259479
unittest.main()

torchtitan/models/common/moe/kernels.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _fill_indices_kernel(
6868
# ==============
6969

7070

71-
def fill_indices_wrapper(
71+
def _fill_indices_impl(
7272
tokens_per_expert_group: torch.Tensor,
7373
start_index_values: torch.Tensor,
7474
write_offsets: torch.Tensor,
@@ -77,7 +77,7 @@ def fill_indices_wrapper(
7777
max_len: int,
7878
block_size: int = 128,
7979
max_blocks: int = 1024, # cap on total number of blocks to launch
80-
):
80+
) -> torch.Tensor:
8181
# preallocate output
8282
permuted_indices = torch.full(
8383
(max_len,), -1, dtype=torch.int64, device=tokens_per_expert_group.device
@@ -104,6 +104,45 @@ def fill_indices_wrapper(
104104
return permuted_indices
105105

106106

107+
@torch.library.custom_op("torchtitan::fill_indices", mutates_args=())
def fill_indices_wrapper(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,
) -> torch.Tensor:
    """Custom-op entry point for the fill-indices kernel launch.

    Registering the launch as ``torchtitan::fill_indices`` keeps the kernel
    opaque to tracing machinery; only the fake implementation below is used
    to infer output metadata during symbolic tracing.
    """
    # Pure pass-through: forward every argument by keyword so this signature
    # and _fill_indices_impl's stay visibly in sync.
    return _fill_indices_impl(
        tokens_per_expert_group=tokens_per_expert_group,
        start_index_values=start_index_values,
        write_offsets=write_offsets,
        experts_per_rank=experts_per_rank,
        num_ranks=num_ranks,
        max_len=max_len,
        block_size=block_size,
        max_blocks=max_blocks,
    )
128+
129+
130+
@fill_indices_wrapper.register_fake
def _fill_indices_fake(
    tokens_per_expert_group: torch.Tensor,
    start_index_values: torch.Tensor,
    write_offsets: torch.Tensor,
    experts_per_rank: int,
    num_ranks: int,
    max_len: int,
    block_size: int = 128,
    max_blocks: int = 1024,
) -> torch.Tensor:
    """Fake (meta) implementation of ``torchtitan::fill_indices``.

    Tracing only needs the output metadata — a 1-D int64 tensor of length
    ``max_len`` on the input's device — so an uninitialized tensor suffices.
    """
    out_device = tokens_per_expert_group.device
    return torch.empty((max_len,), dtype=torch.int64, device=out_device)
144+
145+
107146
# reference
108147
def fill_indices_cpu(
109148
tokens_per_expert_group: torch.Tensor,

0 commit comments

Comments
 (0)