1- import torch
21import torch .nn as nn
3- import torch .nn .functional as F
42
53
# Linear (fully connected) autoencoder.
class Autoencoder(nn.Module):
    """Symmetric fully connected autoencoder.

    The encoder compresses a flat ``input_dim`` vector through
    1024 -> 512 -> 256 -> 128 -> 64 down to ``encoding_dim``; the decoder
    mirrors the stack back up and applies a final ``Sigmoid`` so outputs
    lie in (0, 1) (suited to inputs normalized to that range).

    Args:
        input_dim: Size of the flattened input vector.
        encoding_dim: Size of the latent (bottleneck) representation.
    """

    def __init__(self, input_dim, encoding_dim):
        super(Autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            # NOTE(review): this bottleneck layer was lost in the elided diff
            # hunk; reconstructed from the mirrored decoder — confirm that
            # 64 -> encoding_dim matches the intended design.
            nn.Linear(64, encoding_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(encoding_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, input_dim),
            # Sigmoid keeps the reconstruction in (0, 1).
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode then decode ``x``; returns the reconstruction."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x
30-
31-
# Convolutional autoencoder.
class ConvolutionalAutoencoder(nn.Module):
    """Convolutional autoencoder for 3-channel images.

    Three conv + max-pool stages halve the spatial size three times
    (so H and W should be divisible by 8); three stride-2 transposed
    convolutions restore the original resolution. A final sigmoid maps
    the reconstruction into (0, 1).
    """

    def __init__(self):
        super(ConvolutionalAutoencoder, self).__init__()

        # Encoder: size-preserving convs, channels 3 -> 64 -> 32 -> 16.
        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        # The pooling indices were requested but never used by the decoder
        # (it upsamples with ConvTranspose2d, not MaxUnpool2d), so
        # return_indices is dropped; pooled values are unchanged.
        self.pool = nn.MaxPool2d(2, 2)

        # Decoder: stride-2 deconvs double the spatial size each stage.
        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)

    def forward(self, x):
        """Return the (N, 3, H, W) reconstruction of image batch ``x``."""
        x = self.pool(F.relu(self.enc1(x)))
        x = self.pool(F.relu(self.enc2(x)))
        x = self.pool(F.relu(self.enc3(x)))

        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        x = torch.sigmoid(self.dec3(x))
        return x
57-
58-
# Variational autoencoder (fully connected).
class VariationalAutoencoder(nn.Module):
    """Fully connected VAE over flattened 3*64*64 inputs.

    ``forward`` returns ``(reconstruction, mu, log_var)`` so the caller
    can combine a reconstruction term with a KL-divergence term.

    Args:
        encoding_dim: Width of the latent representation (default 128).
    """

    def __init__(self, encoding_dim=128):
        super().__init__()

        # Encoder: 3*64*64 -> 512 -> 256 -> encoding_dim.
        self.enc1 = nn.Linear(3 * 64 * 64, 512)
        self.enc2 = nn.Linear(512, 256)
        self.enc3 = nn.Linear(256, encoding_dim)

        # Heads producing the Gaussian posterior parameters.
        self.fc_mu = nn.Linear(encoding_dim, encoding_dim)
        self.fc_log_var = nn.Linear(encoding_dim, encoding_dim)

        # Decoder: encoding_dim -> encoding_dim -> 256 -> 512 -> 3*64*64.
        self.dec1 = nn.Linear(encoding_dim, encoding_dim)
        self.dec2 = nn.Linear(encoding_dim, 256)
        self.dec3 = nn.Linear(256, 512)
        self.dec4 = nn.Linear(512, 3 * 64 * 64)

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent, decode; return (x_hat, mu, log_var)."""
        h = F.relu(self.enc1(x))
        h = F.relu(self.enc2(h))
        h = F.relu(self.enc3(h))

        mu = self.fc_mu(h)
        log_var = self.fc_log_var(h)
        z = self.reparameterize(mu, log_var)

        h = F.relu(self.dec1(z))
        h = F.relu(self.dec2(h))
        h = F.relu(self.dec3(h))
        recon = torch.sigmoid(self.dec4(h))

        return recon, mu, log_var
99-
100-
# Convolutional variational autoencoder.
class ConvolutionalVAE(nn.Module):
    """Convolutional VAE for 3-channel images with a 128-d latent space.

    ``forward`` returns ``(reconstruction, mu, log_var)``.
    NOTE(review): the flatten/unflatten sizes hard-code a 64x64 input
    (three poolings give a 16 x 8 x 8 feature map) — confirm against callers.
    """

    def __init__(self):
        super().__init__()

        # Encoder: size-preserving convs (pooling halves the spatial size).
        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Posterior parameter heads over the flattened 16*8*8 feature map.
        self.fc_mu = nn.Linear(16 * 8 * 8, 128)
        self.fc_log_var = nn.Linear(16 * 8 * 8, 128)

        # Decoder: project the latent back to a feature map, then alternate
        # bilinear upsampling with size-preserving transposed convs.
        self.decoder_input = nn.Linear(128, 16 * 8 * 8)
        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1)
        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1)
        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def reparameterize(self, mu, log_var):
        """Draw z ~ N(mu, sigma^2) using the reparameterization trick."""
        sigma = torch.exp(0.5 * log_var)
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        # Encode: three conv + pool stages, then flatten.
        for conv in (self.enc1, self.enc2, self.enc3):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.size(0), -1)

        mu = self.fc_mu(flat)
        log_var = self.fc_log_var(flat)
        z = self.reparameterize(mu, log_var)

        # Decode: unflatten, then upsample + deconv back to full resolution.
        y = self.decoder_input(z).view(z.size(0), 16, 8, 8)
        y = F.relu(self.dec1(self.upsample(y)))
        y = F.relu(self.dec2(self.upsample(y)))
        y = torch.sigmoid(self.dec3(self.upsample(y)))

        return y, mu, log_var