-
-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Adding DPT #1079
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Adding DPT #1079
Changes from 5 commits
Commits
Show all changes
45 commits
Select commit
Hold shift + click to select a range
78ba0e8
Initial timm vit encoder commit
vedantdalimkar 2c38de6
Add DPT model and update logic for TimmViTEncoder class
vedantdalimkar 5599409
Removed redudant documentation
vedantdalimkar c47bdfb
Added intitial test and some minor code modifications
vedantdalimkar 71e2acb
Code refactor
vedantdalimkar e85836d
Added weight conversion script
vedantdalimkar 35cb060
Moved conversion script to appropriate location
vedantdalimkar aa84f4e
Added logic in timm table generation for adding ViT encoders for DPT
67c4a75
Ruff formatting
vedantdalimkar 85f22fb
Code revision
vedantdalimkar ef48032
Remove unnecessary comment
vedantdalimkar 28204ad
Simplify ViT encoder
qubvel 1b9a6f6
Refactor ProjectionReadout
qubvel 334cfbb
Refactor modeling DPT
qubvel 7e1ef3b
Support more encoders
qubvel d65c0f7
Refactor a bit conversion, added validation
qubvel 0a62fe0
Fixup
qubvel e3238ae
Split forward for timm_vit
qubvel df4d087
Rename readout, remove feature_dim
qubvel 8bcb0ed
refactor + add transform
qubvel 6ba6746
Fixup
qubvel 8fd8c77
Refine docs a bit
qubvel 9bf1fd2
Refine docs
qubvel 0e9170f
Refine model size a bit and docs
qubvel a0aa5a8
Add to docs
qubvel 6cfd3be
Add note
qubvel d4b162d
Remove txt
qubvel 5fe80a5
Fix doc
qubvel 0a14972
Fix docstring
qubvel 5b28978
Fixing list in activation
qubvel 0ed621c
Fixing list
qubvel 6207310
Fixing list
qubvel 19eeebe
Fixup, fix type hint
qubvel f2e3f89
Merge branch 'main' into pr/vedantdalimkar/1079
qubvel 1257c4b
Add to README
qubvel 21a164a
Add example
qubvel 8d3ed4f
Add decoder_readout according to initial impl
qubvel 4eb6ec3
Tests update
vedantdalimkar 165b9c0
Fix encoder tests
qubvel 5603707
Fix DPT tests
qubvel 9518964
Refactor a bit
qubvel 38cb944
Tests
qubvel 17d3328
Update gen test models
qubvel 83b9655
Revert gitignore
qubvel 343fbe0
Fix test
qubvel File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should remove this file |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
|Encoder |Pretrained weights |Params, M |Script |Compile |Export | | ||
|--------------------------------|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:|:------------------------------:| |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .model import DPT | ||
|
||
__all__ = ["DPT"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,267 @@ | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
def _get_feature_processing_out_channels(encoder_name: str) -> list[int]: | ||
""" | ||
Get the output embedding dimensions for the features after decoder processing | ||
""" | ||
|
||
encoder_name = encoder_name.lower() | ||
# Output channels for hybrid ViT encoder after feature processing | ||
if "vit" in encoder_name and "resnet" in encoder_name: | ||
return [256, 512, 768, 768] | ||
|
||
# Output channels for ViT-large,ViT-huge,ViT-giant encoders after feature processing | ||
if "vit" in encoder_name and any( | ||
[variant in encoder_name for variant in ["huge", "large", "giant"]] | ||
): | ||
return [256, 512, 1024, 1024] | ||
|
||
# Output channels for ViT-base and other encoders after feature processing | ||
return [96, 192, 384, 768] | ||
|
||
|
||
class Transpose(nn.Module): | ||
def __init__(self, dim0: int, dim1: int): | ||
super().__init__() | ||
self.dim0 = dim0 | ||
self.dim1 = dim1 | ||
|
||
def forward(self, x: torch.Tensor): | ||
return torch.transpose(x, dim0=self.dim0, dim1=self.dim1) | ||
|
||
|
||
class ProjectionReadout(nn.Module): | ||
""" | ||
Concatenates the cls tokens with the features to make use of the global information aggregated in the cls token. | ||
Projects the combined feature map to the original embedding dimension using a MLP | ||
""" | ||
|
||
def __init__(self, in_features: int, encoder_output_stride: int): | ||
super().__init__() | ||
self.project = nn.Sequential( | ||
nn.Linear(in_features=2 * in_features, out_features=in_features), nn.GELU() | ||
) | ||
|
||
self.flatten = nn.Flatten(start_dim=2) | ||
self.transpose = Transpose(dim0=1, dim1=2) | ||
self.encoder_output_stride = encoder_output_stride | ||
|
||
def forward(self, feature: torch.Tensor, cls_token: torch.Tensor): | ||
batch_size, _, height_dim, width_dim = feature.shape | ||
feature = self.flatten(feature) | ||
feature = self.transpose(feature) | ||
|
||
cls_token = cls_token.expand_as(feature) | ||
|
||
features = torch.cat([feature, cls_token], dim=2) | ||
features = self.project(features) | ||
features = self.transpose(features) | ||
|
||
features = features.view(batch_size, -1, height_dim, width_dim) | ||
return features | ||
|
||
|
||
class IgnoreReadout(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, feature: torch.Tensor, cls_token: torch.Tensor): | ||
return feature | ||
|
||
|
||
class FeatureProcessBlock(nn.Module): | ||
""" | ||
Processes the features such that they have progressively increasing embedding size and progressively decreasing | ||
spatial dimension | ||
""" | ||
|
||
def __init__( | ||
self, embed_dim: int, feature_dim: int, out_channel: int, upsample_factor: int | ||
): | ||
super().__init__() | ||
|
||
self.project_to_out_channel = nn.Conv2d( | ||
in_channels=embed_dim, out_channels=out_channel, kernel_size=1 | ||
) | ||
|
||
if upsample_factor > 1.0: | ||
self.upsample = nn.ConvTranspose2d( | ||
in_channels=out_channel, | ||
out_channels=out_channel, | ||
kernel_size=int(upsample_factor), | ||
stride=int(upsample_factor), | ||
) | ||
|
||
elif upsample_factor == 1.0: | ||
self.upsample = nn.Identity() | ||
|
||
else: | ||
self.upsample = nn.Conv2d( | ||
in_channels=out_channel, | ||
out_channels=out_channel, | ||
kernel_size=3, | ||
stride=int(1 / upsample_factor), | ||
padding=1, | ||
) | ||
qubvel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self.project_to_feature_dim = nn.Conv2d( | ||
in_channels=out_channel, out_channels=feature_dim, kernel_size=3, padding=1 | ||
) | ||
|
||
def forward(self, x: torch.Tensor): | ||
x = self.project_to_out_channel(x) | ||
x = self.upsample(x) | ||
x = self.project_to_feature_dim(x) | ||
|
||
return x | ||
|
||
|
||
class ResidualConvBlock(nn.Module): | ||
def __init__(self, feature_dim: int): | ||
super().__init__() | ||
self.conv_block = nn.Sequential( | ||
nn.ReLU(), | ||
nn.Conv2d( | ||
in_channels=feature_dim, | ||
out_channels=feature_dim, | ||
kernel_size=3, | ||
padding=1, | ||
bias=False, | ||
), | ||
nn.BatchNorm2d(num_features=feature_dim), | ||
nn.ReLU(), | ||
nn.Conv2d( | ||
in_channels=feature_dim, | ||
out_channels=feature_dim, | ||
kernel_size=3, | ||
padding=1, | ||
bias=False, | ||
), | ||
nn.BatchNorm2d(num_features=feature_dim), | ||
) | ||
|
||
def forward(self, x: torch.Tensor): | ||
return x + self.conv_block(x) | ||
|
||
|
||
class FusionBlock(nn.Module): | ||
""" | ||
Fuses the processed encoder features in a residual manner and upsamples them | ||
""" | ||
|
||
def __init__(self, feature_dim: int): | ||
super().__init__() | ||
self.residual_conv_block1 = ResidualConvBlock(feature_dim=feature_dim) | ||
self.residual_conv_block2 = ResidualConvBlock(feature_dim=feature_dim) | ||
self.project = nn.Conv2d( | ||
in_channels=feature_dim, out_channels=feature_dim, kernel_size=1 | ||
) | ||
self.activation = nn.ReLU() | ||
|
||
def forward(self, feature: torch.Tensor, preceding_layer_feature: torch.Tensor): | ||
feature = self.residual_conv_block1(feature) | ||
|
||
if preceding_layer_feature is not None: | ||
feature += preceding_layer_feature | ||
|
||
feature = self.residual_conv_block2(feature) | ||
|
||
feature = nn.functional.interpolate( | ||
feature, scale_factor=2, align_corners=True, mode="bilinear" | ||
) | ||
feature = self.project(feature) | ||
feature = self.activation(feature) | ||
|
||
return feature | ||
|
||
|
||
class DPTDecoder(nn.Module): | ||
""" | ||
Decoder part for DPT | ||
|
||
Processes the encoder features and class tokens (if encoder has class_tokens) to have spatial downsampling ratios of | ||
[1/32,1/16,1/8,1/4] relative to the input image spatial dimension. | ||
|
||
The decoder then fuses these features in a residual manner and progressively upsamples them by a factor of 2 so that the | ||
output has a downsampling ratio of 1/2 relative to the input image spatial dimension | ||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
encoder_name: str, | ||
transformer_embed_dim: int, | ||
encoder_output_stride: int, | ||
feature_dim: int = 256, | ||
encoder_depth: int = 4, | ||
cls_token_supported: bool = False, | ||
): | ||
super().__init__() | ||
|
||
self.cls_token_supported = cls_token_supported | ||
|
||
# If encoder has cls token, then concatenate it with the features along the embedding dimension and project it | ||
# back to the feature_dim dimension. Else, ignore the non-existent cls token | ||
|
||
if cls_token_supported: | ||
self.readout_blocks = nn.ModuleList( | ||
[ | ||
ProjectionReadout( | ||
in_features=transformer_embed_dim, | ||
encoder_output_stride=encoder_output_stride, | ||
) | ||
for _ in range(encoder_depth) | ||
] | ||
) | ||
else: | ||
self.readout_blocks = [IgnoreReadout() for _ in range(encoder_depth)] | ||
|
||
upsample_factors = [ | ||
(encoder_output_stride / 2 ** (index + 2)) | ||
for index in range(0, encoder_depth) | ||
] | ||
feature_processing_out_channels = _get_feature_processing_out_channels( | ||
encoder_name | ||
) | ||
if encoder_depth < len(feature_processing_out_channels): | ||
feature_processing_out_channels = feature_processing_out_channels[ | ||
:encoder_depth | ||
] | ||
|
||
self.feature_processing_blocks = nn.ModuleList( | ||
[ | ||
FeatureProcessBlock( | ||
transformer_embed_dim, feature_dim, out_channel, upsample_factor | ||
) | ||
for upsample_factor, out_channel in zip( | ||
upsample_factors, feature_processing_out_channels | ||
) | ||
] | ||
) | ||
|
||
self.fusion_blocks = nn.ModuleList( | ||
[FusionBlock(feature_dim=feature_dim) for _ in range(encoder_depth)] | ||
) | ||
|
||
def forward( | ||
self, features: list[torch.Tensor], cls_tokens: list[torch.Tensor] | ||
) -> torch.Tensor: | ||
processed_features = [] | ||
|
||
# Process the encoder features to scale of [1/32,1/16,1/8,1/4] | ||
for index, (feature, cls_token) in enumerate(zip(features, cls_tokens)): | ||
readout_feature = self.readout_blocks[index](feature, cls_token) | ||
processed_feature = self.feature_processing_blocks[index](readout_feature) | ||
processed_features.append(processed_feature) | ||
|
||
preceding_layer_feature = None | ||
|
||
# Fusion and progressive upsampling starting from the last processed feature | ||
processed_features = processed_features[::-1] | ||
for fusion_block, feature in zip(self.fusion_blocks, processed_features): | ||
out = fusion_block(feature, preceding_layer_feature) | ||
preceding_layer_feature = out | ||
|
||
return out |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.