File tree Expand file tree Collapse file tree 1 file changed +8
-4
lines changed
segmentation_models_pytorch/encoders Expand file tree Collapse file tree 1 file changed +8
-4
lines changed Original file line number Diff line number Diff line change @@ -80,14 +80,18 @@ def __init__(
8080 common_kwargs .pop ("output_stride" )
8181
8282 # Load a temporary model to analyze its feature hierarchy
83- self .model = timm .create_model (name , features_only = True )
83+ try :
84+ with torch .device ("meta" ):
85+ tmp_model = timm .create_model (name , features_only = True )
86+ except Exception :
87+ tmp_model = timm .create_model (name , features_only = True )
8488
8589 # Check if model output is in channel-last format (NHWC)
86- self ._is_channel_last = getattr (self . model , "output_fmt" , None ) == "NHWC"
90+ self ._is_channel_last = getattr (tmp_model , "output_fmt" , None ) == "NHWC"
8791
8892 # Determine the model's downsampling pattern and set hierarchy flags
89- encoder_stage = len (self . model .feature_info .reduction ())
90- reduction_scales = self . model .feature_info .reduction ()
93+ encoder_stage = len (tmp_model .feature_info .reduction ())
94+ reduction_scales = tmp_model .feature_info .reduction ()
9195
9296 if reduction_scales == [2 ** (i + 2 ) for i in range (encoder_stage )]:
9397 # Transformer-style downsampling: scales (4, 8, 16, 32)
You can’t perform that action at this time.
0 commit comments