1 file changed, +10 -4 lines changed
 def test_freeze_and_unfreeze_encoder():
     model = smp.Unet(encoder_name="resnet18", encoder_weights=None)
-    model.train()
+
     # Initially, encoder params should be trainable
+    model.train()
     assert all(p.requires_grad for p in model.encoder.parameters())
-    model.freeze_encoder()
+
     # Check encoder params are frozen
+    model.freeze_encoder()
+
     assert all(not p.requires_grad for p in model.encoder.parameters())
-    # Check normalization layers are in eval mode
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert not m.training
+
     # Call train() and ensure encoder norm layers stay frozen
     model.train()
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert not m.training
-    model.unfreeze_encoder()
+
     # Params should be trainable again
+    model.unfreeze_encoder()
+
     assert all(p.requires_grad for p in model.encoder.parameters())
     # Norm layers should go back to training mode after unfreeze
     for m in model.encoder.modules():
         if isinstance(m, torch.nn.modules.batchnorm._NormBase):
             assert m.training
+
     model.train()
     # Norm layers should have the same training mode after train()
     for m in model.encoder.modules():
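
For context, here is a minimal sketch of the behavior this test asserts: freezing the encoder disables gradients for its parameters and keeps its normalization layers in eval mode even after a later model.train() call, while unfreezing restores both. This is an illustrative subclass only, not the library's actual implementation; the class name FreezableUnet and the _encoder_frozen flag are hypothetical.

import torch
import segmentation_models_pytorch as smp


class FreezableUnet(smp.Unet):
    # Illustrative sketch of the freeze/unfreeze semantics checked above.
    _encoder_frozen = False

    def freeze_encoder(self):
        # Disable gradient updates for every encoder parameter.
        for p in self.encoder.parameters():
            p.requires_grad = False
        # Put normalization layers in eval mode so running stats stop updating.
        for m in self.encoder.modules():
            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
                m.eval()
        self._encoder_frozen = True

    def unfreeze_encoder(self):
        # Re-enable gradients and return normalization layers to training mode.
        for p in self.encoder.parameters():
            p.requires_grad = True
        for m in self.encoder.modules():
            if isinstance(m, torch.nn.modules.batchnorm._NormBase):
                m.train()
        self._encoder_frozen = False

    def train(self, mode=True):
        # A later model.train() must not flip frozen encoder norm layers
        # back into training mode.
        super().train(mode)
        if self._encoder_frozen:
            for m in self.encoder.modules():
                if isinstance(m, torch.nn.modules.batchnorm._NormBase):
                    m.eval()
        return self

In a typical fine-tuning setup, one would call freeze_encoder() before constructing the optimizer (e.g. passing only parameters with requires_grad=True), then call unfreeze_encoder() for a later full fine-tuning stage.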