"""
TimmUniversalEncoder provides a unified feature extraction interface built on the
`timm` library, supporting various backbone architectures, including traditional
CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
(e.g., Swin Transformer, ConvNeXt).

This encoder produces standardized multi-level feature maps, facilitating integration

- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
  and 1/32 scales.
- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
  start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
  feature. TimmUniversalEncoder compensates for this omission to ensure a unified
  multi-stage output.

Notes:
- Not all models support modifying `output_stride` (especially transformer-based or
  transformer-like models).
- Certain models (e.g., TResNet, DLA) require special handling to ensure correct
  feature indexing.
- Most `timm` models output features in (B, C, H, W) format. However, some
  (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
  currently unsupported.
"""
class TimmUniversalEncoder(nn.Module):
    """
    A universal encoder built on the `timm` library, designed to adapt to a wide variety of
    model architectures, including both traditional CNNs and those that follow a
    transformer-like hierarchy.

    Features:
@@ -94,10 +94,8 @@ def __init__(
        # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
        # rather than a traditional CNN hierarchy (starting at 1/2 scale).
        if len(self.model.feature_info.channels()) == 5:
-           # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32)
            self._is_transformer_style = False
        else:
-           # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32)
            self._is_transformer_style = True

        if self._is_transformer_style:
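# Illustrative note, not part of the diff: the `len(...) == 5` test above can be
# reproduced from timm's feature metadata (names and channel counts are examples
# and may vary across timm versions):
#
#   import timm
#   timm.create_model("resnet50", features_only=True, pretrained=False).feature_info.channels()
#   # -> 5 entries, e.g. [64, 256, 512, 1024, 2048]   (traditional CNN hierarchy)
#   timm.create_model("convnext_tiny", features_only=True, pretrained=False).feature_info.channels()
#   # -> 4 entries, e.g. [96, 192, 384, 768]          (transformer-like hierarchy)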
@@ -138,7 +136,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
-           List[torch.Tensor]: A list of feature maps extracted at various scales.
+           list[torch.Tensor]: A list of feature maps extracted at various scales.
        """
        features = self.model(x)
@@ -158,7 +156,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
    def out_channels(self) -> list[int]:
        """
        Returns:
-           List[int]: A list of output channels for each stage of the encoder,
+           list[int]: A list of output channels for each stage of the encoder,
            including the input channels at the first stage.
        """
        return self._out_channels
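# Usage sketch, not part of the original file. Only `forward` (returning a list of
# feature maps) and the `out_channels` property are shown in the excerpt above; the
# constructor arguments below are assumptions for illustration.
#
#   import torch
#   encoder = TimmUniversalEncoder(name="resnet50", pretrained=False)  # hypothetical args
#   x = torch.randn(1, 3, 224, 224)
#   features = encoder(x)                       # list[torch.Tensor], one map per stage
#   print([tuple(f.shape) for f in features])
#   print(encoder.out_channels)                 # per-stage channels, first entry = input channels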