Skip to content

Commit d6135a0

Browse files
committed
numpy 2; update imports
1 parent bad9d62 commit d6135a0

File tree

18 files changed

+51
-58
lines changed

18 files changed

+51
-58
lines changed

multi_mst/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .k_mst import kMST, KMST
2+
from .noisy_mst import noisyMST, NoisyMST
3+
from .k_mst_descent import kMSTDescent, KMSTDescent
4+
5+
__all__ = ["kMST", "KMST", "noisyMST", "NoisyMST", "kMSTDescent", "KMSTDescent"]

multi_mst/k_mst/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
import numpy as np
21
from .api import KMST, kMST
32

4-
# Force JIT compilation on import
5-
random_state = np.random.RandomState(42)
6-
random_data = random_state.random(size=(50, 3))
7-
KMST().fit(random_data)
8-
93
__all__ = ["KMST", "kMST"]

multi_mst/k_mst/api.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,10 @@ def kMST(data, num_neighbors=3, min_samples=1, epsilon=None, umap_kwargs=None):
7777
)
7878
with warn.catch_warnings():
7979
warn.filterwarnings(
80-
"ignore", category=UserWarning, module="umap.umap_", lineno=2010
80+
"ignore",
81+
category=UserWarning,
82+
module="umap.umap_",
83+
message=".*is not an NNDescent object.*",
8184
)
8285
umap = UMAP(
8386
n_neighbors=mst_indices.shape[1],
@@ -133,7 +136,9 @@ class KMST(BaseEstimator):
133136
missing values. Use the graph_ and embedding_ attributes instead!
134137
"""
135138

136-
def __init__(self, *, num_neighbors=3, min_samples=1, epsilon=None, umap_kwargs=None):
139+
def __init__(
140+
self, *, num_neighbors=3, min_samples=1, epsilon=None, umap_kwargs=None
141+
):
137142
self.num_neighbors = num_neighbors
138143
self.min_samples = min_samples
139144
self.epsilon = epsilon
@@ -173,9 +178,7 @@ def fit(self, X, y=None, **fit_params):
173178
clean_data = X
174179

175180
kwargs = self.get_params()
176-
self.mst_indices_, self.mst_distances_, self._umap = kMST(
177-
clean_data, **kwargs
178-
)
181+
self.mst_indices_, self.mst_distances_, self._umap = kMST(clean_data, **kwargs)
179182
self.graph_ = self._umap.graph_.copy()
180183
self.embedding_ = (
181184
self._umap.embedding_.copy() if hasattr(self._umap, "embedding_") else None
Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
import numpy as np
21
from .api import KMSTDescent, kMSTDescent
32

4-
# Force JIT compilation on import
5-
random_state = np.random.RandomState(42)
6-
random_data = random_state.random(size=(50, 3))
7-
KMSTDescent().fit(random_data)
8-
93
__all__ = ["KMSTDescent", "kMSTDescent"]

multi_mst/k_mst_descent/api.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def kMSTDescent(
103103
The number of neighbors for computing the mutual reachability distance.
104104
Value must be lower or equal to the number of neighbors. `epsilon`
105105
operates on the mutual reachability distance, so always allows the
106-
nearest `min_samples` points. Acts as UMAP's `local connnectivity`
106+
nearest `min_samples` points. Acts as UMAP's `local connectivity`
107107
parameter. Default is 1.
108108
epsilon: float, optional
109109
A fraction of the initial MST edge distance to act as upper distance
@@ -114,11 +114,11 @@ def kMSTDescent(
114114
umap_kwargs: dict
115115
Additional keyword arguments passed to UMAP.
116116
nn_kwargs: dict
117-
Additional keyword arguments passsed to NNDescent.
117+
Additional keyword arguments passed to NNDescent.
118118
n_jobs : int, optional
119119
The number of threads to use for the computation. -1 means using all
120120
threads.
121-
121+
122122
Returns
123123
-------
124124
mst_indices_: numpy.ndarray, shape (n_samples, num_found_neighbors)
@@ -155,7 +155,10 @@ def kMSTDescent(
155155
)
156156
with warn.catch_warnings():
157157
warn.filterwarnings(
158-
"ignore", category=UserWarning, module="umap.umap_", lineno=2010
158+
"ignore",
159+
category=UserWarning,
160+
module="umap.umap_",
161+
message=".*is not an NNDescent object.*",
159162
)
160163
umap = UMAP(
161164
n_neighbors=mst_indices.shape[1],
@@ -236,7 +239,7 @@ class KMSTDescent(BaseEstimator):
236239
Additional keyword arguments passsed to NNDescent.
237240
n_jobs : int, optional
238241
The number of threads to use for the computation. -1 means using all
239-
threads.
242+
threads.
240243
241244
Attributes
242245
----------
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .api import (
22
KMSTDescentLogRecall as _KMSTDescentLogRecall,
3-
kMSTDescentLogRecall as _kMSTDescentLogRecall
3+
kMSTDescentLogRecall as _kMSTDescentLogRecall,
44
)
55

66
__all__ = ["_KMSTDescentLogRecall", "_kMSTDescentLogRecall"]

multi_mst/k_mst_descent_recall/api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,10 @@ def kMSTDescentLogRecall(
142142
]
143143
with warn.catch_warnings():
144144
warn.filterwarnings(
145-
"ignore", category=UserWarning, module="umap.umap_", lineno=2010
145+
"ignore",
146+
category=UserWarning,
147+
module="umap.umap_",
148+
message=".*is not an NNDescent object.*",
146149
)
147150
umap = UMAP(
148151
n_neighbors=mst_indices.shape[1],

multi_mst/noisy_mst/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
import numpy as np
21
from .api import NoisyMST, noisyMST
32

4-
# Force JIT compilation on import
5-
random_state = np.random.RandomState(42)
6-
random_data = random_state.random(size=(50, 3))
7-
NoisyMST().fit(random_data)
8-
93
__all__ = ["NoisyMST", "noisyMST"]

multi_mst/noisy_mst/api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,10 @@ def noisyMST(data, num_trees=3, noise_fraction=0.1, min_samples=1, umap_kwargs=N
7676
)
7777
with warn.catch_warnings():
7878
warn.filterwarnings(
79-
"ignore", category=UserWarning, module="umap.umap_", lineno=2010
79+
"ignore",
80+
category=UserWarning,
81+
module="umap.umap_",
82+
message=".*is not an NNDescent object.*",
8083
)
8184
umap = UMAP(
8285
n_neighbors=mst_indices.shape[1],

notebooks/Benchmark MNIST.ipynb

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
3030
"\n",
3131
"from umap import UMAP\n",
32-
"from multi_mst.k_mst import KMST\n",
33-
"from multi_mst.k_mst_descent import KMSTDescent"
32+
"from multi_mst import KMST, KMSTDescent"
3433
]
3534
},
3635
{
@@ -393,7 +392,7 @@
393392
"name": "python",
394393
"nbconvert_exporter": "python",
395394
"pygments_lexer": "ipython3",
396-
"version": "3.9.19"
395+
"version": "3.10.15"
397396
}
398397
},
399398
"nbformat": 4,

0 commit comments

Comments
 (0)