@@ -113,8 +113,8 @@ def __init__(
         )

     def forward(self, *features):
-        # Resize all FPN features to the size of the largest feature
-        target_size = features[0].shape[2:]
+        output_size = features[0].shape[2:]
+        target_size = [size // 4 for size in output_size]

         features = features[1:]  # remove first skip with same spatial resolution
         features = features[::-1]  # reverse channels to start from head of encoder
@@ -126,6 +126,7 @@ def forward(self, *features):
             fpn_feature = stage(fpn_features[-1], feature)
             fpn_features.append(fpn_feature)

+        # Resize all FPN features to 1/4 of the original resolution.
         resized_fpn_features = []
         for feature in fpn_features:
             resized_feature = F.interpolate(
@@ -134,5 +135,8 @@ def forward(self, *features):
             resized_fpn_features.append(resized_feature)

         output = self.fpn_bottleneck(torch.cat(resized_fpn_features, dim=1))
+        output = F.interpolate(
+            output, size=output_size, mode="bilinear", align_corners=False
+        )

         return output
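
For reference, a minimal standalone sketch of what this change does: the FPN levels are first resized to 1/4 of the input resolution before fusion, and the fused output is then upsampled back to the full input size. The tensor shapes, channel counts, and the 1x1 convolution standing in for `fpn_bottleneck` are illustrative assumptions, not the decoder's actual configuration.

```python
import torch
import torch.nn.functional as F

# Hypothetical input: the highest-resolution skip is 64x64 (the input size),
# deeper FPN levels sit at strides 4 to 32.
output_size = (64, 64)                       # features[0].shape[2:]
target_size = [s // 4 for s in output_size]  # 1/4 resolution: [16, 16]

fpn_features = [
    torch.randn(1, 256, 16, 16),  # stride 4
    torch.randn(1, 256, 8, 8),    # stride 8
    torch.randn(1, 256, 4, 4),    # stride 16
    torch.randn(1, 256, 2, 2),    # stride 32
]

# Bring every FPN level to 1/4 of the input resolution before fusion.
resized = [
    F.interpolate(f, size=target_size, mode="bilinear", align_corners=False)
    for f in fpn_features
]
fused = torch.cat(resized, dim=1)  # (1, 1024, 16, 16)

# A 1x1 conv stands in for the decoder's fpn_bottleneck in this sketch.
bottleneck = torch.nn.Conv2d(fused.shape[1], 256, kernel_size=1)
output = bottleneck(fused)         # (1, 256, 16, 16)

# The newly added final interpolate restores the full input resolution.
output = F.interpolate(output, size=output_size, mode="bilinear", align_corners=False)
print(output.shape)  # torch.Size([1, 256, 64, 64])
```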