"""
TimmUniversalEncoder provides a unified feature extraction interface built on the
`timm` library, supporting various backbone architectures, including traditional
CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
(e.g., Swin Transformer, ConvNeXt).

This encoder produces standardized multi-level feature maps, facilitating integration
with semantic segmentation tasks. It allows configuring the number of feature extraction
stages (`depth`) and adjusting `output_stride` when supported.

Key Features:
- Flexible model selection through `timm.create_model`.
- A unified interface that outputs consistent, multi-level features even if the
  underlying model differs in its feature hierarchy.
- Automatic alignment: if a model lacks certain early-stage features (for example,
  modern architectures that start at 1/4 scale rather than 1/2 scale), the encoder
  inserts dummy features to maintain consistency with traditional CNN structures.
- Easy access to channel information: use the `out_channels` property to retrieve
  the number of channels at each feature stage.

Feature Scale Differences:
- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
  and 1/32 scales.
- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
  start at 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2-scale feature.
  TimmUniversalEncoder compensates for this omission to ensure a unified multi-stage
  output.

Notes:
- Not all models support modifying `output_stride` (especially transformer-based or
  transformer-like models).
- Certain models (e.g., TResNet, DLA) require special handling to ensure correct
  feature indexing.
- Most `timm` models output features in (B, C, H, W) format. However, some
  (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
  currently unsupported.
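
Example:
    A minimal usage sketch (the backbone name is illustrative; any `timm` model
    with `features_only` support should behave similarly):

        import torch

        encoder = TimmUniversalEncoder("resnet34", pretrained=False)
        features = encoder(torch.randn(1, 3, 224, 224))
        # Six outputs: the input itself plus five pyramid levels at
        # 1/2, 1/4, 1/8, 1/16, and 1/32 resolution.
        assert len(features) == len(encoder.out_channels) == 6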
"""

from typing import Any

import timm
import torch
import torch.nn as nn


class TimmUniversalEncoder(nn.Module):
    """
    A universal encoder built on the `timm` library, designed to adapt to a wide
    variety of model architectures, including both traditional CNNs and those that
    follow a transformer-like hierarchy.

    Features:
        - Supports flexible depth and output stride for feature extraction.
        - Automatically adjusts to input/output channel structures based on the model type.
        - Compatible with both convolutional and transformer-like encoders.
    """

    def __init__(
        self,
        name: str,
        pretrained: bool = True,
        in_channels: int = 3,
        depth: int = 5,
        output_stride: int = 32,
        **kwargs: dict[str, Any],
    ):
        """
        Initialize the encoder.

        Args:
            name (str): Name of the model to be loaded from the `timm` library.
            pretrained (bool): If True, loads pretrained weights.
            in_channels (int): Number of input channels (default: 3 for RGB).
            depth (int): Number of feature extraction stages (default: 5).
            output_stride (int): Desired output stride (default: 32).
            **kwargs: Additional keyword arguments for `timm.create_model`.
        """
        super().__init__()

        common_kwargs = dict(
            in_chans=in_channels,
            features_only=True,
            output_stride=output_stride,
            pretrained=pretrained,
            out_indices=tuple(range(depth)),
        )

        # Not all models support the output_stride argument; drop it when left at
        # the default so unsupported backbones are not passed an unknown kwarg.
        if output_stride == 32:
            common_kwargs.pop("output_stride")

        # Load a preliminary model to determine its feature hierarchy structure.
        self.model = timm.create_model(name, features_only=True)

        # Determine whether this model uses a transformer-like hierarchy (i.e.,
        # starting at 1/4 scale) rather than a traditional CNN hierarchy
        # (starting at 1/2 scale).
        if len(self.model.feature_info.channels()) == 5:
            # Five feature levels indicate a traditional hierarchy:
            # (1/2, 1/4, 1/8, 1/16, 1/32).
            self._is_transformer_style = False
        else:
            # Four feature levels indicate a transformer-like hierarchy:
            # (1/4, 1/8, 1/16, 1/32).
            self._is_transformer_style = True

        if self._is_transformer_style:
            if "tresnet" in name:
                # 'tresnet' models start feature extraction at stage 1,
                # so out_indices=(1, 2, 3, 4) for depth=5.
                common_kwargs["out_indices"] = tuple(range(1, depth))
            else:
                # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5.
                common_kwargs["out_indices"] = tuple(range(depth - 1))

            self.model = timm.create_model(
                name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
            )
            # Add a dummy output channel (0) to align with traditional encoder structures.
            self._out_channels = (
                [in_channels] + [0] + self.model.feature_info.channels()
            )
        else:
            if "dla" in name:
                # 'dla' models expose an extra stage at index 0 whose output matches
                # the input size, so skip it by starting out_indices at 1.
                common_kwargs["out_indices"] = tuple(range(1, depth + 1))

            self.model = timm.create_model(
                name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
            )
            self._out_channels = [in_channels] + self.model.feature_info.channels()
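
        # Illustrative bookkeeping (channel counts are assumptions for these
        # particular backbones, shown only to make the layout concrete):
        #   "resnet34" (CNN-style, 5 levels):     out_channels = [3, 64, 64, 128, 256, 512]
        #   "convnext_tiny" (4 levels + dummy 0): out_channels = [3, 0, 96, 192, 384, 768]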

        self._in_channels = in_channels
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Pass the input through the encoder and return extracted features.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            list[torch.Tensor]: A list of feature maps extracted at various scales.
        """
        features = self.model(x)

        if self._is_transformer_style:
            # Models using a transformer-like hierarchy may not generate all
            # expected feature maps. Insert a dummy feature map at the 1/2 scale
            # to keep compatibility with decoders expecting a 5-level pyramid;
            # with zero channels it carries no data.
            B, _, H, W = x.shape
            dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
            features = [x] + [dummy] + features
        else:
            features = [x] + features

        return features

    @property
    def out_channels(self) -> list[int]:
        """
        Returns:
            list[int]: A list of output channels for each stage of the encoder,
            including the input channels at the first stage.
        """
        return self._out_channels

    @property
    def output_stride(self) -> int:
        """
        Returns:
            int: The effective output stride of the encoder, capped by the depth
            (e.g., with depth=3 the deepest feature is at 1/8 scale, so the
            effective stride is 8 even if output_stride=32).
        """
        return min(self._output_stride, 2**self._depth)
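

# `_merge_kwargs_no_duplicates` is called above but defined elsewhere in the
# module. The definition below is a minimal sketch inferred from the call sites
# and the helper's name: merge internal and user-supplied kwargs, refusing
# silent overrides of internally managed arguments.
def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    """Merge two kwargs dicts, raising if a user key duplicates an internal one."""
    duplicates = a.keys() & b.keys()
    if duplicates:
        raise ValueError(f"parameters {duplicates} are duplicated in kwargs")
    return a | b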