1+ from collections .abc import Sequence
2+ from typing import Literal
3+
14import torch
25import torch .nn as nn
36import torch .nn .functional as F
@@ -44,7 +47,9 @@ def forward(self, x):
4447
4548
4649class FPABlock (nn .Module ):
47- def __init__ (self , in_channels , out_channels , upscale_mode = "bilinear" ):
50+ def __init__ (
51+ self , in_channels : int , out_channels : int , upscale_mode : str = "bilinear"
52+ ):
4853 super (FPABlock , self ).__init__ ()
4954
5055 self .upscale_mode = upscale_mode
@@ -118,7 +123,9 @@ def forward(self, x):
118123 mid = self .mid (x )
119124 x1 = self .down1 (x )
120125 x2 = self .down2 (x1 )
126+ print (x2 .shape )
121127 x3 = self .down3 (x2 )
128+ print (x3 .shape )
122129 x3 = F .interpolate (x3 , size = (h // 4 , w // 4 ), ** upscale_parameters )
123130
124131 x2 = self .conv2 (x2 )
@@ -176,9 +183,9 @@ def forward(self, x, y):
176183class PANDecoder (nn .Module ):
177184 def __init__ (
178185 self ,
179- encoder_channels ,
180- encoder_depth ,
181- decoder_channels ,
186+ encoder_channels : Sequence [ int ] ,
187+ encoder_depth : Literal [ 3 , 4 , 5 ] ,
188+ decoder_channels : int ,
182189 upscale_mode : str = "bilinear" ,
183190 ):
184191 super ().__init__ ()
@@ -197,11 +204,14 @@ def __init__(
197204 )
198205
199206 for i in range (1 , len (encoder_channels )):
200- self .add_module (f"gau{ len (encoder_channels )- i } " , GAUBlock (
201- in_channels = encoder_channels [i ],
202- out_channels = decoder_channels ,
203- upscale_mode = upscale_mode ,
204- ))
207+ self .add_module (
208+ f"gau{ len (encoder_channels )- i } " ,
209+ GAUBlock (
210+ in_channels = encoder_channels [i ],
211+ out_channels = decoder_channels ,
212+ upscale_mode = upscale_mode ,
213+ ),
214+ )
205215
206216 def forward (self , * features ):
207217 features = features [2 :] # remove first and second skip
0 commit comments