Skip to content

Commit 36d05f2

Browse files
committed
random state in compute_landmarks
1 parent c93396a commit 36d05f2

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# v1.6.1
2+
3+
- forward random state to k-means to compute landmarks
4+
15
# v1.6.0
26

37
- use [PyNNDescent](https://github.com/lmcinnes/pynndescent) for faster nearest neighbor distance computation

mellon/parameters.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def compute_gp_type(n_landmarks, rank, n_samples):
240240
return GaussianProcessType.SPARSE_NYSTROEM
241241

242242

243-
def compute_landmarks(x, gp_type=None, n_landmarks=DEFAULT_N_LANDMARKS):
243+
def compute_landmarks(x, gp_type=None, n_landmarks=DEFAULT_N_LANDMARKS, random_state=DEFAULT_RANDOM_SEED):
244244
R"""
245245
Computes the landmark points as k-means centroids.
246246
@@ -260,6 +260,9 @@ def compute_landmarks(x, gp_type=None, n_landmarks=DEFAULT_N_LANDMARKS):
260260
The desired number of landmark points. If less than 2 or greater
261261
than the number of data points, the function will return None.
262262
Defaults to DEFAULT_N_LANDMARKS.
263+
random_state : int, optional
264+
Random seed for the k-means algorithm to ensure reproducible landmark selection.
265+
Defaults to DEFAULT_RANDOM_SEED (42).
263266
264267
Returns
265268
-------
@@ -284,12 +287,12 @@ def compute_landmarks(x, gp_type=None, n_landmarks=DEFAULT_N_LANDMARKS):
284287
logger.warning(message)
285288
return x
286289
return None
287-
logger.info(f"Computing {n_landmarks:,} landmarks with k-means clustering.")
288-
return k_means(x, n_landmarks, n_init=1)[0]
290+
logger.info(f"Computing {n_landmarks:,} landmarks with k-means clustering (random_state={random_state}).")
291+
return k_means(x, n_landmarks, n_init=1, random_state=random_state)[0]
289292

290293

291294
def compute_landmarks_rescale_time(
292-
x, ls, ls_time, times=None, n_landmarks=DEFAULT_N_LANDMARKS
295+
x, ls, ls_time, times=None, n_landmarks=DEFAULT_N_LANDMARKS, random_state=DEFAULT_RANDOM_SEED
293296
):
294297
R"""
295298
Computes landmark points for time-rescaled input data using k-means centroids.
@@ -315,6 +318,9 @@ def compute_landmarks_rescale_time(
315318
Shape must be either (n_samples,) or (n_samples, 1).
316319
n_landmarks : int, optional
317320
The desired number of landmark points. Defaults to DEFAULT_N_LANDMARKS.
321+
random_state : int, optional
322+
Random seed for the k-means algorithm to ensure reproducible landmark selection.
323+
Defaults to DEFAULT_RANDOM_SEED (42).
318324
319325
Returns
320326
-------
@@ -333,7 +339,7 @@ def compute_landmarks_rescale_time(
333339
x = validate_time_x(x, times)
334340
time_factor = ls / ls_time
335341
x = x.at[:, -1].set(x[:, -1] * time_factor)
336-
landmarks = compute_landmarks(x, n_landmarks=n_landmarks)
342+
landmarks = compute_landmarks(x, n_landmarks=n_landmarks, random_state=random_state)
337343
if landmarks is not None:
338344
try:
339345
landmarks = landmarks.at[:, -1].set(landmarks[:, -1] / time_factor)

0 commit comments

Comments
 (0)