diff --git a/pyproject.toml b/pyproject.toml index 432001f5..379da0be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -203,6 +203,7 @@ checks = [ ] exclude = [ # don't report on checks for these '\.__init__$', + '\.__iter__$', '\.__repr__$', '\.__str__$', ] diff --git a/src/data_morph/__init__.py b/src/data_morph/__init__.py index 036bcbd5..6eae26ef 100644 --- a/src/data_morph/__init__.py +++ b/src/data_morph/__init__.py @@ -28,5 +28,5 @@ and `this slide deck `_. """ -__version__ = '0.3.1' +__version__ = '0.4.0.dev0' MAIN_DIR = __name__ diff --git a/src/data_morph/cli.py b/src/data_morph/cli.py index ab38bf89..8f7cc817 100644 --- a/src/data_morph/cli.py +++ b/src/data_morph/cli.py @@ -154,6 +154,15 @@ def generate_parser() -> argparse.ArgumentParser: 'iterations (see ``--iterations``).' ), ) + morph_config_group.add_argument( + '--with-median', + default=False, + action='store_true', + help=( + 'Whether to require the median to be preserved. Note that this will be a ' + 'little slower.' + ), + ) file_group = parser.add_argument_group( 'Output File Configuration', @@ -285,6 +294,7 @@ def _morph( forward_only_animation=args.forward_only, num_frames=100, in_notebook=False, + with_median=args.with_median, ) _ = morpher.morph( @@ -399,6 +409,7 @@ def _serialize(args: argparse.Namespace, target_shapes: Sequence[str]) -> None: forward_only_animation=args.forward_only, num_frames=100, in_notebook=False, + with_median=args.with_median, ) for target_shape in target_shapes: diff --git a/src/data_morph/data/stats.py b/src/data_morph/data/stats.py index 3a35003c..f78b202d 100644 --- a/src/data_morph/data/stats.py +++ b/src/data_morph/data/stats.py @@ -1,36 +1,59 @@ """Utility functions for calculating summary statistics.""" -from collections import namedtuple +from __future__ import annotations -import pandas as pd +from typing import TYPE_CHECKING, NamedTuple -SummaryStatistics = namedtuple( - 'SummaryStatistics', ['x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation'] -) -SummaryStatistics.__doc__ = ( - 'Named tuple containing the summary statistics for plotting/analysis.' -) +if TYPE_CHECKING: + from collections.abc import Generator + import pandas as pd -def get_summary_statistics(data: pd.DataFrame) -> SummaryStatistics: + +class SummaryStatistics(NamedTuple): + """Named tuple containing the summary statistics for plotting/analysis.""" + + x_mean: float + y_mean: float + + x_stdev: float + y_stdev: float + + correlation: float + + x_median: float | None + y_median: float | None + + def __iter__(self) -> Generator[float, None, None]: + for statistic in self._fields: + if (value := getattr(self, statistic)) is not None: + yield value + + +def get_summary_statistics(data: pd.DataFrame, with_median: bool) -> SummaryStatistics: """ Calculate the summary statistics for the given set of points. Parameters ---------- data : pandas.DataFrame - A dataset with columns x and y. + A dataset with columns ``x`` and ``y``. + with_median : bool + Whether to include the median of ``x`` and ``y``. Returns ------- SummaryStatistics - Named tuple consisting of mean and standard deviations of x and y, - along with the Pearson correlation coefficient between the two. + Named tuple consisting of mean and standard deviations of ``x`` and ``y``, + along with the Pearson correlation coefficient between the two, and optionally, + the median of ``x`` and ``y``. """ return SummaryStatistics( - data.x.mean(), - data.y.mean(), - data.x.std(), - data.y.std(), - data.corr().x.y, + x_mean=data.x.mean(), + y_mean=data.y.mean(), + x_stdev=data.x.std(), + y_stdev=data.y.std(), + correlation=data.corr().x.y, + x_median=data.x.median() if with_median else None, + y_median=data.y.median() if with_median else None, ) diff --git a/src/data_morph/morpher.py b/src/data_morph/morpher.py index c7a5483f..a9509f5a 100644 --- a/src/data_morph/morpher.py +++ b/src/data_morph/morpher.py @@ -60,6 +60,9 @@ class DataMorpher: forward_only_animation : bool, default ``False`` Whether to generate the animation in the forward direction only. By default, the animation will play forward and then reverse. + with_median : bool, default ``False`` + Whether to preserve the median in addition to the other summary statistics. + Note that this will be a little slower. """ def __init__( @@ -74,6 +77,7 @@ def __init__( num_frames: int = 100, keep_frames: bool = False, forward_only_animation: bool = False, + with_median: bool = False, ) -> None: self._rng = np.random.default_rng(seed) @@ -129,6 +133,8 @@ def __init__( self._ProgressTracker = partial(DataMorphProgress, not self._in_notebook) + self._with_median = with_median + def _select_frames( self, iterations: int, ease_in: bool, ease_out: bool, freeze_for: int ) -> list: @@ -222,6 +228,7 @@ def _record_frames( decimals=self.decimals, x_bounds=bounds.x_bounds, y_bounds=bounds.y_bounds, + with_median=self._with_median, dpi=150, ) if ( @@ -253,7 +260,12 @@ def _is_close_enough(self, df1: pd.DataFrame, df2: pd.DataFrame) -> bool: np.subtract( *( np.floor( - np.array(get_summary_statistics(data)) * 10**self.decimals + np.array( + get_summary_statistics( + data, with_median=self._with_median + ) + ) + * 10**self.decimals ) for data in [df1, df2] ) diff --git a/src/data_morph/plotting/static.py b/src/data_morph/plotting/static.py index 20ba7d17..f937367f 100644 --- a/src/data_morph/plotting/static.py +++ b/src/data_morph/plotting/static.py @@ -20,6 +20,16 @@ import pandas as pd from matplotlib.axes import Axes +_STATISTIC_DISPLAY_NAME_MAPPING: dict[str, str] = { + 'x_mean': 'X Mean', + 'y_mean': 'Y Mean', + 'x_stdev': 'X SD', + 'y_stdev': 'Y SD', + 'x_median': 'X Med.', + 'y_median': 'Y Med.', + 'correlation': 'Corr.', +} + @plot_with_custom_style def plot( @@ -28,6 +38,7 @@ def plot( y_bounds: Iterable[Number], save_to: str | Path, decimals: int, + with_median: bool, **save_kwds: Any, # noqa: ANN401 ) -> Axes | None: """ @@ -43,6 +54,8 @@ def plot( Path to save the plot frame to. decimals : int The number of integers to highlight as preserved. + with_median : bool + Whether to include the median. **save_kwds Additional keyword arguments that will be passed down to :meth:`matplotlib.figure.Figure.savefig`. @@ -64,10 +77,24 @@ def plot( ax.xaxis.set_major_formatter(tick_formatter) ax.yaxis.set_major_formatter(tick_formatter) - res = get_summary_statistics(data) + res = get_summary_statistics(data, with_median=with_median) + + if with_median: + fields = ( + 'x_mean', + 'x_median', + 'x_stdev', + 'y_mean', + 'y_median', + 'y_stdev', + 'correlation', + ) + locs = [0.9, 0.78, 0.66, 0.5, 0.38, 0.26, 0.1] + else: + fields = ('x_mean', 'y_mean', 'x_stdev', 'y_stdev', 'correlation') + locs = np.linspace(0.8, 0.2, num=len(fields)) - labels = ('X Mean', 'Y Mean', 'X SD', 'Y SD', 'Corr.') - locs = np.linspace(0.8, 0.2, num=len(labels)) + labels = [_STATISTIC_DISPLAY_NAME_MAPPING[field] for field in fields] max_label_length = max([len(label) for label in labels]) max_stat = int(np.log10(np.max(np.abs(res)))) + 1 mean_x_digits, mean_y_digits = ( @@ -95,17 +122,23 @@ def plot( transform=ax.transAxes, va='center', ) - for label, loc, stat in zip(labels[:-1], locs, res): - add_stat_text(loc, formatter(label, stat), alpha=0.3) - add_stat_text(loc, formatter(label, stat)[:-stat_clip]) - - correlation_str = corr_formatter(labels[-1], res.correlation) - for alpha, text in zip([0.3, 1], [correlation_str, correlation_str[:-stat_clip]]): - add_stat_text( - locs[-1], - text, - alpha=alpha, - ) + for loc, field in zip(locs, fields): + label = _STATISTIC_DISPLAY_NAME_MAPPING[field] + stat = getattr(res, field) + + if field == 'correlation': + correlation_str = corr_formatter(labels[-1], res.correlation) + for alpha, text in zip( + [0.3, 1], [correlation_str, correlation_str[:-stat_clip]] + ): + add_stat_text( + locs[-1], + text, + alpha=alpha, + ) + else: + add_stat_text(loc, formatter(label, stat), alpha=0.3) + add_stat_text(loc, formatter(label, stat)[:-stat_clip]) if not save_to: return ax diff --git a/tests/data/test_stats.py b/tests/data/test_stats.py index d10857c3..75f25d2a 100644 --- a/tests/data/test_stats.py +++ b/tests/data/test_stats.py @@ -1,18 +1,27 @@ """Test the stats module.""" +import pytest + from data_morph.data.loader import DataLoader from data_morph.data.stats import get_summary_statistics -def test_stats(): +@pytest.mark.parametrize('with_median', [True, False]) +def test_stats(with_median): """Test that summary statistics tuple is correct.""" data = DataLoader.load_dataset('dino').data - stats = get_summary_statistics(data) + stats = get_summary_statistics(data, with_median) assert stats.x_mean == data.x.mean() assert stats.y_mean == data.y.mean() assert stats.x_stdev == data.x.std() assert stats.y_stdev == data.y.std() assert stats.correlation == data.corr().x.y + + if with_median: + assert stats.x_median == data.x.median() + assert stats.y_median == data.y.median() + else: + assert stats.x_median is stats.y_median is None diff --git a/tests/plotting/test_animation.py b/tests/plotting/test_animation.py index bb873062..347094a1 100644 --- a/tests/plotting/test_animation.py +++ b/tests/plotting/test_animation.py @@ -29,6 +29,7 @@ def test_frame_stitching(sample_data, tmp_path, forward_only): y_bounds=bounds, save_to=(tmp_path / f'{start_shape}-to-{target_shape}-{frame}.png'), decimals=2, + with_median=False, ) duration_multipliers = [0, 0, 0, 0, 1, 1, *frame_numbers[2:], frame_numbers[-1]] diff --git a/tests/plotting/test_static.py b/tests/plotting/test_static.py index d7a203ba..9b1064e5 100644 --- a/tests/plotting/test_static.py +++ b/tests/plotting/test_static.py @@ -7,8 +7,15 @@ pytestmark = pytest.mark.plotting -@pytest.mark.parametrize('file_path', ['test_plot.png', None]) -def test_plot(sample_data, tmp_path, file_path): +@pytest.mark.parametrize( + ('file_path', 'with_median'), + [ + ('test_plot.png', False), + (None, True), + (None, False), + ], +) +def test_plot(sample_data, tmp_path, file_path, with_median): """Test static plot creation.""" bounds = (-5.0, 105.0) if file_path: @@ -20,12 +27,18 @@ def test_plot(sample_data, tmp_path, file_path): y_bounds=bounds, save_to=save_to, decimals=2, + with_median=with_median, ) assert save_to.is_file() else: ax = plot( - data=sample_data, x_bounds=bounds, y_bounds=bounds, save_to=None, decimals=2 + data=sample_data, + x_bounds=bounds, + y_bounds=bounds, + save_to=None, + decimals=2, + with_median=with_median, ) # confirm that the stylesheet was used @@ -34,3 +47,8 @@ def test_plot(sample_data, tmp_path, file_path): # confirm that bounds are correct assert ax.get_xlim() == bounds assert ax.get_ylim() == bounds + + # confirm that the right number of stats was drawn + expected_stats = 7 if with_median else 5 + expected_texts = 2 * expected_stats # label and the number + assert len(ax.texts) == expected_texts diff --git a/tests/test_cli.py b/tests/test_cli.py index 9e6edc9b..3a8bf28d 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -127,6 +127,7 @@ def test_cli_one_shape(start_shape, flag, mocker, tmp_path): 'forward_only_animation': flag, 'num_frames': 100, 'in_notebook': False, + 'with_median': flag, } morph_args = { 'start_shape_name': start_shape, @@ -153,6 +154,7 @@ def test_cli_one_shape(start_shape, flag, mocker, tmp_path): '--write-data' if init_args['write_data'] else '', '--keep-frames' if init_args['keep_frames'] else '', '--forward-only' if init_args['forward_only_animation'] else '', + '--with-median' if init_args['with_median'] else '', f'--shake={morph_args["min_shake"]}' if morph_args['min_shake'] else '', f'--freeze={morph_args["freeze"]}' if morph_args['freeze'] else '', '--ease-in' if morph_args['ease_in'] else '',