Skip to content

Commit 340e677

Browse files
committed
Revert asymmetric padding. Raise error if block size is not even.
1 parent bcd7fbb commit 340e677

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

torchvision/ops/drop_block.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,16 @@ def drop_block2d(
3636

3737
N, C, H, W = input.size()
3838
block_size = min(block_size, W, H)
39+
if block_size % 2 == 0:
40+
raise ValueError(f"block size should be odd. Got {block_size} which is even.")
41+
3942
# compute the gamma of Bernoulli distribution
4043
gamma = (p * H * W) / ((block_size**2) * ((H - block_size + 1) * (W - block_size + 1)))
4144
noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device)
4245
noise.bernoulli_(gamma)
4346

4447
noise = F.pad(noise, [block_size // 2] * 4, value=0)
45-
left_pad = right_pad = block_size // 2
46-
if left_pad > 0 and block_size % 2 == 0:
47-
left_pad -= 1
48-
noise = F.pad(noise, pad=(left_pad, right_pad, left_pad, right_pad))
49-
noise = F.max_pool2d(noise, stride=1, kernel_size=block_size)
48+
noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2)
5049
noise = 1 - noise
5150
normalize_scale = noise.numel() / (eps + noise.sum())
5251
if inplace:
@@ -86,6 +85,9 @@ def drop_block3d(
8685

8786
N, C, D, H, W = input.size()
8887
block_size = min(block_size, D, H, W)
88+
if block_size % 2 == 0:
89+
raise ValueError(f"block size should be odd. Got {block_size} which is even.")
90+
8991
# compute the gamma of Bernoulli distribution
9092
gamma = (p * D * H * W) / ((block_size**3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1)))
9193
noise = torch.empty(

0 commit comments

Comments
 (0)