Skip to content

Commit eadfe1f

Browse files
committed
Docs + flag for anyres
1 parent 4d7fed0 commit eadfe1f

File tree

1 file changed

+10
-1
lines changed
  • segmentation_models_pytorch/base

1 file changed

+10
-1
lines changed

segmentation_models_pytorch/base/model.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,22 @@
55

66

77
class SegmentationModel(torch.nn.Module, SMPHubMixin):
8+
"""Base class for all segmentation models."""
9+
10+
# if model supports shape not divisible by 2 ^ n
11+
# set to False
12+
requires_divisible_input_shape = True
13+
814
def initialize(self):
915
init.initialize_decoder(self.decoder)
1016
init.initialize_head(self.segmentation_head)
1117
if self.classification_head is not None:
1218
init.initialize_head(self.classification_head)
1319

1420
def check_input_shape(self, x):
21+
"""Check if the input shape is divisible by the output stride.
22+
If not, raise a RuntimeError.
23+
"""
1524
h, w = x.shape[-2:]
1625
output_stride = self.encoder.output_stride
1726
if h % output_stride != 0 or w % output_stride != 0:
@@ -33,7 +42,7 @@ def check_input_shape(self, x):
3342
def forward(self, x):
3443
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""
3544

36-
if not torch.jit.is_tracing():
45+
if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
3746
self.check_input_shape(x)
3847

3948
features = self.encoder(x)

0 commit comments

Comments
 (0)