Skip to content

Commit 8b0fece

Browse files
committed
Update timm_universal.py
1 parent 51a4d7b commit 8b0fece

File tree

1 file changed

+79
-54
lines changed

1 file changed

+79
-54
lines changed

segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 79 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,29 @@
11
"""
2-
TimmUniversalEncoder provides a unified feature extraction interface built on the
3-
`timm` library, supporting various backbone architectures, including traditional
4-
CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
5-
(e.g., Swin Transformer, ConvNeXt).
2+
TimmUniversalEncoder provides a unified feature extraction interface built on the
3+
`timm` library, supporting both traditional-style (e.g., ResNet) and transformer-style
4+
models (e.g., Swin Transformer, ConvNeXt).
65
7-
This encoder produces standardized multi-level feature maps, facilitating integration
8-
with semantic segmentation tasks. It allows configuring the number of feature extraction
9-
stages (`depth`) and adjusting `output_stride` when supported.
6+
This encoder produces consistent multi-level feature maps for semantic segmentation tasks.
7+
It allows configuring the number of feature extraction stages (`depth`) and adjusting
8+
`output_stride` when supported.
109
1110
Key Features:
12-
- Flexible model selection through `timm.create_model`.
13-
- A unified interface that outputs consistent, multi-level features even if the
14-
underlying model differs in its feature hierarchy.
15-
- Automatic alignment: If a model lacks certain early-stage features (for example,
16-
modern architectures that start from a 1/4 scale rather than 1/2 scale), the encoder
17-
inserts dummy features to maintain consistency with traditional CNN structures.
18-
- Easy access to channel information: Use the `out_channels` property to retrieve
19-
the number of channels at each feature stage.
11+
- Flexible model selection using `timm.create_model`.
12+
- Unified multi-level output across different model hierarchies.
13+
- Automatic alignment for inconsistent feature scales:
14+
- Transformer-style models (start at 1/4 scale): Insert dummy features for 1/2 scale.
15+
- VGG-style models (include scale-1 features): Align outputs for compatibility.
16+
- Easy access to per-stage channel counts via the `out_channels` property and to the effective stride via the `output_stride` property.
2017
2118
Feature Scale Differences:
22-
- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
23-
and 1/32 scales.
24-
- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
25-
start from the 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2 scale
26-
feature. TimmUniversalEncoder compensates for this omission to ensure a unified
27-
multi-stage output.
19+
- Traditional-style models (e.g., ResNet): Scales at 1/2, 1/4, 1/8, 1/16, 1/32.
20+
- Transformer-style models (e.g., Swin Transformer): Start at 1/4 scale, skip 1/2 scale.
21+
- VGG-style models: Include scale-1 features (input resolution).
2822
2923
Notes:
30-
- Not all models support modifying `output_stride` (especially transformer-based or
31-
transformer-like models).
32-
- Certain models (e.g., TResNet, DLA) require special handling to ensure correct
33-
feature indexing.
24+
- `output_stride` is unsupported in some models, especially transformer-based architectures.
25+
- Special handling for models like TResNet and DLA to ensure correct feature indexing.
26+
- VGG-style models use `_is_skip_first` to align scale-1 features with standard outputs.
3427
"""
3528

3629
from typing import Any
@@ -42,14 +35,13 @@
4235

4336
class TimmUniversalEncoder(nn.Module):
4437
"""
45-
A universal encoder built on the `timm` library, designed to adapt to a wide variety of
46-
model architectures, including both traditional CNNs and those that follow a
47-
transformer-like hierarchy.
38+
A universal encoder leveraging the `timm` library for feature extraction from
39+
various model architectures, including traditional-style and transformer-style models.
4840
4941
Features:
50-
- Supports flexible depth and output stride for feature extraction.
51-
- Automatically adjusts to input/output channel structures based on the model type.
52-
- Compatible with both convolutional and transformer-like encoders.
42+
- Supports configurable depth and output stride.
43+
- Ensures consistent multi-level feature extraction across diverse models.
44+
- Compatible with convolutional and transformer-like backbones.
5345
"""
5446

5547
def __init__(
@@ -65,15 +57,16 @@ def __init__(
6557
Initialize the encoder.
6658
6759
Args:
68-
name (str): Name of the model to be loaded from the `timm` library.
69-
pretrained (bool): If True, loads pretrained weights.
60+
name (str): Model name to load from `timm`.
61+
pretrained (bool): Load pretrained weights (default: True).
7062
in_channels (int): Number of input channels (default: 3 for RGB).
71-
depth (int): Number of feature extraction stages (default: 5).
63+
depth (int): Number of feature stages to extract (default: 5).
7264
output_stride (int): Desired output stride (default: 32).
73-
**kwargs: Additional keyword arguments for `timm.create_model`.
65+
**kwargs: Additional arguments passed to `timm.create_model`.
7466
"""
7567
super().__init__()
7668

69+
# Default model configuration for feature extraction
7770
common_kwargs = dict(
7871
in_chans=in_channels,
7972
features_only=True,
@@ -82,24 +75,37 @@ def __init__(
8275
out_indices=tuple(range(depth)),
8376
)
8477

85-
# not all models support output stride argument, drop it by default
78+
# Not all models support output stride argument, drop it by default
8679
if output_stride == 32:
8780
common_kwargs.pop("output_stride")
8881

89-
# Load a preliminary model to determine its feature hierarchy structure.
82+
# Load a temporary model to analyze its feature hierarchy
9083
self.model = timm.create_model(name, features_only=True)
9184

92-
# Check if the model's output is in channel-last format (B, H, W, C).
85+
# Check if model output is in channel-last format (NHWC)
9386
self._is_channel_last = getattr(self.model, "output_fmt", None) == "NHWC"
9487

95-
# Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
96-
# rather than a traditional CNN hierarchy (starting at 1/2 scale).
97-
if len(self.model.feature_info.channels()) == 5:
88+
# Determine the model's downsampling pattern and set hierarchy flags
89+
encoder_stage = len(self.model.feature_info.reduction())
90+
reduction_scales = self.model.feature_info.reduction()
91+
92+
if reduction_scales == [2 ** (i + 2) for i in range(encoder_stage)]:
93+
# Transformer-style downsampling: scales (4, 8, 16, 32)
94+
self._is_transformer_style = True
95+
self._is_skip_first = False
96+
elif reduction_scales == [2 ** (i + 1) for i in range(encoder_stage)]:
97+
# Traditional-style downsampling: scales (2, 4, 8, 16, 32)
9898
self._is_transformer_style = False
99+
self._is_skip_first = False
100+
elif reduction_scales == [2 ** i for i in range(encoder_stage)]:
101+
# Models including scale 1: scales (1, 2, 4, 8, 16, 32)
102+
self._is_transformer_style = False
103+
self._is_skip_first = True
99104
else:
100-
self._is_transformer_style = True
105+
raise ValueError("Unsupported model downsampling pattern.")
101106

102107
if self._is_transformer_style:
108+
# Transformer-like models (start at scale 4)
103109
if "tresnet" in name:
104110
# 'tresnet' models start feature extraction at stage 1,
105111
# so out_indices=(1, 2, 3, 4) for depth=5.
@@ -119,65 +125,84 @@ def __init__(
119125
if "dla" in name:
120126
# For 'dla' models, out_indices starts at 0 and matches the input size.
121127
common_kwargs["out_indices"] = tuple(range(1, depth + 1))
128+
if self._is_skip_first:
129+
common_kwargs["out_indices"] = tuple(range(depth + 1))
122130

123131
self.model = timm.create_model(
124132
name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
125133
)
126-
self._out_channels = [in_channels] + self.model.feature_info.channels()
134+
135+
if self._is_skip_first:
136+
self._out_channels = self.model.feature_info.channels()
137+
else:
138+
self._out_channels = [in_channels] + self.model.feature_info.channels()
127139

128140
self._in_channels = in_channels
129141
self._depth = depth
130142
self._output_stride = output_stride
131143

132144
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
133145
"""
134-
Pass the input through the encoder and return extracted features.
146+
Forward pass to extract multi-stage features.
135147
136148
Args:
137149
x (torch.Tensor): Input tensor of shape (B, C, H, W).
138150
139151
Returns:
140-
list[torch.Tensor]: A list of feature maps extracted at various scales.
152+
list[torch.Tensor]: List of feature maps at different scales.
141153
"""
142154
features = self.model(x)
143155

156+
# Convert NHWC to NCHW if needed
144157
if self._is_channel_last:
145-
# Convert to channel-first (B, C, H, W).
146158
features = [
147159
feature.permute(0, 3, 1, 2).contiguous() for feature in features
148160
]
149161

162+
# Add dummy feature for scale 1/2 if missing (transformer-style models)
150163
if self._is_transformer_style:
151-
# Models using a transformer-like hierarchy may not generate
152-
# all expected feature maps. Insert a dummy feature map to ensure
153-
# compatibility with decoders expecting a 5-level pyramid.
154164
B, _, H, W = x.shape
155165
dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
156-
features = [x] + [dummy] + features
157-
else:
166+
features = [dummy] + features
167+
168+
# Prepend the input tensor as the scale-1 feature, unless the model already outputs one (`self._is_skip_first` is True)
169+
if not self._is_skip_first:
158170
features = [x] + features
159171

160172
return features
161173

162174
@property
163175
def out_channels(self) -> list[int]:
164176
"""
177+
Returns the number of output channels for each feature stage.
178+
165179
Returns:
166-
list[int]: A list of output channels for each stage of the encoder,
167-
including the input channels at the first stage.
180+
list[int]: A list of channel dimensions at each scale.
168181
"""
169182
return self._out_channels
170183

171184
@property
172185
def output_stride(self) -> int:
173186
"""
187+
Returns the effective output stride: the requested `output_stride`, capped at `2 ** depth`.
188+
174189
Returns:
175-
int: The effective output stride of the encoder, considering the depth.
190+
int: The effective output stride.
176191
"""
177192
return min(self._output_stride, 2**self._depth)
178193

179194

180195
def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
196+
"""
197+
Merge two dictionaries, raising `ValueError` if any key appears in both.
198+
199+
Args:
200+
a (dict): Base dictionary.
201+
b (dict): Additional parameters to merge.
202+
203+
Returns:
204+
dict: A merged dictionary.
205+
"""
181206
duplicates = a.keys() & b.keys()
182207
if duplicates:
183208
raise ValueError(f"'{duplicates}' already specified internally")

0 commit comments

Comments
 (0)