Commit 78ba0e8

Initial timm vit encoder commit
1 parent f4c73c6 commit 78ba0e8

2 files changed: +209 -0 lines changed

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 13 additions & 0 deletions
@@ -24,6 +24,7 @@
 from .mobileone import mobileone_encoders

 from .timm_universal import TimmUniversalEncoder
+from .timm_vit import TimmViTEncoder

 from ._preprocessing import preprocess_input
 from ._legacy_pretrained_settings import pretrained_settings
@@ -81,8 +82,20 @@ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **
     if "mobilenetv3" in name:
         name = name.replace("tu-", "tu-tf_")

+    use_vit_encoder = kwargs.pop("use_vit_encoder", False)
     if name.startswith("tu-"):
         name = name[3:]
+
+    if use_vit_encoder:
+        encoder = TimmViTEncoder(
+            name=name,
+            in_channels=in_channels,
+            depth=depth,
+            pretrained=weights is not None,
+            **kwargs
+        )
+        return encoder
+
     encoder = TimmUniversalEncoder(
         name=name,
         in_channels=in_channels,
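
For context, the new flag travels through `**kwargs`, so from the caller's side the routing looks roughly like this. A minimal sketch, not part of the commit; `tu-vit_base_patch16_224` is assumed to resolve to a valid timm ViT and is used purely for illustration:

import torch
from segmentation_models_pytorch.encoders import get_encoder

# depth is passed explicitly because get_encoder defaults to 5, while
# TimmViTEncoder only accepts depth in [1, 4].
encoder = get_encoder(
    "tu-vit_base_patch16_224",  # "tu-" prefix is stripped before the timm lookup
    in_channels=3,
    depth=4,
    weights=None,           # weights=None -> pretrained=False in TimmViTEncoder
    use_vit_encoder=True,   # popped from **kwargs to select the new encoder
)
features = encoder(torch.randn(1, 3, 224, 224))  # four maps at 1/16 resolution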
segmentation_models_pytorch/encoders/timm_vit.py (new file)
Lines changed: 196 additions & 0 deletions
"""
TimmViTEncoder provides a feature extraction interface built on the `timm`
library for ViT-style (non-hierarchical) models, such as plain Vision
Transformers.

Unlike hierarchical CNN or Swin-style backbones, these models keep a single
spatial resolution across all transformer blocks. The encoder therefore
returns `depth` feature maps taken from intermediate blocks, all at the same
reduction scale (the patch size, typically 16).

Key Features:
- Flexible model selection using `timm.create_model`.
- Features taken from evenly spaced transformer blocks by default, with an
  `out_indices` argument to override the block selection.
- Automatic conversion of channel-last (NHWC) outputs to channel-first (NCHW).

Notes:
- Only models whose feature maps share a single, constant reduction scale are
  supported; hierarchical models should use TimmUniversalEncoder instead.
"""

from typing import Any, Optional

import timm
import torch
import torch.nn as nn


class TimmViTEncoder(nn.Module):
    """
    Encoder for ViT-style `timm` models that keep a single spatial resolution
    across blocks. Extracts `depth` intermediate feature maps from evenly
    spaced transformer blocks.
    """

    _is_torch_scriptable = True
    _is_torch_exportable = True
    _is_torch_compilable = True

    def __init__(
        self,
        name: str,
        pretrained: bool = True,
        in_channels: int = 3,
        depth: int = 4,
        out_indices: Optional[list[int]] = None,
        **kwargs: dict[str, Any],
    ):
        """
        Initialize the encoder.

        Args:
            name (str): Model name to load from `timm`.
            pretrained (bool): Load pretrained weights (default: True).
            in_channels (int): Number of input channels (default: 3 for RGB).
            depth (int): Number of feature stages to extract (default: 4).
            out_indices (Optional[list[int]]): Block indices to take features
                from. If None, `depth` evenly spaced blocks are selected.
            **kwargs: Additional arguments passed to `timm.create_model`.
        """
        # At the moment we do not support models with more than 4 stages,
        # but this can be reconfigured in the future.
        if depth > 4 or depth < 1:
            raise ValueError(
                f"{self.__class__.__name__} depth should be in range [1, 4], got {depth}"
            )

        super().__init__()
        self.name = name

        # Default model configuration for feature extraction
        common_kwargs = dict(
            in_chans=in_channels,
            features_only=True,
            pretrained=pretrained,
            out_indices=tuple(range(depth)),
        )

        # Load a temporary model to analyze its feature hierarchy
        # (on the meta device when possible, to avoid allocating weights)
        try:
            with torch.device("meta"):
                tmp_model = timm.create_model(name, features_only=True)
        except Exception:
            tmp_model = timm.create_model(name, features_only=True)

        # Check if model output is in channel-last format (NHWC)
        self._is_channel_last = getattr(tmp_model, "output_fmt", None) == "NHWC"

        # Determine the model's downsampling pattern and set hierarchy flags
        reduction_scales = list(tmp_model.feature_info.reduction())
        output_stride = reduction_scales[0]

        # The model must output ViT-style features with no downsampling
        # between stages (a single, constant reduction scale)
        if len(set(reduction_scales)) != 1:
            raise ValueError("Unsupported model downsampling pattern.")

        num_blocks = len(tmp_model.blocks)
        if out_indices is None:
            # Select `depth` evenly spaced block indices, e.g. a 12-block
            # model with depth=4 yields out_indices = [2, 5, 8, 11]
            out_indices = [
                int(index * (num_blocks / 4)) - 1 for index in range(1, depth + 1)
            ]

            # A model with 24 blocks should use features from layers
            # [5, 12, 18, 24] (1-based), so shift the first index down by one
            if num_blocks == 24:
                out_indices[0] -= 1

        common_kwargs["out_indices"] = out_indices
        self.model = timm.create_model(
            name, **_merge_kwargs_no_duplicates(common_kwargs, kwargs)
        )

        self._out_channels = self.model.feature_info.channels()
        self._in_channels = in_channels
        self._depth = depth
        self._output_stride = output_stride

    def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
        """
        Forward pass to extract multi-stage features.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            list[torch.Tensor]: List of feature maps, all at the same
                reduction scale.
        """
        features = self.model(x)

        # Convert NHWC to NCHW if needed
        if self._is_channel_last:
            features = [
                feature.permute(0, 3, 1, 2).contiguous() for feature in features
            ]

        return features

    @property
    def out_channels(self) -> list[int]:
        """
        Returns the number of output channels for each feature stage.

        Returns:
            list[int]: A list of channel dimensions, one per extracted
                feature map.
        """
        return self._out_channels

    @property
    def output_stride(self) -> int:
        """
        Returns the effective output stride (the constant reduction scale
        shared by all extracted features).

        Returns:
            int: The effective output stride.
        """
        return self._output_stride

    def load_state_dict(self, state_dict, **kwargs):
        # for compatibility of weights for
        # timm-ported encoders with TimmUniversalEncoder
        patterns = ["regnet", "res2", "resnest", "mobilenetv3", "gernet"]

        is_deprecated_encoder = any(
            self.name.startswith(pattern) for pattern in patterns
        )

        if is_deprecated_encoder:
            keys = list(state_dict.keys())
            for key in keys:
                new_key = key
                if not key.startswith("model."):
                    new_key = "model." + key
                if "gernet" in self.name:
                    new_key = new_key.replace(".stages.", ".stages_")
                state_dict[new_key] = state_dict.pop(key)

        return super().load_state_dict(state_dict, **kwargs)


def _merge_kwargs_no_duplicates(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    """
    Merge two dictionaries, ensuring no duplicate keys exist.

    Args:
        a (dict): Base dictionary.
        b (dict): Additional parameters to merge.

    Returns:
        dict: A merged dictionary.
    """
    duplicates = a.keys() & b.keys()
    if duplicates:
        raise ValueError(f"'{duplicates}' already specified internally")

    return a | b
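
For orientation, this is roughly how the encoder is meant to be exercised end to end. A minimal sketch, not part of the commit; it assumes a timm version recent enough to support `features_only=True` for plain ViT models, and uses `vit_base_patch16_224` purely as an illustrative model name:

import torch
from segmentation_models_pytorch.encoders.timm_vit import TimmViTEncoder

# A 12-block ViT with depth=4 selects blocks [2, 5, 8, 11]; all four
# feature maps share the same reduction scale (16 for patch-16 models).
encoder = TimmViTEncoder("vit_base_patch16_224", pretrained=False, depth=4)

x = torch.randn(1, 3, 224, 224)
features = encoder(x)

print(encoder.output_stride)   # 16
print(len(features))           # 4
print(features[0].shape)       # e.g. torch.Size([1, 768, 14, 14])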
