diff --git a/docs/insights.rst b/docs/insights.rst
index ad5355b9..e5848ef6 100644
--- a/docs/insights.rst
+++ b/docs/insights.rst
@@ -117,3 +117,25 @@ Example:
 
     mask.shape, label.shape
    # (N, 4, H, W), (N, 4)
+
+4. Freezing and unfreezing the encoder
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Sometimes you may want to freeze the encoder during training, e.g. when using a pretrained backbone and fine-tuning only the decoder and segmentation head.
+
+All segmentation models in SMP provide two helper methods:
+
+.. code-block:: python
+
+    model = smp.Unet("resnet34", classes=2)
+
+    # Freeze encoder: stops gradient updates and freezes normalization layer stats
+    model.freeze_encoder()
+
+    # Unfreeze encoder: re-enables training for encoder parameters and normalization layers
+    model.unfreeze_encoder()
+
+.. important::
+    - Freezing sets ``requires_grad = False`` for all encoder parameters.
+    - Normalization layers that track running statistics (e.g., BatchNorm and InstanceNorm layers) are set to ``.eval()`` mode to prevent updates to ``running_mean`` and ``running_var``.
+    - If you later call ``model.train()``, a frozen encoder will remain frozen until you call ``unfreeze_encoder()``.
diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py
index 71322cf0..aa01d7b9 100644
--- a/segmentation_models_pytorch/base/model.py
+++ b/segmentation_models_pytorch/base/model.py
@@ -24,6 +24,10 @@ def __new__(cls: Type[T], *args, **kwargs) -> T:
         instance = super().__new__(cls, *args, **kwargs)
         return instance
 
+    def __init__(self):
+        super().__init__()
+        self._is_encoder_frozen = False
+
     def initialize(self):
         init.initialize_decoder(self.decoder)
         init.initialize_head(self.segmentation_head)
@@ -137,3 +141,70 @@ def load_state_dict(self, state_dict, **kwargs):
             warnings.warn(text, stacklevel=-1)
 
         return super().load_state_dict(state_dict, **kwargs)
+
+    def train(self, mode: bool = True):
+        """Set the module in training mode.
+
+        This method behaves like the standard :meth:`torch.nn.Module.train`,
+        with one exception: if the encoder has been frozen via
+        :meth:`freeze_encoder`, the encoder is skipped, so its normalization
+        layers are not affected by this call. In other words, calling
+        ``model.train()`` will not re-enable updates to the frozen encoder's
+        normalization layers (e.g., BatchNorm, InstanceNorm).
+
+        To restore the encoder to normal training behavior, use
+        :meth:`unfreeze_encoder`.
+
+        Args:
+            mode (bool): whether to set training mode (``True``) or evaluation
+                mode (``False``). Default: ``True``.
+
+        Returns:
+            Module: self
+        """
+        if not isinstance(mode, bool):
+            raise ValueError("training mode is expected to be boolean")
+        self.training = mode
+        for name, module in self.named_children():
+            # skip encoder if it is frozen
+            if self._is_encoder_frozen and name == "encoder":
+                continue
+            module.train(mode)
+        return self
+
+    def _set_encoder_trainable(self, mode: bool):
+        for param in self.encoder.parameters():
+            param.requires_grad = mode
+
+        for module in self.encoder.modules():
+            # _NormBase is the common base of classes like _InstanceNorm
+            # and _BatchNorm that track running stats
+            if isinstance(module, torch.nn.modules.batchnorm._NormBase):
+                module.train(mode)
+
+        self._is_encoder_frozen = not mode
+
+    def freeze_encoder(self):
+        """
+        Freeze encoder parameters and disable updates to normalization
+        layer statistics.
+
+        This method:
+        - Sets ``requires_grad = False`` for all encoder parameters,
+          preventing them from being updated during backpropagation.
+        - Puts normalization layers that track running statistics
+          (e.g., BatchNorm, InstanceNorm) into evaluation mode (``.eval()``),
+          so their ``running_mean`` and ``running_var`` are no longer updated.
+        """
+        return self._set_encoder_trainable(False)
+
+    def unfreeze_encoder(self):
+        """
+        Unfreeze encoder parameters and restore normalization layers to training mode.
+
+        This method reverts the effect of :meth:`freeze_encoder`. Specifically:
+        - Sets ``requires_grad = True`` for all encoder parameters.
+        - Restores normalization layers (e.g., BatchNorm, InstanceNorm) to training mode,
+          so their running statistics are updated again.
+        """
+        return self._set_encoder_trainable(True)
diff --git a/tests/base/test_freeze_encoder.py b/tests/base/test_freeze_encoder.py
new file mode 100644
index 00000000..c4e082b4
--- /dev/null
+++ b/tests/base/test_freeze_encoder.py
@@ -0,0 +1,70 @@
+import torch
+import segmentation_models_pytorch as smp
+
+
+def test_freeze_and_unfreeze_encoder():
+    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
+
+    def assert_encoder_params_trainable(expected: bool):
+        assert all(p.requires_grad == expected for p in model.encoder.parameters())
+
+    def assert_norm_layers_training(expected: bool):
+        for m in model.encoder.modules():
+            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+                assert m.training == expected
+
+    # Initially, encoder params should be trainable
+    model.train()
+    assert_encoder_params_trainable(True)
+
+    # Freeze encoder
+    model.freeze_encoder()
+    assert_encoder_params_trainable(False)
+    assert_norm_layers_training(False)
+
+    # Call train() and ensure encoder norm layers stay frozen
+    model.train()
+    assert_norm_layers_training(False)
+
+    # Unfreeze encoder
+    model.unfreeze_encoder()
+    assert_encoder_params_trainable(True)
+    assert_norm_layers_training(True)
+
+    # Call train() again — should stay trainable
+    model.train()
+    assert_norm_layers_training(True)
+
+    # Switch to eval, then freeze
+    model.eval()
+    model.freeze_encoder()
+    assert_encoder_params_trainable(False)
+    assert_norm_layers_training(False)
+
+    # Unfreeze again
+    model.unfreeze_encoder()
+    assert_encoder_params_trainable(True)
+    assert_norm_layers_training(True)
+
+
+def test_freeze_encoder_stops_running_stats():
+    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
+    model.freeze_encoder()
+    model.train()  # overridden train, encoder should remain frozen
+    bn = None
+
+    for m in model.encoder.modules():
+        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+            bn = m
+            break
+
+    assert bn is not None
+
+    orig_mean = bn.running_mean.clone()
+    orig_var = bn.running_var.clone()
+
+    x = torch.randn(2, 3, 64, 64)
+    _ = model(x)
+
+    torch.testing.assert_close(orig_mean, bn.running_mean)
+    torch.testing.assert_close(orig_var, bn.running_var)
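
Usage sketch (not part of the patch above): it shows how the new freeze/unfreeze API is meant to be used end to end. Only smp.Unet, freeze_encoder(), unfreeze_encoder(), and the overridden train() come from this change set; the dummy batch, the loss, the optimizer choice, and the learning rates are placeholders for illustration.

    import torch
    import segmentation_models_pytorch as smp

    # Binary segmentation model. Pass encoder_weights="imagenet" in practice
    # to get a pretrained backbone; None keeps this sketch offline.
    model = smp.Unet("resnet34", encoder_weights=None, classes=1)

    # Phase 1: train only the decoder and segmentation head.
    model.freeze_encoder()  # encoder params: requires_grad=False, norm layers -> eval()
    model.train()           # overridden train(): the frozen encoder stays in eval mode

    # Only trainable parameters go to the optimizer.
    optimizer = torch.optim.Adam(
        (p for p in model.parameters() if p.requires_grad), lr=1e-3
    )

    # Placeholder batch instead of a real dataloader.
    images = torch.randn(4, 3, 256, 256)
    targets = torch.randint(0, 2, (4, 1, 256, 256)).float()

    logits = model(images)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Phase 2: unfreeze for end-to-end fine-tuning, typically at a lower learning rate.
    model.unfreeze_encoder()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

The optimizer is rebuilt after unfreeze_encoder() because the encoder parameters were excluded from the first one; adding them as a new parameter group would work as well.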