11"""
22TimmUniversalEncoder provides a unified feature extraction interface built on the
33`timm` library, supporting various backbone architectures, including traditional
4- CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
4+ CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
55(e.g., Swin Transformer, ConvNeXt).
66
77This encoder produces standardized multi-level feature maps, facilitating integration
2222- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
2323 and 1/32 scales.
2424- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
25- start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
25+ start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
2626 feature. TimmUniversalEncoder compensates for this omission to ensure a unified
2727 multi-stage output.
2828
2929Notes:
30- - Not all models support modifying `output_stride` (especially transformer-based or
30+ - Not all models support modifying `output_stride` (especially transformer-based or
3131 transformer-like models).
3232- Certain models (e.g., TResNet, DLA) require special handling to ensure correct
3333 feature indexing.
34- - Most `timm` models output features in (B, C, H, W) format. However, some
34+ - Most `timm` models output features in (B, C, H, W) format. However, some
3535 (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
3636 currently unsupported.
3737"""
4646class TimmUniversalEncoder (nn .Module ):
4747 """
4848 A universal encoder built on the `timm` library, designed to adapt to a wide variety of
49- model architectures, including both traditional CNNs and those that follow a
49+ model architectures, including both traditional CNNs and those that follow a
5050 transformer-like hierarchy.
5151
5252 Features:
@@ -94,10 +94,8 @@ def __init__(
9494 # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
9595 # rather than a traditional CNN hierarchy (starting at 1/2 scale).
9696 if len (self .model .feature_info .channels ()) == 5 :
97- # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32)
9897 self ._is_transformer_style = False
9998 else :
100- # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32)
10199 self ._is_transformer_style = True
102100
103101 if self ._is_transformer_style :
@@ -138,7 +136,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
138136 x (torch.Tensor): Input tensor of shape (B, C, H, W).
139137
140138 Returns:
141- List [torch.Tensor]: A list of feature maps extracted at various scales.
139+ list [torch.Tensor]: A list of feature maps extracted at various scales.
142140 """
143141 features = self .model (x )
144142
@@ -158,7 +156,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
158156 def out_channels (self ) -> list [int ]:
159157 """
160158 Returns:
161- List [int]: A list of output channels for each stage of the encoder,
159+ list [int]: A list of output channels for each stage of the encoder,
162160 including the input channels at the first stage.
163161 """
164162 return self ._out_channels
0 commit comments