File tree Expand file tree Collapse file tree 1 file changed +10
-1
lines changed
segmentation_models_pytorch/base Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Original file line number Diff line number Diff line change 55
66
77class SegmentationModel (torch .nn .Module , SMPHubMixin ):
8+ """Base class for all segmentation models."""
9+
10+ # if model supports shape not divisible by 2 ^ n
11+ # set to False
12+ requires_divisible_input_shape = True
13+
814 def initialize (self ):
915 init .initialize_decoder (self .decoder )
1016 init .initialize_head (self .segmentation_head )
1117 if self .classification_head is not None :
1218 init .initialize_head (self .classification_head )
1319
1420 def check_input_shape (self , x ):
21+ """Check if the input shape is divisible by the output stride.
22+ If not, raise a RuntimeError.
23+ """
1524 h , w = x .shape [- 2 :]
1625 output_stride = self .encoder .output_stride
1726 if h % output_stride != 0 or w % output_stride != 0 :
@@ -33,7 +42,7 @@ def check_input_shape(self, x):
3342 def forward (self , x ):
3443 """Sequentially pass `x` trough model`s encoder, decoder and heads"""
3544
36- if not torch .jit .is_tracing ():
45+ if not torch .jit .is_tracing () or self . requires_divisible_input_shape :
3746 self .check_input_shape (x )
3847
3948 features = self .encoder (x )
You can’t perform that action at this time.
0 commit comments