Skip to content

Commit 2dce31a

Browse files
committed
Fix progress bar
1 parent 1a5e093 commit 2dce31a

File tree

5 files changed

+28
-12
lines changed

5 files changed

+28
-12
lines changed

ctgan/synthesizers/_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,14 @@ def _set_device(enable_gpu, device=None):
5555
def validate_and_set_device(enable_gpu, cuda):
5656
enable_gpu = get_enable_gpu_value(enable_gpu, cuda)
5757
return _set_device(enable_gpu)
58+
59+
60+
def _format_score(score):
61+
"""Format a score as a fixed-length string ``±XX.XX``.
62+
63+
Values are clipped to the range ``[-99.99, +99.99]`` so the result
64+
is always exactly 6 characters.
65+
"""
66+
score = max(-99.99, min(99.99, score))
67+
sign = '+' if score >= 0 else '-'
68+
return f'{sign}{abs(score):05.2f}'

ctgan/synthesizers/ctgan.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ctgan.data_sampler import DataSampler
1313
from ctgan.data_transformer import DataTransformer
1414
from ctgan.errors import InvalidDataError
15-
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
15+
from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device
1616
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1717

1818

@@ -379,8 +379,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
379379

380380
epoch_iterator = tqdm(range(epochs), disable=(not self._verbose))
381381
if self._verbose:
382-
description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
383-
epoch_iterator.set_description(description.format(gen=0, dis=0))
382+
description = 'Gen. ({gen}) | Discrim. ({dis})'
383+
epoch_iterator.set_description(
384+
description.format(gen=_format_score(0), dis=_format_score(0))
385+
)
384386

385387
steps_per_epoch = max(len(train_data) // self._batch_size, 1)
386388
for i in epoch_iterator:
@@ -479,7 +481,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
479481

480482
if self._verbose:
481483
epoch_iterator.set_description(
482-
description.format(gen=generator_loss, dis=discriminator_loss)
484+
description.format(
485+
gen=_format_score(generator_loss),
486+
dis=_format_score(discriminator_loss),
487+
)
483488
)
484489

485490
@random_state

ctgan/synthesizers/tvae.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tqdm import tqdm
1111

1212
from ctgan.data_transformer import DataTransformer
13-
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
13+
from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device
1414
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1515

1616

@@ -161,8 +161,8 @@ def fit(self, train_data, discrete_columns=()):
161161
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
162162
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
163163
if self.verbose:
164-
iterator_description = 'Loss: {loss:.3f}'
165-
iterator.set_description(iterator_description.format(loss=0))
164+
iterator_description = 'Loss: {loss}'
165+
iterator.set_description(iterator_description.format(loss=_format_score(0)))
166166

167167
for i in iterator:
168168
loss_values = []
@@ -205,7 +205,7 @@ def fit(self, train_data, discrete_columns=()):
205205

206206
if self.verbose:
207207
iterator.set_description(
208-
iterator_description.format(loss=loss.detach().cpu().item())
208+
iterator_description.format(loss=_format_score(loss.detach().cpu().item()))
209209
)
210210

211211
@random_state

tests/integration/synthesizer/test_tvae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,5 @@ def test_tvae_save(tmpdir, capsys):
120120
assert len(loss_values) == 10
121121
assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'}
122122
assert all(loss_values['Batch'] == 0)
123-
last_loss_val = loss_values['Loss'].iloc[-1]
124-
assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out
123+
last_loss_val = max(-99.99, min(99.99, loss_values['Loss'].iloc[-1]))
124+
assert f'Loss: {last_loss_val:+06.2f}: 100%' in captured_out

tests/unit/synthesizer/test_tvae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,6 @@ def mock_add(a, b):
6060

6161
# Assert
6262
tqdm_mock.assert_called_once_with(range(epochs), disable=False)
63-
assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000')
64-
assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235')
63+
assert iterator_mock.set_description.call_args_list[0] == call('Loss: +00.00')
64+
assert iterator_mock.set_description.call_args_list[1] == call('Loss: +01.23')
6565
assert iterator_mock.set_description.call_count == 2

0 commit comments

Comments
 (0)