
Commit 4d20629

ricber and qubvel authored
Add encoder_freeze method to SegmentationModel (#1220)
* Add encoder_freeze method
* Add encoder freeze/unfreeze methods and override train()
* Fix typo
* Remove unnecessary print
* Refactor _set_encoder_trainable

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Add tests for encoder freezing
* Add assertion on running_var
* Refactor test_freeze_and_unfreeze_encoder

Co-authored-by: Pavel Iakubovskii <[email protected]>

* Refactor test and add call to eval
* add example for encoder freezing

---------

Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent e76ed01 commit 4d20629

File tree

3 files changed (+163, -0 lines changed)


docs/insights.rst

Lines changed: 22 additions & 0 deletions
@@ -117,3 +117,25 @@ Example:

    mask.shape, label.shape
    # (N, 4, H, W), (N, 4)

4. Freezing and unfreezing the encoder
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Sometimes you may want to freeze the encoder during training, e.g. when using pretrained backbones and only fine-tuning the decoder and segmentation head.

All segmentation models in SMP provide two helper methods:

.. code-block:: python

    model = smp.Unet("resnet34", classes=2)

    # Freeze encoder: stops gradient updates and freezes normalization layer stats
    model.freeze_encoder()

    # Unfreeze encoder: re-enables training for encoder parameters and normalization layers
    model.unfreeze_encoder()

.. important::
    - Freezing sets ``requires_grad = False`` for all encoder parameters.
    - Normalization layers that track running statistics (e.g., BatchNorm and InstanceNorm layers) are set to ``.eval()`` mode to prevent updates to ``running_mean`` and ``running_var``.
    - If you later call ``model.train()``, frozen encoders will remain frozen until you call ``unfreeze_encoder()``.

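The documentation above shows the two calls in isolation. As a complement (not part of this commit), here is a minimal sketch of how they might fit into a fine-tuning loop, passing only the still-trainable parameters to the optimizer; the batch shapes, learning rate, and ``encoder_weights=None`` are illustrative assumptions.

    import torch
    import segmentation_models_pytorch as smp

    # Hypothetical fine-tuning setup: train decoder + segmentation head only.
    model = smp.Unet("resnet34", encoder_weights=None, classes=2)
    model.freeze_encoder()
    model.train()  # frozen encoder norm layers stay in eval mode

    # Only parameters that still require gradients go to the optimizer,
    # so the frozen encoder adds no optimizer state.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4
    )

    images = torch.randn(4, 3, 256, 256)        # dummy batch
    masks = torch.randint(0, 2, (4, 256, 256))  # dummy integer masks

    logits = model(images)                      # (4, 2, 256, 256)
    loss = torch.nn.functional.cross_entropy(logits, masks)
    loss.backward()
    optimizer.step()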
segmentation_models_pytorch/base/model.py

Lines changed: 71 additions & 0 deletions
@@ -24,6 +24,10 @@ def __new__(cls: Type[T], *args, **kwargs) -> T:
        instance = super().__new__(cls, *args, **kwargs)
        return instance

    def __init__(self):
        super().__init__()
        self._is_encoder_frozen = False

    def initialize(self):
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
@@ -137,3 +141,70 @@ def load_state_dict(self, state_dict, **kwargs):
        warnings.warn(text, stacklevel=-1)

        return super().load_state_dict(state_dict, **kwargs)

    def train(self, mode: bool = True):
        """Set the module in training mode.

        This method behaves like the standard :meth:`torch.nn.Module.train`,
        with one exception: if the encoder has been frozen via
        :meth:`freeze_encoder`, then its normalization layers are not affected
        by this call. In other words, calling ``model.train()`` will not
        re-enable updates to frozen encoder normalization layers
        (e.g., BatchNorm, InstanceNorm).

        To restore the encoder to normal training behavior, use
        :meth:`unfreeze_encoder`.

        Args:
            mode (bool): whether to set training mode (``True``) or evaluation
                mode (``False``). Default: ``True``.

        Returns:
            Module: self
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for name, module in self.named_children():
            # skip encoder if it is frozen
            if self._is_encoder_frozen and name == "encoder":
                continue
            module.train(mode)
        return self

    def _set_encoder_trainable(self, mode: bool):
        for param in self.encoder.parameters():
            param.requires_grad = mode

        for module in self.encoder.modules():
            # _NormBase is the common base of classes like _InstanceNorm
            # and _BatchNorm that track running stats
            if isinstance(module, torch.nn.modules.batchnorm._NormBase):
                module.train(mode)

        self._is_encoder_frozen = not mode

    def freeze_encoder(self):
        """
        Freeze encoder parameters and disable updates to normalization
        layer statistics.

        This method:
        - Sets ``requires_grad = False`` for all encoder parameters,
          preventing them from being updated during backpropagation.
        - Puts normalization layers that track running statistics
          (e.g., BatchNorm, InstanceNorm) into evaluation mode (``.eval()``),
          so their ``running_mean`` and ``running_var`` are no longer updated.
        """
        return self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        """
        Unfreeze encoder parameters and restore normalization layers to training mode.

        This method reverts the effect of :meth:`freeze_encoder`. Specifically:
        - Sets ``requires_grad = True`` for all encoder parameters.
        - Restores normalization layers (e.g., BatchNorm, InstanceNorm) to training mode,
          so their running statistics are updated again.
        """
        return self._set_encoder_trainable(True)

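Outside the test suite, the overridden ``train()`` can be sanity-checked with a small sketch like the one below; it is not part of the commit, and ``resnet18`` with ``encoder_weights=None`` is just a lightweight, illustrative choice.

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet("resnet18", encoder_weights=None)
    model.freeze_encoder()
    model.train()  # overridden train(): the frozen encoder is skipped

    # Encoder norm layers stay in eval mode and its parameters stay frozen,
    # while the decoder is back in training mode.
    encoder_norms = [
        m for m in model.encoder.modules()
        if isinstance(m, torch.nn.modules.batchnorm._NormBase)
    ]
    assert all(not m.training for m in encoder_norms)
    assert all(not p.requires_grad for p in model.encoder.parameters())
    assert model.decoder.training

    model.unfreeze_encoder()
    assert all(m.training for m in encoder_norms)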
tests/base/test_freeze_encoder.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import torch
import segmentation_models_pytorch as smp


def test_freeze_and_unfreeze_encoder():
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)

    def assert_encoder_params_trainable(expected: bool):
        assert all(p.requires_grad == expected for p in model.encoder.parameters())

    def assert_norm_layers_training(expected: bool):
        for m in model.encoder.modules():
            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
                assert m.training == expected

    # Initially, encoder params should be trainable
    model.train()
    assert_encoder_params_trainable(True)

    # Freeze encoder
    model.freeze_encoder()
    assert_encoder_params_trainable(False)
    assert_norm_layers_training(False)

    # Call train() and ensure encoder norm layers stay frozen
    model.train()
    assert_norm_layers_training(False)

    # Unfreeze encoder
    model.unfreeze_encoder()
    assert_encoder_params_trainable(True)
    assert_norm_layers_training(True)

    # Call train() again - should stay trainable
    model.train()
    assert_norm_layers_training(True)

    # Switch to eval, then freeze
    model.eval()
    model.freeze_encoder()
    assert_encoder_params_trainable(False)
    assert_norm_layers_training(False)

    # Unfreeze again
    model.unfreeze_encoder()
    assert_encoder_params_trainable(True)
    assert_norm_layers_training(True)


def test_freeze_encoder_stops_running_stats():
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.freeze_encoder()
    model.train()  # overridden train, encoder should remain frozen
    bn = None

    for m in model.encoder.modules():
        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
            bn = m
            break

    assert bn is not None

    orig_mean = bn.running_mean.clone()
    orig_var = bn.running_var.clone()

    x = torch.randn(2, 3, 64, 64)
    _ = model(x)

    torch.testing.assert_close(orig_mean, bn.running_mean)
    torch.testing.assert_close(orig_var, bn.running_var)

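The tests above cover running statistics. A complementary check, sketched here as an illustration rather than taken from the commit, is that frozen encoder parameters also receive no gradients during backpropagation (encoder choice and input size are assumptions):

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.freeze_encoder()
    model.train()

    out = model(torch.randn(2, 3, 64, 64))
    out.sum().backward()

    # Frozen encoder parameters receive no gradients ...
    assert all(p.grad is None for p in model.encoder.parameters())
    # ... while decoder parameters still do.
    assert any(p.grad is not None for p in model.decoder.parameters())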