Skip to content

Commit 3ac1ca8

Browse files
authored
Fix UPerNet decoder typo (#945)
* fix typo * fix style issues
1 parent 2be1124 commit 3ac1ca8

File tree

1 file changed

+4
-6
lines changed
  • segmentation_models_pytorch/decoders/upernet

1 file changed

+4
-6
lines changed

segmentation_models_pytorch/decoders/upernet/decoder.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def __init__(
3636
)
3737

3838
def forward(self, x):
39-
_, _, height, weight = x.shape
39+
_, _, height, width = x.shape
4040
out = [x] + [
4141
F.interpolate(
42-
block(x), size=(height, weight), mode="bilinear", align_corners=False
42+
block(x), size=(height, width), mode="bilinear", align_corners=False
4343
)
4444
for block in self.blocks
4545
]
@@ -62,10 +62,8 @@ def __init__(self, skip_channels, pyramid_channels, use_bathcnorm=True):
6262
)
6363

6464
def forward(self, x, skip):
65-
_, channels, height, weight = skip.shape
66-
x = F.interpolate(
67-
x, size=(height, weight), mode="bilinear", align_corners=False
68-
)
65+
_, channels, height, width = skip.shape
66+
x = F.interpolate(x, size=(height, width), mode="bilinear", align_corners=False)
6967
if channels != 0:
7068
skip = self.skip_conv(skip)
7169
x = x + skip

0 commit comments

Comments
 (0)