segmentation_models_pytorch/base/model.py (90 changes: 70 additions & 20 deletions)
@@ -23,6 +23,10 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
     def __new__(cls: Type[T], *args, **kwargs) -> T:
         instance = super().__new__(cls, *args, **kwargs)
         return instance
 
+    def __init__(self):
+        super().__init__()
+        self._is_encoder_frozen = False
+
     def initialize(self):
         init.initialize_decoder(self.decoder)
@@ -138,28 +142,74 @@ def load_state_dict(self, state_dict, **kwargs):
 
         return super().load_state_dict(state_dict, **kwargs)
 
-    def encoder_freeze(self):
-        """
-        Freeze the encoder parameters and normalization layers.
-
-        This method sets ``requires_grad = False`` for all encoder parameters,
-        preventing them from being updated during backpropagation. In addition,
-        it switches BatchNorm and InstanceNorm layers into evaluation mode
-        (``.eval()``), which stops updates to their running statistics
-        (``running_mean`` and ``running_var``) during training.
-
-        **Important:** If you call :meth:`model.train()` after
-        :meth:`encoder_freeze`, the encoder's BatchNorm/InstanceNorm layers
-        will be put back into training mode, and their running statistics
-        will be updated again. To re-freeze them after switching modes,
-        call :meth:`encoder_freeze` again.
+    def train(self, mode: bool = True):
+        """Set the module in training mode.
+
+        This method behaves like the standard :meth:`torch.nn.Module.train`,
+        with one exception: if the encoder has been frozen via
+        :meth:`freeze_encoder`, then its normalization layers are not affected
+        by this call. In other words, calling ``model.train()`` will not
+        re-enable updates to frozen encoder normalization layers
+        (e.g., BatchNorm, InstanceNorm).
+
+        To restore the encoder to normal training behavior, use
+        :meth:`unfreeze_encoder`.
+
+        Args:
+            mode (bool): whether to set training mode (``True``) or evaluation
+                mode (``False``). Default: ``True``.
+
+        Returns:
+            Module: self
         """
+        if not isinstance(mode, bool):
+            raise ValueError("training mode is expected to be boolean")
+        self.training = mode
+        for name, module in self.named_children():
+            # skip the encoder if it is frozen
+            if self._is_encoder_frozen and name == "encoder":
+                continue
+            module.train(mode)
+        return self
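The interplay above can be checked directly. A minimal sketch, assuming a build that includes this change (the `Unet`/`resnet18` choice is arbitrary): after `freeze_encoder()`, a later `model.train()` leaves the encoder's stat-tracking norm layers in eval mode while the rest of the model trains.

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights=None, classes=2)
model.freeze_encoder()
model.train()  # decoder and head enter train mode; the frozen encoder does not

# Norm layers in the encoder that track running stats should remain in eval mode
norm_layers = [
    m
    for m in model.encoder.modules()
    if isinstance(m, torch.nn.modules.batchnorm._NormBase)
]
assert all(not m.training for m in norm_layers)
assert model.decoder.training
```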

+    def _set_encoder_trainable(self, mode: bool):
         for param in self.encoder.parameters():
-            param.requires_grad = False
+            param.requires_grad = mode
 
-        # Putting norm layers into eval mode stops running stats updates
         for module in self.encoder.modules():
-            # _NormBase is the common base of _InstanceNorm and _BatchNorm classes
-            # These are the two classes that track running stats
+            # _NormBase is the common base of classes like _InstanceNorm
+            # and _BatchNorm that track running stats
             if isinstance(module, torch.nn.modules.batchnorm._NormBase):
-                module.eval()
+                if mode:
+                    module.train()
+                else:
+                    # Putting norm layers into eval mode stops running stats updates
+                    module.eval()
+
+        self._is_encoder_frozen = not mode

+    def freeze_encoder(self):
+        """
+        Freeze encoder parameters and disable updates to normalization
+        layer statistics.
+
+        This method:
+        - Sets ``requires_grad = False`` for all encoder parameters,
+          preventing them from being updated during backpropagation.
+        - Puts normalization layers that track running statistics
+          (e.g., BatchNorm, InstanceNorm) into evaluation mode (``.eval()``),
+          so their ``running_mean`` and ``running_var`` are no longer updated.
+        """
+        return self._set_encoder_trainable(False)
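For instance (a sketch under the same assumptions as above), the freeze shows up in `requires_grad`, so an optimizer can be limited to the parameters that remain trainable:

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Unet("resnet18", encoder_weights=None, classes=2)
model.freeze_encoder()

# Encoder weights are excluded from gradient updates...
assert not any(p.requires_grad for p in model.encoder.parameters())

# ...so the optimizer only needs to track decoder/head parameters.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
```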

+    def unfreeze_encoder(self):
+        """
+        Unfreeze encoder parameters and restore normalization layers to
+        training mode.
+
+        This method reverts the effect of :meth:`freeze_encoder`. Specifically:
+        - Sets ``requires_grad = True`` for all encoder parameters.
+        - Restores normalization layers (e.g., BatchNorm, InstanceNorm) to
+          training mode, so their running statistics are updated again.
+        """
+        return self._set_encoder_trainable(True)
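Together, the pair supports the usual two-stage fine-tuning recipe. A rough sketch, with the model, loss, learning rates, and batch all stand-ins rather than recommendations:

```python
import torch
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.losses import DiceLoss

model = smp.Unet("resnet18", encoder_weights=None, classes=2)
loss_fn = DiceLoss(mode="multiclass")

# Stage 1: train only the decoder/head with the encoder frozen.
model.freeze_encoder()
model.train()
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
images = torch.rand(2, 3, 256, 256)         # placeholder batch
masks = torch.randint(0, 2, (2, 256, 256))  # placeholder targets
optimizer.zero_grad()
loss = loss_fn(model(images), masks)
loss.backward()
optimizer.step()

# Stage 2: unfreeze and fine-tune the whole network at a lower learning rate.
model.unfreeze_encoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
```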