@@ -8,22 +8,30 @@ def __init__(self):
88 super (ConvolutionalAutoencoder , self ).__init__ ()
99
1010 # Encoder
11- self .enc1 = nn .Conv2d (3 , 64 , kernel_size = 3 , padding = 1 )
12- self .enc2 = nn .Conv2d (64 , 32 , kernel_size = 3 , padding = 1 )
13- self .enc3 = nn .Conv2d (32 , 16 , kernel_size = 3 , padding = 1 )
11+ self .enc0 = nn .Conv2d (3 , 256 , kernel_size = 3 , padding = 1 )
12+ self .enc1 = nn .Conv2d (256 , 128 , kernel_size = 3 , padding = 1 )
13+ self .enc2 = nn .Conv2d (128 , 64 , kernel_size = 3 , padding = 1 )
14+ self .enc3 = nn .Conv2d (64 , 32 , kernel_size = 3 , padding = 1 )
15+ self .enc4 = nn .Conv2d (32 , 16 , kernel_size = 3 , padding = 1 )
1416 self .pool = nn .MaxPool2d (2 , 2 , return_indices = True )
1517
1618 # Decoder
17- self .dec1 = nn .ConvTranspose2d (16 , 32 , kernel_size = 2 , stride = 2 )
18- self .dec2 = nn .ConvTranspose2d (32 , 64 , kernel_size = 2 , stride = 2 )
19- self .dec3 = nn .ConvTranspose2d (64 , 3 , kernel_size = 2 , stride = 2 )
19+ self .dec0 = nn .ConvTranspose2d (16 , 32 , kernel_size = 2 , stride = 2 )
20+ self .dec1 = nn .ConvTranspose2d (32 , 64 , kernel_size = 2 , stride = 2 )
21+ self .dec2 = nn .ConvTranspose2d (64 , 128 , kernel_size = 2 , stride = 2 )
22+ self .dec3 = nn .ConvTranspose2d (128 , 256 , kernel_size = 2 , stride = 2 )
23+ self .dec4 = nn .ConvTranspose2d (256 , 3 , kernel_size = 2 , stride = 2 )
2024
2125 def forward (self , x ):
26+ x , _ = self .pool (F .relu (self .enc0 (x )))
2227 x , _ = self .pool (F .relu (self .enc1 (x )))
2328 x , _ = self .pool (F .relu (self .enc2 (x )))
2429 x , _ = self .pool (F .relu (self .enc3 (x )))
30+ x , _ = self .pool (F .relu (self .enc4 (x )))
2531
32+ x = F .relu (self .dec0 (x ))
2633 x = F .relu (self .dec1 (x ))
2734 x = F .relu (self .dec2 (x ))
28- x = torch .sigmoid (self .dec3 (x ))
35+ x = F .relu (self .dec3 (x ))
36+ x = torch .sigmoid (self .dec4 (x ))
2937 return x
0 commit comments