2121LOGGER = logging .getLogger (__name__ )
2222
2323
24+ class KLDivergenceLayer (tf .keras .layers .Layer ):
25+ def call (self , inputs ):
26+ z_log_sigma , z_mean = inputs
27+ kl_loss = - 0.5 * K .mean (1 + z_log_sigma - K .square (z_mean ) - K .exp (z_log_sigma ), axis = - 1 )
28+ self .add_loss (kl_loss )
29+ return inputs
30+
31+
2432class VAE (object ):
2533 """VAE model for time series reconstruction.
2634
@@ -117,7 +125,7 @@ def __init__(self, layers_encoder: list, layers_generator: list, optimizer: str,
117125 self .epochs = epochs
118126 self .batch_size = batch_size
119127 self .optimizer = import_object (optimizer )(learning_rate )
120- self .mse_loss = tf .losses .mse
128+ self .mse_loss = tf .keras . losses .MeanSquaredError ()
121129 self .shuffle = shuffle
122130 self .verbose = verbose
123131 self .hyperparameters = hyperparameters
@@ -170,13 +178,13 @@ def _build_vae(self, **kwargs):
170178 h = self .encoder (x )
171179 z_mean = tf .keras .layers .Dense (self .latent_dim )(h )
172180 z_log_sigma = tf .keras .layers .Dense (self .latent_dim )(h )
181+ KLDivergenceLayer ()([z_log_sigma , z_mean ]) # kl loss
173182 z = tf .keras .layers .Lambda (self ._sampling )([z_mean , z_log_sigma ])
174183
175184 y_ = self .generator (z )
176185
177186 self .vae_model = Model ([x , y ], y_ )
178- self .vae_model .add_loss (self ._vae_loss (y , y_ , z_log_sigma , z_mean ))
179- self .vae_model .compile (optimizer = self .optimizer )
187+ self .vae_model .compile (loss = 'mse' , optimizer = self .optimizer )
180188
181189 def fit (self , X : np .ndarray , y : np .ndarray , ** kwargs ):
182190 """Fit the model.
@@ -201,11 +209,11 @@ def fit(self, X: np.ndarray, y: np.ndarray, **kwargs):
201209 for callback in self .callbacks
202210 ]
203211
204- self .fit_history = self .vae_model .fit ((X , y ),
212+ self .fit_history = self .vae_model .fit ((X , y ), y ,
205213 batch_size = self .batch_size ,
206214 epochs = self .epochs ,
207215 shuffle = self .shuffle ,
208- verbose = self . verbose ,
216+ verbose = True ,
209217 callbacks = callbacks ,
210218 validation_split = self .validation_split ,
211219 )
0 commit comments