
Commit 363a361

Update timm_universal.py
1 parent 5c42e61 commit 363a361

1 file changed: +131 -9 lines changed

segmentation_models_pytorch/encoders/timm_universal.py

Lines changed: 131 additions & 9 deletions
@@ -1,10 +1,60 @@
+"""
+TimmUniversalEncoder provides a unified feature extraction interface built on the
+`timm` library, supporting various backbone architectures, including traditional
+CNNs (e.g., ResNet) and models adopting a transformer-like feature hierarchy
+(e.g., Swin Transformer, ConvNeXt).
+
+This encoder produces standardized multi-level feature maps, facilitating integration
+with semantic segmentation tasks. It allows configuring the number of feature extraction
+stages (`depth`) and adjusting `output_stride` when supported.
+
+Key Features:
+- Flexible model selection through `timm.create_model`.
+- A unified interface that outputs consistent multi-level features even if the
+  underlying model differs in its feature hierarchy.
+- Automatic alignment: if a model lacks certain early-stage features (for example,
+  modern architectures that start at 1/4 scale rather than 1/2 scale), the encoder
+  inserts dummy features to maintain consistency with traditional CNN structures.
+- Easy access to channel information: use the `out_channels` property to retrieve
+  the number of channels at each feature stage.
+
+Feature Scale Differences:
+- Traditional CNNs (e.g., ResNet) typically provide features at 1/2, 1/4, 1/8, 1/16,
+  and 1/32 scales.
+- Transformer-style or next-generation models (e.g., Swin Transformer, ConvNeXt) often
+  start at 1/4 scale (then 1/8, 1/16, 1/32), omitting the initial 1/2-scale feature.
+  TimmUniversalEncoder compensates for this omission to ensure a unified
+  multi-stage output.
+
+Notes:
+- Not all models support modifying `output_stride` (especially transformer-based or
+  transformer-like models).
+- Certain models (e.g., TResNet, DLA) require special handling to ensure correct
+  feature indexing.
+- Most `timm` models output features in (B, C, H, W) format. However, some
+  (e.g., MambaOut and certain Swin/SwinV2 variants) use (B, H, W, C) format, which is
+  currently unsupported.
+"""
+
 from typing import Any
 
 import timm
+import torch
 import torch.nn as nn
 
 
 class TimmUniversalEncoder(nn.Module):
+    """
+    A universal encoder built on the `timm` library, designed to adapt to a wide variety
+    of model architectures, including both traditional CNNs and those that follow a
+    transformer-like hierarchy.
+
+    Features:
+        - Supports flexible depth and output stride for feature extraction.
+        - Automatically adjusts to input/output channel structures based on the model type.
+        - Compatible with both convolutional and transformer-like encoders.
+    """
+
     def __init__(
         self,
         name: str,
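
The hierarchy detection introduced further down hinges on how many feature stages
`timm` reports for a backbone: a classic CNN exposes five stages starting at 1/2 scale,
while a transformer-style hierarchy exposes four starting at 1/4 scale. A standalone
illustration (the two model names are examples for this sketch, not part of the commit):

import timm

cnn = timm.create_model("resnet34", features_only=True)
print(cnn.feature_info.reduction())  # [2, 4, 8, 16, 32] -> five stages
print(cnn.feature_info.channels())   # e.g. [64, 64, 128, 256, 512]

tfm = timm.create_model("convnext_tiny", features_only=True)
print(tfm.feature_info.reduction())  # [4, 8, 16, 32] -> four stages
print(tfm.feature_info.channels())   # e.g. [96, 192, 384, 768]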
@@ -14,7 +64,19 @@ def __init__(
         output_stride: int = 32,
         **kwargs: dict[str, Any],
     ):
+        """
+        Initialize the encoder.
+
+        Args:
+            name (str): Name of the model to be loaded from the `timm` library.
+            pretrained (bool): If True, loads pretrained weights.
+            in_channels (int): Number of input channels (default: 3 for RGB).
+            depth (int): Number of feature extraction stages (default: 5).
+            output_stride (int): Desired output stride (default: 32).
+            **kwargs: Additional keyword arguments for `timm.create_model`.
+        """
         super().__init__()
+
         common_kwargs = dict(
             in_chans=in_channels,
             features_only=True,
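
The next hunk funnels both model-creation paths through `_merge_kwargs_no_duplicates`,
a helper defined elsewhere in this file and not shown in the diff. A plausible minimal
sketch, assuming it only merges the two dicts and rejects keys supplied twice (the
exact implementation may differ):

from typing import Any

def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    # Refuse silent overrides: a key the encoder sets internally (e.g. in_chans)
    # must not also arrive via the caller's **kwargs.
    duplicates = a.keys() & b.keys()
    if duplicates:
        raise ValueError(f"Duplicate keys found: {sorted(duplicates)}")
    return {**a, **b}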
@@ -23,30 +85,90 @@ def __init__(
             out_indices=tuple(range(depth)),
         )
 
-        # not all models support output stride argument, drop it by default
         if output_stride == 32:
             common_kwargs.pop("output_stride")
 
-        self.model = timm.create_model(
-            name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
-        )
+        # Load a preliminary model to determine its feature hierarchy structure.
+        self.model = timm.create_model(name, features_only=True)
+
+        # Determine if this model uses a transformer-like hierarchy (i.e., starting at 1/4 scale)
+        # rather than a traditional CNN hierarchy (starting at 1/2 scale).
+        if len(self.model.feature_info.channels()) == 5:
+            # This indicates a traditional hierarchy: (1/2, 1/4, 1/8, 1/16, 1/32)
+            self._is_transformer_style = False
+        else:
+            # This indicates a transformer-like hierarchy: (1/4, 1/8, 1/16, 1/32)
+            self._is_transformer_style = True
+
+        if self._is_transformer_style:
+            if "tresnet" in name:
+                # 'tresnet' models start feature extraction at stage 1,
+                # so out_indices=(1, 2, 3, 4) for depth=5.
+                common_kwargs["out_indices"] = tuple(range(1, depth))
+            else:
+                # Most transformer-like models use out_indices=(0, 1, 2, 3) for depth=5.
+                common_kwargs["out_indices"] = tuple(range(depth - 1))
+
+            self.model = timm.create_model(
+                name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
+            )
+            # Add a dummy output channel (0) to align with traditional encoder structures.
+            self._out_channels = (
+                [in_channels] + [0] + self.model.feature_info.channels()
+            )
+        else:
+            if "dla" in name:
+                # For 'dla' models, out_indices starts at 0 and matches the input size.
+                kwargs["out_indices"] = tuple(range(1, depth + 1))
+
+            self.model = timm.create_model(
+                name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
+            )
+            self._out_channels = [in_channels] + self.model.feature_info.channels()
 
         self._in_channels = in_channels
-        self._out_channels = [in_channels] + self.model.feature_info.channels()
         self._depth = depth
         self._output_stride = output_stride
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+        """
+        Pass the input through the encoder and return extracted features.
+
+        Args:
+            x (torch.Tensor): Input tensor of shape (B, C, H, W).
+
+        Returns:
+            List[torch.Tensor]: A list of feature maps extracted at various scales.
+        """
         features = self.model(x)
-        features = [x] + features
+
+        if self._is_transformer_style:
+            # Models using a transformer-like hierarchy may not generate
+            # all expected feature maps. Insert a dummy feature map to ensure
+            # compatibility with decoders expecting a 5-level pyramid.
+            B, _, H, W = x.shape
+            dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
+            features = [x] + [dummy] + features
+        else:
+            features = [x] + features
+
         return features
 
     @property
-    def out_channels(self):
+    def out_channels(self) -> list[int]:
+        """
+        Returns:
+            List[int]: A list of output channels for each stage of the encoder,
+            including the input channels at the first stage.
+        """
         return self._out_channels
 
     @property
-    def output_stride(self):
+    def output_stride(self) -> int:
+        """
+        Returns:
+            int: The effective output stride of the encoder, considering the depth.
+        """
         return min(self._output_stride, 2**self._depth)
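A note on the zero-channel placeholder used in `forward`: a tensor with C == 0 stores
no data, and concatenating it with a real feature map along the channel dimension is a
no-op, so decoders that walk the 5-level pyramid can carry it without special-casing.
A standalone illustration (not from the commit):

import torch

dummy = torch.empty(2, 0, 64, 64)  # batch 2, zero channels, 64x64 spatial
print(dummy.shape)   # torch.Size([2, 0, 64, 64])
print(dummy.numel()) # 0 -> stores no data

feat = torch.randn(2, 32, 64, 64)
print(torch.cat([dummy, feat], dim=1).shape)  # torch.Size([2, 32, 64, 64])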
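Finally, a minimal end-to-end sketch of the updated encoder (backbone name and shapes
illustrative; `convnext_tiny` is one of the transformer-style hierarchies the docstring
mentions, so stage 1 becomes the 0-channel dummy):

import torch

from segmentation_models_pytorch.encoders.timm_universal import TimmUniversalEncoder

encoder = TimmUniversalEncoder("convnext_tiny", pretrained=False)
print(encoder.out_channels)  # e.g. [3, 0, 96, 192, 384, 768]

x = torch.randn(1, 3, 256, 256)
for f in encoder(x):
    print(tuple(f.shape))
# (1, 3, 256, 256)    input passed through unchanged
# (1, 0, 128, 128)    dummy 1/2-scale stage
# (1, 96, 64, 64)     1/4 scale
# (1, 192, 32, 32)    1/8 scale
# (1, 384, 16, 16)    1/16 scale
# (1, 768, 8, 8)      1/32 scale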