Skip to content

Commit a699d67

Browse files
committed
Add tests for encoder freezing
1 parent 9f51206 commit a699d67

File tree

1 file changed

+54
-0
lines changed

1 file changed

+54
-0
lines changed

tests/base/test_freeze_encoder.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import torch
2+
import segmentation_models_pytorch as smp
3+
4+
5+
def test_freeze_and_unfreeze_encoder():
    """Freezing the encoder disables parameter grads and pins norm layers
    to eval mode; unfreezing restores both. Either state must survive a
    subsequent call to ``train()``.
    """

    def norm_modules(net):
        # Yield every normalization layer inside the encoder.
        return (
            m
            for m in net.encoder.modules()
            if isinstance(m, torch.nn.modules.batchnorm._NormBase)
        )

    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.train()

    # Fresh model: every encoder parameter is trainable.
    assert all(p.requires_grad for p in model.encoder.parameters())

    model.freeze_encoder()

    # Frozen: no encoder parameter takes gradients...
    assert not any(p.requires_grad for p in model.encoder.parameters())
    # ...and every norm layer sits in eval mode.
    assert all(not m.training for m in norm_modules(model))

    # The overridden train() must not thaw norm layers of a frozen encoder.
    model.train()
    assert all(not m.training for m in norm_modules(model))

    model.unfreeze_encoder()

    # Unfrozen: gradients flow again and norm layers are back in training mode.
    assert all(p.requires_grad for p in model.encoder.parameters())
    assert all(m.training for m in norm_modules(model))

    # A subsequent train() keeps the norm layers in training mode.
    model.train()
    assert all(m.training for m in norm_modules(model))
34+
35+
36+
def test_freeze_encoder_stops_running_stats():
    """A forward pass in train() mode after freeze_encoder() must not
    update batch-norm running statistics.
    """
    model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
    model.freeze_encoder()
    model.train()  # overridden train, encoder should remain frozen

    # Grab the first normalization layer found inside the encoder.
    bn = next(
        (
            m
            for m in model.encoder.modules()
            if isinstance(m, torch.nn.modules.batchnorm._NormBase)
        ),
        None,
    )
    assert bn is not None

    # Snapshot the running mean, run one forward pass, then compare.
    before = bn.running_mean.clone()
    _ = model(torch.randn(2, 3, 64, 64))
    torch.testing.assert_close(before, bn.running_mean)

0 commit comments

Comments
 (0)