Skip to content

Commit 343fbe0

Browse files
committed
Fix test
1 parent 83b9655 commit 343fbe0

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

tests/models/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def test_in_channels_and_depth_and_out_classes(
9999
if self.model_type in ["unet", "unetplusplus", "manet"]:
100100
kwargs = {"decoder_channels": self.decoder_channels[:depth]}
101101

102+
if self.model_type == "dpt":
103+
kwargs = {"decoder_intermediate_channels": self.decoder_channels[:depth]}
104+
102105
model = (
103106
smp.create_model(
104107
arch=self.model_type,

tests/models/test_dpt.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import inspect
23
import torch
34
import segmentation_models_pytorch as smp
45

@@ -22,6 +23,11 @@ class TestDPTModel(base.BaseModelTester):
2223

2324
compile_dynamic = False
2425

26+
@property
27+
def decoder_channels(self):
28+
signature = inspect.signature(self.model_class)
29+
return signature.parameters["decoder_intermediate_channels"].default
30+
2531
@property
2632
def hub_checkpoint(self):
2733
return "smp-test-models/dpt-tu-test_vit"

0 commit comments

Comments
 (0)