
Commit e32de47

ricber and qubvel authored
Refactor test_freeze_and_unfreeze_encoder
Co-authored-by: Pavel Iakubovskii <[email protected]>
1 parent c9e6e0e commit e32de47

File tree

1 file changed: +10 −4 lines changed


tests/base/test_freeze_encoder.py

Lines changed: 10 additions & 4 deletions
@@ -4,28 +4,34 @@
 
 def test_freeze_and_unfreeze_encoder():
     model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
-    model.train()
+
     # Initially, encoder params should be trainable
+    model.train()
     assert all(p.requires_grad for p in model.encoder.parameters())
-    model.freeze_encoder()
+
     # Check encoder params are frozen
+    model.freeze_encoder()
+
     assert all(not p.requires_grad for p in model.encoder.parameters())
-    # Check normalization layers are in eval mode
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert not m.training
+
     # Call train() and ensure encoder norm layers stay frozen
     model.train()
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert not m.training
-    model.unfreeze_encoder()
+
     # Params should be trainable again
+    model.unfreeze_encoder()
+
     assert all(p.requires_grad for p in model.encoder.parameters())
     # Norm layers should go back to training mode after unfreeze
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert m.training
+
     model.train()
     # Norm layers should have the same training mode after train()
     for m in model.encoder.modules():
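For readers trying out the behaviour this test pins down, below is a minimal usage sketch built only from what the diff shows (smp.Unet, freeze_encoder, unfreeze_encoder, and the requires_grad / norm-layer behaviour asserted above). The optimizer choice, learning rates, dummy input shape, and the trainable variable are illustrative assumptions, not part of this commit.

import torch
import segmentation_models_pytorch as smp

# Build a model as in the test; encoder_weights=None mirrors the test and
# avoids downloading pretrained weights.
model = smp.Unet(encoder_name="resnet18", encoder_weights=None)

# Freeze the encoder: its params get requires_grad=False and its norm layers
# are put (and kept) in eval mode, even after model.train(), per the test.
model.freeze_encoder()
model.train()

# Illustrative choice: pass only still-trainable params (decoder + head) to
# the optimizer, so the frozen encoder is skipped entirely.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-4)

# One dummy step with an illustrative input shape (N, 3, H, W), H and W
# divisible by 32 for the default Unet depth.
x = torch.randn(2, 3, 64, 64)
loss = model(x).mean()
loss.backward()
optimizer.step()

# Later, unfreeze to fine-tune the whole network; re-create the optimizer so
# it also covers the newly trainable encoder parameters.
model.unfreeze_encoder()
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-5
)

Re-creating the optimizer after unfreeze_encoder is one simple way to pick up the newly trainable encoder parameters; separate parameter groups with their own learning rates would be another common choice.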
