Skip to content

Commit b260114

Browse files
committed
Refactored non-normalized decorator
1 parent 7cb6fbc commit b260114

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

stumpy/core.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1597,7 +1597,7 @@ def _get_partial_mp_func(mp_func, dask_client=None, device_id=None):
15971597
return partial_mp_func
15981598

15991599

1600-
def compare_parameters(norm, non_norm, exclude=None, translate=None):
1600+
def compare_parameters(norm, non_norm, exclude=None):
16011601
"""
16021602
Compare if the parameters in `norm` and `non_norm` are the same
16031603
@@ -1612,7 +1612,7 @@ def compare_parameters(norm, non_norm, exclude=None, translate=None):
16121612
z-normalized function (or class)
16131613
16141614
exclude : list
1615-
A list of parameters to exclude
1615+
A list of parameters to exclude for the comparison
16161616
16171617
Returns
16181618
-------
@@ -1634,12 +1634,14 @@ def compare_parameters(norm, non_norm, exclude=None, translate=None):
16341634
if not is_same_params:
16351635
if exclude is not None:
16361636
logger.warning(f"Excluding `{exclude}` parameters, ")
1637-
logger.warning(f"`{norm}` and `{non_norm}` have different parameters.")
1637+
logger.warning(f"`{norm}`: ({norm_params}) and ")
1638+
logger.warning(f"`{non_norm}`: ({non_norm_params}) ")
1639+
logger.warning("have different parameters.")
16381640

16391641
return is_same_params
16401642

16411643

1642-
def non_normalized(non_norm):
1644+
def non_normalized(non_norm, exclude=None, replace=None):
16431645
"""
16441646
Decorator for swapping a z-normalized function (or class) for its complementary
16451647
non-normalized function (or class) as defined by `non_norm`. This requires that
@@ -1655,25 +1657,37 @@ def non_normalized(non_norm):
16551657
The non-normalized function (or class) that is complementary to the
16561658
z-normalized function (or class)
16571659
1660+
exclude : list, default None
1661+
A list of function (or class) parameter names to exclude when comparing the
1662+
function (or class) signatures
1663+
1664+
replace : dict, default None
1665+
A dictionary of function (or class) parameter key-value pairs. Each key that
1666+
is found as a parameter name in the `norm` function (or class) will be replaced
1667+
by its corresponding or complementary parameter name in the `non_norm` function
1668+
(or class).
1669+
16581670
Returns
16591671
-------
16601672
outer_wrapper : object
16611673
The desired z-normalized/non-normalized function (or class)
16621674
"""
1675+
if exclude is None:
1676+
exclude = ["normalize"]
16631677

16641678
@functools.wraps(non_norm)
16651679
def outer_wrapper(norm):
16661680
@functools.wraps(norm)
16671681
def inner_wrapper(*args, **kwargs):
1668-
exclude = ["normalize", "pre_scrump", "pre_scraamp"]
16691682
is_same_params = compare_parameters(norm, non_norm, exclude=exclude)
1670-
16711683
if not is_same_params or kwargs.get("normalize", True):
16721684
return norm(*args, **kwargs)
16731685
else:
16741686
kwargs = {k: v for k, v in kwargs.items() if k != "normalize"}
1675-
if "pre_scrump" in kwargs.keys():
1676-
kwargs["pre_scraamp"] = kwargs.pop("pre_scrump")
1687+
if replace is not None:
1688+
for k, v in replace.items():
1689+
if k in kwargs.keys():
1690+
kwargs[v] = kwargs.pop(k)
16771691
return non_norm(*args, **kwargs)
16781692

16791693
return inner_wrapper

stumpy/scrump.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,11 @@ def prescrump(T_A, m, T_B=None, s=None, normalize=True):
245245
return P, I
246246

247247

248-
@core.non_normalized(scraamp.scraamp)
248+
@core.non_normalized(
249+
scraamp.scraamp,
250+
exclude=["normalize", "pre_scrump", "pre_scraamp"],
251+
replace={"pre_scrump": "pre_scraamp"},
252+
)
249253
class scrump(object):
250254
"""
251255
Compute an approximate z-normalized matrix profile

0 commit comments

Comments
 (0)