Merged
71 changes: 71 additions & 0 deletions segmentation_models_pytorch/base/model.py
@@ -24,6 +24,10 @@ def __new__(cls: Type[T], *args, **kwargs) -> T:
        instance = super().__new__(cls, *args, **kwargs)
        return instance

    def __init__(self):
        super().__init__()
        self._is_encoder_frozen = False

    def initialize(self):
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
@@ -137,3 +141,70 @@ def load_state_dict(self, state_dict, **kwargs):
            warnings.warn(text, stacklevel=-1)

        return super().load_state_dict(state_dict, **kwargs)

    def train(self, mode: bool = True):
        """Set the module in training mode.

        This method behaves like the standard :meth:`torch.nn.Module.train`,
        with one exception: if the encoder has been frozen via
        :meth:`freeze_encoder`, the encoder is skipped entirely, so its
        normalization layers are not affected by this call. In other words,
        calling ``model.train()`` will not re-enable updates to frozen encoder
        normalization layers (e.g., BatchNorm, InstanceNorm).

        To restore the encoder to normal training behavior, use
        :meth:`unfreeze_encoder`.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for name, module in self.named_children():
            # skip the encoder if it is frozen
            if self._is_encoder_frozen and name == "encoder":
                continue
            module.train(mode)
        return self

    def _set_encoder_trainable(self, mode: bool):
        for param in self.encoder.parameters():
            param.requires_grad = mode

        for module in self.encoder.modules():
            # _NormBase is the common base of classes like _InstanceNorm
            # and _BatchNorm that track running stats
            if isinstance(module, torch.nn.modules.batchnorm._NormBase):
                module.train(mode)

        self._is_encoder_frozen = not mode

    def freeze_encoder(self):
        """
        Freeze encoder parameters and disable updates to normalization
        layer statistics.

        This method:
        - Sets ``requires_grad = False`` for all encoder parameters,
          preventing them from being updated during backpropagation.
        - Puts normalization layers that track running statistics
          (e.g., BatchNorm, InstanceNorm) into evaluation mode (``.eval()``),
          so their ``running_mean`` and ``running_var`` are no longer updated.
        """
        return self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """
        Unfreeze encoder parameters and restore normalization layers to
        training mode.

        This method reverts the effect of :meth:`freeze_encoder`. Specifically:
        - Sets ``requires_grad = True`` for all encoder parameters.
        - Restores normalization layers (e.g., BatchNorm, InstanceNorm) to
          training mode, so their running statistics are updated again.
        """
        return self._set_encoder_trainable(True)
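
For reviewers, a minimal usage sketch of the intended freeze/unfreeze workflow (the encoder choice and learning rate below are illustrative assumptions, not part of this diff):

import torch
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="resnet18", encoder_weights=None)

# Freeze the encoder: its parameters stop receiving gradients and its
# BatchNorm/InstanceNorm layers go to eval mode, so running stats are fixed.
model.freeze_encoder()
model.train()  # the overridden train() leaves the frozen encoder untouched

# Only decoder/head parameters still require grad, so the optimizer can skip the rest.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)

# Later, restore full fine-tuning.
model.unfreeze_encoder()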
56 changes: 56 additions & 0 deletions tests/base/test_freeze_encoder.py
@@ -0,0 +1,56 @@
import torch
import segmentation_models_pytorch as smp


def test_freeze_and_unfreeze_encoder():
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.train()
    # Initially, encoder params should be trainable
    assert all(p.requires_grad for p in model.encoder.parameters())
    model.freeze_encoder()
    # Check encoder params are frozen
    assert all(not p.requires_grad for p in model.encoder.parameters())
    # Check normalization layers are in eval mode
    for m in model.encoder.modules():
        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
            assert not m.training
    # Call train() and ensure encoder norm layers stay frozen
    model.train()
    for m in model.encoder.modules():
        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
            assert not m.training
    model.unfreeze_encoder()
    # Params should be trainable again
    assert all(p.requires_grad for p in model.encoder.parameters())
    # Norm layers should go back to training mode after unfreeze
    for m in model.encoder.modules():
        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
            assert m.training
    model.train()
    # Norm layers should keep the same training mode after train()
    for m in model.encoder.modules():
        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
            assert m.training


def test_freeze_encoder_stops_running_stats():
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.freeze_encoder()
    model.train()  # overridden train, encoder should remain frozen
    bn = None

    for m in model.encoder.modules():
        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
            bn = m
            break

    assert bn is not None

    orig_mean = bn.running_mean.clone()
    orig_var = bn.running_var.clone()

    x = torch.randn(2, 3, 64, 64)
    _ = model(x)

    torch.testing.assert_close(orig_mean, bn.running_mean)
    torch.testing.assert_close(orig_var, bn.running_var)
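
Complementing the tests above, a small interactive sanity check of both sides of the behavior (not part of this PR; the second printout relies on random inputs shifting the running mean, which is overwhelmingly likely but not strictly guaranteed):

import torch
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
model.freeze_encoder()
model.train()

# Snapshot the running mean of the first norm layer in the encoder.
bn = next(
    m for m in model.encoder.modules()
    if isinstance(m, torch.nn.modules.batchnorm._NormBase)
)
frozen_mean = bn.running_mean.clone()

_ = model(torch.randn(2, 3, 64, 64))
print(torch.equal(frozen_mean, bn.running_mean))  # True: stats untouched while frozen

model.unfreeze_encoder()  # norm layers return to training mode
_ = model(torch.randn(2, 3, 64, 64))
print(torch.equal(frozen_mean, bn.running_mean))  # expected False: stats update again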