4 | 4 | # This source code is licensed under the BSD-style license found in the |
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 | |
| 7 | +import contextlib |
7 | 8 | import unittest |
8 | 9 | |
9 | 10 | import torch |
@@ -255,5 +256,224 @@ def test_dtensor_train_step(self): |
255 | 256 | self.assertTrue(torch.equal(gr.full_tensor(), gt.full_tensor())) |
256 | 257 | |
257 | 258 | |
| 259 | +@contextlib.contextmanager |
| 260 | +def _use_raw_flex_attn(): |
| 261 | + """Swap the compiled flex_attention with the raw (uncompiled) version. |
| 262 | + |
| 263 | + FlexAttentionWrapper uses torch.compile'd flex_attention by default. |
| 264 | + torch.compile inside make_fx tracing is not supported and raises: |
| 265 | + "Detected that you are using FX to symbolically trace a |
| 266 | + dynamo-optimized function." |
| 267 | + Using the raw version lets make_fx decompose flex_attention into |
| 268 | + plain aten ops (bmm, softmax, etc.), which trace correctly. |
| 269 | + |
| 270 | + Note: make_fx(..., pre_dispatch=True) with raw flex_attention preserves |
| 271 | + it as a FlexAttentionHOP higher-order op in the graph instead of |
| 272 | + decomposing it, which is what torch.export also does. |
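| | + |
| | + Usage sketch (trace_module is the tracing helper used in the tests below; |
| | + names are illustrative): |
| | + |
| | + with _use_raw_flex_attn(): |
| | + traced = trace_module(train_step_module, example_args) |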
| 273 | + """ |
| 274 | + from torch.nn.attention.flex_attention import flex_attention as raw_flex_attention |
| 275 | + |
| 276 | + from torchtitan.models.common.attention import FlexAttentionWrapper |
| 277 | + |
| 278 | + original = FlexAttentionWrapper._compiled_flex_attn |
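| | + # Wrap in staticmethod() so attribute lookup on instances does not bind self. |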
| 279 | + FlexAttentionWrapper._compiled_flex_attn = staticmethod(raw_flex_attention) |
| 280 | + try: |
| 281 | + yield |
| 282 | + finally: |
| 283 | + FlexAttentionWrapper._compiled_flex_attn = original |
| 284 | + |
| 285 | + |
| 286 | +@unittest.skipUnless(torch.cuda.is_available(), "CUDA required") |
| 287 | +class TestTraceModels(unittest.TestCase): |
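| | + """Bitwise-equivalence tests: each model's train step is traced once and |
| | + the traced graph must reproduce the eager losses and gradients exactly.""" |
| | + |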
| 288 | + DEVICE = "cuda" |
| 289 | + DTYPE = torch.float32 |
| 290 | + BATCH_SIZE = 2 |
| 291 | + SEQ_LEN = 128 |
| 292 | + NUM_STEPS = 5 |
| 293 | + LR = 1e-3 |
| 294 | + |
| 295 | + def setUp(self): |
| 296 | + torch.manual_seed(42) |
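| | + # Bitwise comparisons below require deterministic kernels. |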
| 297 | + torch.use_deterministic_algorithms(True) |
| 298 | + |
| 299 | + def tearDown(self): |
| 300 | + torch.use_deterministic_algorithms(False) |
| 301 | + |
| 302 | + def _run_bitwise_test( |
| 303 | + self, |
| 304 | + model_ref, |
| 305 | + model_copy, |
| 306 | + fwd_args, |
| 307 | + labels, |
| 308 | + check_collective_ops=False, |
| 309 | + num_steps=5, |
| 310 | + lr=1e-3, |
| 311 | + ): |
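| | + """Trace ``model_ref``'s train step once, then run ``num_steps`` optimizer |
| | + steps both eagerly (model_ref) and through the traced graph (model_copy), |
| | + asserting that losses and gradients match bitwise at every step.""" |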
| 312 | + train_step_ref = TrainStepModule(model_ref, get_loss) |
| 313 | + |
| 314 | + with _use_raw_flex_attn(): |
| 315 | + traced_result = trace_module(train_step_ref, (*fwd_args, labels)) |
| 316 | + |
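| | + # When the model is FSDP-sharded, the traced graph should contain the |
| | + # all_gather / reduce_scatter collectives emitted by sharding. |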
| 317 | + if check_collective_ops: |
| 318 | + ag = sum( |
| 319 | + 1 |
| 320 | + for n in traced_result.gm.graph.nodes |
| 321 | + if "all_gather_into_tensor" in str(n.target) |
| 322 | + ) |
| 323 | + rs = sum( |
| 324 | + 1 |
| 325 | + for n in traced_result.gm.graph.nodes |
| 326 | + if "reduce_scatter_tensor" in str(n.target) |
| 327 | + ) |
| 328 | + self.assertTrue( |
| 329 | + ag > 0 and rs > 0, |
| 330 | + f"Expected collective ops in FSDP graph (ag={ag}, rs={rs})", |
| 331 | + ) |
| 332 | + |
| 333 | + opt_ref = torch.optim.Adam(model_ref.parameters(), lr=lr) |
| 334 | + opt_copy = torch.optim.Adam(model_copy.parameters(), lr=lr) |
| 335 | + |
| 336 | + for step in range(1, num_steps + 1): |
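| | + # Eager reference: forward/backward on model_ref, then an Adam step. |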
| 337 | + with _use_raw_flex_attn(): |
| 338 | + logits_ref = model_ref(*fwd_args) |
| 339 | + loss_ref = get_loss(logits_ref, labels) |
| 340 | + loss_ref.backward() |
| 341 | + grads_ref = [p.grad.clone() for p in model_ref.parameters()] |
| 342 | + opt_ref.step() |
| 343 | + opt_ref.zero_grad() |
| 344 | + |
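| | + # Traced path: run the captured graph on model_copy's params/buffers and |
| | + # apply the returned gradients manually before stepping its optimizer. |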
| 345 | + train_step_copy = TrainStepModule(model_copy, get_loss) |
| 346 | + pab = _get_params_and_buffers(train_step_copy) |
| 347 | + outputs = run_traced_module(traced_result, pab, (*fwd_args, labels)) |
| 348 | + loss_tr = outputs[0] |
| 349 | + grads_tr = outputs[1:] |
| 350 | + for p, g in zip(model_copy.parameters(), grads_tr, strict=True): |
| 351 | + p.grad = g |
| 352 | + opt_copy.step() |
| 353 | + opt_copy.zero_grad() |
| 354 | + |
| 355 | + self.assertTrue( |
| 356 | + torch.equal(loss_ref, loss_tr), f"Step {step}: loss mismatch" |
| 357 | + ) |
| 358 | + for gr, gt in zip(grads_ref, grads_tr, strict=True): |
| 359 | + self.assertTrue(torch.equal(gr, gt), f"Step {step}: grad mismatch") |
| 360 | + |
| 361 | + def _run_model_test(self, config_cls, model_config, use_attn_masks=False): |
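| | + """Build two identical copies of the model plus random tokens/labels, then |
| | + run the bitwise comparison (passing flex-attention masks when requested).""" |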
| 362 | + vocab_size = model_config.vocab_size |
| 363 | + model_ref = create_model(config_cls, model_config, self.DEVICE, self.DTYPE) |
| 364 | + model_copy = create_model(config_cls, model_config, self.DEVICE, self.DTYPE) |
| 365 | + model_copy.load_state_dict(model_ref.state_dict()) |
| 366 | + tokens = torch.randint( |
| 367 | + 0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE |
| 368 | + ) |
| 369 | + labels = torch.randint( |
| 370 | + 0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE |
| 371 | + ) |
| 372 | + |
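| | + # Models whose forward takes explicit flex-attention masks (e.g. Llama 4 |
| | + # below) get a causal mask built up front. |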
| 373 | + if use_attn_masks: |
| 374 | + from torchtitan.models.common.attention import ( |
| 375 | + create_attention_mask, |
| 376 | + get_causal_mask_mod, |
| 377 | + ) |
| 378 | + |
| 379 | + attn_masks = create_attention_mask( |
| 380 | + get_causal_mask_mod(), 1, None, self.SEQ_LEN, self.SEQ_LEN |
| 381 | + ) |
| 382 | + self._run_bitwise_test( |
| 383 | + model_ref, |
| 384 | + model_copy, |
| 385 | + (tokens, attn_masks), |
| 386 | + labels, |
| 387 | + num_steps=self.NUM_STEPS, |
| 388 | + lr=self.LR, |
| 389 | + ) |
| 390 | + return |
| 391 | + |
| 392 | + self._run_bitwise_test( |
| 393 | + model_ref, |
| 394 | + model_copy, |
| 395 | + (tokens,), |
| 396 | + labels, |
| 397 | + num_steps=self.NUM_STEPS, |
| 398 | + lr=self.LR, |
| 399 | + ) |
| 400 | + |
| 401 | + def test_llama3(self): |
| 402 | + from torchtitan.models.llama3 import llama3_configs, Llama3Model |
| 403 | + |
| 404 | + self._run_model_test(Llama3Model, llama3_configs["debugmodel"]) |
| 405 | + |
| 406 | + def test_qwen3(self): |
| 407 | + from torchtitan.models.qwen3 import qwen3_configs |
| 408 | + from torchtitan.models.qwen3.model import Qwen3Model |
| 409 | + |
| 410 | + self._run_model_test(Qwen3Model, qwen3_configs["debugmodel"]) |
| 411 | + |
| 412 | + def test_qwen3_moe(self): |
| 413 | + from torchtitan.models.qwen3 import qwen3_configs |
| 414 | + from torchtitan.models.qwen3.model import Qwen3Model |
| 415 | + |
| 416 | + self._run_model_test(Qwen3Model, qwen3_configs["debugmodel_moe"]) |
| 417 | + |
| 418 | + def test_deepseek_v3(self): |
| 419 | + from torchtitan.models.deepseek_v3 import deepseekv3_configs |
| 420 | + from torchtitan.models.deepseek_v3.model import DeepSeekV3Model |
| 421 | + |
| 422 | + self._run_model_test(DeepSeekV3Model, deepseekv3_configs["debugmodel"]) |
| 423 | + |
| 424 | + def test_llama4(self): |
| 425 | + from torchtitan.models.llama4 import llama4_configs |
| 426 | + from torchtitan.models.llama4.model import Llama4Model |
| 427 | + |
| 428 | + self._run_model_test( |
| 429 | + Llama4Model, llama4_configs["debugmodel"], use_attn_masks=True |
| 430 | + ) |
| 431 | + |
| 432 | + def test_gpt_oss(self): |
| 433 | + from torch.nn.attention.flex_attention import and_masks |
| 434 | + |
| 435 | + from torchtitan.models.common.attention import ( |
| 436 | + create_attention_mask, |
| 437 | + get_causal_mask_mod, |
| 438 | + get_sliding_window_mask_mod, |
| 439 | + ) |
| 440 | + from torchtitan.models.gpt_oss import gptoss_configs |
| 441 | + from torchtitan.models.gpt_oss.model import GptOssModel |
| 442 | + |
| 443 | + config = gptoss_configs["debugmodel"] |
| 444 | + vocab_size = config.vocab_size |
| 445 | + model_ref = create_model(GptOssModel, config, self.DEVICE, self.DTYPE) |
| 446 | + model_copy = create_model(GptOssModel, config, self.DEVICE, self.DTYPE) |
| 447 | + model_copy.load_state_dict(model_ref.state_dict()) |
| 448 | + tokens = torch.randint( |
| 449 | + 0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE |
| 450 | + ) |
| 451 | + labels = torch.randint( |
| 452 | + 0, vocab_size, (self.BATCH_SIZE, self.SEQ_LEN), device=self.DEVICE |
| 453 | + ) |
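| | + # gpt-oss layers look up their attention mask by name, so provide both the |
| | + # plain causal mask and the sliding-window variant. |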
| 454 | + causal = get_causal_mask_mod() |
| 455 | + sw_size = config.layer.attention.sliding_window_size |
| 456 | + basic_mask = create_attention_mask(causal, 1, None, self.SEQ_LEN, self.SEQ_LEN) |
| 457 | + sliding_window_mask = create_attention_mask( |
| 458 | + and_masks(causal, get_sliding_window_mask_mod(sw_size)), |
| 459 | + 1, |
| 460 | + None, |
| 461 | + self.SEQ_LEN, |
| 462 | + self.SEQ_LEN, |
| 463 | + ) |
| 464 | + attn_masks = { |
| 465 | + "basic_mask": basic_mask, |
| 466 | + "sliding_window_mask": sliding_window_mask, |
| 467 | + } |
| 468 | + self._run_bitwise_test( |
| 469 | + model_ref, |
| 470 | + model_copy, |
| 471 | + (tokens, attn_masks), |
| 472 | + labels, |
| 473 | + num_steps=self.NUM_STEPS, |
| 474 | + lr=self.LR, |
| 475 | + ) |
| 476 | + |
| 477 | + |
258 | 478 | if __name__ == "__main__": |
259 | 479 | unittest.main() |