1+ """
2+ TimmUniversalEncoder provides a unified feature extraction interface built on the
3+ `timm` library, supporting various backbone architectures, including traditional
4+ CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
5+ (e.g., Swin Transformer, ConvNeXt).
6+
7+ This encoder produces standardized multi-level feature maps, facilitating integration
8+ with semantic segmentation tasks. It allows configuring the number of feature extraction
9+ stages (`depth`) and adjusting `output_stride` when supported.
10+
11+ Key Features:
12+ - Flexible model selection through `timm.create_model`.
13+ - A unified interface that outputs consistent, multi-level features even if the
14+ underlying model differs in its feature hierarchy.
15+ - Automatic alignment: If a model lacks certain early-stage features (for example,
16+ modern architectures that start from a 1/4 scale rather than 1/2 scale), the encoder
17+ inserts dummy features to maintain consistency with traditional CNN structures.
18+ - Easy access to channel information: Use the `out_channels` property to retrieve
19+ the number of channels at each feature stage.
20+
21+ Feature Scale Differences:
22+ - Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
23+ and 1/32 scales.
24+ - Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
25+ start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
26+ feature. TimmUniversalEncoder compensates for this omission to ensure a unified
27+ multi-stage output.
28+
29+ Notes:
30+ - Not all models support modifying `output_stride` (especially transformer-based or
31+ transformer-like models).
32+ - Certain models (e.g., TResNet, DLA) require special handling to ensure correct
33+ feature indexing.
34+ - Most `timm` models output features in (B, C, H, W) format. However, some
35+ (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
36+ currently unsupported.
37+ """
38+
139from typing import Any
240
341import timm
42+ import torch
443import torch .nn as nn
544
645
746class TimmUniversalEncoder (nn .Module ):
47+ """
48+ A universal encoder built on the `timm` library, designed to adapt to a wide variety of
49+ model architectures, including both traditional CNNs and those that follow a
50+ transformer-like hierarchy.
51+
52+ Features:
53+ - Supports flexible depth and output stride for feature extraction.
54+ - Automatically adjusts to input/output channel structures based on the model type.
55+ - Compatible with both convolutional and transformer-like encoders.
56+ """
57+
858 def __init__ (
959 self ,
1060 name : str ,
@@ -14,7 +64,19 @@ def __init__(
1464 output_stride : int = 32 ,
1565 ** kwargs : dict [str , Any ],
1666 ):
67+ """
68+ Initialize the encoder.
69+
70+ Args:
71+ name (str): Name of the model to be loaded from the `timm` library.
72+ pretrained (bool): If True, loads pretrained weights.
73+ in_channels (int): Number of input channels (default: 3 for RGB).
74+ depth (int): Number of feature extraction stages (default: 5).
75+ output_stride (int): Desired output stride (default: 32).
76+ **kwargs: Additional keyword arguments for `timm.create_model`.
77+ """
1778 super ().__init__ ()
79+
1880 common_kwargs = dict (
1981 in_chans = in_channels ,
2082 features_only = True ,
@@ -23,30 +85,90 @@ def __init__(
2385 out_indices = tuple (range (depth )),
2486 )
2587
26- # not all models support output stride argument, drop it by default
2788 if output_stride == 32 :
2889 common_kwargs .pop ("output_stride" )
2990
30- self .model = timm .create_model (
31- name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
32- )
91+ # Load a preliminary model to determine its feature hierarchy structure.
92+ self .model = timm .create_model (name , features_only = True )
93+
94+ # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
95+ # rather than a traditional CNN hierarchy (starting at 1/2 scale).
96+ if len (self .model .feature_info .channels ()) == 5 :
97+ # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32)
98+ self ._is_transformer_style = False
99+ else :
100+ # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32)
101+ self ._is_transformer_style = True
102+
103+ if self ._is_transformer_style :
104+ if "tresnet" in name :
105+ # 'tresnet' models start feature extraction at stage 1,
106+ # so out_indices=(1, 2, 3, 4) for depth=5.
107+ common_kwargs ["out_indices" ] = tuple (range (1 , depth ))
108+ else :
109+ # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5.
110+ common_kwargs ["out_indices" ] = tuple (range (depth - 1 ))
111+
112+ self .model = timm .create_model (
113+ name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
114+ )
115+ # Add a dummy output channel (0) to align with traditional encoder structures.
116+ self ._out_channels = (
117+ [in_channels ] + [0 ] + self .model .feature_info .channels ()
118+ )
119+ else :
120+ if "dla" in name :
121+ # For 'dla' models, out_indices starts at 0 and matches the input size.
122+ kwargs ["out_indices" ] = tuple (range (1 , depth + 1 ))
123+
124+ self .model = timm .create_model (
125+ name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
126+ )
127+ self ._out_channels = [in_channels ] + self .model .feature_info .channels ()
33128
34129 self ._in_channels = in_channels
35- self ._out_channels = [in_channels ] + self .model .feature_info .channels ()
36130 self ._depth = depth
37131 self ._output_stride = output_stride
38132
39- def forward (self , x ):
133+ def forward (self , x : torch .Tensor ) -> list [torch .Tensor ]:
134+ """
135+ Pass the input through the encoder and return extracted features.
136+
137+ Args:
138+ x (torch.Tensor): Input tensor of shape (B, C, H, W).
139+
140+ Returns:
141+ List[torch.Tensor]: A list of feature maps extracted at various scales.
142+ """
40143 features = self .model (x )
41- features = [x ] + features
144+
145+ if self ._is_transformer_style :
146+ # Models using a transformer-like hierarchy may not generate
147+ # all expected feature maps. Insert a dummy feature map to ensure
148+ # compatibility with decoders expecting a 5-level pyramid.
149+ B , _ , H , W = x .shape
150+ dummy = torch .empty ([B , 0 , H // 2 , W // 2 ], dtype = x .dtype , device = x .device )
151+ features = [x ] + [dummy ] + features
152+ else :
153+ features = [x ] + features
154+
42155 return features
43156
44157 @property
45- def out_channels (self ):
158+ def out_channels (self ) -> list [int ]:
159+ """
160+ Returns:
161+ List[int]: A list of output channels for each stage of the encoder,
162+ including the input channels at the first stage.
163+ """
46164 return self ._out_channels
47165
48166 @property
49- def output_stride (self ):
167+ def output_stride (self ) -> int :
168+ """
169+ Returns:
170+ int: The effective output stride of the encoder, considering the depth.
171+ """
50172 return min (self ._output_stride , 2 ** self ._depth )
51173
52174
0 commit comments