File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed
segmentation_models_pytorch/encoders Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -80,14 +80,18 @@ def __init__(
80
80
common_kwargs .pop ("output_stride" )
81
81
82
82
# Load a temporary model to analyze its feature hierarchy
83
- self .model = timm .create_model (name , features_only = True )
83
+ try :
84
+ with torch .device ("meta" ):
85
+ tmp_model = timm .create_model (name , features_only = True )
86
+ except Exception :
87
+ tmp_model = timm .create_model (name , features_only = True )
84
88
85
89
# Check if model output is in channel-last format (NHWC)
86
- self ._is_channel_last = getattr (self . model , "output_fmt" , None ) == "NHWC"
90
+ self ._is_channel_last = getattr (tmp_model , "output_fmt" , None ) == "NHWC"
87
91
88
92
# Determine the model's downsampling pattern and set hierarchy flags
89
- encoder_stage = len (self . model .feature_info .reduction ())
90
- reduction_scales = self . model .feature_info .reduction ()
93
+ encoder_stage = len (tmp_model .feature_info .reduction ())
94
+ reduction_scales = tmp_model .feature_info .reduction ()
91
95
92
96
if reduction_scales == [2 ** (i + 2 ) for i in range (encoder_stage )]:
93
97
# Transformer-style downsampling: scales (4, 8, 16, 32)
You can’t perform that action at this time.
0 commit comments