From fd6a900c340b7d5a441ae96443268d69a3a74104 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Wed, 20 Aug 2025 14:59:33 +0200
Subject: [PATCH 01/10] Add encoder_freeze method

---
 segmentation_models_pytorch/base/model.py | 26 ++++++++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py
index 71322cf0..72ec1b50 100644
--- a/segmentation_models_pytorch/base/model.py
+++ b/segmentation_models_pytorch/base/model.py
@@ -137,3 +137,29 @@ def load_state_dict(self, state_dict, **kwargs):
             warnings.warn(text, stacklevel=-1)
 
         return super().load_state_dict(state_dict, **kwargs)
+
+    def encoder_freeze(self):
+        """
+        Freeze the encoder parameters and normalization layers.
+
+        This method sets ``requires_grad = False`` for all encoder parameters,
+        preventing them from being updated during backpropagation. In addition,
+        it switches BatchNorm and InstanceNorm layers into evaluation mode
+        (``.eval()``), which stops updates to their running statistics
+        (``running_mean`` and ``running_var``) during training.
+
+        **Important:** If you call :meth:`model.train()` after
+        :meth:`encoder_freeze`, the encoder’s BatchNorm/InstanceNorm layers
+        will be put back into training mode, and their running statistics
+        will be updated again. To re-freeze them after switching modes,
+        call :meth:`encoder_freeze` again.
+        """
+        for param in self.encoder.parameters():
+            param.requires_grad = False
+
+        # Putting norm layers into eval mode stops running stats updates
+        for module in self.encoder.modules():
+            # _NormBase is the common base of _InstanceNorm and _BatchNorm classes
+            # These are the two classes that track running stats
+            if isinstance(module, torch.nn.modules.batchnorm._NormBase):
+                module.eval()

From 92b7e1c1cbf54d3c454e61f460b067d4cbb6bf87 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Thu, 21 Aug 2025 16:15:54 +0200
Subject: [PATCH 02/10] Add encoder freeze/unfreeze methods and override train()

---
 segmentation_models_pytorch/base/model.py | 90 ++++++++++++++++++-----
 1 file changed, 70 insertions(+), 20 deletions(-)

diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py
index 72ec1b50..240691b7 100644
--- a/segmentation_models_pytorch/base/model.py
+++ b/segmentation_models_pytorch/base/model.py
@@ -23,6 +23,10 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
     def __new__(cls: Type[T], *args, **kwargs) -> T:
         instance = super().__new__(cls, *args, **kwargs)
         return instance
+
+    def __init__(self):
+        super().__init__()
+        self._is_encoder_frozen = False
 
     def initialize(self):
         init.initialize_decoder(self.decoder)
@@ -138,28 +142,74 @@ def load_state_dict(self, state_dict, **kwargs):
 
         return super().load_state_dict(state_dict, **kwargs)
 
-    def encoder_freeze(self):
-        """
-        Freeze the encoder parameters and normalization layers.
-
-        This method sets ``requires_grad = False`` for all encoder parameters,
-        preventing them from being updated during backpropagation. In addition,
-        it switches BatchNorm and InstanceNorm layers into evaluation mode
-        (``.eval()``), which stops updates to their running statistics
-        (``running_mean`` and ``running_var``) during training.
-
-        **Important:** If you call :meth:`model.train()` after
-        :meth:`encoder_freeze`, the encoder’s BatchNorm/InstanceNorm layers
-        will be put back into training mode, and their running statistics
-        will be updated again. To re-freeze them after switching modes,
-        call :meth:`encoder_freeze` again.
+    def train(self, mode: bool = True):
+        """Set the module in training mode.
+
+        This method behaves like the standard :meth:`torch.nn.Module.train`,
+        with one exception: if the encoder has been frozen via
+        :meth:`freeze_encoder`, then its normalization layers are not affected
+        by this call. In other words, calling ``model.train()`` will not
+        re-enable updates to frozen encoder normalization layers
+        (e.g., BatchNorm, InstanceNorm).
+
+        To restore the encoder to normal training behavior, use
+        :meth:`unfreeze_encoder`.
+
+        Args:
+            mode (bool): whether to set training mode (``True``) or evaluation
+                         mode (``False``). Default: ``True``.
+
+        Returns:
+            Module: self
         """
+        if not isinstance(mode, bool):
+            raise ValueError("training mode is expected to be boolean")
+        self.training = mode
+        for name, module in self.named_children():
+            print(name)
+            # skip encoder encoder if it is frozen
+            if self._is_encoder_frozen and name == "encoder":
+                continue
+            module.train(mode)
+        return self
+
+    def _set_encoder_trainable(self, mode: bool):
         for param in self.encoder.parameters():
-            param.requires_grad = False
+            param.requires_grad = mode
 
-        # Putting norm layers into eval mode stops running stats updates
         for module in self.encoder.modules():
-            # _NormBase is the common base of _InstanceNorm and _BatchNorm classes
-            # These are the two classes that track running stats
+            # _NormBase is the common base of classes like _InstanceNorm
+            # and _BatchNorm that track running stats
             if isinstance(module, torch.nn.modules.batchnorm._NormBase):
-                module.eval()
+                if mode:
+                    module.train()
+                else:
+                    # Putting norm layers into eval mode stops running stats updates
+                    module.eval()
+
+        self._is_encoder_frozen = not mode
+
+    def freeze_encoder(self):
+        """
+        Freeze encoder parameters and disable updates to normalization
+        layer statistics.
+
+        This method:
+        - Sets ``requires_grad = False`` for all encoder parameters,
+          preventing them from being updated during backpropagation.
+        - Puts normalization layers that track running statistics
+          (e.g., BatchNorm, InstanceNorm) into evaluation mode (``.eval()``),
+          so their ``running_mean`` and ``running_var`` are no longer updated.
+        """
+        return self._set_encoder_trainable(False)
+
+    def unfreeze_encoder(self):
+        """
+        Unfreeze encoder parameters and restore normalization layers to training mode.
+
+        This method reverts the effect of :meth:`freeze_encoder`. Specifically:
+        - Sets ``requires_grad=True`` for all encoder parameters.
+        - Restores normalization layers (e.g. BatchNorm, InstanceNorm) to training mode,
+          so their running statistics are updated again.
+        """
+        return self._set_encoder_trainable(True)
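For orientation, here is a minimal sketch of how the methods added in PATCH 02 behave
together once the patch is applied. The snippet is illustrative only (it is not part of
the series); the specific model configuration is an assumption, while ``freeze_encoder``,
the overridden ``train()``, and the ``_NormBase`` check all come from the patch itself:

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet("resnet18", encoder_weights=None, classes=2)

    model.freeze_encoder()
    model.train()  # overridden train(): the frozen encoder is skipped

    # Encoder weights receive no gradient updates
    print(any(p.requires_grad for p in model.encoder.parameters()))  # False

    # Encoder norm layers stay in eval mode; the decoder keeps training normally
    encoder_norms = [
        m for m in model.encoder.modules()
        if isinstance(m, torch.nn.modules.batchnorm._NormBase)
    ]
    print(any(m.training for m in encoder_norms))  # False
    print(model.decoder.training)                  # True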
From 32f3c07d5a7f83fa5bc80c209431cdf0f2d94641 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Thu, 21 Aug 2025 16:27:11 +0200
Subject: [PATCH 03/10] Fix typo

---
 segmentation_models_pytorch/base/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py
index 240691b7..7b99d314 100644
--- a/segmentation_models_pytorch/base/model.py
+++ b/segmentation_models_pytorch/base/model.py
@@ -167,7 +167,7 @@ def train(self, mode: bool = True):
         self.training = mode
         for name, module in self.named_children():
             print(name)
-            # skip encoder encoder if it is frozen
+            # skip encoder if it is frozen
             if self._is_encoder_frozen and name == "encoder":
                 continue
             module.train(mode)

From a76967c7f14db1df8cb944d511f812c5d0f6b695 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Thu, 21 Aug 2025 16:34:15 +0200
Subject: [PATCH 04/10] Remove unnecessary print

---
 segmentation_models_pytorch/base/model.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py
index 7b99d314..3edb513a 100644
--- a/segmentation_models_pytorch/base/model.py
+++ b/segmentation_models_pytorch/base/model.py
@@ -23,7 +23,7 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin):
     def __new__(cls: Type[T], *args, **kwargs) -> T:
         instance = super().__new__(cls, *args, **kwargs)
         return instance
-    
+
     def __init__(self):
         super().__init__()
         self._is_encoder_frozen = False
@@ -166,13 +166,12 @@ def train(self, mode: bool = True):
             raise ValueError("training mode is expected to be boolean")
         self.training = mode
         for name, module in self.named_children():
-            print(name)
             # skip encoder if it is frozen
             if self._is_encoder_frozen and name == "encoder":
                 continue
             module.train(mode)
         return self
-    
+
     def _set_encoder_trainable(self, mode: bool):
         for param in self.encoder.parameters():
             param.requires_grad = mode

From 9f51206e749cdb4307c18ffce010abfdbfb21abf Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio <18362950+ricber@users.noreply.github.com>
Date: Thu, 21 Aug 2025 16:39:52 +0200
Subject: [PATCH 05/10] Refactor _set_encoder_trainable

Co-authored-by: Pavel Iakubovskii
---
 segmentation_models_pytorch/base/model.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py
index 3edb513a..aa01d7b9 100644
--- a/segmentation_models_pytorch/base/model.py
+++ b/segmentation_models_pytorch/base/model.py
@@ -180,11 +180,7 @@ def _set_encoder_trainable(self, mode: bool):
             # _NormBase is the common base of classes like _InstanceNorm
             # and _BatchNorm that track running stats
             if isinstance(module, torch.nn.modules.batchnorm._NormBase):
-                if mode:
-                    module.train()
-                else:
-                    # Putting norm layers into eval mode stops running stats updates
-                    module.eval()
+                module.train(mode)
 
         self._is_encoder_frozen = not mode
 

From a699d67f42d035bacda020e69252ab9c3c466cb1 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Thu, 21 Aug 2025 17:12:05 +0200
Subject: [PATCH 06/10] Add tests for encoder freezing

---
 tests/base/test_freeze_encoder.py | 54 ++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
 create mode 100644 tests/base/test_freeze_encoder.py

diff --git a/tests/base/test_freeze_encoder.py b/tests/base/test_freeze_encoder.py
new file mode 100644
index 00000000..4876ec08
--- /dev/null
+++ b/tests/base/test_freeze_encoder.py
@@ -0,0 +1,54 @@
+import torch
+import segmentation_models_pytorch as smp
+
+
+def test_freeze_and_unfreeze_encoder():
+    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
+    model.train()
+    # Initially, encoder params should be trainable
+    assert all(p.requires_grad for p in model.encoder.parameters())
+    model.freeze_encoder()
+    # Check encoder params are frozen
+    assert all(not p.requires_grad for p in model.encoder.parameters())
+    # Check normalization layers are in eval mode
+    for m in model.encoder.modules():
+        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+            assert not m.training
+    # Call train() and ensure encoder norm layers stay frozen
+    model.train()
+    for m in model.encoder.modules():
+        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+            assert not m.training
+    model.unfreeze_encoder()
+    # Params should be trainable again
+    assert all(p.requires_grad for p in model.encoder.parameters())
+    # Norm layers should go back to training mode after unfreeze
+    for m in model.encoder.modules():
+        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+            assert m.training
+    model.train()
+    # Norm layers should have the same training mode after train()
+    for m in model.encoder.modules():
+        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+            assert m.training
+
+
+def test_freeze_encoder_stops_running_stats():
+    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
+    model.freeze_encoder()
+    model.train()  # overridden train, encoder should remain frozen
+    bn = None
+
+    for m in model.encoder.modules():
+        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+            bn = m
+            break
+
+    assert bn is not None
+
+    orig_mean = bn.running_mean.clone()
+
+    x = torch.randn(2, 3, 64, 64)
+    _ = model(x)
+
+    torch.testing.assert_close(orig_mean, bn.running_mean)

From c9e6e0e8a4155df9ef081bcf38e4b5a00dd92e51 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Thu, 21 Aug 2025 17:15:38 +0200
Subject: [PATCH 07/10] Add assertion on running_var

---
 tests/base/test_freeze_encoder.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/base/test_freeze_encoder.py b/tests/base/test_freeze_encoder.py
index 4876ec08..abda1399 100644
--- a/tests/base/test_freeze_encoder.py
+++ b/tests/base/test_freeze_encoder.py
@@ -47,8 +47,10 @@ def test_freeze_encoder_stops_running_stats():
     assert bn is not None
 
     orig_mean = bn.running_mean.clone()
+    orig_var = bn.running_var.clone()
 
     x = torch.randn(2, 3, 64, 64)
     _ = model(x)
 
     torch.testing.assert_close(orig_mean, bn.running_mean)
+    torch.testing.assert_close(orig_var, bn.running_var)
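The running-statistics assertions in these tests lean on standard PyTorch behaviour: a
BatchNorm or InstanceNorm layer updates ``running_mean``/``running_var`` on every forward
pass while in training mode and leaves them untouched in eval mode, which is exactly what
``freeze_encoder`` relies on. A small standalone sketch of that behaviour (plain PyTorch,
independent of this patch series):

    import torch

    bn = torch.nn.BatchNorm2d(3)
    x = torch.randn(4, 3, 8, 8)

    bn.train()
    bn(x)  # running stats are updated by this forward pass
    frozen = bn.running_mean.clone()

    bn.eval()  # what freeze_encoder() does to the encoder's norm layers
    bn(x)
    print(torch.equal(frozen, bn.running_mean))  # True: no further updates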
From e32de477c3620b91c62743707edb479edd652c09 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio <18362950+ricber@users.noreply.github.com>
Date: Thu, 21 Aug 2025 20:34:56 +0200
Subject: [PATCH 08/10] Refactor test_freeze_and_unfreeze_encoder

Co-authored-by: Pavel Iakubovskii
---
 tests/base/test_freeze_encoder.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/tests/base/test_freeze_encoder.py b/tests/base/test_freeze_encoder.py
index abda1399..a27e6a54 100644
--- a/tests/base/test_freeze_encoder.py
+++ b/tests/base/test_freeze_encoder.py
@@ -4,28 +4,34 @@
 
 def test_freeze_and_unfreeze_encoder():
     model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
-    model.train()
+
     # Initially, encoder params should be trainable
+    model.train()
     assert all(p.requires_grad for p in model.encoder.parameters())
-    model.freeze_encoder()
+
     # Check encoder params are frozen
+    model.freeze_encoder()
+
     assert all(not p.requires_grad for p in model.encoder.parameters())
-    # Check normalization layers are in eval mode
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert not m.training
+
     # Call train() and ensure encoder norm layers stay frozen
     model.train()
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert not m.training
-    model.unfreeze_encoder()
+
     # Params should be trainable again
+    model.unfreeze_encoder()
+
     assert all(p.requires_grad for p in model.encoder.parameters())
     # Norm layers should go back to training mode after unfreeze
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert m.training
+
     model.train()
     # Norm layers should have the same training mode after train()
     for m in model.encoder.modules():

From 77fd4906ca9ee51fcbbf753f1ec61647d70eb638 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Thu, 21 Aug 2025 20:59:28 +0200
Subject: [PATCH 09/10] Refactor test and add call to eval

---
 tests/base/test_freeze_encoder.py | 58 ++++++++++++++++++++++++++++++++++-------------------------
 1 file changed, 33 insertions(+), 25 deletions(-)

diff --git a/tests/base/test_freeze_encoder.py b/tests/base/test_freeze_encoder.py
index a27e6a54..c4e082b4 100644
--- a/tests/base/test_freeze_encoder.py
+++ b/tests/base/test_freeze_encoder.py
@@ -4,39 +4,47 @@
 
 def test_freeze_and_unfreeze_encoder():
     model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
-    
+
+    def assert_encoder_params_trainable(expected: bool):
+        assert all(p.requires_grad == expected for p in model.encoder.parameters())
+
+    def assert_norm_layers_training(expected: bool):
+        for m in model.encoder.modules():
+            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+                assert m.training == expected
+
     # Initially, encoder params should be trainable
     model.train()
-    assert all(p.requires_grad for p in model.encoder.parameters())
-
-    # Check encoder params are frozen
+    assert_encoder_params_trainable(True)
+
+    # Freeze encoder
     model.freeze_encoder()
-
-    assert all(not p.requires_grad for p in model.encoder.parameters())
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert not m.training
+    assert_encoder_params_trainable(False)
+    assert_norm_layers_training(False)
 
     # Call train() and ensure encoder norm layers stay frozen
     model.train()
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert not m.training
-
-    # Params should be trainable again
+    assert_norm_layers_training(False)
+
+    # Unfreeze encoder
    model.unfreeze_encoder()
-
-    assert all(p.requires_grad for p in model.encoder.parameters())
-    # Norm layers should go back to training mode after unfreeze
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert m.training
-
+    assert_encoder_params_trainable(True)
+    assert_norm_layers_training(True)
+
+    # Call train() again — should stay trainable
     model.train()
-    # Norm layers should have the same training mode after train()
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert m.training
+    assert_norm_layers_training(True)
+
+    # Switch to eval, then freeze
+    model.eval()
+    model.freeze_encoder()
+    assert_encoder_params_trainable(False)
+    assert_norm_layers_training(False)
+
+    # Unfreeze again
+    model.unfreeze_encoder()
+    assert_encoder_params_trainable(True)
+    assert_norm_layers_training(True)
 
 
 def test_freeze_encoder_stops_running_stats():

From ffa3c5af900db5909f811be022a4f1b27578a099 Mon Sep 17 00:00:00 2001
From: Riccardo Bertoglio
Date: Thu, 21 Aug 2025 21:13:50 +0200
Subject: [PATCH 10/10] add example for encoder freezing

---
 docs/insights.rst | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/docs/insights.rst b/docs/insights.rst
index ad5355b9..e5848ef6 100644
--- a/docs/insights.rst
+++ b/docs/insights.rst
@@ -117,3 +117,25 @@ Example:
 
     mask.shape, label.shape
     # (N, 4, H, W), (N, 4)
+
+4. Freezing and unfreezing the encoder
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+Sometimes you may want to freeze the encoder during training, e.g. when using pretrained backbones and only fine-tuning the decoder and segmentation head.
+
+All segmentation models in SMP provide two helper methods:
+
+.. code-block:: python
+
+    model = smp.Unet("resnet34", classes=2)
+
+    # Freeze encoder: stops gradient updates and freezes normalization layer stats
+    model.freeze_encoder()
+
+    # Unfreeze encoder: re-enables training for encoder parameters and normalization layers
+    model.unfreeze_encoder()
+
+.. important::
+    - Freezing sets ``requires_grad = False`` for all encoder parameters.
+    - Normalization layers that track running statistics (e.g., BatchNorm and InstanceNorm layers) are set to ``.eval()`` mode to prevent updates to ``running_mean`` and ``running_var``.
+    - If you later call ``model.train()``, frozen encoders will remain frozen until you call ``unfreeze_encoder()``.
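As a possible extension of the documentation example above, the sketch below shows one way
the freeze/unfreeze pair can support staged fine-tuning: warm up the decoder and head with
a frozen encoder, then unfreeze and continue. The data, loss, epoch counts, and learning
rates are illustrative placeholders and not part of the library or this patch series:

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet("resnet34", classes=2)
    model.freeze_encoder()  # decoder and segmentation head warm up first

    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3
    )
    loss_fn = torch.nn.CrossEntropyLoss()

    # Placeholder data: 4 batches of 2 images with integer masks
    loader = [
        (torch.randn(2, 3, 64, 64), torch.randint(0, 2, (2, 64, 64)))
        for _ in range(4)
    ]
    num_epochs, warmup_epochs = 5, 2

    for epoch in range(num_epochs):
        if epoch == warmup_epochs:
            model.unfreeze_encoder()  # encoder weights and norm stats update again
            optimizer.add_param_group(
                {"params": model.encoder.parameters(), "lr": 1e-4}
            )
        model.train()  # while frozen, encoder norm layers stay in eval mode
        for images, masks in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(images), masks)
            loss.backward()
            optimizer.step()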