Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions deepecho/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from deepecho.sequences import assemble_sequences


def _format_score(score):
"""Format a score as a fixed-length string ``±XX.XX``.

Values are clipped to the range ``[-99.99, +99.99]`` so the result
is always exactly 6 characters.
"""
score = max(-99.99, min(99.99, score))
sign = '+' if score >= 0 else '-'
return f'{sign}{abs(score):05.2f}'


class DeepEcho:
"""The base class for DeepEcho models."""

Expand Down
7 changes: 5 additions & 2 deletions deepecho/models/basic_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from tqdm import tqdm

from deepecho.models.base import DeepEcho
from deepecho.models.base import DeepEcho, _format_score

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -547,7 +547,10 @@ def fit_sequences(self, sequences, context_types, data_types):
if self._verbose:
d_loss = discriminator_score.item()
g_loss = generator_score.item()
iterator.set_description(f'Epoch {epoch + 1} | D Loss {d_loss} | G Loss {g_loss}')
iterator.set_description(
f'Epoch {epoch + 1} | D Loss {_format_score(d_loss)}'
f' | G Loss {_format_score(g_loss)}'
)

def sample_sequence(self, context, sequence_length=None):
"""Sample a single sequence conditioned on context.
Expand Down
8 changes: 4 additions & 4 deletions deepecho/models/par.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch
from tqdm import tqdm

from deepecho.models.base import DeepEcho
from deepecho.models.base import DeepEcho, _format_score

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -336,8 +336,8 @@ def fit_sequences(self, sequences, context_types, data_types):

iterator = tqdm(range(self.epochs), disable=(not self.verbose))
if self.verbose:
pbar_description = 'Loss ({loss:.3f})'
iterator.set_description(pbar_description.format(loss=0))
pbar_description = 'Loss ({loss})'
iterator.set_description(pbar_description.format(loss=_format_score(0)))

# Reset loss_values dataframe
self.loss_values = pd.DataFrame(columns=['Epoch', 'Loss'])
Expand All @@ -364,7 +364,7 @@ def fit_sequences(self, sequences, context_types, data_types):
self.loss_values = epoch_loss_df

if self.verbose:
iterator.set_description(pbar_description.format(loss=loss.item()))
iterator.set_description(pbar_description.format(loss=_format_score(loss.item())))

optimizer.step()

Expand Down
26 changes: 26 additions & 0 deletions tests/unit/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""Unit tests for the ``base`` module."""

import pytest

from deepecho.models.base import _format_score


@pytest.mark.parametrize(
'score, expected',
[
(0, '+00.00'),
(1.233434, '+01.23'),
(-0.93, '-00.93'),
(0.01, '+00.01'),
(-1.21, '-01.21'),
(99.99, '+99.99'),
(-99.99, '-99.99'),
(150, '+99.99'),
(-200, '-99.99'),
],
)
def test__format_score(score, expected):
"""Test the ``_format_score`` method."""
result = _format_score(score)
assert result == expected
assert len(result) == 6
Loading