Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ checks = [
]
exclude = [ # don't report on checks for these
'\.__init__$',
'\.__iter__$',
'\.__repr__$',
'\.__str__$',
]
Expand Down
2 changes: 1 addition & 1 deletion src/data_morph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@
and `this slide deck <https://stefaniemolin.com/data-morph-talk/#/>`_.
"""

__version__ = '0.3.1'
__version__ = '0.4.0.dev0'
MAIN_DIR = __name__
11 changes: 11 additions & 0 deletions src/data_morph/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ def generate_parser() -> argparse.ArgumentParser:
'iterations (see ``--iterations``).'
),
)
morph_config_group.add_argument(
'--with-median',
default=False,
action='store_true',
help=(
'Whether to require the median to be preserved. Note that this will be a '
'little slower.'
),
)

file_group = parser.add_argument_group(
'Output File Configuration',
Expand Down Expand Up @@ -285,6 +294,7 @@ def _morph(
forward_only_animation=args.forward_only,
num_frames=100,
in_notebook=False,
with_median=args.with_median,
)

_ = morpher.morph(
Expand Down Expand Up @@ -399,6 +409,7 @@ def _serialize(args: argparse.Namespace, target_shapes: Sequence[str]) -> None:
forward_only_animation=args.forward_only,
num_frames=100,
in_notebook=False,
with_median=args.with_median,
)

for target_shape in target_shapes:
Expand Down
57 changes: 40 additions & 17 deletions src/data_morph/data/stats.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,59 @@
"""Utility functions for calculating summary statistics."""

from collections import namedtuple
from __future__ import annotations

import pandas as pd
from typing import TYPE_CHECKING, NamedTuple

SummaryStatistics = namedtuple(
'SummaryStatistics', ['x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation']
)
SummaryStatistics.__doc__ = (
'Named tuple containing the summary statistics for plotting/analysis.'
)
if TYPE_CHECKING:
from collections.abc import Generator

import pandas as pd

def get_summary_statistics(data: pd.DataFrame) -> SummaryStatistics:

class SummaryStatistics(NamedTuple):
    """
    Named tuple containing the summary statistics for plotting/analysis.

    Iterating over an instance yields only the statistics that are not ``None``,
    so a result computed without the medians iterates as five values, while one
    computed with them iterates as seven. Attribute access, indexing, and
    ``len()`` are unaffected and always reflect all seven fields.
    """

    # Means of the x and y columns.
    x_mean: float
    y_mean: float

    # Standard deviations of the x and y columns.
    x_stdev: float
    y_stdev: float

    # Pearson correlation coefficient between x and y.
    correlation: float

    # Medians are optional; they default to ``None`` so that code constructing
    # the tuple from just the five original statistics keeps working.
    x_median: float | None = None
    y_median: float | None = None

    def __iter__(self) -> Generator[float, None, None]:
        """
        Iterate over the statistics that were actually computed.

        Yields
        ------
        float
            Each field value, in field order, skipping any that is ``None``.
        """
        for statistic in self._fields:
            if (value := getattr(self, statistic)) is not None:
                yield value


def get_summary_statistics(
    data: pd.DataFrame, with_median: bool = False
) -> SummaryStatistics:
    """
    Calculate the summary statistics for the given set of points.

    Parameters
    ----------
    data : pandas.DataFrame
        A dataset with columns ``x`` and ``y``.
    with_median : bool, default ``False``
        Whether to include the median of ``x`` and ``y``.

    Returns
    -------
    SummaryStatistics
        Named tuple consisting of mean and standard deviations of ``x`` and ``y``,
        along with the Pearson correlation coefficient between the two, and optionally,
        the median of ``x`` and ``y``.
    """
    # Every field is passed by keyword so the mapping of value to statistic does
    # not depend on the field order of the named tuple. The medians are only
    # computed on request; otherwise they are recorded as ``None``.
    return SummaryStatistics(
        x_mean=data.x.mean(),
        y_mean=data.y.mean(),
        x_stdev=data.x.std(),
        y_stdev=data.y.std(),
        correlation=data.corr().x.y,
        x_median=data.x.median() if with_median else None,
        y_median=data.y.median() if with_median else None,
    )
14 changes: 13 additions & 1 deletion src/data_morph/morpher.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class DataMorpher:
forward_only_animation : bool, default ``False``
Whether to generate the animation in the forward direction only.
By default, the animation will play forward and then reverse.
with_median : bool, default ``False``
Whether to preserve the median in addition to the other summary statistics.
Note that this will be a little slower.
"""

def __init__(
Expand All @@ -74,6 +77,7 @@ def __init__(
num_frames: int = 100,
keep_frames: bool = False,
forward_only_animation: bool = False,
with_median: bool = False,
) -> None:
self._rng = np.random.default_rng(seed)

Expand Down Expand Up @@ -129,6 +133,8 @@ def __init__(

self._ProgressTracker = partial(DataMorphProgress, not self._in_notebook)

self._with_median = with_median

def _select_frames(
self, iterations: int, ease_in: bool, ease_out: bool, freeze_for: int
) -> list:
Expand Down Expand Up @@ -222,6 +228,7 @@ def _record_frames(
decimals=self.decimals,
x_bounds=bounds.x_bounds,
y_bounds=bounds.y_bounds,
with_median=self._with_median,
dpi=150,
)
if (
Expand Down Expand Up @@ -253,7 +260,12 @@ def _is_close_enough(self, df1: pd.DataFrame, df2: pd.DataFrame) -> bool:
np.subtract(
*(
np.floor(
np.array(get_summary_statistics(data)) * 10**self.decimals
np.array(
get_summary_statistics(
data, with_median=self._with_median
)
)
* 10**self.decimals
)
for data in [df1, df2]
)
Expand Down
61 changes: 47 additions & 14 deletions src/data_morph/plotting/static.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
import pandas as pd
from matplotlib.axes import Axes

# Maps each summary statistic field name to the short label shown next to that
# statistic's value on the plot.
_STATISTIC_DISPLAY_NAME_MAPPING: dict[str, str] = {
'x_mean': 'X Mean',
'y_mean': 'Y Mean',
'x_stdev': 'X SD',
'y_stdev': 'Y SD',
'x_median': 'X Med.',
'y_median': 'Y Med.',
'correlation': 'Corr.',
}


@plot_with_custom_style
def plot(
Expand All @@ -28,6 +38,7 @@ def plot(
y_bounds: Iterable[Number],
save_to: str | Path,
decimals: int,
with_median: bool,
**save_kwds: Any, # noqa: ANN401
) -> Axes | None:
"""
Expand All @@ -43,6 +54,8 @@ def plot(
Path to save the plot frame to.
decimals : int
The number of integers to highlight as preserved.
with_median : bool
Whether to include the median.
**save_kwds
Additional keyword arguments that will be passed down to
:meth:`matplotlib.figure.Figure.savefig`.
Expand All @@ -64,10 +77,24 @@ def plot(
ax.xaxis.set_major_formatter(tick_formatter)
ax.yaxis.set_major_formatter(tick_formatter)

res = get_summary_statistics(data)
res = get_summary_statistics(data, with_median=with_median)

if with_median:
fields = (
'x_mean',
'x_median',
'x_stdev',
'y_mean',
'y_median',
'y_stdev',
'correlation',
)
locs = [0.9, 0.78, 0.66, 0.5, 0.38, 0.26, 0.1]
else:
fields = ('x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation')
locs = np.linspace(0.8, 0.2, num=len(fields))

labels = ('X Mean', 'Y Mean', 'X SD', 'Y SD', 'Corr.')
locs = np.linspace(0.8, 0.2, num=len(labels))
labels = [_STATISTIC_DISPLAY_NAME_MAPPING[field] for field in fields]
max_label_length = max([len(label) for label in labels])
max_stat = int(np.log10(np.max(np.abs(res)))) + 1
mean_x_digits, mean_y_digits = (
Expand Down Expand Up @@ -95,17 +122,23 @@ def plot(
transform=ax.transAxes,
va='center',
)
for label, loc, stat in zip(labels[:-1], locs, res):
add_stat_text(loc, formatter(label, stat), alpha=0.3)
add_stat_text(loc, formatter(label, stat)[:-stat_clip])

correlation_str = corr_formatter(labels[-1], res.correlation)
for alpha, text in zip([0.3, 1], [correlation_str, correlation_str[:-stat_clip]]):
add_stat_text(
locs[-1],
text,
alpha=alpha,
)
for loc, field in zip(locs, fields):
label = _STATISTIC_DISPLAY_NAME_MAPPING[field]
stat = getattr(res, field)

if field == 'correlation':
correlation_str = corr_formatter(labels[-1], res.correlation)
for alpha, text in zip(
[0.3, 1], [correlation_str, correlation_str[:-stat_clip]]
):
add_stat_text(
locs[-1],
text,
alpha=alpha,
)
else:
add_stat_text(loc, formatter(label, stat), alpha=0.3)
add_stat_text(loc, formatter(label, stat)[:-stat_clip])

if not save_to:
return ax
Expand Down
13 changes: 11 additions & 2 deletions tests/data/test_stats.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
"""Test the stats module."""

import pytest

from data_morph.data.loader import DataLoader
from data_morph.data.stats import get_summary_statistics


@pytest.mark.parametrize('with_median', [True, False])
def test_stats(with_median):
    """Test that summary statistics tuple is correct."""

    data = DataLoader.load_dataset('dino').data

    stats = get_summary_statistics(data, with_median=with_median)

    # these statistics are always calculated
    assert stats.x_mean == data.x.mean()
    assert stats.y_mean == data.y.mean()
    assert stats.x_stdev == data.x.std()
    assert stats.y_stdev == data.y.std()
    assert stats.correlation == data.corr().x.y

    # the medians are only populated when requested
    if with_median:
        assert stats.x_median == data.x.median()
        assert stats.y_median == data.y.median()
    else:
        assert stats.x_median is stats.y_median is None
1 change: 1 addition & 0 deletions tests/plotting/test_animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_frame_stitching(sample_data, tmp_path, forward_only):
y_bounds=bounds,
save_to=(tmp_path / f'{start_shape}-to-{target_shape}-{frame}.png'),
decimals=2,
with_median=False,
)

duration_multipliers = [0, 0, 0, 0, 1, 1, *frame_numbers[2:], frame_numbers[-1]]
Expand Down
24 changes: 21 additions & 3 deletions tests/plotting/test_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
pytestmark = pytest.mark.plotting


@pytest.mark.parametrize('file_path', ['test_plot.png', None])
def test_plot(sample_data, tmp_path, file_path):
@pytest.mark.parametrize(
('file_path', 'with_median'),
[
('test_plot.png', False),
(None, True),
(None, False),
],
)
def test_plot(sample_data, tmp_path, file_path, with_median):
"""Test static plot creation."""
bounds = (-5.0, 105.0)
if file_path:
Expand All @@ -20,12 +27,18 @@ def test_plot(sample_data, tmp_path, file_path):
y_bounds=bounds,
save_to=save_to,
decimals=2,
with_median=with_median,
)
assert save_to.is_file()

else:
ax = plot(
data=sample_data, x_bounds=bounds, y_bounds=bounds, save_to=None, decimals=2
data=sample_data,
x_bounds=bounds,
y_bounds=bounds,
save_to=None,
decimals=2,
with_median=with_median,
)

# confirm that the stylesheet was used
Expand All @@ -34,3 +47,8 @@ def test_plot(sample_data, tmp_path, file_path):
# confirm that bounds are correct
assert ax.get_xlim() == bounds
assert ax.get_ylim() == bounds

# confirm that the right number of stats was drawn
expected_stats = 7 if with_median else 5
expected_texts = 2 * expected_stats # label and the number
assert len(ax.texts) == expected_texts
2 changes: 2 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def test_cli_one_shape(start_shape, flag, mocker, tmp_path):
'forward_only_animation': flag,
'num_frames': 100,
'in_notebook': False,
'with_median': flag,
}
morph_args = {
'start_shape_name': start_shape,
Expand All @@ -153,6 +154,7 @@ def test_cli_one_shape(start_shape, flag, mocker, tmp_path):
'--write-data' if init_args['write_data'] else '',
'--keep-frames' if init_args['keep_frames'] else '',
'--forward-only' if init_args['forward_only_animation'] else '',
'--with-median' if init_args['with_median'] else '',
f'--shake={morph_args["min_shake"]}' if morph_args['min_shake'] else '',
f'--freeze={morph_args["freeze"]}' if morph_args['freeze'] else '',
'--ease-in' if morph_args['ease_in'] else '',
Expand Down