Skip to content

Commit d8ea35f

Browse files
committed
Update timm_universal.py
1. rename temporary model 2. create temporary model on meta device to speed up
1 parent 330e6e5 commit d8ea35f

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,18 @@ def __init__(
8080
common_kwargs.pop("output_stride")
8181

8282
# Load a temporary model to analyze its feature hierarchy
83-
self.model = timm.create_model(name, features_only=True)
83+
try:
84+
with torch.device("meta"):
85+
tmp_model = timm.create_model(name, features_only=True)
86+
except Exception:
87+
tmp_model = timm.create_model(name, features_only=True)
8488

8589
# Check if model output is in channel-last format (NHWC)
86-
self._is_channel_last = getattr(self.model, "output_fmt", None) == "NHWC"
90+
self._is_channel_last = getattr(tmp_model, "output_fmt", None) == "NHWC"
8791

8892
# Determine the model's downsampling pattern and set hierarchy flags
89-
encoder_stage = len(self.model.feature_info.reduction())
90-
reduction_scales = self.model.feature_info.reduction()
93+
encoder_stage = len(tmp_model.feature_info.reduction())
94+
reduction_scales = tmp_model.feature_info.reduction()
9195

9296
if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
9397
# Transformer-style downsampling: scales (4, 8, 16, 32)

0 commit comments

Comments
 (0)