11"""
2- TimmUniversalEncoder provides a unified feature extraction interface built on the
3- `timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
2+ TimmUniversalEncoder provides a unified feature extraction interface built on the
3+ `timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
44models (e.g., Swin Transformer, ConvNeXt).
55
6- This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7- It allows configuring the number of feature extraction stages (`depth`) and adjusting
6+ This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7+ It allows configuring the number of feature extraction stages (`depth`) and adjusting
88`output_stride` when supported.
99
1010Key Features:
1111- Flexible model selection using `timm.create_model`.
12- - Unified multi-level output across different model hierarchies.
12+ - Unified multi-level output across different model hierarchies.
1313- Automatic alignment for inconsistent feature scales:
14- - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15- - VGG-style models (include scale-1 features): Align outputs for compatibility.
14+ - Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15+ - VGG-style models (include scale-1 features): Align outputs for compatibility.
1616- Easy access to feature scale information via the `reduction` property.
1717
1818Feature Scale Differences:
19- - Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20- - Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
19+ - Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20+ - Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
2121- VGG-style models: Include scale-1 features (input resolution).
2222
2323Notes:
24- - `output_stride` is unsupported in some models, especially transformer-based architectures.
25- - Special handling for models like TResNet and DLA to ensure correct feature indexing.
26- - VGG-style models use `_is_skip_first ` to align scale-1 features with standard outputs.
24+ - `output_stride` is unsupported in some models, especially transformer-based architectures.
25+ - Special handling for models like TResNet and DLA to ensure correct feature indexing.
26+ - VGG-style models use `_is_vgg_style` to align scale-1 features with standard outputs.
2727"""
2828
2929from typing import Any
3535
3636class TimmUniversalEncoder (nn .Module ):
3737 """
38- A universal encoder leveraging the `timm` library for feature extraction from
38+ A universal encoder leveraging the `timm` library for feature extraction from
3939 various model architectures, including traditional-style and transformer-style models.
4040
4141 Features:
@@ -92,15 +92,15 @@ def __init__(
9292 if reduction_scales == [2 ** (i + 2 ) for i in range (encoder_stage )]:
9393 # Transformer-style downsampling: scales (4, 8, 16, 32)
9494 self ._is_transformer_style = True
95- self ._is_skip_first = False
95+ self ._is_vgg_style = False
9696 elif reduction_scales == [2 ** (i + 1 ) for i in range (encoder_stage )]:
9797 # Traditional-style downsampling: scales (2, 4, 8, 16, 32)
9898 self ._is_transformer_style = False
99- self ._is_skip_first = False
100- elif reduction_scales == [2 ** i for i in range (encoder_stage )]:
101- # Models including scale 1: scales (1, 2, 4, 8, 16, 32)
99+ self ._is_vgg_style = False
100+ elif reduction_scales == [2 ** i for i in range (encoder_stage )]:
101+ # VGG-style models including scale 1: scales (1, 2, 4, 8, 16, 32)
102102 self ._is_transformer_style = False
103- self ._is_skip_first = True
103+ self ._is_vgg_style = True
104104 else :
105105 raise ValueError ("Unsupported model downsampling pattern." )
106106
@@ -125,14 +125,14 @@ def __init__(
125125 if "dla" in name :
126126 # For 'dla' models, out_indices starts at 0 and matches the input size.
127127 common_kwargs ["out_indices" ] = tuple (range (1 , depth + 1 ))
128- if self ._is_skip_first :
128+ if self ._is_vgg_style :
129129 common_kwargs ["out_indices" ] = tuple (range (depth + 1 ))
130130
131131 self .model = timm .create_model (
132132 name , ** _merge_kwargs_no_duplicates (common_kwargs , kwargs )
133133 )
134134
135- if self ._is_skip_first :
135+ if self ._is_vgg_style :
136136 self ._out_channels = self .model .feature_info .channels ()
137137 else :
138138 self ._out_channels = [in_channels ] + self .model .feature_info .channels ()
@@ -164,9 +164,9 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
164164 B , _ , H , W = x .shape
165165 dummy = torch .empty ([B , 0 , H // 2 , W // 2 ], dtype = x .dtype , device = x .device )
166166 features = [dummy ] + features
167-
168- # Add input tensor as scale 1 feature if `self._is_skip_first ` is False
169- if not self ._is_skip_first :
167+
168+ # Add input tensor as scale 1 feature if `self._is_vgg_style` is False
169+ if not self ._is_vgg_style :
170170 features = [x ] + features
171171
172172 return features
0 commit comments