@@ -36,17 +36,16 @@ def drop_block2d(
3636
3737    N , C , H , W  =  input .size ()
3838    block_size  =  min (block_size , W , H )
39+     if  block_size  %  2  ==  0 :
40+         raise  ValueError (f"block size should be odd. Got { block_size }   which is even." )
41+ 
3942    # compute the gamma of Bernoulli distribution 
4043    gamma  =  (p  *  H  *  W ) /  ((block_size ** 2 ) *  ((H  -  block_size  +  1 ) *  (W  -  block_size  +  1 )))
4144    noise  =  torch .empty ((N , C , H  -  block_size  +  1 , W  -  block_size  +  1 ), dtype = input .dtype , device = input .device )
4245    noise .bernoulli_ (gamma )
4346
4447    noise  =  F .pad (noise , [block_size  //  2 ] *  4 , value = 0 )
45-     left_pad  =  right_pad  =  block_size  //  2 
46-     if  left_pad  >  0  and  block_size  %  2  ==  0 :
47-         left_pad  -=  1 
48-     noise  =  F .pad (noise , pad = (left_pad , right_pad , left_pad , right_pad ))
49-     noise  =  F .max_pool2d (noise , stride = 1 , kernel_size = block_size )
48+     noise  =  F .max_pool2d (noise , stride = (1 , 1 ), kernel_size = (block_size , block_size ), padding = block_size  //  2 )
5049    noise  =  1  -  noise 
5150    normalize_scale  =  noise .numel () /  (eps  +  noise .sum ())
5251    if  inplace :
@@ -86,6 +85,9 @@ def drop_block3d(
8685
8786    N , C , D , H , W  =  input .size ()
8887    block_size  =  min (block_size , D , H , W )
88+     if  block_size  %  2  ==  0 :
89+         raise  ValueError (f"block size should be odd. Got { block_size }   which is even." )
90+ 
8991    # compute the gamma of Bernoulli distribution 
9092    gamma  =  (p  *  D  *  H  *  W ) /  ((block_size ** 3 ) *  ((D  -  block_size  +  1 ) *  (H  -  block_size  +  1 ) *  (W  -  block_size  +  1 )))
9193    noise  =  torch .empty (
0 commit comments