11"""
2- TimmUniversalEncoder provides a unified feature extraction interface built on the
3- `timm` library, supporting various backbone architectures, including traditional
4- CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
5- (e.g., Swin Transformer, ConvNeXt).
2+ TimmUniversalEncoder provides a unified feature extraction interface built on the
3+ `timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
4+ models (e.g., Swin Transformer, ConvNeXt).
65
7- This encoder produces standardized multi-level feature maps, facilitating integration
8- with semantic segmentation tasks. It allows configuring the number of feature extraction
9- stages (`depth`) and adjusting `output_stride` when supported.
6+ This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7+ It allows configuring the number of feature extraction stages (`depth`) and adjusting
8+ `output_stride` when supported.
109
1110Key Features:
12- - Flexible model selection through `timm.create_model`.
13- - A unified interface that outputs consistent, multi-level features even if the
14- underlying model differs in its feature hierarchy.
15- - Automatic alignment: If a model lacks certain early-stage features (for example,
16- modern architectures that start from a 1/4 scale rather than 1/2 scale), the encoder
17- inserts dummy features to maintain consistency with traditional CNN structures.
18- - Easy access to channel information: Use the `out_channels` property to retrieve
19- the number of channels at each feature stage.
11+ - Flexible model selection using `timm.create_model`.
12+ - Unified multi-level output across different model hierarchies.
13+ - Automatic alignment for inconsistent feature scales:
14+ - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15+ - VGG-style models (include scale-1 features): Align outputs for compatibility.
16+ - Easy access to per-stage channel counts via the `out_channels` property; feature scales are read from timm's `feature_info.reduction()`.
2017
2118Feature Scale Differences:
22- - Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
23- and 1/32 scales.
24- - Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
25- start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
26- feature. TimmUniversalEncoder compensates for this omission to ensure a unified
27- multi-stage output.
19+ - Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20+ - Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
21+ - VGG-style models: Include scale-1 features (input resolution).
2822
2923Notes:
30- - Not all models support modifying `output_stride` (especially transformer-based or
31- transformer-like models).
32- - Certain models (e.g., TResNet, DLA) require special handling to ensure correct
33- feature indexing.
24+ - `output_stride` is unsupported in some models, especially transformer-based architectures.
25+ - Models like TResNet and DLA receive special handling to ensure correct feature indexing.
26+ - VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs.
3427"""
3528
3629from typing import Any
4235
4336class TimmUniversalEncoder (nn .Module ):
4437 """
45- A universal encoder built on the `timm` library, designed to adapt to a wide variety of
46- model architectures, including both traditional CNNs and those that follow a
47- transformer-like hierarchy.
38+ A universal encoder leveraging the `timm` library for feature extraction from
39+ various model architectures, including traditional-style and transformer-style models.
4840
4941 Features:
50- - Supports flexible depth and output stride for feature extraction .
51- - Automatically adjusts to input/output channel structures based on the model type .
52- - Compatible with both convolutional and transformer-like encoders .
42+ - Supports configurable depth and output stride.
43+ - Ensures consistent multi-level feature extraction across diverse models .
44+ - Compatible with convolutional and transformer-like backbones .
5345 """
5446
5547 def __init__ (
@@ -65,15 +57,16 @@ def __init__(
6557 Initialize the encoder.
6658
6759 Args:
68- name (str): Name of the model to be loaded from the `timm` library .
69- pretrained (bool): If True, loads pretrained weights.
60+ name (str): Model name to load from `timm`.
61+ pretrained (bool): Load pretrained weights (default: True) .
7062 in_channels (int): Number of input channels (default: 3 for RGB).
71- depth (int): Number of feature extraction stages (default: 5).
63+ depth (int): Number of feature stages to extract (default: 5).
7264 output_stride (int): Desired output stride (default: 32).
73- **kwargs: Additional keyword arguments for `timm.create_model`.
65+ **kwargs: Additional arguments passed to `timm.create_model`.
7466 """
7567 super ().__init__ ()
7668
69+ # Default model configuration for feature extraction
7770 common_kwargs = dict (
7871 in_chans = in_channels ,
7972 features_only = True ,
@@ -82,24 +75,37 @@ def __init__(
8275 out_indices = tuple (range (depth )),
8376 )
8477
85- # not all models support output stride argument, drop it by default
78+ # Not all models accept an `output_stride` argument, so omit it when the default value (32) is requested
8679 if output_stride == 32 :
8780 common_kwargs .pop ("output_stride" )
8881
89- # Load a preliminary model to determine its feature hierarchy structure.
82+ # Load a temporary model to analyze its feature hierarchy
9083 self .model = timm .create_model (name , features_only = True )
9184
92- # Check if the model's output is in channel-last format (B, H, W, C).
85+ # Check if model output is in channel-last format (NHWC)
9386 self ._is_channel_last = getattr (self .model , "output_fmt" , None ) == "NHWC"
9487
95- # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
96- # rather than a traditional CNN hierarchy (starting at 1/2 scale).
97- if len (self .model .feature_info .channels ()) == 5 :
88+ # Determine the model's downsampling pattern and set hierarchy flags
89+ encoder_stage = len (self .model .feature_info .reduction ())
90+ reduction_scales = self .model .feature_info .reduction ()
91+
92+ if reduction_scales == [2 ** (i + 2 ) for i in range (encoder_stage )]:
93+ # Transformer-style downsampling: scales (4, 8, 16, 32)
94+ self ._is_transformer_style = True
95+ self ._is_skip_first = False
96+ elif reduction_scales == [2 ** (i + 1 ) for i in range (encoder_stage )]:
97+ # Traditional-style downsampling: scales (2, 4, 8, 16, 32)
9898 self ._is_transformer_style = False
99+ self ._is_skip_first = False
100+ elif reduction_scales == [2 ** i for i in range (encoder_stage )]:
101+ # Models including scale 1: scales (1, 2, 4, 8, 16, 32)
102+ self ._is_transformer_style = False
103+ self ._is_skip_first = True
99104 else :
100- self . _is_transformer_style = True
105+ raise ValueError ( "Unsupported model downsampling pattern." )
101106
102107 if self ._is_transformer_style :
108+ # Transformer-like models (start at scale 4)
103109 if "tresnet" in name :
104110 # 'tresnet' models start feature extraction at stage 1,
105111 # so out_indices=(1, 2, 3, 4) for depth=5.
@@ -119,65 +125,84 @@ def __init__(
119125 if "dla" in name :
120126 # For 'dla' models, feature index 0 is at the input resolution, so out_indices start at 1.
121127 common_kwargs ["out_indices" ] = tuple (range (1 , depth + 1 ))
128+ if self ._is_skip_first :
129+ common_kwargs ["out_indices" ] = tuple (range (depth + 1 ))
122130
123131 self .model = timm .create_model (
124132 name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
125133 )
126- self ._out_channels = [in_channels ] + self .model .feature_info .channels ()
134+
135+ if self ._is_skip_first :
136+ self ._out_channels = self .model .feature_info .channels ()
137+ else :
138+ self ._out_channels = [in_channels ] + self .model .feature_info .channels ()
127139
128140 self ._in_channels = in_channels
129141 self ._depth = depth
130142 self ._output_stride = output_stride
131143
132144 def forward (self , x : torch .Tensor ) -> list [torch .Tensor ]:
133145 """
134- Pass the input through the encoder and return extracted features.
146+ Forward pass to extract multi-stage features.
135147
136148 Args:
137149 x (torch.Tensor): Input tensor of shape (B, C, H, W).
138150
139151 Returns:
140- list[torch.Tensor]: A list of feature maps extracted at various scales.
152+ list[torch.Tensor]: List of feature maps at different scales.
141153 """
142154 features = self .model (x )
143155
156+ # Convert NHWC to NCHW if needed
144157 if self ._is_channel_last :
145- # Convert to channel-first (B, C, H, W).
146158 features = [
147159 feature .permute (0 , 3 , 1 , 2 ).contiguous () for feature in features
148160 ]
149161
162+ # Add dummy feature for scale 1/2 if missing (transformer-style models)
150163 if self ._is_transformer_style :
151- # Models using a transformer-like hierarchy may not generate
152- # all expected feature maps. Insert a dummy feature map to ensure
153- # compatibility with decoders expecting a 5-level pyramid.
154164 B , _ , H , W = x .shape
155165 dummy = torch .empty ([B , 0 , H // 2 , W // 2 ], dtype = x .dtype , device = x .device )
156- features = [x ] + [dummy ] + features
157- else :
166+ features = [dummy ] + features
167+
168+ # Prepend the input tensor as the scale-1 feature unless the model already provides one (`self._is_skip_first`)
169+ if not self ._is_skip_first :
158170 features = [x ] + features
159171
160172 return features
161173
162174 @property
163175 def out_channels (self ) -> list [int ]:
164176 """
177+ Returns the number of output channels for each feature stage.
178+
165179 Returns:
166- list[int]: A list of output channels for each stage of the encoder,
167- including the input channels at the first stage.
180+ list[int]: A list of channel dimensions at each scale.
168181 """
169182 return self ._out_channels
170183
171184 @property
172185 def output_stride (self ) -> int :
173186 """
187+ Returns the effective output stride: the smaller of the requested `output_stride` and `2 ** depth`.
188+
174189 Returns:
175- int: The effective output stride of the encoder, considering the depth .
190+ int: The effective output stride.
176191 """
177192 return min (self ._output_stride , 2 ** self ._depth )
178193
179194
180195def _merge_kwargs_no_duplicates (a : dict [str , Any ], b : dict [str , Any ]) -> dict [str , Any ]:
196+ """
197+ Merge two dictionaries, ensuring no duplicate keys exist.
198+
199+ Args:
200+ a (dict): Base dictionary.
201+ b (dict): Additional parameters to merge.
202+
203+ Returns:
204+ dict: A merged dictionary.
205+ """
181206 duplicates = a .keys () & b .keys ()
182207 if duplicates :
183208 raise ValueError (f"'{ duplicates } ' already specified internally" )
0 commit comments