Commit 2c38de6

Add DPT model and update logic for TimmViTEncoder class
Parent: 78ba0e8

File tree

6 files changed: +487, -31 lines

encoders_table.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
|Encoder |Pretrained weights |Params, M |Script |Compile |Export |
|--------------------------------|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .model import DPT

__all__ = ["DPT"]
Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@
import torch
import torch.nn as nn


def _get_feature_processing_out_channels(encoder_name: str) -> list[int]:
    """
    Get the output embedding dimensions for the features after decoder processing
    """

    encoder_name = encoder_name.lower()
    # Output channels for hybrid ViT encoders after feature processing
    if "vit" in encoder_name and "resnet" in encoder_name:
        return [256, 512, 768, 768]

    # Output channels for ViT-large, ViT-huge and ViT-giant encoders after feature processing
    if "vit" in encoder_name and any(
        [variant in encoder_name for variant in ["huge", "large", "giant"]]
    ):
        return [256, 512, 1024, 1024]

    # Output channels for ViT-base and other encoders after feature processing
    return [96, 192, 384, 768]


class Transpose(nn.Module):
    def __init__(self, dim0: int, dim1: int):
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x: torch.Tensor):
        return torch.transpose(x, dim0=self.dim0, dim1=self.dim1)


class ProjectionReadout(nn.Module):
    """
    Concatenates the cls token with the features to make use of the global information aggregated in the cls token.
    Projects the combined feature map back to the original embedding dimension using an MLP.
    """

    def __init__(self, in_features: int, encoder_output_stride: int):
        super().__init__()
        self.project = nn.Sequential(
            nn.Linear(in_features=2 * in_features, out_features=in_features), nn.GELU()
        )

        self.flatten = nn.Flatten(start_dim=2)
        self.transpose = Transpose(dim0=1, dim1=2)
        self.encoder_output_stride = encoder_output_stride

    def forward(self, feature: torch.Tensor, cls_token: torch.Tensor):
        batch_size, _, height_dim, width_dim = feature.shape
        feature = self.flatten(feature)
        feature = self.transpose(feature)

        cls_token = cls_token.expand_as(feature)

        features = torch.cat([feature, cls_token], dim=2)
        features = self.project(features)
        features = self.transpose(features)

        features = features.view(batch_size, -1, height_dim, width_dim)
        return features


class IgnoreReadout(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, feature: torch.Tensor, cls_token: torch.Tensor):
        return feature


class FeatureProcessBlock(nn.Module):
    """
    Processes the features such that they have progressively increasing embedding size and progressively decreasing
    spatial dimension.
    """

    def __init__(
        self, embed_dim: int, feature_dim: int, out_channel: int, upsample_factor: int
    ):
        super().__init__()

        self.project_to_out_channel = nn.Conv2d(
            in_channels=embed_dim, out_channels=out_channel, kernel_size=1
        )

        if upsample_factor > 1.0:
            self.upsample = nn.ConvTranspose2d(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=int(upsample_factor),
                stride=int(upsample_factor),
            )

        elif upsample_factor == 1.0:
            self.upsample = nn.Identity()

        else:
            self.upsample = nn.Conv2d(
                in_channels=out_channel,
                out_channels=out_channel,
                kernel_size=3,
                stride=int(1 / upsample_factor),
                padding=1,
            )

        self.project_to_feature_dim = nn.Conv2d(
            in_channels=out_channel, out_channels=feature_dim, kernel_size=3, padding=1
        )

    def forward(self, x: torch.Tensor):
        x = self.project_to_out_channel(x)
        x = self.upsample(x)
        x = self.project_to_feature_dim(x)

        return x


class ResidualConvBlock(nn.Module):
    def __init__(self, feature_dim: int):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(
                in_channels=feature_dim,
                out_channels=feature_dim,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=feature_dim),
            nn.ReLU(),
            nn.Conv2d(
                in_channels=feature_dim,
                out_channels=feature_dim,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=feature_dim),
        )

    def forward(self, x: torch.Tensor):
        return x + self.conv_block(x)


class FusionBlock(nn.Module):
    """
    Fuses the processed encoder features in a residual manner and upsamples them.
    """

    def __init__(self, feature_dim: int):
        super().__init__()
        self.residual_conv_block1 = ResidualConvBlock(feature_dim=feature_dim)
        self.residual_conv_block2 = ResidualConvBlock(feature_dim=feature_dim)
        self.project = nn.Conv2d(
            in_channels=feature_dim, out_channels=feature_dim, kernel_size=1
        )
        self.activation = nn.ReLU()

    def forward(self, feature: torch.Tensor, preceding_layer_feature: torch.Tensor):
        feature = self.residual_conv_block1(feature)

        if preceding_layer_feature is not None:
            feature += preceding_layer_feature

        feature = self.residual_conv_block2(feature)

        feature = nn.functional.interpolate(
            feature, scale_factor=2, align_corners=True, mode="bilinear"
        )
        feature = self.project(feature)
        feature = self.activation(feature)

        return feature


class DPTDecoder(nn.Module):
    """
    Decoder part for DPT

    Processes the encoder features and class tokens (if the encoder has class tokens) to have spatial downsampling ratios of
    [1/32, 1/16, 1/8, 1/4] relative to the input image spatial dimension.

    The decoder then fuses these features in a residual manner and progressively upsamples them by a factor of 2, so that the
    output has a downsampling ratio of 1/2 relative to the input image spatial dimension.
    """

    def __init__(
        self,
        encoder_name: str,
        transformer_embed_dim: int,
        encoder_output_stride: int,
        feature_dim: int = 256,
        encoder_depth: int = 4,
        prefix_token_supported: bool = False,
    ):
        super().__init__()

        self.prefix_token_supported = prefix_token_supported

        # If the encoder has a cls token, concatenate it with the features along the embedding dimension
        # and project back to the transformer embedding dimension. Otherwise, ignore the non-existent cls token.

        if prefix_token_supported:
            self.readout_blocks = nn.ModuleList(
                [
                    ProjectionReadout(
                        in_features=transformer_embed_dim,
                        encoder_output_stride=encoder_output_stride,
                    )
                    for _ in range(encoder_depth)
                ]
            )
        else:
            self.readout_blocks = [IgnoreReadout() for _ in range(encoder_depth)]

        upsample_factors = [
            (encoder_output_stride / 2 ** (index + 2))
            for index in range(0, encoder_depth)
        ]
        feature_processing_out_channels = _get_feature_processing_out_channels(
            encoder_name
        )
        if encoder_depth < len(feature_processing_out_channels):
            feature_processing_out_channels = feature_processing_out_channels[
                :encoder_depth
            ]

        self.feature_processing_blocks = nn.ModuleList(
            [
                FeatureProcessBlock(
                    transformer_embed_dim, feature_dim, out_channel, upsample_factor
                )
                for upsample_factor, out_channel in zip(
                    upsample_factors, feature_processing_out_channels
                )
            ]
        )

        self.fusion_blocks = nn.ModuleList(
            [FusionBlock(feature_dim=feature_dim) for _ in range(encoder_depth)]
        )

    def forward(
        self, encoder_output: list[list[torch.Tensor], list[torch.Tensor]]
    ) -> torch.Tensor:
        features, cls_tokens = encoder_output
        processed_features = []

        # Process the encoder features to scales of [1/32, 1/16, 1/8, 1/4]
        for index, (feature, cls_token) in enumerate(zip(features, cls_tokens)):
            readout_feature = self.readout_blocks[index](feature, cls_token)
            processed_feature = self.feature_processing_blocks[index](readout_feature)
            processed_features.append(processed_feature)

        preceding_layer_feature = None

        # Fusion and progressive upsampling, starting from the last processed feature
        processed_features = processed_features[::-1]
        for fusion_block, feature in zip(self.fusion_blocks, processed_features):
            out = fusion_block(feature, preceding_layer_feature)
            preceding_layer_feature = out

        return out
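
The decoder above is self-contained, so its data flow can be checked in isolation. The sketch below is a minimal, hypothetical example: the encoder name, stride, and tensor shapes are assumptions chosen to mimic a ViT-Base/16-style encoder on a 384x384 input, where each stage yields a [B, 768, 24, 24] token map plus a [B, 1, 768] cls token. With encoder_output_stride=16 the computed upsample_factors are [4.0, 2.0, 1.0, 0.5], so the processed features sit at 1/4, 1/8, 1/16 and 1/32 of the input resolution before fusion brings the output to 1/2 scale.

import torch

# Hypothetical standalone check of DPTDecoder; assumes the classes defined above are in scope.
decoder = DPTDecoder(
    encoder_name="tu-vit_base_patch16_384",  # illustrative name, only used to pick out channels
    transformer_embed_dim=768,
    encoder_output_stride=16,
    feature_dim=256,
    encoder_depth=4,
    prefix_token_supported=True,
)

batch_size = 2
# Four encoder stages: token maps reshaped to [B, embed_dim, H/16, W/16], plus one cls token each
features = [torch.randn(batch_size, 768, 24, 24) for _ in range(4)]
cls_tokens = [torch.randn(batch_size, 1, 768) for _ in range(4)]

out = decoder([features, cls_tokens])
print(out.shape)  # torch.Size([2, 256, 192, 192]) -> 1/2 of the 384x384 input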
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
from typing import Any, Optional, Union, Callable

from segmentation_models_pytorch.base import (
    ClassificationHead,
    SegmentationHead,
    SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
from .decoder import DPTDecoder


class DPT(SegmentationModel):
    """
    DPT is a dense prediction architecture that leverages vision transformers in place of convolutional networks as
    a backbone for dense prediction tasks.

    It assembles tokens from various stages of the vision transformer into image-like representations at various resolutions
    and progressively combines them into full-resolution predictions using a convolutional decoder.

    The transformer backbone processes representations at a constant and relatively high resolution and has a global receptive
    field at every stage. These properties allow the dense vision transformer to provide finer-grained and more globally coherent
    predictions when compared to fully-convolutional networks.

    Args:
        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
            to extract features of different spatial resolution
        encoder_depth: A number of stages used in the encoder in range [1, 4]. Each stage generates features
            smaller by a factor equal to the ViT model patch_size in spatial dimensions.
            Default is 4
        encoder_weights: One of **None** (random initialization), or other pretrained weights (see table with
            available weights for each encoder_name)
        feature_dim: The latent dimension to which the encoder features will be projected.
        in_channels: Number of input channels for the model, default is 3 (RGB images)
        classes: Number of classes for output mask (or you can think of it as the number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
            **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
            on top of the encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)
        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with
            ``None`` values are pruned before passing.
        allow_downsampling: Allow the ViT encoder to have progressive downsampling. Set to False for DPT as the architecture
            requires all encoder feature outputs to have the same spatial shape.
        allow_output_stride_not_power_of_two: Allow ViT encoders with an output_stride that is not a power of 2. This
            is set to False for DPT as the architecture requires the encoder output features to be processed to downsampling
            ratios of [1/32, 1/16, 1/8, 1/4]

    Returns:
        ``torch.nn.Module``: DPT
    """

    @supports_config_loading
    def __init__(
        self,
        encoder_name: str = "tu-vit_base_patch8_224",
        encoder_depth: int = 4,
        encoder_weights: Optional[str] = None,
        feature_dim: int = 256,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[Union[str, Callable]] = None,
        aux_params: Optional[dict] = None,
        **kwargs: dict[str, Any],
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
            use_vit_encoder=True,
            allow_downsampling=False,
            allow_output_stride_not_power_of_two=False,
            **kwargs,
        )

        transformer_embed_dim = self.encoder.embed_dim
        encoder_output_stride = self.encoder.output_stride
        cls_token_supported = self.encoder.prefix_token_supported

        self.decoder = DPTDecoder(
            encoder_name=encoder_name,
            transformer_embed_dim=transformer_embed_dim,
            feature_dim=feature_dim,
            encoder_depth=encoder_depth,
            encoder_output_stride=encoder_output_stride,
            prefix_token_supported=cls_token_supported,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=feature_dim,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=2,
        )

        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        self.name = "dpt-{}".format(encoder_name)
        self.initialize()
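
For context, a hedged usage sketch of the model class. It assumes that DPT ends up exported under the usual segmentation_models_pytorch namespace and that the companion TimmViTEncoder changes in this commit (not shown in this excerpt) make the default ``tu-vit_base_patch8_224`` backbone loadable with ``use_vit_encoder=True``. Per the code above, the decoder returns features at 1/2 of the input resolution and the segmentation head's ``upsampling=2`` brings the mask back to full resolution.

import torch
import segmentation_models_pytorch as smp

# Hypothetical usage; namespace registration and encoder availability are assumptions.
model = smp.DPT(
    encoder_name="tu-vit_base_patch8_224",  # the documented default backbone
    encoder_weights=None,                   # random initialization
    classes=3,
)
model.eval()

images = torch.randn(1, 3, 224, 224)
with torch.inference_mode():
    masks = model(images)

# Decoder output is at 1/2 scale; the segmentation head upsamples by 2 back to full resolution.
print(masks.shape)  # expected: torch.Size([1, 3, 224, 224])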
