26 changes: 26 additions & 0 deletions segmentation_models_pytorch/base/model.py
@@ -137,3 +137,29 @@ def load_state_dict(self, state_dict, **kwargs):
        warnings.warn(text, stacklevel=-1)

        return super().load_state_dict(state_dict, **kwargs)

    def encoder_freeze(self):
Collaborator:

let's name it freeze_encoder, what do you think?

Contributor Author (@ricber, Aug 21, 2025):

Yep, I kept the name I found in the repo, but freeze_encoder would be better ;) Addressed in 92b7e1c

"""
Freeze the encoder parameters and normalization layers.

This method sets ``requires_grad = False`` for all encoder parameters,
preventing them from being updated during backpropagation. In addition,
it switches BatchNorm and InstanceNorm layers into evaluation mode
(``.eval()``), which stops updates to their running statistics
(``running_mean`` and ``running_var``) during training.

**Important:** If you call :meth:`model.train()` after
:meth:`encoder_freeze`, the encoder’s BatchNorm/InstanceNorm layers
will be put back into training mode, and their running statistics
will be updated again. To re-freeze them after switching modes,
call :meth:`encoder_freeze` again.
Collaborator:
To make the user's experience better, we can override the train method for the model, similar to the original one:

    def train(self, mode: bool = True) -> Self:
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for name, module in self.named_children():
            # skip the encoder if it is frozen
            if self._is_encoder_frozen and name == "encoder":
                continue
            module.train(mode)
        return self
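
A note on this sketch: `Self` requires `from typing import Self` (Python 3.11+) or `from typing_extensions import Self` on older versions, and `_is_encoder_frozen` is assumed here to be a boolean flag maintained by the freeze/unfreeze methods, e.g. initialized as a class attribute:

    # hypothetical flag, toggled by freeze_encoder() / unfreeze_encoder()
    _is_encoder_frozen: bool = False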

Contributor Author:

Addressed in 92b7e1c

"""
for param in self.encoder.parameters():
param.requires_grad = False

# Putting norm layers into eval mode stops running stats updates
for module in self.encoder.modules():
# _NormBase is the common base of _InstanceNorm and _BatchNorm classes
# These are the two classes that track running stats
if isinstance(module, torch.nn.modules.batchnorm._NormBase):
module.eval()
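
To illustrate the caveat in the docstring, here is a minimal usage sketch (assuming the method name in this commit, encoder_freeze, and a standard smp model; not part of the diff):

    import torch
    import segmentation_models_pytorch as smp

    model = smp.Unet("resnet34", encoder_weights=None, classes=2)
    model.encoder_freeze()  # encoder grads off, norm running stats frozen

    # Only parameters that still require grad are handed to the optimizer
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3
    )

    model.train()           # puts encoder norm layers back into training mode
    model.encoder_freeze()  # re-freeze after switching modes, as documented above
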
Collaborator:
We would also need a method to unfreeze the encoder, which will revert these ops; it would be better to define one common method:

def _set_encoder_trainable(..

def freeze_encoder(...
    return self._set_encoder_trainable(False)
    
def unfreeze_encoder(...
    return self._set_encoder_trainable(True)
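
Fleshed out, this suggestion might look something like the following (a sketch using the names proposed above, not the merged code):

    def _set_encoder_trainable(self, mode: bool):
        for param in self.encoder.parameters():
            param.requires_grad = mode

        # Resume (or stop) running-stats updates in BatchNorm/InstanceNorm layers
        for module in self.encoder.modules():
            if isinstance(module, torch.nn.modules.batchnorm._NormBase):
                module.train(mode)

        self._is_encoder_frozen = not mode
        return self

    def freeze_encoder(self):
        return self._set_encoder_trainable(False)

    def unfreeze_encoder(self):
        return self._set_encoder_trainable(True)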

Contributor Author:

Addressed in 92b7e1c