Commit 77fd490

Refactor test and add call to eval
1 parent e32de47 commit 77fd490

File tree

1 file changed (+33, -25 lines changed)

tests/base/test_freeze_encoder.py

Lines changed: 33 additions & 25 deletions
@@ -4,39 +4,47 @@
 
 def test_freeze_and_unfreeze_encoder():
     model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
-
+
+    def assert_encoder_params_trainable(expected: bool):
+        assert all(p.requires_grad == expected for p in model.encoder.parameters())
+
+    def assert_norm_layers_training(expected: bool):
+        for m in model.encoder.modules():
+            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
+                assert m.training == expected
+
     # Initially, encoder params should be trainable
     model.train()
-    assert all(p.requires_grad for p in model.encoder.parameters())
-
-    # Check encoder params are frozen
+    assert_encoder_params_trainable(True)
+
+    # Freeze encoder
     model.freeze_encoder()
-
-    assert all(not p.requires_grad for p in model.encoder.parameters())
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert not m.training
+    assert_encoder_params_trainable(False)
+    assert_norm_layers_training(False)
 
     # Call train() and ensure encoder norm layers stay frozen
     model.train()
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert not m.training
-
-    # Params should be trainable again
+    assert_norm_layers_training(False)
+
+    # Unfreeze encoder
     model.unfreeze_encoder()
-
-    assert all(p.requires_grad for p in model.encoder.parameters())
-    # Norm layers should go back to training mode after unfreeze
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert m.training
-
+    assert_encoder_params_trainable(True)
+    assert_norm_layers_training(True)
+
+    # Call train() again — should stay trainable
     model.train()
-    # Norm layers should have the same training mode after train()
-    for m in model.encoder.modules():
-        if isinstance(m, torch.nn.modules.batchnorm._NormBase):
-            assert m.training
+    assert_norm_layers_training(True)
+
+    # Switch to eval, then freeze
+    model.eval()
+    model.freeze_encoder()
+    assert_encoder_params_trainable(False)
+    assert_norm_layers_training(False)
+
+    # Unfreeze again
+    model.unfreeze_encoder()
+    assert_encoder_params_trainable(True)
+    assert_norm_layers_training(True)
 
 
 def test_freeze_encoder_stops_running_stats():
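The refactor replaces the repeated inline assertions with two local helpers and extends the test to cover freezing from eval mode. The contract the test pins down is: freeze_encoder() must disable requires_grad on every encoder parameter and force normalization layers into eval mode; a later model.train() must not undo that; and unfreeze_encoder() must restore both, even if model.eval() was called in between. Below is a minimal sketch of one way such behavior can be implemented; the FreezableModel wrapper and _is_encoder_frozen flag are hypothetical illustrations, not the actual segmentation_models.pytorch implementation.

import torch


class FreezableModel(torch.nn.Module):
    """Hypothetical wrapper illustrating the freeze/unfreeze contract tested above."""

    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder
        self._is_encoder_frozen = False  # assumed internal flag, not from the source

    def _norm_layers(self):
        # The test checks torch.nn.modules.batchnorm._NormBase, the shared
        # base class of BatchNorm and InstanceNorm layers.
        for m in self.encoder.modules():
            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
                yield m

    def freeze_encoder(self):
        for p in self.encoder.parameters():
            p.requires_grad = False  # stop gradient updates
        for m in self._norm_layers():
            m.eval()  # stop running-stat updates
        self._is_encoder_frozen = True

    def unfreeze_encoder(self):
        for p in self.encoder.parameters():
            p.requires_grad = True
        for m in self._norm_layers():
            m.train()  # resume running-stat updates (the test expects this
            # even when the model as a whole was switched to eval() first)
        self._is_encoder_frozen = False

    def train(self, mode: bool = True):
        super().train(mode)
        if self._is_encoder_frozen:
            # Keep norm layers in eval mode even after an explicit .train()
            # call, which is exactly what the test asserts.
            for m in self._norm_layers():
                m.eval()
        return self


# Quick check of the behavior the test exercises:
encoder = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
model = FreezableModel(encoder)
model.freeze_encoder()
model.train()  # BatchNorm stays in eval mode
assert not any(m.training for m in model._norm_layers())

Overriding train() is the key design point in this sketch: requires_grad alone does not stop BatchNorm running statistics from drifting during forward passes, so a frozen encoder must also survive the train()/eval() mode switches that training loops perform routinely.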
