|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | + |
| 4 | + |
| 5 | +def _get_feature_processing_out_channels(encoder_name: str) -> list[int]: |
| 6 | + """ |
| 7 | + Get the output embedding dimensions for the features after decoder processing |
| 8 | + """ |
| 9 | + |
| 10 | + encoder_name = encoder_name.lower() |
| 11 | + # Output channels for hybrid ViT encoder after feature processing |
| 12 | + if "vit" in encoder_name and "resnet" in encoder_name: |
| 13 | + return [256, 512, 768, 768] |
| 14 | + |
| 15 | + # Output channels for ViT-large,ViT-huge,ViT-giant encoders after feature processing |
| 16 | + if "vit" in encoder_name and any( |
| 17 | + [variant in encoder_name for variant in ["huge", "large", "giant"]] |
| 18 | + ): |
| 19 | + return [256, 512, 1024, 1024] |
| 20 | + |
| 21 | + # Output channels for ViT-base and other encoders after feature processing |
| 22 | + return [96, 192, 384, 768] |
| 23 | + |
| 24 | + |
| 25 | +class Transpose(nn.Module): |
| 26 | + def __init__(self, dim0: int, dim1: int): |
| 27 | + super().__init__() |
| 28 | + self.dim0 = dim0 |
| 29 | + self.dim1 = dim1 |
| 30 | + |
| 31 | + def forward(self, x: torch.Tensor): |
| 32 | + return torch.transpose(x, dim0=self.dim0, dim1=self.dim1) |
| 33 | + |
| 34 | + |
| 35 | +class ProjectionReadout(nn.Module): |
| 36 | + """ |
| 37 | + Concatenates the cls tokens with the features to make use of the global information aggregated in the cls token. |
| 38 | + Projects the combined feature map to the original embedding dimension using a MLP |
| 39 | + """ |
| 40 | + |
| 41 | + def __init__(self, in_features: int, encoder_output_stride: int): |
| 42 | + super().__init__() |
| 43 | + self.project = nn.Sequential( |
| 44 | + nn.Linear(in_features=2 * in_features, out_features=in_features), nn.GELU() |
| 45 | + ) |
| 46 | + |
| 47 | + self.flatten = nn.Flatten(start_dim=2) |
| 48 | + self.transpose = Transpose(dim0=1, dim1=2) |
| 49 | + self.encoder_output_stride = encoder_output_stride |
| 50 | + |
| 51 | + def forward(self, feature: torch.Tensor, cls_token: torch.Tensor): |
| 52 | + batch_size, _, height_dim, width_dim = feature.shape |
| 53 | + feature = self.flatten(feature) |
| 54 | + feature = self.transpose(feature) |
| 55 | + |
| 56 | + cls_token = cls_token.expand_as(feature) |
| 57 | + |
| 58 | + features = torch.cat([feature, cls_token], dim=2) |
| 59 | + features = self.project(features) |
| 60 | + features = self.transpose(features) |
| 61 | + |
| 62 | + features = features.view(batch_size, -1, height_dim, width_dim) |
| 63 | + return features |
| 64 | + |
| 65 | + |
| 66 | +class IgnoreReadout(nn.Module): |
| 67 | + def __init__(self): |
| 68 | + super().__init__() |
| 69 | + |
| 70 | + def forward(self, feature: torch.Tensor, cls_token: torch.Tensor): |
| 71 | + return feature |
| 72 | + |
| 73 | + |
| 74 | +class FeatureProcessBlock(nn.Module): |
| 75 | + """ |
| 76 | + Processes the features such that they have progressively increasing embedding size and progressively decreasing |
| 77 | + spatial dimension |
| 78 | + """ |
| 79 | + |
| 80 | + def __init__( |
| 81 | + self, embed_dim: int, feature_dim: int, out_channel: int, upsample_factor: int |
| 82 | + ): |
| 83 | + super().__init__() |
| 84 | + |
| 85 | + self.project_to_out_channel = nn.Conv2d( |
| 86 | + in_channels=embed_dim, out_channels=out_channel, kernel_size=1 |
| 87 | + ) |
| 88 | + |
| 89 | + if upsample_factor > 1.0: |
| 90 | + self.upsample = nn.ConvTranspose2d( |
| 91 | + in_channels=out_channel, |
| 92 | + out_channels=out_channel, |
| 93 | + kernel_size=int(upsample_factor), |
| 94 | + stride=int(upsample_factor), |
| 95 | + ) |
| 96 | + |
| 97 | + elif upsample_factor == 1.0: |
| 98 | + self.upsample = nn.Identity() |
| 99 | + |
| 100 | + else: |
| 101 | + self.upsample = nn.Conv2d( |
| 102 | + in_channels=out_channel, |
| 103 | + out_channels=out_channel, |
| 104 | + kernel_size=3, |
| 105 | + stride=int(1 / upsample_factor), |
| 106 | + padding=1, |
| 107 | + ) |
| 108 | + |
| 109 | + self.project_to_feature_dim = nn.Conv2d( |
| 110 | + in_channels=out_channel, out_channels=feature_dim, kernel_size=3, padding=1 |
| 111 | + ) |
| 112 | + |
| 113 | + def forward(self, x: torch.Tensor): |
| 114 | + x = self.project_to_out_channel(x) |
| 115 | + x = self.upsample(x) |
| 116 | + x = self.project_to_feature_dim(x) |
| 117 | + |
| 118 | + return x |
| 119 | + |
| 120 | + |
| 121 | +class ResidualConvBlock(nn.Module): |
| 122 | + def __init__(self, feature_dim: int): |
| 123 | + super().__init__() |
| 124 | + self.conv_block = nn.Sequential( |
| 125 | + nn.ReLU(), |
| 126 | + nn.Conv2d( |
| 127 | + in_channels=feature_dim, |
| 128 | + out_channels=feature_dim, |
| 129 | + kernel_size=3, |
| 130 | + padding=1, |
| 131 | + bias=False, |
| 132 | + ), |
| 133 | + nn.BatchNorm2d(num_features=feature_dim), |
| 134 | + nn.ReLU(), |
| 135 | + nn.Conv2d( |
| 136 | + in_channels=feature_dim, |
| 137 | + out_channels=feature_dim, |
| 138 | + kernel_size=3, |
| 139 | + padding=1, |
| 140 | + bias=False, |
| 141 | + ), |
| 142 | + nn.BatchNorm2d(num_features=feature_dim), |
| 143 | + ) |
| 144 | + |
| 145 | + def forward(self, x: torch.Tensor): |
| 146 | + return x + self.conv_block(x) |
| 147 | + |
| 148 | + |
| 149 | +class FusionBlock(nn.Module): |
| 150 | + """ |
| 151 | + Fuses the processed encoder features in a residual manner and upsamples them |
| 152 | + """ |
| 153 | + |
| 154 | + def __init__(self, feature_dim: int): |
| 155 | + super().__init__() |
| 156 | + self.residual_conv_block1 = ResidualConvBlock(feature_dim=feature_dim) |
| 157 | + self.residual_conv_block2 = ResidualConvBlock(feature_dim=feature_dim) |
| 158 | + self.project = nn.Conv2d( |
| 159 | + in_channels=feature_dim, out_channels=feature_dim, kernel_size=1 |
| 160 | + ) |
| 161 | + self.activation = nn.ReLU() |
| 162 | + |
| 163 | + def forward(self, feature: torch.Tensor, preceding_layer_feature: torch.Tensor): |
| 164 | + feature = self.residual_conv_block1(feature) |
| 165 | + |
| 166 | + if preceding_layer_feature is not None: |
| 167 | + feature += preceding_layer_feature |
| 168 | + |
| 169 | + feature = self.residual_conv_block2(feature) |
| 170 | + |
| 171 | + feature = nn.functional.interpolate( |
| 172 | + feature, scale_factor=2, align_corners=True, mode="bilinear" |
| 173 | + ) |
| 174 | + feature = self.project(feature) |
| 175 | + feature = self.activation(feature) |
| 176 | + |
| 177 | + return feature |
| 178 | + |
| 179 | + |
| 180 | +class DPTDecoder(nn.Module): |
| 181 | + """ |
| 182 | + Decoder part for DPT |
| 183 | +
|
| 184 | + Processes the encoder features and class tokens (if encoder has class_tokens) to have spatial downsampling ratios of |
| 185 | + [1/32,1/16,1/8,1/4] relative to the input image spatial dimension. |
| 186 | +
|
| 187 | + The decoder then fuses these features in a residual manner and progressively upsamples them by a factor of 2 so that the |
| 188 | + output has a downsampling ratio of 1/2 relative to the input image spatial dimension |
| 189 | +
|
| 190 | + """ |
| 191 | + |
| 192 | + def __init__( |
| 193 | + self, |
| 194 | + encoder_name: str, |
| 195 | + transformer_embed_dim: int, |
| 196 | + encoder_output_stride: int, |
| 197 | + feature_dim: int = 256, |
| 198 | + encoder_depth: int = 4, |
| 199 | + prefix_token_supported: bool = False, |
| 200 | + ): |
| 201 | + super().__init__() |
| 202 | + |
| 203 | + self.prefix_token_supported = prefix_token_supported |
| 204 | + |
| 205 | + # If encoder has cls token, then concatenate it with the features along the embedding dimension and project it |
| 206 | + # back to the feature_dim dimension. Else, ignore the non-existent cls token |
| 207 | + |
| 208 | + if prefix_token_supported: |
| 209 | + self.readout_blocks = nn.ModuleList( |
| 210 | + [ |
| 211 | + ProjectionReadout( |
| 212 | + in_features=transformer_embed_dim, |
| 213 | + encoder_output_stride=encoder_output_stride, |
| 214 | + ) |
| 215 | + for _ in range(encoder_depth) |
| 216 | + ] |
| 217 | + ) |
| 218 | + else: |
| 219 | + self.readout_blocks = [IgnoreReadout() for _ in range(encoder_depth)] |
| 220 | + |
| 221 | + upsample_factors = [ |
| 222 | + (encoder_output_stride / 2 ** (index + 2)) |
| 223 | + for index in range(0, encoder_depth) |
| 224 | + ] |
| 225 | + feature_processing_out_channels = _get_feature_processing_out_channels( |
| 226 | + encoder_name |
| 227 | + ) |
| 228 | + if encoder_depth < len(feature_processing_out_channels): |
| 229 | + feature_processing_out_channels = feature_processing_out_channels[ |
| 230 | + :encoder_depth |
| 231 | + ] |
| 232 | + |
| 233 | + self.feature_processing_blocks = nn.ModuleList( |
| 234 | + [ |
| 235 | + FeatureProcessBlock( |
| 236 | + transformer_embed_dim, feature_dim, out_channel, upsample_factor |
| 237 | + ) |
| 238 | + for upsample_factor, out_channel in zip( |
| 239 | + upsample_factors, feature_processing_out_channels |
| 240 | + ) |
| 241 | + ] |
| 242 | + ) |
| 243 | + |
| 244 | + self.fusion_blocks = nn.ModuleList( |
| 245 | + [FusionBlock(feature_dim=feature_dim) for _ in range(encoder_depth)] |
| 246 | + ) |
| 247 | + |
| 248 | + def forward( |
| 249 | + self, encoder_output: list[list[torch.Tensor], list[torch.Tensor]] |
| 250 | + ) -> torch.Tensor: |
| 251 | + features, cls_tokens = encoder_output |
| 252 | + processed_features = [] |
| 253 | + |
| 254 | + # Process the encoder features to scale of [1/32,1/16,1/8,1/4] |
| 255 | + for index, (feature, cls_token) in enumerate(zip(features, cls_tokens)): |
| 256 | + readout_feature = self.readout_blocks[index](feature, cls_token) |
| 257 | + processed_feature = self.feature_processing_blocks[index](readout_feature) |
| 258 | + processed_features.append(processed_feature) |
| 259 | + |
| 260 | + preceding_layer_feature = None |
| 261 | + |
| 262 | + # Fusion and progressive upsampling starting from the last processed feature |
| 263 | + processed_features = processed_features[::-1] |
| 264 | + for fusion_block, feature in zip(self.fusion_blocks, processed_features): |
| 265 | + out = fusion_block(feature, preceding_layer_feature) |
| 266 | + preceding_layer_feature = out |
| 267 | + |
| 268 | + return out |
0 commit comments