Skip to content

Commit ebdec3d

Browse files
Add option to preserve the median when morphing (#188)
- Added `--with-median` to CLI and `with_median` to morpher and utility functions to have the option of preserving the median (optional because it is slower and more restrictive) - Restyled plot when median is present to show groups of x and y so it is easier to read --------- Co-authored-by: Stefanie Molin <[email protected]>
1 parent 18ae0bb commit ebdec3d

File tree

10 files changed

+148
-38
lines changed

10 files changed

+148
-38
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ checks = [
203203
]
204204
exclude = [ # don't report on checks for these
205205
'\.__init__$',
206+
'\.__iter__$',
206207
'\.__repr__$',
207208
'\.__str__$',
208209
]

src/data_morph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@
2828
and `this slide deck <https://stefaniemolin.com/data-morph-talk/#/>`_.
2929
"""
3030

31-
__version__ = '0.3.1'
31+
__version__ = '0.4.0.dev0'
3232
MAIN_DIR = __name__

src/data_morph/cli.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,15 @@ def generate_parser() -> argparse.ArgumentParser:
154154
'iterations (see ``--iterations``).'
155155
),
156156
)
157+
morph_config_group.add_argument(
158+
'--with-median',
159+
default=False,
160+
action='store_true',
161+
help=(
162+
'Whether to require the median to be preserved. Note that this will be a '
163+
'little slower.'
164+
),
165+
)
157166

158167
file_group = parser.add_argument_group(
159168
'Output File Configuration',
@@ -285,6 +294,7 @@ def _morph(
285294
forward_only_animation=args.forward_only,
286295
num_frames=100,
287296
in_notebook=False,
297+
with_median=args.with_median,
288298
)
289299

290300
_ = morpher.morph(
@@ -399,6 +409,7 @@ def _serialize(args: argparse.Namespace, target_shapes: Sequence[str]) -> None:
399409
forward_only_animation=args.forward_only,
400410
num_frames=100,
401411
in_notebook=False,
412+
with_median=args.with_median,
402413
)
403414

404415
for target_shape in target_shapes:

src/data_morph/data/stats.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,59 @@
11
"""Utility functions for calculating summary statistics."""
22

3-
from collections import namedtuple
3+
from __future__ import annotations
44

5-
import pandas as pd
5+
from typing import TYPE_CHECKING, NamedTuple
66

7-
SummaryStatistics = namedtuple(
8-
'SummaryStatistics', ['x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation']
9-
)
10-
SummaryStatistics.__doc__ = (
11-
'Named tuple containing the summary statistics for plotting/analysis.'
12-
)
7+
if TYPE_CHECKING:
8+
from collections.abc import Generator
139

10+
import pandas as pd
1411

15-
def get_summary_statistics(data: pd.DataFrame) -> SummaryStatistics:
12+
13+
class SummaryStatistics(NamedTuple):
14+
"""Named tuple containing the summary statistics for plotting/analysis."""
15+
16+
x_mean: float
17+
y_mean: float
18+
19+
x_stdev: float
20+
y_stdev: float
21+
22+
correlation: float
23+
24+
x_median: float | None
25+
y_median: float | None
26+
27+
def __iter__(self) -> Generator[float, None, None]:
28+
for statistic in self._fields:
29+
if (value := getattr(self, statistic)) is not None:
30+
yield value
31+
32+
33+
def get_summary_statistics(data: pd.DataFrame, with_median: bool) -> SummaryStatistics:
1634
"""
1735
Calculate the summary statistics for the given set of points.
1836
1937
Parameters
2038
----------
2139
data : pandas.DataFrame
22-
A dataset with columns x and y.
40+
A dataset with columns ``x`` and ``y``.
41+
with_median : bool
42+
Whether to include the median of ``x`` and ``y``.
2343
2444
Returns
2545
-------
2646
SummaryStatistics
27-
Named tuple consisting of mean and standard deviations of x and y,
28-
along with the Pearson correlation coefficient between the two.
47+
Named tuple consisting of mean and standard deviations of ``x`` and ``y``,
48+
along with the Pearson correlation coefficient between the two, and optionally,
49+
the median of ``x`` and ``y``.
2950
"""
3051
return SummaryStatistics(
31-
data.x.mean(),
32-
data.y.mean(),
33-
data.x.std(),
34-
data.y.std(),
35-
data.corr().x.y,
52+
x_mean=data.x.mean(),
53+
y_mean=data.y.mean(),
54+
x_stdev=data.x.std(),
55+
y_stdev=data.y.std(),
56+
correlation=data.corr().x.y,
57+
x_median=data.x.median() if with_median else None,
58+
y_median=data.y.median() if with_median else None,
3659
)

src/data_morph/morpher.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class DataMorpher:
6060
forward_only_animation : bool, default ``False``
6161
Whether to generate the animation in the forward direction only.
6262
By default, the animation will play forward and then reverse.
63+
with_median : bool, default ``False``
64+
Whether to preserve the median in addition to the other summary statistics.
65+
Note that this will be a little slower.
6366
"""
6467

6568
def __init__(
@@ -74,6 +77,7 @@ def __init__(
7477
num_frames: int = 100,
7578
keep_frames: bool = False,
7679
forward_only_animation: bool = False,
80+
with_median: bool = False,
7781
) -> None:
7882
self._rng = np.random.default_rng(seed)
7983

@@ -129,6 +133,8 @@ def __init__(
129133

130134
self._ProgressTracker = partial(DataMorphProgress, not self._in_notebook)
131135

136+
self._with_median = with_median
137+
132138
def _select_frames(
133139
self, iterations: int, ease_in: bool, ease_out: bool, freeze_for: int
134140
) -> list:
@@ -222,6 +228,7 @@ def _record_frames(
222228
decimals=self.decimals,
223229
x_bounds=bounds.x_bounds,
224230
y_bounds=bounds.y_bounds,
231+
with_median=self._with_median,
225232
dpi=150,
226233
)
227234
if (
@@ -253,7 +260,12 @@ def _is_close_enough(self, df1: pd.DataFrame, df2: pd.DataFrame) -> bool:
253260
np.subtract(
254261
*(
255262
np.floor(
256-
np.array(get_summary_statistics(data)) * 10**self.decimals
263+
np.array(
264+
get_summary_statistics(
265+
data, with_median=self._with_median
266+
)
267+
)
268+
* 10**self.decimals
257269
)
258270
for data in [df1, df2]
259271
)

src/data_morph/plotting/static.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@
2020
import pandas as pd
2121
from matplotlib.axes import Axes
2222

23+
_STATISTIC_DISPLAY_NAME_MAPPING: dict[str, str] = {
24+
'x_mean': 'X Mean',
25+
'y_mean': 'Y Mean',
26+
'x_stdev': 'X SD',
27+
'y_stdev': 'Y SD',
28+
'x_median': 'X Med.',
29+
'y_median': 'Y Med.',
30+
'correlation': 'Corr.',
31+
}
32+
2333

2434
@plot_with_custom_style
2535
def plot(
@@ -28,6 +38,7 @@ def plot(
2838
y_bounds: Iterable[Number],
2939
save_to: str | Path,
3040
decimals: int,
41+
with_median: bool,
3142
**save_kwds: Any, # noqa: ANN401
3243
) -> Axes | None:
3344
"""
@@ -43,6 +54,8 @@ def plot(
4354
Path to save the plot frame to.
4455
decimals : int
4556
The number of integers to highlight as preserved.
57+
with_median : bool
58+
Whether to include the median.
4659
**save_kwds
4760
Additional keyword arguments that will be passed down to
4861
:meth:`matplotlib.figure.Figure.savefig`.
@@ -64,10 +77,24 @@ def plot(
6477
ax.xaxis.set_major_formatter(tick_formatter)
6578
ax.yaxis.set_major_formatter(tick_formatter)
6679

67-
res = get_summary_statistics(data)
80+
res = get_summary_statistics(data, with_median=with_median)
81+
82+
if with_median:
83+
fields = (
84+
'x_mean',
85+
'x_median',
86+
'x_stdev',
87+
'y_mean',
88+
'y_median',
89+
'y_stdev',
90+
'correlation',
91+
)
92+
locs = [0.9, 0.78, 0.66, 0.5, 0.38, 0.26, 0.1]
93+
else:
94+
fields = ('x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation')
95+
locs = np.linspace(0.8, 0.2, num=len(fields))
6896

69-
labels = ('X Mean', 'Y Mean', 'X SD', 'Y SD', 'Corr.')
70-
locs = np.linspace(0.8, 0.2, num=len(labels))
97+
labels = [_STATISTIC_DISPLAY_NAME_MAPPING[field] for field in fields]
7198
max_label_length = max([len(label) for label in labels])
7299
max_stat = int(np.log10(np.max(np.abs(res)))) + 1
73100
mean_x_digits, mean_y_digits = (
@@ -95,17 +122,23 @@ def plot(
95122
transform=ax.transAxes,
96123
va='center',
97124
)
98-
for label, loc, stat in zip(labels[:-1], locs, res):
99-
add_stat_text(loc, formatter(label, stat), alpha=0.3)
100-
add_stat_text(loc, formatter(label, stat)[:-stat_clip])
101-
102-
correlation_str = corr_formatter(labels[-1], res.correlation)
103-
for alpha, text in zip([0.3, 1], [correlation_str, correlation_str[:-stat_clip]]):
104-
add_stat_text(
105-
locs[-1],
106-
text,
107-
alpha=alpha,
108-
)
125+
for loc, field in zip(locs, fields):
126+
label = _STATISTIC_DISPLAY_NAME_MAPPING[field]
127+
stat = getattr(res, field)
128+
129+
if field == 'correlation':
130+
correlation_str = corr_formatter(labels[-1], res.correlation)
131+
for alpha, text in zip(
132+
[0.3, 1], [correlation_str, correlation_str[:-stat_clip]]
133+
):
134+
add_stat_text(
135+
locs[-1],
136+
text,
137+
alpha=alpha,
138+
)
139+
else:
140+
add_stat_text(loc, formatter(label, stat), alpha=0.3)
141+
add_stat_text(loc, formatter(label, stat)[:-stat_clip])
109142

110143
if not save_to:
111144
return ax

tests/data/test_stats.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
"""Test the stats module."""
22

3+
import pytest
4+
35
from data_morph.data.loader import DataLoader
46
from data_morph.data.stats import get_summary_statistics
57

68

7-
def test_stats():
9+
@pytest.mark.parametrize('with_median', [True, False])
10+
def test_stats(with_median):
811
"""Test that summary statistics tuple is correct."""
912

1013
data = DataLoader.load_dataset('dino').data
1114

12-
stats = get_summary_statistics(data)
15+
stats = get_summary_statistics(data, with_median)
1316

1417
assert stats.x_mean == data.x.mean()
1518
assert stats.y_mean == data.y.mean()
1619
assert stats.x_stdev == data.x.std()
1720
assert stats.y_stdev == data.y.std()
1821
assert stats.correlation == data.corr().x.y
22+
23+
if with_median:
24+
assert stats.x_median == data.x.median()
25+
assert stats.y_median == data.y.median()
26+
else:
27+
assert stats.x_median is stats.y_median is None

tests/plotting/test_animation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def test_frame_stitching(sample_data, tmp_path, forward_only):
2929
y_bounds=bounds,
3030
save_to=(tmp_path / f'{start_shape}-to-{target_shape}-{frame}.png'),
3131
decimals=2,
32+
with_median=False,
3233
)
3334

3435
duration_multipliers = [0, 0, 0, 0, 1, 1, *frame_numbers[2:], frame_numbers[-1]]

tests/plotting/test_static.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,15 @@
77
pytestmark = pytest.mark.plotting
88

99

10-
@pytest.mark.parametrize('file_path', ['test_plot.png', None])
11-
def test_plot(sample_data, tmp_path, file_path):
10+
@pytest.mark.parametrize(
11+
('file_path', 'with_median'),
12+
[
13+
('test_plot.png', False),
14+
(None, True),
15+
(None, False),
16+
],
17+
)
18+
def test_plot(sample_data, tmp_path, file_path, with_median):
1219
"""Test static plot creation."""
1320
bounds = (-5.0, 105.0)
1421
if file_path:
@@ -20,12 +27,18 @@ def test_plot(sample_data, tmp_path, file_path):
2027
y_bounds=bounds,
2128
save_to=save_to,
2229
decimals=2,
30+
with_median=with_median,
2331
)
2432
assert save_to.is_file()
2533

2634
else:
2735
ax = plot(
28-
data=sample_data, x_bounds=bounds, y_bounds=bounds, save_to=None, decimals=2
36+
data=sample_data,
37+
x_bounds=bounds,
38+
y_bounds=bounds,
39+
save_to=None,
40+
decimals=2,
41+
with_median=with_median,
2942
)
3043

3144
# confirm that the stylesheet was used
@@ -34,3 +47,8 @@ def test_plot(sample_data, tmp_path, file_path):
3447
# confirm that bounds are correct
3548
assert ax.get_xlim() == bounds
3649
assert ax.get_ylim() == bounds
50+
51+
# confirm that the right number of stats was drawn
52+
expected_stats = 7 if with_median else 5
53+
expected_texts = 2 * expected_stats # label and the number
54+
assert len(ax.texts) == expected_texts

tests/test_cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ def test_cli_one_shape(start_shape, flag, mocker, tmp_path):
127127
'forward_only_animation': flag,
128128
'num_frames': 100,
129129
'in_notebook': False,
130+
'with_median': flag,
130131
}
131132
morph_args = {
132133
'start_shape_name': start_shape,
@@ -153,6 +154,7 @@ def test_cli_one_shape(start_shape, flag, mocker, tmp_path):
153154
'--write-data' if init_args['write_data'] else '',
154155
'--keep-frames' if init_args['keep_frames'] else '',
155156
'--forward-only' if init_args['forward_only_animation'] else '',
157+
'--with-median' if init_args['with_median'] else '',
156158
f'--shake={morph_args["min_shake"]}' if morph_args['min_shake'] else '',
157159
f'--freeze={morph_args["freeze"]}' if morph_args['freeze'] else '',
158160
'--ease-in' if morph_args['ease_in'] else '',

0 commit comments

Comments
 (0)