1
1
"""
2
- TimmUniversalEncoder provides a unified feature extraction interface built on the
3
- `timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
2
+ TimmUniversalEncoder provides a unified feature extraction interface built on the
3
+ `timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
4
4
models (e.g., Swin Transformer, ConvNeXt).
5
5
6
- This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7
- It allows configuring the number of feature extraction stages (`depth`) and adjusting
6
+ This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7
+ It allows configuring the number of feature extraction stages (`depth`) and adjusting
8
8
`output_stride` when supported.
9
9
10
10
Key Features:
11
11
- Flexible model selection using `timm.create_model`.
12
- - Unified multi-level output across different model hierarchies.
12
+ - Unified multi-level output across different model hierarchies.
13
13
- Automatic alignment for inconsistent feature scales:
14
- - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15
- - VGG-style models (include scale-1 features): Align outputs for compatibility.
14
+ - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15
+ - VGG-style models (include scale-1 features): Align outputs for compatibility.
16
16
- Easy access to feature scale information via the `reduction` property.
17
17
18
18
Feature Scale Differences:
19
- - Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20
- - Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
19
+ - Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20
+ - Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
21
21
- VGG-style models: Include scale-1 features (input resolution).
22
22
23
23
Notes:
24
- - `output_stride` is unsupported in some models, especially transformer-based architectures.
25
- - Special handling for models like TResNet and DLA to ensure correct feature indexing.
26
- - VGG-style models use `_is_skip_first ` to align scale-1 features with standard outputs.
24
+ - `output_stride` is unsupported in some models, especially transformer-based architectures.
25
+ - Special handling for models like TResNet and DLA to ensure correct feature indexing.
26
+ - VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs.
27
27
"""
28
28
29
29
from typing import Any
35
35
36
36
class TimmUniversalEncoder (nn .Module ):
37
37
"""
38
- A universal encoder leveraging the `timm` library for feature extraction from
38
+ A universal encoder leveraging the `timm` library for feature extraction from
39
39
various model architectures, including traditional-style and transformer-style models.
40
40
41
41
Features:
@@ -92,15 +92,15 @@ def __init__(
92
92
if reduction_scales == [2 ** (i + 2 ) for i in range (encoder_stage )]:
93
93
# Transformer-style downsampling: scales (4, 8, 16, 32)
94
94
self ._is_transformer_style = True
95
- self ._is_skip_first = False
95
+ self ._is_vgg_style = False
96
96
elif reduction_scales == [2 ** (i + 1 ) for i in range (encoder_stage )]:
97
97
# Traditional-style downsampling: scales (2, 4, 8, 16, 32)
98
98
self ._is_transformer_style = False
99
- self ._is_skip_first = False
100
- elif reduction_scales == [2 ** i for i in range (encoder_stage )]:
101
- # Models including scale 1: scales (1, 2, 4, 8, 16, 32)
99
+ self ._is_vgg_style = False
100
+ elif reduction_scales == [2 ** i for i in range (encoder_stage )]:
101
+ # VGG-style models including scale 1: scales (1, 2, 4, 8, 16, 32)
102
102
self ._is_transformer_style = False
103
- self ._is_skip_first = True
103
+ self ._is_vgg_style = True
104
104
else :
105
105
raise ValueError ("Unsupported model downsampling pattern." )
106
106
@@ -125,14 +125,14 @@ def __init__(
125
125
if "dla" in name :
126
126
# For 'dla' models, out_indices starts at 0 and matches the input size.
127
127
common_kwargs ["out_indices" ] = tuple (range (1 , depth + 1 ))
128
- if self ._is_skip_first :
128
+ if self ._is_vgg_style :
129
129
common_kwargs ["out_indices" ] = tuple (range (depth + 1 ))
130
130
131
131
self .model = timm .create_model (
132
132
name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
133
133
)
134
134
135
- if self ._is_skip_first :
135
+ if self ._is_vgg_style :
136
136
self ._out_channels = self .model .feature_info .channels ()
137
137
else :
138
138
self ._out_channels = [in_channels ] + self .model .feature_info .channels ()
@@ -164,9 +164,9 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
164
164
B , _ , H , W = x .shape
165
165
dummy = torch .empty ([B , 0 , H // 2 , W // 2 ], dtype = x .dtype , device = x .device )
166
166
features = [dummy ] + features
167
-
168
- # Add input tensor as scale 1 feature if `self._is_skip_first ` is False
169
- if not self ._is_skip_first :
167
+
168
+ # Add input tensor as scale 1 feature if `self._is_vgg_style` is False
169
+ if not self ._is_vgg_style :
170
170
features = [x ] + features
171
171
172
172
return features
0 commit comments