Skip to content

Commit 492d117

Browse files
committed
updated vae
1 parent 277ad9b commit 492d117

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

orion/primitives/vae.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121
LOGGER = logging.getLogger(__name__)
2222

2323

24+
class KLDivergenceLayer(tf.keras.layers.Layer):
25+
def call(self, inputs):
26+
z_log_sigma, z_mean = inputs
27+
kl_loss = -0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
28+
self.add_loss(kl_loss)
29+
return inputs
30+
31+
2432
class VAE(object):
2533
"""VAE model for time series reconstruction.
2634
@@ -117,7 +125,7 @@ def __init__(self, layers_encoder: list, layers_generator: list, optimizer: str,
117125
self.epochs = epochs
118126
self.batch_size = batch_size
119127
self.optimizer = import_object(optimizer)(learning_rate)
120-
self.mse_loss = tf.losses.mse
128+
self.mse_loss = tf.keras.losses.MeanSquaredError()
121129
self.shuffle = shuffle
122130
self.verbose = verbose
123131
self.hyperparameters = hyperparameters
@@ -170,13 +178,13 @@ def _build_vae(self, **kwargs):
170178
h = self.encoder(x)
171179
z_mean = tf.keras.layers.Dense(self.latent_dim)(h)
172180
z_log_sigma = tf.keras.layers.Dense(self.latent_dim)(h)
181+
KLDivergenceLayer()([z_log_sigma, z_mean]) # kl loss
173182
z = tf.keras.layers.Lambda(self._sampling)([z_mean, z_log_sigma])
174183

175184
y_ = self.generator(z)
176185

177186
self.vae_model = Model([x, y], y_)
178-
self.vae_model.add_loss(self._vae_loss(y, y_, z_log_sigma, z_mean))
179-
self.vae_model.compile(optimizer=self.optimizer)
187+
self.vae_model.compile(loss='mse', optimizer=self.optimizer)
180188

181189
def fit(self, X: np.ndarray, y: np.ndarray, **kwargs):
182190
"""Fit the model.
@@ -201,11 +209,11 @@ def fit(self, X: np.ndarray, y: np.ndarray, **kwargs):
201209
for callback in self.callbacks
202210
]
203211

204-
self.fit_history = self.vae_model.fit((X, y),
212+
self.fit_history = self.vae_model.fit((X, y), y,
205213
batch_size=self.batch_size,
206214
epochs=self.epochs,
207215
shuffle=self.shuffle,
208-
verbose=self.verbose,
216+
verbose=True,
209217
callbacks=callbacks,
210218
validation_split=self.validation_split,
211219
)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
install_requires = [
20+
'Keras>=3,<4',
2021
'tensorflow>=2.16.1,<2.20',
2122
'numpy>=1.23.5,<2',
2223
'pandas>=1.4.0,<3',

0 commit comments

Comments
 (0)