@@ -467,11 +467,10 @@ def test_sharding_cw(
467
467
data_type : DataType ,
468
468
allow_zero_batch_size : bool ,
469
469
) -> None :
470
- if (
471
- self .device == torch .device ("cpu" )
472
- and kernel_type != EmbeddingComputeKernel .FUSED .value
473
- ):
474
- self .skipTest ("CPU does not support uvm." )
470
+ assume (
471
+ self .device != torch .device ("cpu" )
472
+ or kernel_type == EmbeddingComputeKernel .FUSED .value
473
+ )
475
474
476
475
sharding_type = ShardingType .COLUMN_WISE .value
477
476
assume (
@@ -548,11 +547,10 @@ def test_sharding_twcw(
548
547
variable_batch_size : bool ,
549
548
data_type : DataType ,
550
549
) -> None :
551
- if (
552
- self .device == torch .device ("cpu" )
553
- and kernel_type != EmbeddingComputeKernel .FUSED .value
554
- ):
555
- self .skipTest ("CPU does not support uvm." )
550
+ assume (
551
+ self .device != torch .device ("cpu" )
552
+ or kernel_type == EmbeddingComputeKernel .FUSED .value
553
+ )
556
554
557
555
sharding_type = ShardingType .TABLE_COLUMN_WISE .value
558
556
assume (
@@ -629,11 +627,10 @@ def test_sharding_tw(
629
627
variable_batch_size : bool ,
630
628
data_type : DataType ,
631
629
) -> None :
632
- if (
633
- self .device == torch .device ("cpu" )
634
- and kernel_type != EmbeddingComputeKernel .FUSED .value
635
- ):
636
- self .skipTest ("CPU does not support uvm." )
630
+ assume (
631
+ self .device != torch .device ("cpu" )
632
+ or kernel_type == EmbeddingComputeKernel .FUSED .value
633
+ )
637
634
638
635
sharding_type = ShardingType .TABLE_WISE .value
639
636
assume (
0 commit comments