1010import logging
1111import math
1212from collections import OrderedDict
13- from typing import List
13+ from typing import List , Callable
1414
1515from timm import create_model
1616from timm .models .layers import create_conv2d , drop_path , create_pool2d , Swish , get_act_layer
@@ -440,6 +440,17 @@ def _init_weight_alt(m, n='', ):
440440 m .bias .data .zero_ ()
441441
442442
443+ def get_feature_info (backbone ):
444+ if isinstance (backbone .feature_info , Callable ):
445+ # old accessor for timm versions <= 0.1.30, efficientnet and mobilenetv3 and related nets only
446+ feature_info = [dict (num_chs = f ['num_chs' ], reduction = f ['reduction' ])
447+ for i , f in enumerate (backbone .feature_info ())]
448+ else :
449+ # new feature info accessor, timm >= 0.2, all models supported
450+ feature_info = backbone .feature_info .get_dicts (keys = ['num_chs' , 'reduction' ])
451+ return feature_info
452+
453+
443454class EfficientDet (nn .Module ):
444455
445456 def __init__ (self , config , norm_kwargs = None , pretrained_backbone = True , alternate_init = False ):
@@ -448,8 +459,7 @@ def __init__(self, config, norm_kwargs=None, pretrained_backbone=True, alternate
448459 self .backbone = create_model (
449460 config .backbone_name , features_only = True , out_indices = (2 , 3 , 4 ),
450461 pretrained = pretrained_backbone , ** config .backbone_args )
451- feature_info = [dict (num_chs = f ['num_chs' ], reduction = f ['reduction' ])
452- for i , f in enumerate (self .backbone .feature_info ())]
462+ feature_info = get_feature_info (self .backbone )
453463
454464 act_layer = get_act_layer (config .act_type )
455465 self .fpn = BiFpn (config , feature_info , norm_kwargs = norm_kwargs , act_layer = act_layer )
0 commit comments