diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index eddf15faa..f421ae1bb 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -583,11 +583,10 @@ def test_sharding( """ Tests resharding from DMP module interface, rather than EBC level. """ - if ( - self.device == torch.device("cpu") - and kernel_type != EmbeddingComputeKernel.FUSED.value - ): - self.skipTest("CPU does not support uvm.") + assume( + self.device != torch.device("cpu") + or kernel_type == EmbeddingComputeKernel.FUSED.value + ) assume( sharder_type == SharderType.EMBEDDING_BAG_COLLECTION.value