Skip to content

Commit 9f51206

Browse files
ricberqubvel
andauthored
Refactor _set_encoder_trainable
Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent a76967c commit 9f51206

File tree

1 file changed

+1
-5
lines changed
  • segmentation_models_pytorch/base

1 file changed

+1
-5
lines changed

segmentation_models_pytorch/base/model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -180,11 +180,7 @@ def _set_encoder_trainable(self, mode: bool):
180180
# _NormBase is the common base of classes like _InstanceNorm
181181
# and _BatchNorm that track running stats
182182
if isinstance(module, torch.nn.modules.batchnorm._NormBase):
183-
if mode:
184-
module.train()
185-
else:
186-
# Putting norm layers into eval mode stops running stats updates
187-
module.eval()
183+
module.train(mode)
188184

189185
self._is_encoder_frozen = not mode
190186

0 commit comments

Comments
 (0)