Skip to content

Commit e8852c9

Browse files
committed
Fix deprecation tests
1 parent 846e112 commit e8852c9

File tree

2 files changed

+27
-15
lines changed

2 files changed

+27
-15
lines changed

tests/encoders/test_batchnorm_deprecation.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from segmentation_models_pytorch import create_model
5+
import segmentation_models_pytorch as smp
66
from tests.utils import check_two_models_strictly_equal
77

88

@@ -11,18 +11,21 @@
1111
def test_seg_models_before_after_use_norm(model_name, decoder_option):
1212
torch.manual_seed(42)
1313
with pytest.warns(DeprecationWarning):
14-
model_decoder_batchnorm = create_model(
15-
model_name, "mobilenet_v2", None, decoder_use_batchnorm=decoder_option
14+
model_decoder_batchnorm = smp.create_model(
15+
model_name,
16+
"mobilenet_v2",
17+
encoder_weights=None,
18+
decoder_use_batchnorm=decoder_option,
1619
)
17-
torch.manual_seed(42)
18-
model_decoder_norm = create_model(
20+
model_decoder_norm = smp.create_model(
1921
model_name,
2022
"mobilenet_v2",
21-
None,
22-
decoder_use_batchnorm=None,
23+
encoder_weights=None,
2324
decoder_use_norm=decoder_option,
2425
)
2526

27+
model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict())
28+
2629
check_two_models_strictly_equal(
2730
model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)
2831
)
@@ -32,17 +35,19 @@ def test_seg_models_before_after_use_norm(model_name, decoder_option):
3235
def test_pspnet_before_after_use_norm(decoder_option):
3336
torch.manual_seed(42)
3437
with pytest.warns(DeprecationWarning):
35-
model_decoder_batchnorm = create_model(
36-
"pspnet", "mobilenet_v2", None, psp_use_batchnorm=decoder_option
38+
model_decoder_batchnorm = smp.create_model(
39+
"pspnet",
40+
"mobilenet_v2",
41+
encoder_weights=None,
42+
psp_use_batchnorm=decoder_option,
3743
)
38-
torch.manual_seed(42)
39-
model_decoder_norm = create_model(
44+
model_decoder_norm = smp.create_model(
4045
"pspnet",
4146
"mobilenet_v2",
42-
None,
43-
psp_use_batchnorm=None,
47+
encoder_weights=None,
4448
decoder_use_norm=decoder_option,
4549
)
50+
model_decoder_norm.load_state_dict(model_decoder_batchnorm.state_dict())
4651

4752
check_two_models_strictly_equal(
4853
model_decoder_batchnorm, model_decoder_norm, torch.rand(1, 3, 224, 224)

tests/utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,14 @@ def check_two_models_strictly_equal(
6767
model_a.state_dict().items(), model_b.state_dict().items()
6868
):
6969
assert k1 == k2, f"Key mismatch: {k1} != {k2}"
70-
assert (v1 == v2).all(), f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}"
70+
torch.testing.assert_close(
71+
v1, v2, msg=f"Tensor mismatch at key '{k1}':\n{v1} !=\n{v2}"
72+
)
7173

74+
model_a.eval()
75+
model_b.eval()
7276
with torch.inference_mode():
73-
assert (model_a(input_data) == model_b(input_data)).all()
77+
output_a = model_a(input_data)
78+
output_b = model_b(input_data)
79+
80+
torch.testing.assert_close(output_a, output_b)

0 commit comments

Comments
 (0)