 import math
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from functools import partial
+from typing import Dict, Sequence, List
 
 from timm.layers import DropPath, to_2tuple, trunc_normal_
 
 
 class LayerNorm(nn.LayerNorm):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.ndim == 4:
-            B, C, H, W = x.shape
-            x = x.view(B, C, -1).transpose(1, 2)
-            x = super().forward(x)
-            x = x.transpose(1, 2).view(B, C, H, W)
+            batch_size, channels, height, width = x.shape
+            x = x.view(batch_size, channels, -1).transpose(1, 2)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
+            x = x.transpose(1, 2).view(batch_size, channels, height, width)
         else:
-            x = super().forward(x)
+            x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
         return x
 
 
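Aside: the rewritten LayerNorm flattens an NCHW map to (B, N, C) and calls F.layer_norm directly rather than super().forward(x), which, like the nn.Identity placeholders added further down, looks aimed at torch.jit.script support. A minimal equivalence sketch (sizes are made up; LayerNorm refers to the subclass above):

    import torch
    import torch.nn as nn

    ln = LayerNorm(64)                      # subclass defined above
    ref = nn.LayerNorm(64)
    ref.load_state_dict(ln.state_dict())    # copy the same affine parameters

    x = torch.randn(2, 64, 8, 8)            # NCHW feature map
    # Normalizing over channels at each spatial position should match a plain
    # channels-last LayerNorm.
    expected = ref(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(ln(x), expected, atol=1e-6)
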
@@ -60,9 +62,9 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x, H, W):
+    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
         x = self.fc1(x)
-        x = self.dwconv(x, H, W)
+        x = self.dwconv(x, height, width)
         x = self.act(x)
         x = self.drop(x)
         x = self.fc2(x)
@@ -101,6 +103,10 @@ def __init__(
         if sr_ratio > 1:
             self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
             self.norm = LayerNorm(dim)
+        else:
+            # for torchscript compatibility
+            self.sr = nn.Identity()
+            self.norm = nn.Identity()
 
         self.apply(self._init_weights)
 
@@ -119,27 +125,27 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x, H, W):
-        B, N, C = x.shape
+    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        batch_size, N, C = x.shape
         q = (
             self.q(x)
-            .reshape(B, N, self.num_heads, C // self.num_heads)
+            .reshape(batch_size, N, self.num_heads, C // self.num_heads)
             .permute(0, 2, 1, 3)
         )
 
         if self.sr_ratio > 1:
-            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
-            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
+            x_ = x.permute(0, 2, 1).reshape(batch_size, C, height, width)
+            x_ = self.sr(x_).reshape(batch_size, C, -1).permute(0, 2, 1)
             x_ = self.norm(x_)
             kv = (
                 self.kv(x_)
-                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
+                .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads)
                 .permute(2, 0, 3, 1, 4)
             )
         else:
             kv = (
                 self.kv(x)
-                .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
+                .reshape(batch_size, -1, 2, self.num_heads, C // self.num_heads)
                 .permute(2, 0, 3, 1, 4)
             )
         k, v = kv[0], kv[1]
@@ -148,7 +154,7 @@ def forward(self, x, H, W):
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(batch_size, N, C)
         x = self.proj(x)
         x = self.proj_drop(x)
 
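Aside: the sr_ratio > 1 branch above builds keys and values from a spatially reduced copy of the tokens, so the attention matrix shrinks by roughly sr_ratio² along the key/value dimension. A standalone shape sketch with assumed numbers (dim=64, sr_ratio=4, a 16×16 token grid; none of these values come from this file):

    import torch
    import torch.nn as nn

    batch_size, height, width, dim, sr_ratio = 2, 16, 16, 64, 4
    x = torch.randn(batch_size, height * width, dim)   # (B, N, C), N = 256 tokens
    sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)

    x_ = x.permute(0, 2, 1).reshape(batch_size, dim, height, width)  # tokens -> NCHW
    x_ = sr(x_).reshape(batch_size, dim, -1).permute(0, 2, 1)        # (B, N // sr_ratio**2, C)
    print(x_.shape)  # torch.Size([2, 16, 64]): K/V are computed from 16 tokens instead of 256
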
@@ -209,12 +215,12 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x):
-        B, _, H, W = x.shape
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        batch_size, _, height, width = x.shape
         x = x.flatten(2).transpose(1, 2)
-        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
-        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-        x = x.transpose(1, 2).view(B, -1, H, W)
+        x = x + self.drop_path(self.attn(self.norm1(x), height, width))
+        x = x + self.drop_path(self.mlp(self.norm2(x), height, width))
+        x = x.transpose(1, 2).view(batch_size, -1, height, width)
         return x
 
 
@@ -256,7 +262,7 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         x = self.norm(x)
         return x
@@ -462,7 +468,7 @@ def reset_classifier(self, num_classes, global_pool=""):
             nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
         )
 
-    def forward_features(self, x):
+    def forward_features(self, x: torch.Tensor) -> List[torch.Tensor]:
         outs = []
 
         # stage 1
@@ -491,21 +497,21 @@ def forward_features(self, x):
 
         return outs
 
-    def forward(self, x):
-        x = self.forward_features(x)
+    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
+        features = self.forward_features(x)
         # x = self.head(x)
 
-        return x
+        return features
 
 
 class DWConv(nn.Module):
     def __init__(self, dim=768):
         super(DWConv, self).__init__()
         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
 
-    def forward(self, x, H, W):
-        B, _, C = x.shape
-        x = x.transpose(1, 2).view(B, C, H, W)
+    def forward(self, x: torch.Tensor, height: int, width: int) -> torch.Tensor:
+        batch_size, _, channels = x.shape
+        x = x.transpose(1, 2).view(batch_size, channels, height, width)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
 
@@ -516,7 +522,6 @@ def forward(self, x, H, W):
 # End of NVIDIA code
 # ---------------------------------------------------------------
 
-from typing import Dict, Sequence, List  # noqa E402
 from ._base import EncoderMixin  # noqa E402
 
 
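Aside: one practical effect of the int annotations added to height and width is that torch.jit.script stops treating them as Tensors (unannotated arguments default to Tensor). A usage sketch for the DWConv module shown above; the sizes are invented:

    import torch

    dwconv = DWConv(dim=64)                 # class defined in this file
    scripted = torch.jit.script(dwconv)     # expected to script thanks to the annotations

    tokens = torch.randn(2, 16 * 16, 64)    # (B, N, C) token sequence
    out = scripted(tokens, 16, 16)          # reshaped to 16x16, depthwise conv, back to (2, 256, 64)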