@@ -2718,3 +2718,144 @@ def test_errors(self):
2718
2718
2719
2719
with pytest .raises (ValueError , match = "Padding mode should be either" ):
2720
2720
transforms .RandomCrop ([10 , 12 ], padding = 1 , padding_mode = "abc" )
2721
+
2722
+
2723
+ class TestErase :
2724
+ INPUT_SIZE = (17 , 11 )
2725
+ FUNCTIONAL_KWARGS = dict (
2726
+ zip ("ijhwv" , [2 , 2 , 10 , 8 , torch .tensor (0.0 , dtype = torch .float32 , device = "cpu" ).reshape (- 1 , 1 , 1 )])
2727
+ )
2728
+
2729
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
2730
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
2731
+ def test_kernel_image (self , dtype , device ):
2732
+ check_kernel (F .erase_image , make_image (self .INPUT_SIZE , dtype = dtype , device = device ), ** self .FUNCTIONAL_KWARGS )
2733
+
2734
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
2735
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
2736
+ def test_kernel_image_inplace (self , dtype , device ):
2737
+ input = make_image (self .INPUT_SIZE , dtype = dtype , device = device )
2738
+ input_version = input ._version
2739
+
2740
+ output_out_of_place = F .erase_image (input , ** self .FUNCTIONAL_KWARGS )
2741
+ assert output_out_of_place .data_ptr () != input .data_ptr ()
2742
+ assert output_out_of_place is not input
2743
+
2744
+ output_inplace = F .erase_image (input , ** self .FUNCTIONAL_KWARGS , inplace = True )
2745
+ assert output_inplace .data_ptr () == input .data_ptr ()
2746
+ assert output_inplace ._version > input_version
2747
+ assert output_inplace is input
2748
+
2749
+ assert_equal (output_inplace , output_out_of_place )
2750
+
2751
+ def test_kernel_video (self ):
2752
+ check_kernel (F .erase_video , make_video (self .INPUT_SIZE ), ** self .FUNCTIONAL_KWARGS )
2753
+
2754
+ @pytest .mark .parametrize (
2755
+ "make_input" ,
2756
+ [make_image_tensor , make_image_pil , make_image , make_video ],
2757
+ )
2758
+ def test_functional (self , make_input ):
2759
+ check_functional (F .erase , make_input (), ** self .FUNCTIONAL_KWARGS )
2760
+
2761
+ @pytest .mark .parametrize (
2762
+ ("kernel" , "input_type" ),
2763
+ [
2764
+ (F .erase_image , torch .Tensor ),
2765
+ (F ._erase_image_pil , PIL .Image .Image ),
2766
+ (F .erase_image , tv_tensors .Image ),
2767
+ (F .erase_video , tv_tensors .Video ),
2768
+ ],
2769
+ )
2770
+ def test_functional_signature (self , kernel , input_type ):
2771
+ check_functional_kernel_signature_match (F .erase , kernel = kernel , input_type = input_type )
2772
+
2773
+ @pytest .mark .parametrize (
2774
+ "make_input" ,
2775
+ [make_image_tensor , make_image_pil , make_image , make_video ],
2776
+ )
2777
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
2778
+ def test_transform (self , make_input , device ):
2779
+ check_transform (transforms .RandomErasing (p = 1 ), make_input (device = device ))
2780
+
2781
+ def _reference_erase_image (self , image , * , i , j , h , w , v ):
2782
+ mask = torch .zeros_like (image , dtype = torch .bool )
2783
+ mask [..., i : i + h , j : j + w ] = True
2784
+
2785
+ # The broadcasting and type casting logic is handled automagically in the kernel through indexing
2786
+ value = torch .broadcast_to (v , (* image .shape [:- 2 ], h , w )).to (image )
2787
+
2788
+ erased_image = torch .empty_like (image )
2789
+ erased_image [mask ] = value .flatten ()
2790
+ erased_image [~ mask ] = image [~ mask ]
2791
+
2792
+ return erased_image
2793
+
2794
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
2795
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
2796
+ def test_functional_image_correctness (self , dtype , device ):
2797
+ image = make_image (dtype = dtype , device = device )
2798
+
2799
+ actual = F .erase (image , ** self .FUNCTIONAL_KWARGS )
2800
+ expected = self ._reference_erase_image (image , ** self .FUNCTIONAL_KWARGS )
2801
+
2802
+ assert_equal (actual , expected )
2803
+
2804
+ @param_value_parametrization (
2805
+ scale = [(0.1 , 0.2 ), [0.0 , 1.0 ]],
2806
+ ratio = [(0.3 , 0.7 ), [0.1 , 5.0 ]],
2807
+ value = [0 , 0.5 , (0 , 1 , 0 ), [- 0.2 , 0.0 , 1.3 ], "random" ],
2808
+ )
2809
+ @pytest .mark .parametrize ("dtype" , [torch .float32 , torch .uint8 ])
2810
+ @pytest .mark .parametrize ("device" , cpu_and_cuda ())
2811
+ @pytest .mark .parametrize ("seed" , list (range (5 )))
2812
+ def test_transform_image_correctness (self , param , value , dtype , device , seed ):
2813
+ transform = transforms .RandomErasing (** {param : value }, p = 1 )
2814
+
2815
+ image = make_image (dtype = dtype , device = device )
2816
+
2817
+ with freeze_rng_state ():
2818
+ torch .manual_seed (seed )
2819
+ # This emulates the random apply check that happens before _get_params is called
2820
+ torch .rand (1 )
2821
+ params = transform ._get_params ([image ])
2822
+
2823
+ torch .manual_seed (seed )
2824
+ actual = transform (image )
2825
+
2826
+ expected = self ._reference_erase_image (image , ** params )
2827
+
2828
+ assert_equal (actual , expected )
2829
+
2830
+ def test_transform_errors (self ):
2831
+ with pytest .raises (TypeError , match = "Argument value should be either a number or str or a sequence" ):
2832
+ transforms .RandomErasing (value = {})
2833
+
2834
+ with pytest .raises (ValueError , match = "If value is str, it should be 'random'" ):
2835
+ transforms .RandomErasing (value = "abc" )
2836
+
2837
+ with pytest .raises (TypeError , match = "Scale should be a sequence" ):
2838
+ transforms .RandomErasing (scale = 123 )
2839
+
2840
+ with pytest .raises (TypeError , match = "Ratio should be a sequence" ):
2841
+ transforms .RandomErasing (ratio = 123 )
2842
+
2843
+ with pytest .raises (ValueError , match = "Scale should be between 0 and 1" ):
2844
+ transforms .RandomErasing (scale = [- 1 , 2 ])
2845
+
2846
+ transform = transforms .RandomErasing (value = [1 , 2 , 3 , 4 ])
2847
+
2848
+ with pytest .raises (ValueError , match = "If value is a sequence, it should have either a single value" ):
2849
+ transform ._get_params ([make_image ()])
2850
+
2851
+ @pytest .mark .parametrize ("make_input" , [make_bounding_boxes , make_detection_mask ])
2852
+ def test_transform_passthrough (self , make_input ):
2853
+ transform = transforms .RandomErasing (p = 1 )
2854
+
2855
+ input = make_input (self .INPUT_SIZE )
2856
+
2857
+ with pytest .warns (UserWarning , match = "currently passing through inputs of type" ):
2858
+ # RandomErasing requires an image or video to be present
2859
+ _ , output = transform (make_image (self .INPUT_SIZE ), input )
2860
+
2861
+ assert output is input
0 commit comments