@@ -309,7 +309,6 @@ def is_inplace(op, variant):
     skip('pca_lowrank', ''),  # fails on cuda, runs okay on cpu
     skip('svd_lowrank', ''),  # fails on cuda, runs okay on cpu
     skip('nn.functional.dropout2d', ''),  # fails on cuda, runs okay on cpu
-    xfail('__getitem__', device_type='cuda'),
 }


@@ -318,18 +317,6 @@ class TestOperators(TestCase):
     @skipOps('TestOperators', 'test_grad', vjp_fail.union({
         skip('nn.functional.fractional_max_pool2d'),  # fails on cuda, runs okay on cpu
         skip('nn.functional.fractional_max_pool3d'),  # fails on cuda, runs okay on cpu
-        xfail('__getitem__', 'functorch', device_type='cuda'),
-        xfail('_masked.amax', device_type='cuda'),
-        xfail('_masked.amin', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
     }))
     @opsToleranceOverride('TestOperators', 'test_grad', (
         tol1('nn.functional.binary_cross_entropy_with_logits',
@@ -409,16 +396,6 @@ def wrapped_fn(*args, **kwargs):
         skip('nn.functional.max_unpool1d'),  # fails everywhere except on mac
         skip('nn.functional.max_unpool2d'),  # fails everywhere except on windows
         xfail('nn.functional.max_unpool3d'),
-        xfail('__getitem__', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
     }))
     @opsToleranceOverride('TestOperators', 'test_jvp', (
         tol1('nn.functional.conv_transpose3d',
@@ -466,19 +443,6 @@ def test_jvp(self, device, dtype, op):
         xfail('nn.functional.dropout2d', ''),
         xfail('nn.functional.feature_alpha_dropout', 'without_train'),
         xfail('svd_lowrank', ''),
-
-        xfail('__getitem__', 'functorch', device_type='cuda'),
-        xfail('_masked.amax', device_type='cuda'),
-        xfail('_masked.amin', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
     }))
     @opsToleranceOverride('TestOperators', 'test_vjp', (
         tol1('nn.functional.conv_transpose3d',
@@ -524,19 +488,6 @@ def _test(_op):
         skip('nn.functional.fractional_max_pool2d'),  # randomness
         skip('nn.functional.fractional_max_pool3d'),  # randomness
         xfail('nn.functional.binary_cross_entropy'),  # testing problem
-
-        xfail('__getitem__', 'functorch', device_type='cuda'),
-        xfail('_masked.amax', device_type='cuda'),
-        xfail('_masked.amin', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
     }))
     @opsToleranceOverride('TestOperators', 'test_vjpvjp', (
         tol1('nn.functional.conv_transpose3d',
@@ -672,19 +623,6 @@ def vjp_of_vjp(*args_and_cotangents):
         # NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format
         xfail('nn.functional.max_unpool2d'),
         xfail('nn.functional.max_unpool2d', 'grad'),
-
-        xfail('__getitem__', 'functorch', device_type='cuda'),
-        xfail('_masked.amax', device_type='cuda'),
-        xfail('_masked.amin', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
     })

     @ops(functorch_lagging_op_db + additional_op_db, allowed_dtypes=(torch.float,))
@@ -774,18 +712,8 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('nn.functional.max_unpool2d'),
         xfail('nn.functional.max_unpool3d'),

-        xfail('__getitem__', device_type='cuda'),
-        xfail('_masked.amax', device_type='cuda'),
-        xfail('_masked.amin', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
+        xfail('nn.functional.embedding'),  # embedding_renorm_ does not support fwd AD
+        xfail('put'),  # calls put_ during vmap with only vmaps over other, not self
     })
     def test_vmapjvp(self, device, dtype, op):
         if is_inplace(op, op.get_op()):
@@ -820,15 +748,13 @@ def test_vmapjvp(self, device, dtype, op):

         # The following are bugs that we should fix
         skip('nn.functional.max_pool1d'),  # fails on cpu, runs on cuda
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
         xfail('nn.functional.batch_norm', device_type='cuda'),
         xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
         xfail('nn.functional.hinge_embedding_loss', device_type='cuda'),
+        xfail('_masked.mean'),
+        xfail('_masked.prod'),

         # Causing issues with multiple cpu levels of forward mode AD
-        xfail('_masked.mean', device_type='cpu'),
-        xfail('_masked.prod', device_type='cpu'),
         xfail('nn.functional.batch_norm', device_type='cpu'),
         xfail('nn.functional.hinge_embedding_loss', device_type='cpu'),

@@ -863,18 +789,9 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('nn.functional.max_unpool2d'),
         xfail('nn.functional.max_unpool3d'),

-        xfail('__getitem__', device_type='cuda'),
-        xfail('_masked.amax', device_type='cuda'),
-        xfail('_masked.amin', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
+        xfail('nn.functional.embedding'),  # embedding_renorm_ does not support fwd AD
+        xfail('put'),  # calls put_ during vmap with only vmaps over other, not self
+        xfail('nn.functional.prelu'),  # Call Tensor.as_strided
     }

     @ops(functorch_lagging_op_db, allowed_dtypes=(torch.float,))
@@ -962,6 +879,7 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('nn.functional.max_unpool1d', 'grad'),
         xfail('lu_unpack'),
         xfail('nn.functional.glu'),
+        xfail('nn.functional.bilinear'),  # trilinear doesn't have batching rule
     }))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -1222,11 +1140,9 @@ def test_vjpvmap(self, device, dtype, op):
         xfail('nansum', ''),
         xfail('nn.functional.batch_norm', ''),
         xfail('nn.functional.batch_norm', 'without_cudnn', device_type='cuda'),
-        xfail('nn.functional.bilinear', ''),
         xfail('nn.functional.embedding', ''),
         xfail('nn.functional.embedding', 'functorch'),
         xfail('nn.functional.embedding_bag', ''),
-        xfail('nn.functional.glu', ''),
         xfail('nn.functional.grid_sample', ''),
         xfail('nn.functional.hardsigmoid', ''),
         xfail('nn.functional.hardswish', ''),
@@ -1239,11 +1155,9 @@ def test_vjpvmap(self, device, dtype, op):
         xfail('nn.functional.softmin', ''),
         xfail('nn.functional.softmin', 'with_dtype'),
         xfail('nn.functional.softplus', ''),
-        xfail('put', ''),
         xfail('renorm', ''),
         xfail('std_mean', ''),
         xfail('symeig', ''),
-        xfail('take', ''),
         xfail('var_mean', ''),
         xfail('nn.functional.feature_alpha_dropout', 'with_train'),
         xfail('nn.functional.kl_div', ''),
@@ -1264,18 +1178,6 @@ def test_vjpvmap(self, device, dtype, op):
         xfail('scatter_reduce', 'prod'),
         skip('linalg.householder_product', '', device_type='cuda'),  # flaky, I'm not sure why
         xfail('nn.functional.binary_cross_entropy_with_logits'),
-        xfail('__getitem__', 'functorch', device_type='cuda'),
-        xfail('_masked.amax', device_type='cuda'),
-        xfail('_masked.amin', device_type='cuda'),
-        xfail('_masked.log_softmax', device_type='cuda'),
-        xfail('_masked.mean', device_type='cuda'),
-        xfail('_masked.norm', device_type='cuda'),
-        xfail('_masked.prod', device_type='cuda'),
-        xfail('_masked.softmax', device_type='cuda'),
-        xfail('_masked.softmin', device_type='cuda'),
-        xfail('_masked.std', device_type='cuda'),
-        xfail('_masked.sum', device_type='cuda'),
-        xfail('_masked.var', device_type='cuda'),
     }))
     def test_jvpvjp(self, device, dtype, op):
         if not op.supports_autograd:
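A note on the markers touched by this diff: skip() prevents a test instance from running at all, xfail() runs it and requires it to fail, and device_type restricts a marker to one backend, which is why deleting the CUDA-only _masked.* entries does not change CPU runs. The sketch below is a minimal, self-contained illustration of that pattern; the helper signatures and tuple layout are assumptions for illustration, not functorch's actual skipOps implementation.

import unittest

# Simplified stand-ins for the markers above (assumed shapes, not the real helpers):
# each marker records an op name, a variant, an optional device_type, and whether
# the test is an expected failure (xfail) or should be skipped entirely.
def xfail(op_name, variant_name='', *, device_type=None):
    return (op_name, variant_name, device_type, True)

def skip(op_name, variant_name='', *, device_type=None):
    return (op_name, variant_name, device_type, False)

def decorators_for(markers, op_name, variant_name, device_type):
    # Yield the unittest decorators that apply to one (op, variant, device) combination.
    for name, variant, dev, is_xfail in markers:
        if name != op_name or variant != variant_name:
            continue
        if dev is not None and dev != device_type:
            continue  # device-restricted marker does not apply to this backend
        yield unittest.expectedFailure if is_xfail else unittest.skip('listed in skip set')

# A CUDA-only xfail (like the entries removed in this diff) only affects CUDA instances:
markers = {
    xfail('_masked.mean', device_type='cuda'),
    skip('pca_lowrank', ''),  # fails on cuda, runs okay on cpu
}
print(len(list(decorators_for(markers, '_masked.mean', '', 'cuda'))))  # 1
print(len(list(decorators_for(markers, '_masked.mean', '', 'cpu'))))   # 0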