Skip to content

Commit 9802f8a

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Fix skipped tests in MultiRankDMPDynamicShardingTest
Summary: Fix skipped tests. See more context in D78355780. https://www.internalfb.com/intern/test/562950182314458?ref_report_id=0 Differential Revision: D78583353
1 parent 411f71b commit 9802f8a

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

torchrec/distributed/tests/test_dynamic_sharding.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -583,11 +583,10 @@ def test_sharding(
583583
"""
584584
Tests resharding from DMP module interface, rather than EBC level.
585585
"""
586-
if (
587-
self.device == torch.device("cpu")
588-
and kernel_type != EmbeddingComputeKernel.FUSED.value
589-
):
590-
self.skipTest("CPU does not support uvm.")
586+
assume(
587+
self.device != torch.device("cpu")
588+
or kernel_type == EmbeddingComputeKernel.FUSED.value
589+
)
591590

592591
assume(
593592
sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value

0 commit comments

Comments
 (0)