Skip to content

Commit 66984f1

Browse files
authored
In verbose mode, make the prefix of the progress bar a fixed-length (#494)
1 parent 1a5e093 commit 66984f1

File tree

7 files changed

+70
-13
lines changed

7 files changed

+70
-13
lines changed

ctgan/synthesizers/_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,14 @@ def _set_device(enable_gpu, device=None):
5555
def validate_and_set_device(enable_gpu, cuda):
5656
enable_gpu = get_enable_gpu_value(enable_gpu, cuda)
5757
return _set_device(enable_gpu)
58+
59+
60+
def _format_score(score):
61+
"""Format a score as a fixed-length string ``±XX.XX``.
62+
63+
Values are clipped to the range ``[-99.99, +99.99]`` so the result
64+
is always exactly 6 characters.
65+
"""
66+
score = max(-99.99, min(99.99, score))
67+
sign = '+' if score >= 0 else '-'
68+
return f'{sign}{abs(score):05.2f}'

ctgan/synthesizers/ctgan.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ctgan.data_sampler import DataSampler
1313
from ctgan.data_transformer import DataTransformer
1414
from ctgan.errors import InvalidDataError
15-
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
15+
from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device
1616
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1717

1818

@@ -379,8 +379,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
379379

380380
epoch_iterator = tqdm(range(epochs), disable=(not self._verbose))
381381
if self._verbose:
382-
description = 'Gen. ({gen:.2f}) | Discrim. ({dis:.2f})'
383-
epoch_iterator.set_description(description.format(gen=0, dis=0))
382+
description = 'Gen. ({gen}) | Discrim. ({dis})'
383+
epoch_iterator.set_description(
384+
description.format(gen=_format_score(0), dis=_format_score(0))
385+
)
384386

385387
steps_per_epoch = max(len(train_data) // self._batch_size, 1)
386388
for i in epoch_iterator:
@@ -479,7 +481,10 @@ def fit(self, train_data, discrete_columns=(), epochs=None):
479481

480482
if self._verbose:
481483
epoch_iterator.set_description(
482-
description.format(gen=generator_loss, dis=discriminator_loss)
484+
description.format(
485+
gen=_format_score(generator_loss),
486+
dis=_format_score(discriminator_loss),
487+
)
483488
)
484489

485490
@random_state

ctgan/synthesizers/tvae.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from tqdm import tqdm
1111

1212
from ctgan.data_transformer import DataTransformer
13-
from ctgan.synthesizers._utils import _set_device, validate_and_set_device
13+
from ctgan.synthesizers._utils import _format_score, _set_device, validate_and_set_device
1414
from ctgan.synthesizers.base import BaseSynthesizer, random_state
1515

1616

@@ -161,8 +161,8 @@ def fit(self, train_data, discrete_columns=()):
161161
self.loss_values = pd.DataFrame(columns=['Epoch', 'Batch', 'Loss'])
162162
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
163163
if self.verbose:
164-
iterator_description = 'Loss: {loss:.3f}'
165-
iterator.set_description(iterator_description.format(loss=0))
164+
iterator_description = 'Loss: {loss}'
165+
iterator.set_description(iterator_description.format(loss=_format_score(0)))
166166

167167
for i in iterator:
168168
loss_values = []
@@ -205,7 +205,7 @@ def fit(self, train_data, discrete_columns=()):
205205

206206
if self.verbose:
207207
iterator.set_description(
208-
iterator_description.format(loss=loss.detach().cpu().item())
208+
iterator_description.format(loss=_format_score(loss.detach().cpu().item()))
209209
)
210210

211211
@random_state

tests/integration/synthesizer/test_tvae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,5 +120,5 @@ def test_tvae_save(tmpdir, capsys):
120120
assert len(loss_values) == 10
121121
assert set(loss_values.columns) == {'Epoch', 'Batch', 'Loss'}
122122
assert all(loss_values['Batch'] == 0)
123-
last_loss_val = loss_values['Loss'].iloc[-1]
124-
assert f'Loss: {round(last_loss_val, 3):.3f}: 100%' in captured_out
123+
last_loss_val = max(-99.99, min(99.99, loss_values['Loss'].iloc[-1]))
124+
assert f'Loss: {last_loss_val:+06.2f}: 100%' in captured_out

tests/unit/synthesizer/test__utils.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@
55
import pytest
66
import torch
77

8-
from ctgan.synthesizers._utils import _set_device, get_enable_gpu_value, validate_and_set_device
8+
from ctgan.synthesizers._utils import (
9+
_format_score,
10+
_set_device,
11+
get_enable_gpu_value,
12+
validate_and_set_device,
13+
)
914

1015

1116
def test__validate_gpu_parameter():
@@ -61,6 +66,27 @@ def test__set_device():
6166
assert device_4 == torch.device('cpu')
6267

6368

69+
@pytest.mark.parametrize(
70+
'score, expected',
71+
[
72+
(0, '+00.00'),
73+
(1.233434, '+01.23'),
74+
(-0.93, '-00.93'),
75+
(0.01, '+00.01'),
76+
(-1.21, '-01.21'),
77+
(99.99, '+99.99'),
78+
(-99.99, '-99.99'),
79+
(150, '+99.99'),
80+
(-200, '-99.99'),
81+
],
82+
)
83+
def test__format_score(score, expected):
84+
"""Test the ``_format_score`` method."""
85+
result = _format_score(score)
86+
assert result == expected
87+
assert len(result) == 6
88+
89+
6490
@patch('ctgan.synthesizers._utils._set_device')
6591
@patch('ctgan.synthesizers._utils.get_enable_gpu_value')
6692
def test_validate_and_set_device(mock_validate, mock_set_device):

tests/unit/synthesizer/test_ctgan.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,21 @@ def test__cond_loss(self):
286286

287287
assert (result - expected).abs() < 1e-3
288288

289+
@patch('ctgan.synthesizers.ctgan._format_score')
290+
def test_fit_verbose_calls_format_score(self, format_score_mock):
291+
"""Test that ``_format_score`` is called during verbose fitting."""
292+
# Setup
293+
format_score_mock.side_effect = lambda x: f'+{abs(x):05.2f}'
294+
data = pd.DataFrame({'col1': [0, 1, 2, 3, 4], 'col2': ['a', 'b', 'c', 'a', 'b']})
295+
296+
# Run
297+
ctgan = CTGAN(epochs=1, verbose=True)
298+
ctgan.fit(data, discrete_columns=['col2'])
299+
300+
# Assert
301+
assert format_score_mock.call_count == 4
302+
format_score_mock.assert_any_call(0)
303+
289304
def test__validate_discrete_columns(self):
290305
"""Test `_validate_discrete_columns` if the discrete column doesn't exist.
291306

tests/unit/synthesizer/test_tvae.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,6 @@ def mock_add(a, b):
6060

6161
# Assert
6262
tqdm_mock.assert_called_once_with(range(epochs), disable=False)
63-
assert iterator_mock.set_description.call_args_list[0] == call('Loss: 0.000')
64-
assert iterator_mock.set_description.call_args_list[1] == call('Loss: 1.235')
63+
assert iterator_mock.set_description.call_args_list[0] == call('Loss: +00.00')
64+
assert iterator_mock.set_description.call_args_list[1] == call('Loss: +01.23')
6565
assert iterator_mock.set_description.call_count == 2

0 commit comments

Comments
 (0)