@@ -1597,7 +1597,7 @@ def _get_partial_mp_func(mp_func, dask_client=None, device_id=None):
15971597 return partial_mp_func
15981598
15991599
1600- def compare_parameters (norm , non_norm , exclude = None , translate = None ):
1600+ def compare_parameters (norm , non_norm , exclude = None ):
16011601 """
16021602 Compare if the parameters in `norm` and `non_norm` are the same
16031603
@@ -1612,7 +1612,7 @@ def compare_parameters(norm, non_norm, exclude=None, translate=None):
16121612 z-normalized function (or class)
16131613
16141614 exclude : list
1615- A list of parameters to exclude
1615+ A list of parameters to exclude for the comparison
16161616
16171617 Returns
16181618 -------
@@ -1634,12 +1634,14 @@ def compare_parameters(norm, non_norm, exclude=None, translate=None):
16341634 if not is_same_params :
16351635 if exclude is not None :
16361636 logger .warning (f"Excluding `{ exclude } ` parameters, " )
1637- logger .warning (f"`{ norm } ` and `{ non_norm } ` have different parameters." )
1637+ logger .warning (f"`{ norm } `: ({ norm_params } ) and " )
1638+ logger .warning (f"`{ non_norm } `: ({ non_norm_params } ) " )
1639+ logger .warning ("have different parameters." )
16381640
16391641 return is_same_params
16401642
16411643
1642- def non_normalized (non_norm ):
1644+ def non_normalized (non_norm , exclude = None , replace = None ):
16431645 """
16441646 Decorator for swapping a z-normalized function (or class) for its complementary
16451647 non-normalized function (or class) as defined by `non_norm`. This requires that
@@ -1655,25 +1657,37 @@ def non_normalized(non_norm):
16551657 The non-normalized function (or class) that is complementary to the
16561658 z-normalized function (or class)
16571659
1660+ exclude : list, default None
1661+ A list of function (or class) parameter names to exclude when comparing the
1662+ function (or class) signatures
1663+
1664+ replace : dict, default None
1665+ A dictionary of function (or class) parameter key-value pairs. Each key that
1666+ is found as a parameter name in the `norm` function (or class) will be replaced
1667+ by its corresponding or complementary parameter name in the `non_norm` function
1668+ (or class).
1669+
16581670 Returns
16591671 -------
16601672 outer_wrapper : object
16611673 The desired z-normalized/non-normalized function (or class)
16621674 """
1675+ if exclude is None :
1676+ exclude = ["normalize" ]
16631677
16641678 @functools .wraps (non_norm )
16651679 def outer_wrapper (norm ):
16661680 @functools .wraps (norm )
16671681 def inner_wrapper (* args , ** kwargs ):
1668- exclude = ["normalize" , "pre_scrump" , "pre_scraamp" ]
16691682 is_same_params = compare_parameters (norm , non_norm , exclude = exclude )
1670-
16711683 if not is_same_params or kwargs .get ("normalize" , True ):
16721684 return norm (* args , ** kwargs )
16731685 else :
16741686 kwargs = {k : v for k , v in kwargs .items () if k != "normalize" }
1675- if "pre_scrump" in kwargs .keys ():
1676- kwargs ["pre_scraamp" ] = kwargs .pop ("pre_scrump" )
1687+ if replace is not None :
1688+ for k , v in replace .items ():
1689+ if k in kwargs .keys ():
1690+ kwargs [v ] = kwargs .pop (k )
16771691 return non_norm (* args , ** kwargs )
16781692
16791693 return inner_wrapper
0 commit comments