1+ import torch
12import torch .nn as nn
3+ import torch .nn .functional as F
24
35
# Linear autoencoder
47class Autoencoder (nn .Module ):
58 def __init__ (self , input_dim , encoding_dim ):
69 super (Autoencoder , self ).__init__ ()
@@ -24,3 +27,127 @@ def forward(self, x):
2427 x = self .encoder (x )
2528 x = self .decoder (x )
2629 return x
30+
31+
# Convolutional autoencoder
class ConvolutionalAutoencoder(nn.Module):
    """Convolutional autoencoder with a three-stage encoder and decoder.

    Each encoder stage is a 3x3 convolution (padding 1, so the spatial size
    is kept), a ReLU and a 2x2 max-pool that halves height and width. Each
    decoder stage is a 2x2 transposed convolution with stride 2 that doubles
    height and width. The reconstruction therefore has the same spatial size
    as the input only when the input height and width are multiples of 8;
    otherwise the pooling floors the size and the output comes out smaller.

    Args:
        in_channels: Number of channels of the input (and of the
            reconstruction). Defaults to 3.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # Encoder: in_channels -> 64 -> 32 -> 16 feature maps.
        self.enc1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        # The decoder relies on transposed convolutions rather than
        # MaxUnpool2d, so the pooling indices are never needed and are not
        # requested here.
        self.pool = nn.MaxPool2d(2, 2)

        # Decoder: 16 -> 32 -> 64 -> in_channels feature maps.
        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=2, stride=2)
        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=2, stride=2)
        self.dec3 = nn.ConvTranspose2d(64, in_channels, kernel_size=2, stride=2)

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Args:
            x: Tensor whose channel dimension has ``in_channels`` entries.

        Returns:
            The reconstruction, passed through a sigmoid so every value lies
            between 0 and 1.
        """
        x = self.pool(F.relu(self.enc1(x)))
        x = self.pool(F.relu(self.enc2(x)))
        x = self.pool(F.relu(self.enc3(x)))

        x = F.relu(self.dec1(x))
        x = F.relu(self.dec2(x))
        # Sigmoid instead of ReLU on the last layer bounds the output.
        return torch.sigmoid(self.dec3(x))
57+
58+
59+ # Variational Autoencoder
class VariationalAutoencoder(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps ``input_dim`` features through 512 and 256 hidden units
    down to ``encoding_dim``; two linear heads then produce the mean and the
    log-variance of the latent distribution. The decoder mirrors the encoder
    and ends in a sigmoid.

    Args:
        encoding_dim: Size of the latent vector. Defaults to 128.
        input_dim: Number of features per sample. Defaults to
            ``3 * 64 * 64`` (presumably a flattened 3x64x64 image -- confirm
            against the caller).
    """

    def __init__(self, encoding_dim=128, input_dim=3 * 64 * 64):
        super().__init__()

        # Encoder
        self.enc1 = nn.Linear(input_dim, 512)
        self.enc2 = nn.Linear(512, 256)
        self.enc3 = nn.Linear(256, encoding_dim)

        # Latent space: separate heads for mean and log-variance.
        self.fc_mu = nn.Linear(encoding_dim, encoding_dim)
        self.fc_log_var = nn.Linear(encoding_dim, encoding_dim)

        # Decoder
        self.dec1 = nn.Linear(encoding_dim, encoding_dim)
        self.dec2 = nn.Linear(encoding_dim, 256)
        self.dec3 = nn.Linear(256, 512)
        self.dec4 = nn.Linear(512, input_dim)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * std`` with ``eps`` sampled from N(0, 1).

        Args:
            mu: Mean of the latent distribution.
            log_var: Log-variance of the latent distribution.

        Returns:
            A sample with the same shape as ``mu``. Writing the sample as a
            function of ``mu`` and ``std`` keeps it differentiable with
            respect to both.
        """
        std = torch.exp(0.5 * log_var)  # log(sigma^2) -> sigma
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``. The reconstruction is
            passed through a sigmoid, so its values lie between 0 and 1.
        """
        x = F.relu(self.enc1(x))
        x = F.relu(self.enc2(x))
        x = F.relu(self.enc3(x))

        mu = self.fc_mu(x)
        log_var = self.fc_log_var(x)
        z = self.reparameterize(mu, log_var)

        x = F.relu(self.dec1(z))
        x = F.relu(self.dec2(x))
        x = F.relu(self.dec3(x))
        x = torch.sigmoid(self.dec4(x))

        return x, mu, log_var
99+
100+
# Convolutional variational autoencoder
class ConvolutionalVAE(nn.Module):
    """Variational autoencoder with a convolutional encoder and decoder.

    The encoder applies three conv -> ReLU -> 2x2 max-pool stages and flattens
    the result; two linear heads turn it into a 128-dimensional mean and
    log-variance. The linear layers expect ``16 * 8 * 8`` flattened features,
    which is what a 3-channel 64x64 input produces after the three poolings.
    The decoder projects the latent sample back to ``16 x 8 x 8`` and applies
    three bilinear-upsample -> transposed-conv stages.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.enc1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.enc2 = nn.Conv2d(64, 32, kernel_size=3, padding=1)
        self.enc3 = nn.Conv2d(32, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # Latent heads
        self.fc_mu = nn.Linear(16 * 8 * 8, 128)
        self.fc_log_var = nn.Linear(16 * 8 * 8, 128)

        # Decoder
        self.decoder_input = nn.Linear(128, 16 * 8 * 8)
        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=3, padding=1)
        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=3, padding=1)
        self.dec3 = nn.ConvTranspose2d(64, 3, kernel_size=3, padding=1)

        self.upsample = nn.Upsample(
            scale_factor=2, mode='bilinear', align_corners=True
        )

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * std`` where ``eps`` is standard normal noise.

        ``std`` is recovered from the log-variance as ``exp(0.5 * log_var)``.
        """
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``; the reconstruction is
            passed through a sigmoid.
        """
        # Encoding: each stage halves the spatial resolution.
        for conv in (self.enc1, self.enc2, self.enc3):
            x = self.pool(F.relu(conv(x)))

        flat = x.view(x.size(0), -1)  # Flatten

        mu = self.fc_mu(flat)
        log_var = self.fc_log_var(flat)
        z = self.reparameterize(mu, log_var)

        # Decoding: project back to a feature map, then upsample three times.
        out = self.decoder_input(z)
        out = out.view(out.size(0), 16, 8, 8)  # Unflatten
        for deconv in (self.dec1, self.dec2):
            out = F.relu(deconv(self.upsample(out)))
        out = torch.sigmoid(self.dec3(self.upsample(out)))

        return out, mu, log_var
0 commit comments