|
12 | 12 | from ctgan.data_sampler import DataSampler |
13 | 13 | from ctgan.data_transformer import DataTransformer |
14 | 14 | from ctgan.errors import InvalidDataError |
15 | | -from ctgan.synthesizers._utils import _set_device, validate_and_set_device |
| 15 | +from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device |
16 | 16 | from ctgan.synthesizers.base import BaseSynthesizer, random_state |
17 | 17 |
|
18 | 18 |
|
@@ -379,8 +379,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None): |
379 | 379 |
|
380 | 380 | epoch_iterator = tqdm(range(epochs), disable=(not self._verbose)) |
381 | 381 | if self._verbose: |
382 | | - description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})' |
383 | | - epoch_iterator.set_description(description.format(gen=0, dis=0)) |
| 382 | + description = 'Gen. ({gen}) | Discrim. ({dis})' |
| 383 | + epoch_iterator.set_description( |
| 384 | + description.format(gen=_format_score(0), dis=_format_score(0)) |
| 385 | + ) |
384 | 386 |
|
385 | 387 | steps_per_epoch = max(len(train_data) // self._batch_size, 1) |
386 | 388 | for i in epoch_iterator: |
@@ -479,7 +481,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None): |
479 | 481 |
|
480 | 482 | if self._verbose: |
481 | 483 | epoch_iterator.set_description( |
482 | | - description.format(gen=generator_loss, dis=discriminator_loss) |
| 484 | + description.format( |
| 485 | + gen=_format_score(generator_loss), |
| 486 | + dis=_format_score(discriminator_loss), |
| 487 | + ) |
483 | 488 | ) |
484 | 489 |
|
485 | 490 | @random_state |
|
0 commit comments