diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index a09b4f3d4f5c..c4320b2ab10e 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -2761,7 +2761,8 @@ def halfrange(self, halfrange): self.vmax = self.vcenter + abs(halfrange) -def make_norm_from_scale(scale_cls, base_norm_cls=None, *, init=None): +def make_norm_from_scale(scale_cls, base_norm_cls=None, *, init=None, + norm_before_trf=False): """ Decorator for building a `.Normalize` subclass from a `~.scale.ScaleBase` subclass. @@ -2793,7 +2794,8 @@ class norm_cls(Normalize): """ if base_norm_cls is None: - return functools.partial(make_norm_from_scale, scale_cls, init=init) + return functools.partial(make_norm_from_scale, scale_cls, init=init, + norm_before_trf=norm_before_trf) if isinstance(scale_cls, functools.partial): scale_args = scale_cls.args @@ -2807,13 +2809,13 @@ def init(vmin=None, vmax=None, clip=False): pass return _make_norm_from_scale( scale_cls, scale_args, scale_kwargs_items, - base_norm_cls, inspect.signature(init)) + base_norm_cls, inspect.signature(init), norm_before_trf) @functools.cache def _make_norm_from_scale( scale_cls, scale_args, scale_kwargs_items, - base_norm_cls, bound_init_signature, + base_norm_cls, bound_init_signature, norm_before_trf ): """ Helper for `make_norm_from_scale`. @@ -2845,7 +2847,7 @@ def __reduce__(self): pass return (_picklable_norm_constructor, (scale_cls, scale_args, scale_kwargs_items, - base_norm_cls, bound_init_signature), + base_norm_cls, bound_init_signature, norm_before_trf), vars(self)) def __init__(self, *args, **kwargs): @@ -2874,6 +2876,14 @@ def __call__(self, value, clip=None): clip = self.clip if clip: value = np.clip(value, self.vmin, self.vmax) + + if norm_before_trf: + value -= self.vmin + value /= (self.vmax - self.vmin) + t_value = self._trf.transform(value).reshape(np.shape(value)) + t_value = np.ma.masked_invalid(t_value, copy=False) + return t_value[0] if is_scalar else t_value + t_value = self._trf.transform(value).reshape(np.shape(value)) t_vmin, t_vmax = self._trf.transform([self.vmin, self.vmax]) if not np.isfinite([t_vmin, t_vmax]).all(): @@ -2888,10 +2898,17 @@ def inverse(self, value): raise ValueError("Not invertible until scaled") if self.vmin > self.vmax: raise ValueError("vmin must be less or equal to vmax") + value, is_scalar = self.process_value(value) + + if norm_before_trf: + value = self._trf.inverted().transform(value).reshape(np.shape(value)) + rescaled = value * (self.vmax - self.vmin) + rescaled += self.vmin + return rescaled[0] if is_scalar else rescaled + t_vmin, t_vmax = self._trf.transform([self.vmin, self.vmax]) if not np.isfinite([t_vmin, t_vmax]).all(): raise ValueError("Invalid vmin or vmax") - value, is_scalar = self.process_value(value) rescaled = value * (t_vmax - t_vmin) rescaled += t_vmin value = (self._trf @@ -3040,6 +3057,10 @@ def linear_width(self, value): self._scale.linear_width = value +@make_norm_from_scale( + scale.PowerScale, + init=lambda gamma=0.5, vmin=None, vmax=None, clip=False: None, + norm_before_trf=True) class PowerNorm(Normalize): r""" Linearly map a given value to the 0-1 range and then apply @@ -3076,56 +3097,13 @@ class PowerNorm(Normalize): For input values below *vmin*, gamma is set to one. """ - def __init__(self, gamma, vmin=None, vmax=None, clip=False): - super().__init__(vmin, vmax, clip) - self.gamma = gamma - - def __call__(self, value, clip=None): - if clip is None: - clip = self.clip - - result, is_scalar = self.process_value(value) - - self.autoscale_None(result) - gamma = self.gamma - vmin, vmax = self.vmin, self.vmax - if vmin > vmax: - raise ValueError("minvalue must be less than or equal to maxvalue") - elif vmin == vmax: - result.fill(0) - else: - if clip: - mask = np.ma.getmask(result) - result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax), - mask=mask) - resdat = result.data - resdat -= vmin - resdat /= (vmax - vmin) - resdat[resdat > 0] = np.power(resdat[resdat > 0], gamma) - - result = np.ma.array(resdat, mask=result.mask, copy=False) - if is_scalar: - result = result[0] - return result - - def inverse(self, value): - if not self.scaled(): - raise ValueError("Not invertible until scaled") - - result, is_scalar = self.process_value(value) - - gamma = self.gamma - vmin, vmax = self.vmin, self.vmax - - resdat = result.data - resdat[resdat > 0] = np.power(resdat[resdat > 0], 1 / gamma) - resdat *= (vmax - vmin) - resdat += vmin + @property + def gamma(self): + return self._scale.gamma - result = np.ma.array(resdat, mask=result.mask, copy=False) - if is_scalar: - result = result[0] - return result + @gamma.setter + def gamma(self, value): + self._scale.gamma = value class BoundaryNorm(Normalize): diff --git a/lib/matplotlib/colors.pyi b/lib/matplotlib/colors.pyi index cdc6e5e7d89f..4f02540e8b7d 100644 --- a/lib/matplotlib/colors.pyi +++ b/lib/matplotlib/colors.pyi @@ -334,14 +334,16 @@ def make_norm_from_scale( scale_cls: type[scale.ScaleBase], base_norm_cls: type[Normalize], *, - init: Callable | None = ... + init: Callable | None = ..., + norm_before_trf: bool = ..., ) -> type[Normalize]: ... @overload def make_norm_from_scale( scale_cls: type[scale.ScaleBase], base_norm_cls: None = ..., *, - init: Callable | None = ... + init: Callable | None = ..., + norm_before_trf: bool = ..., ) -> Callable[[type[Normalize]], type[Normalize]]: ... class FuncNorm(Normalize): @@ -384,14 +386,17 @@ class AsinhNorm(Normalize): def linear_width(self, value: float) -> None: ... class PowerNorm(Normalize): - gamma: float def __init__( self, - gamma: float, + gamma: float = ..., vmin: float | None = ..., vmax: float | None = ..., clip: bool = ..., ) -> None: ... + @property + def gamma(self) -> float: ... + @gamma.setter + def gamma(self, value: float) -> None: ... class BoundaryNorm(Normalize): boundaries: np.ndarray diff --git a/lib/matplotlib/scale.py b/lib/matplotlib/scale.py index 4517b8946b03..36d485f98b3f 100644 --- a/lib/matplotlib/scale.py +++ b/lib/matplotlib/scale.py @@ -17,6 +17,7 @@ "log" `LogScale` `LogTransform` `InvertedLogTransform` "logit" `LogitScale` `LogitTransform` `LogisticTransform` "symlog" `SymmetricalLogScale` `SymmetricalLogTransform` `InvertedSymmetricalLogTransform` +"power" `PowerScale` `PowerTransform` `InvertedPowerTransform` ============= ===================== ================================ ================================= A user will often only use the scale name, e.g. when setting the scale through @@ -265,6 +266,111 @@ def set_default_locators_and_formatters(self, axis): axis.set_minor_locator(NullLocator()) +class PowerTransform(Transform): + """ + A simple power transformation used by `.PowerScale`. + + This transformation applies a power-law scaling to positive values, while + nonpositive values remain unchanged. + """ + input_dims = output_dims = 1 + + def __init__(self, gamma): + """ + Parameters + ---------- + gamma : float + Power law exponent. + """ + super().__init__() + self.gamma = gamma + + def __str__(self): + return "{}(gamma={})".format( + type(self).__name__, self.gamma) + + def transform_non_affine(self, a): + with np.errstate(divide="ignore", invalid="ignore"): + mask = np.ma.getmask(a) + d = np.asarray(a.data) + out = np.where(d > 0, np.power(d, self.gamma), d) + mout = np.ma.masked_array(out, mask=mask) + return mout + + def inverted(self): + return InvertedPowerTransform(self.gamma) + + +class InvertedPowerTransform(Transform): + """ + The inverse of the `.PowerTransform`. + + This transformation applies an inverse power-law scaling to positive values, + while nonpositive values remain unchanged. + """ + input_dims = output_dims = 1 + + def __init__(self, gamma): + """ + Parameters + ---------- + gamma : float + Power law exponent. + """ + super().__init__() + if gamma == 0: + raise ValueError('gamma cannot be 0') + self.gamma = gamma + + def transform_non_affine(self, a): + with np.errstate(divide="ignore", invalid="ignore"): + input_mask = np.ma.getmask(a) + d = np.asarray(a.data) + out = np.where(d > 0, np.power(d, 1. / self.gamma), d) + mout = np.ma.array(out, mask=input_mask) + return mout + + def inverted(self): + return PowerTransform(self.gamma) + + +class PowerScale(ScaleBase): + """ + A standard power scale + """ + name = 'power' + + @_make_axis_parameter_optional + def __init__(self, axis=None, *, gamma=0.5): + """ + Parameters + ---------- + axis : `~matplotlib.axis.Axis` + The axis for the scale. + gamma : float, default: 0.5 + Power law exponent. + """ + self._transform = PowerTransform(gamma) + + gamma = property(lambda self: self._transform.gamma) + + def get_transform(self): + """Return the `.PowerTransform` associated with this scale.""" + return self._transform + + def set_default_locators_and_formatters(self, axis): + # docstring inherited + axis.set_major_locator(AutoLocator()) + axis.set_major_formatter(ScalarFormatter()) + axis.set_minor_formatter(NullFormatter()) + # update the minor locator for x and y axis based on rcParams + if (axis.axis_name == 'x' and mpl.rcParams['xtick.minor.visible'] or + axis.axis_name == 'y' and mpl.rcParams['ytick.minor.visible']): + axis.set_minor_locator(AutoMinorLocator()) + else: + axis.set_minor_locator(NullLocator()) + + class LogTransform(Transform): input_dims = output_dims = 1 @@ -762,6 +868,7 @@ def limit_range_for_scale(self, vmin, vmax, minpos): 'logit': LogitScale, 'function': FuncScale, 'functionlog': FuncScaleLog, + 'power': PowerScale, } diff --git a/lib/matplotlib/scale.pyi b/lib/matplotlib/scale.pyi index ba9f269b8c78..82dc14e4702a 100644 --- a/lib/matplotlib/scale.pyi +++ b/lib/matplotlib/scale.pyi @@ -40,6 +40,29 @@ class FuncScale(ScaleBase): ], ) -> None: ... +class PowerTransform(Transform): + def __init__(self, gamma: float) -> None: ... + def __str__(self) -> str: ... + def transform_non_affine(self, a: ArrayLike) -> ArrayLike: ... + def inverted(self) -> InvertedPowerTransform: ... + +class InvertedPowerTransform(Transform): + def __init__(self, gamma: float) -> None: ... + def transform_non_affine(self, a: ArrayLike) -> ArrayLike: ... + def inverted(self) -> PowerTransform: ... + +class PowerScale(ScaleBase): + name: str + def __init__( + self, + axis: Axis | None = ..., + *, + gamma: float = ..., + ) -> None: ... + @property + def gamma(self) -> float: ... + def get_transform(self) -> PowerTransform: ... + class LogTransform(Transform): input_dims: int output_dims: int