Skip to content

Commit 2da1d51

Browse files
committed
no warning for normalization if d_method was "fractal"
1 parent 95d76d4 commit 2da1d51

File tree

5 files changed

+110
-13
lines changed

5 files changed

+110
-13
lines changed

CHANGELOG.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
# v1.6.1
2+
3+
- improve normalization warning to only show when appropriate
4+
- introduce "manual" d_method for explicitly provided d values
5+
- include detailed information in normalization warnings and messages
6+
- store d and d_method in predictors for better warning behavior
7+
- automatically set d_method to "manual" when d is explicitly provided
8+
19
# v1.6.0
210

311
- use [PyNNDescent](https://github.com/lmcinnes/pynndescent) for faster nearest neighbor distance computation

mellon/base_predictor.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,27 @@ class Predictor(ABC):
9292
n_input_features : int
9393
The number of features/dimensions of the cell-state representation the predictor was
9494
trained on. This is used for validation of input data.
95+
96+
d : int or None
97+
The intrinsic dimensionality of the data used to create this predictor.
98+
Only stored to provide appropriate warnings for normalization.
99+
100+
d_method : str or None
101+
The method used to compute the intrinsic dimensionality of the data.
102+
Only stored to provide appropriate warnings for normalization.
95103
"""
96104

97105
# number of features of input data (x.shape[1]) to be specified in __init__
98106
n_input_features: int
99107

100108
# number of observations trained on (x.shape[0]) to be specified in __init__
101109
n_obs: int
110+
111+
# intrinsic dimensionality of the data, stored only for warning purposes
112+
d: int = None
113+
114+
# method used to compute the intrinsic dimensionality, stored only for warning purposes
115+
d_method: str = None
102116

103117
# a set of attribute names that should be saved to reconstruct the object
104118
_state_variables: Union[Set, List]
@@ -210,9 +224,30 @@ def mean(self, x, normalize=False):
210224
)
211225
logger.error(message)
212226
raise ValueError(message)
213-
logger.warning(
214-
'The normalization is only effective if the density was trained with d_method="fractal".'
215-
)
227+
# Check conditions for warning about normalization
228+
if self.d_method == "fractal":
229+
# No warning needed for fractal method
230+
pass
231+
elif self.d_method == "manual":
232+
# For manual d, show info message
233+
logger.info(
234+
f"Using normalization with manually set d={self.d}. "
235+
"Note: Normalization is most effective when d approximates the intrinsic dimensionality of the data."
236+
)
237+
elif self.d_method is None and isinstance(self.d, (int, float)) and float(self.d).is_integer():
238+
# For None d_method with integer d, show warning
239+
logger.warning(
240+
f"The normalization is only effective if d approximates the intrinsic dimensionality. "
241+
f"Current values: d_method={self.d_method}, d={self.d}. "
242+
f'Consider using d_method="fractal" for more accurate results.'
243+
)
244+
elif self.d_method == "embedding":
245+
# For embedding method, show warning
246+
logger.warning(
247+
f"The normalization is only effective if d approximates the intrinsic dimensionality. "
248+
f"Current values: d_method={self.d_method}, d={self.d}. "
249+
f'Consider using d_method="fractal" for more accurate results.'
250+
)
216251
return self._mean(x) - log(self.n_obs)
217252
else:
218253
return self._mean(x)
@@ -396,6 +431,8 @@ def __getstate__(self):
396431
{
397432
"n_input_features": self.n_input_features,
398433
"n_obs": self.n_obs,
434+
"d": self.d,
435+
"d_method": self.d_method,
399436
"_state_variables": self._state_variables,
400437
}
401438
)
@@ -755,9 +792,30 @@ def mean(self, Xnew, time=None, normalize=False):
755792
)
756793
logger.error(message)
757794
raise ValueError(message)
758-
logger.warning(
759-
'The normalization is only effective if the density was trained with d_method="fractal".'
760-
)
795+
# Check conditions for warning about normalization
796+
if self.d_method == "fractal":
797+
# No warning needed for fractal method
798+
pass
799+
elif self.d_method == "manual":
800+
# For manual d, show info message
801+
logger.info(
802+
f"Using normalization with manually set d={self.d}. "
803+
"Note: Normalization is most effective when d approximates the intrinsic dimensionality of the data."
804+
)
805+
elif self.d_method is None and isinstance(self.d, (int, float)) and float(self.d).is_integer():
806+
# For None d_method with integer d, show warning
807+
logger.warning(
808+
f"The normalization is only effective if d approximates the intrinsic dimensionality. "
809+
f"Current values: d_method={self.d_method}, d={self.d}. "
810+
f'Consider using d_method="fractal" for more accurate results.'
811+
)
812+
elif self.d_method == "embedding":
813+
# For embedding method, show warning
814+
logger.warning(
815+
f"The normalization is only effective if d approximates the intrinsic dimensionality. "
816+
f"Current values: d_method={self.d_method}, d={self.d}. "
817+
f'Consider using d_method="fractal" for more accurate results.'
818+
)
761819
return self._mean(Xnew) - log(self.n_obs)
762820
else:
763821
return self._mean(Xnew)

mellon/density_estimator.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,14 @@ def __init__(
224224
jit=jit,
225225
check_rank=check_rank,
226226
)
227-
self.d_method = validate_string(
228-
d_method, "d_method", choices={"fractal", "embedding"}
229-
)
227+
# If d is explicitly provided, set d_method to "manual"
228+
if d is not None:
229+
self.d_method = "manual"
230+
logger.info(f"Explicitly provided d={d}, setting d_method to 'manual'.")
231+
else:
232+
self.d_method = validate_string(
233+
d_method, "d_method", choices={"fractal", "embedding", "manual"}
234+
)
230235
self.transform = None
231236
self.loss_func = None
232237
self.opt_state = None
@@ -305,7 +310,12 @@ def _compute_d(self):
305310
if self.d_method == "fractal":
306311
d = compute_d_factal(x)
307312
logger.info(f"Using d={d}.")
313+
elif self.d_method == "manual":
314+
# For manual method, d is already set, so we don't need to compute it
315+
d = self.d
316+
logger.info(f"Using manually set d={d}.")
308317
else:
318+
# embedding method uses the number of dimensions
309319
d = compute_d(x)
310320
logger.info(
311321
f"Using embedding dimensionality d={d}. "
@@ -382,6 +392,10 @@ def _set_log_density_func(self):
382392
y_is_mean=True,
383393
with_uncertainty=with_uncertainty,
384394
)
395+
log_density_func.n_obs = self.x.shape[0]
396+
# Store d and d_method for warning purposes
397+
log_density_func.d = self.d
398+
log_density_func.d_method = self.d_method
385399
self.log_density_func = log_density_func
386400

387401
def prepare_inference(self, x):

mellon/time_sensitive_density_estimator.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -288,9 +288,14 @@ def __init__(
288288
if not isinstance(density_estimator_kwargs, dict):
289289
raise ValueError("density_estimator_kwargs needs to be a dictionary.")
290290
self.density_estimator_kwargs = density_estimator_kwargs
291-
self.d_method = validate_string(
292-
d_method, "d_method", choices={"fractal", "embedding"}
293-
)
291+
# If d is explicitly provided, set d_method to "manual"
292+
if d is not None:
293+
self.d_method = "manual"
294+
logger.info(f"Explicitly provided d={d}, setting d_method to 'manual'.")
295+
else:
296+
self.d_method = validate_string(
297+
d_method, "d_method", choices={"fractal", "embedding", "manual"}
298+
)
294299
self.ls_time = validate_positive_float(ls_time, "ls_time", optional=True)
295300
self.ls_time_factor = validate_positive_float(ls_time_factor, "ls_time_factor")
296301
self._save_intermediate_ls_times = _save_intermediate_ls_times
@@ -418,8 +423,17 @@ def _compute_d(self):
418423
logger.warning("Using EXPERIMENTAL fractal dimensionality selection.")
419424
d = compute_d_factal(x)
420425
logger.info(f"Using d={d}.")
426+
elif self.d_method == "manual":
427+
# For manual method, d is already set, so we don't need to compute it
428+
d = self.d
429+
logger.info(f"Using manually set d={d}.")
421430
else:
431+
# embedding method uses the number of dimensions
422432
d = compute_d(x)
433+
logger.info(
434+
f"Using embedding dimensionality d={d}. "
435+
'Use d_method="fractal" to enable effective density normalization.'
436+
)
423437
if d > 50:
424438
message = f"""The detected dimensionality of the data is over 50,
425439
which is likely to cause numerical instability issues.
@@ -580,6 +594,9 @@ def _set_log_density_func(self):
580594
with_uncertainty=with_uncertainty,
581595
)
582596
log_density_func.n_obs = compute_average_cell_count(x, normalize)
597+
# Store d and d_method for warning purposes
598+
log_density_func.d = self.d
599+
log_density_func.d_method = self.d_method
583600
self.log_density_func = log_density_func
584601

585602
def prepare_inference(self, x, times=None):

mellon/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Version information."""
22

3-
__version__ = "1.6.0"
3+
__version__ = "1.6.1"

0 commit comments

Comments
 (0)