Add encoder_freeze method to SegmentationModel #1220
Changes from 1 commit: fd6a900
@@ -137,3 +137,29 @@ def load_state_dict(self, state_dict, **kwargs):
        warnings.warn(text, stacklevel=-1)

        return super().load_state_dict(state_dict, **kwargs)

    def encoder_freeze(self):
        """
        Freeze the encoder parameters and normalization layers.

        This method sets ``requires_grad = False`` for all encoder parameters,
        preventing them from being updated during backpropagation. In addition,
        it switches BatchNorm and InstanceNorm layers into evaluation mode
        (``.eval()``), which stops updates to their running statistics
        (``running_mean`` and ``running_var``) during training.

        **Important:** If you call :meth:`model.train()` after
        :meth:`encoder_freeze`, the encoder's BatchNorm/InstanceNorm layers
        will be put back into training mode, and their running statistics
        will be updated again. To re-freeze them after switching modes,
        call :meth:`encoder_freeze` again.
Review comment: To make the user's experience better, we can override `train`:

def train(self, mode: bool = True) -> Self:
    if not isinstance(mode, bool):
        raise ValueError("training mode is expected to be boolean")
    self.training = mode
    for name, module in self.named_children():
        # skip the encoder if it is frozen
        if self._is_encoder_frozen and name == "encoder":
            continue
        module.train(mode)
    return self

Reply: Addressed in 92b7e1c
        """
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Putting norm layers into eval mode stops running stats updates
        for module in self.encoder.modules():
            # _NormBase is the common base of _InstanceNorm and _BatchNorm classes
            # These are the two classes that track running stats
            if isinstance(module, torch.nn.modules.batchnorm._NormBase):
                module.eval()
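For illustration, a minimal sketch of how the new method could be used, assuming the package is imported as `segmentation_models_pytorch` and using `Unet`/`resnet34` purely as example choices; it also demonstrates the caveat from the docstring about calling `.train()` afterwards:

import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="resnet34", encoder_weights=None, classes=2)
model.encoder_freeze()

# Encoder weights no longer receive gradient updates
assert all(not p.requires_grad for p in model.encoder.parameters())

# Caveat from the docstring: switching back to train mode re-enables
# running-stat updates in the encoder's norm layers, so re-freeze afterwards
model.train()
model.encoder_freeze()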
Review comment: We would also need a method to unfreeze the encoder, which would revert these operations, so it would be better to define one common method `_set_encoder_trainable(...)`:

def freeze_encoder(...):
    return self._set_encoder_trainable(False)

def unfreeze_encoder(...):
    return self._set_encoder_trainable(True)

Reply: Addressed in 92b7e1c
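For context, a rough sketch of what that shared helper might look like as methods on SegmentationModel, combining the reviewer's wrappers with the freeze logic from the diff above; the exact signature and the `_is_encoder_frozen` attribute are assumptions here, not the final implementation:

def _set_encoder_trainable(self, mode: bool):
    for param in self.encoder.parameters():
        param.requires_grad = mode

    for module in self.encoder.modules():
        # BatchNorm and InstanceNorm both derive from _NormBase and track running stats
        if isinstance(module, torch.nn.modules.batchnorm._NormBase):
            module.train(mode)

    # Remembered so an overridden .train() can skip a frozen encoder
    self._is_encoder_frozen = not mode
    return self

def freeze_encoder(self):
    return self._set_encoder_trainable(False)

def unfreeze_encoder(self):
    return self._set_encoder_trainable(True)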
Review comment: Let's name it `freeze_encoder`, what do you think?

Reply: Yep, I kept the name I found in the repo, but `freeze_encoder` would be better ;) Addressed in 92b7e1c
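Assuming the rename to `freeze_encoder`/`unfreeze_encoder` lands as discussed, a typical two-stage fine-tuning flow might look roughly like this (the model choice, learning rates, and optimizer are illustrative only):

import torch
import segmentation_models_pytorch as smp

model = smp.Unet(encoder_name="resnet34", encoder_weights="imagenet", classes=2)

# Stage 1: train only the decoder and segmentation head
model.freeze_encoder()
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)
# ... run a few epochs ...

# Stage 2: unfreeze and fine-tune the whole network at a lower learning rate
model.unfreeze_encoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# ... continue training ...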