Skip to content

Commit 058d7f6

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Fix skipped tests in test_model_parallel.py (#3203)
Summary: Pull Request resolved: #3203 Fix skipped tests. See more context in D78355780. https://www.internalfb.com/intern/test/844425010918803?ref_report_id=0 Reviewed By: jeffkbkim Differential Revision: D78461387 fbshipit-source-id: cfb043f4c15583006ce0dd0d3ea2038d4547be4c
1 parent a2fdb42 commit 058d7f6

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

torchrec/distributed/test_utils/test_model_parallel.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -467,11 +467,10 @@ def test_sharding_cw(
467467
data_type: DataType,
468468
allow_zero_batch_size: bool,
469469
) -> None:
470-
if (
471-
self.device == torch.device("cpu")
472-
and kernel_type != EmbeddingComputeKernel.FUSED.value
473-
):
474-
self.skipTest("CPU does not support uvm.")
470+
assume(
471+
self.device != torch.device("cpu")
472+
or kernel_type == EmbeddingComputeKernel.FUSED.value
473+
)
475474

476475
sharding_type = ShardingType.COLUMN_WISE.value
477476
assume(
@@ -548,11 +547,10 @@ def test_sharding_twcw(
548547
variable_batch_size: bool,
549548
data_type: DataType,
550549
) -> None:
551-
if (
552-
self.device == torch.device("cpu")
553-
and kernel_type != EmbeddingComputeKernel.FUSED.value
554-
):
555-
self.skipTest("CPU does not support uvm.")
550+
assume(
551+
self.device != torch.device("cpu")
552+
or kernel_type == EmbeddingComputeKernel.FUSED.value
553+
)
556554

557555
sharding_type = ShardingType.TABLE_COLUMN_WISE.value
558556
assume(
@@ -629,11 +627,10 @@ def test_sharding_tw(
629627
variable_batch_size: bool,
630628
data_type: DataType,
631629
) -> None:
632-
if (
633-
self.device == torch.device("cpu")
634-
and kernel_type != EmbeddingComputeKernel.FUSED.value
635-
):
636-
self.skipTest("CPU does not support uvm.")
630+
assume(
631+
self.device != torch.device("cpu")
632+
or kernel_type == EmbeddingComputeKernel.FUSED.value
633+
)
637634

638635
sharding_type = ShardingType.TABLE_WISE.value
639636
assume(

0 commit comments

Comments
 (0)