Skip to content

Commit 9f9cfdd

Browse files
committed
change loss name
1 parent f96f21e commit 9f9cfdd

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

orion/primitives/vae.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
LOGGER = logging.getLogger(__name__)
2222

2323

24-
class KLDivergenceLayer(tf.keras.layers.Layer):
24+
class KLDivergenceLoss(tf.keras.layers.Layer):
2525
def call(self, inputs):
2626
z_log_sigma, z_mean = inputs
2727
kl_loss = -0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
@@ -178,7 +178,7 @@ def _build_vae(self, **kwargs):
178178
h = self.encoder(x)
179179
z_mean = tf.keras.layers.Dense(self.latent_dim)(h)
180180
z_log_sigma = tf.keras.layers.Dense(self.latent_dim)(h)
181-
KLDivergenceLayer()([z_log_sigma, z_mean]) # kl loss
181+
KLDivergenceLoss()([z_log_sigma, z_mean]) # kl loss
182182
z = tf.keras.layers.Lambda(self._sampling)([z_mean, z_log_sigma])
183183

184184
y_ = self.generator(z)
@@ -213,7 +213,7 @@ def fit(self, X: np.ndarray, y: np.ndarray, **kwargs):
213213
batch_size=self.batch_size,
214214
epochs=self.epochs,
215215
shuffle=self.shuffle,
216-
verbose=True,
216+
verbose=self.verbose,
217217
callbacks=callbacks,
218218
validation_split=self.validation_split,
219219
)

0 commit comments

Comments
 (0)