Skip to content

Commit c3379d5

Browse files
committed
Fix progress bar
1 parent 4f7a0fa commit c3379d5

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

deepecho/models/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66
from deepecho.sequences import assemble_sequences
77

88

9+
def _format_score(score):
10+
"""Format a score as a fixed-length string ``±XX.XX``.
11+
12+
Values are clipped to the range ``[-99.99, +99.99]`` so the result
13+
is always exactly 6 characters.
14+
"""
15+
score = max(-99.99, min(99.99, score))
16+
sign = '+' if score >= 0 else '-'
17+
return f'{sign}{abs(score):05.2f}'
18+
19+
920
class DeepEcho:
1021
"""The base class for DeepEcho models."""
1122

deepecho/models/basic_gan.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from tqdm import tqdm
99

10-
from deepecho.models.base import DeepEcho
10+
from deepecho.models.base import DeepEcho, _format_score
1111

1212
LOGGER = logging.getLogger(__name__)
1313

@@ -547,7 +547,10 @@ def fit_sequences(self, sequences, context_types, data_types):
547547
if self._verbose:
548548
d_loss = discriminator_score.item()
549549
g_loss = generator_score.item()
550-
iterator.set_description(f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}')
550+
iterator.set_description(
551+
f'Epoch {epoch + 1} | D Loss {_format_score(d_loss)}'
552+
f' | G Loss {_format_score(g_loss)}'
553+
)
551554

552555
def sample_sequence(self, context, sequence_length=None):
553556
"""Sample a single sequence conditioned on context.

deepecho/models/par.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from tqdm import tqdm
99

10-
from deepecho.models.base import DeepEcho
10+
from deepecho.models.base import DeepEcho, _format_score
1111

1212
LOGGER = logging.getLogger(__name__)
1313

@@ -336,8 +336,8 @@ def fit_sequences(self, sequences, context_types, data_types):
336336

337337
iterator = tqdm(range(self.epochs), disable=(not self.verbose))
338338
if self.verbose:
339-
pbar_description = 'Loss ({loss:.3f})'
340-
iterator.set_description(pbar_description.format(loss=0))
339+
pbar_description = 'Loss ({loss})'
340+
iterator.set_description(pbar_description.format(loss=_format_score(0)))
341341

342342
# Reset loss_values dataframe
343343
self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])
@@ -364,7 +364,7 @@ def fit_sequences(self, sequences, context_types, data_types):
364364
self.loss_values = epoch_loss_df
365365

366366
if self.verbose:
367-
iterator.set_description(pbar_description.format(loss=loss.item()))
367+
iterator.set_description(pbar_description.format(loss=_format_score(loss.item())))
368368

369369
optimizer.step()
370370

0 commit comments

Comments
 (0)