Skip to content

Commit eaae17c

Browse files
authored
Sklearn 1.6 compat (#70)
1 parent 93c1f19 commit eaae17c

File tree

7 files changed

+61
-37
lines changed

7 files changed

+61
-37
lines changed

pyproject.toml

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ urls.Documentation = "https://sklearn-ann.readthedocs.io/"
1111
dynamic = ["version", "readme"]
1212
requires-python = "<3.13,>=3.9" # enforced by scipy
1313
dependencies = [
14-
"scikit-learn>=0.24.0",
14+
"scikit-learn>=1.6.0",
1515
"scipy>=1.11.1,<2.0.0",
1616
]
1717

@@ -23,7 +23,7 @@ tests = [
2323
docs = [
2424
"sphinx>=7",
2525
"sphinx-gallery>=0.8.2",
26-
"sphinx-book-theme>=1.1.0rc1",
26+
"sphinx-book-theme>=1.1.0",
2727
"sphinx-issues>=1.2.0",
2828
"numpydoc>=1.1.0",
2929
"matplotlib>=3.3.3",
@@ -37,6 +37,7 @@ faiss = [
3737
]
3838
pynndescent = [
3939
"pynndescent>=0.5.1,<1.0.0",
40+
"numba>=0.52",
4041
]
4142
nmslib = [
4243
"nmslib>=2.1.1,<3.0.0 ; python_version < '3.11'",
@@ -84,16 +85,14 @@ ignore = [
8485
[tool.ruff.lint.isort]
8586
known-first-party = ["sklearn_ann"]
8687

87-
[tool.hatch.envs.default]
88-
features = [
89-
"tests",
90-
"docs",
91-
"annlibs",
92-
]
88+
[tool.hatch.envs.docs]
89+
installer = "uv"
90+
features = ["docs", "annlibs"]
91+
scripts.build = "sphinx-build -M html docs docs/_build"
9392

94-
[tool.hatch.envs.default.scripts]
95-
test = "pytest {args:tests}"
96-
build-docs = "sphinx-build -M html docs docs/_build"
93+
[tool.hatch.envs.hatch-test]
94+
default-args = []
95+
features = ["tests", "annlibs"]
9796

9897
[tool.hatch.build.targets.wheel]
9998
packages = ["src/sklearn_ann"]

src/sklearn_ann/cluster/rnn_dbscan.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from collections import deque
2+
from typing import cast
23

34
import numpy as np
45
from sklearn.base import BaseEstimator, ClusterMixin
56
from sklearn.neighbors import KNeighborsTransformer
7+
from sklearn.utils import Tags
8+
from sklearn.utils.validation import validate_data
69

710
from ..utils import get_sparse_row
811

@@ -143,7 +146,7 @@ def __init__(
143146
self.keep_knns = keep_knns
144147

145148
def fit(self, X, y=None):
146-
X = self._validate_data(X, accept_sparse="csr")
149+
X = validate_data(self, X, accept_sparse="csr")
147150
if self.input_guarantee == "none":
148151
algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors)
149152
X = algorithm.fit_transform(X)
@@ -181,6 +184,11 @@ def drop_knns(self):
181184
del self.knns_
182185
del self.rev_knns_
183186

187+
def __sklearn_tags__(self) -> Tags:
188+
tags = cast(Tags, super().__sklearn_tags__())
189+
tags.input_tags.sparse = True
190+
return tags
191+
184192

185193
def simple_rnn_dbscan_pipeline(
186194
neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs

src/sklearn_ann/kneighbors/annoy.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import numpy as np
33
from scipy.sparse import csr_matrix
44
from sklearn.base import BaseEstimator, TransformerMixin
5+
from sklearn.utils import Tags, TargetTags, TransformerTags
6+
from sklearn.utils.validation import validate_data
57

68
from ..utils import TransformerChecksMixin
79

@@ -16,7 +18,7 @@ def __init__(self, n_neighbors=5, *, metric="euclidean", n_trees=10, search_k=-1
1618
self.metric = metric
1719

1820
def fit(self, X, y=None):
19-
X = self._validate_data(X)
21+
X = validate_data(self, X)
2022
self.n_samples_fit_ = X.shape[0]
2123
metric = self.metric if self.metric != "sqeuclidean" else "euclidean"
2224
self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric)
@@ -68,8 +70,9 @@ def _transform(self, X):
6870

6971
return kneighbors_graph
7072

71-
def _more_tags(self):
72-
return {
73-
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle AnnoyIndex"},
74-
"requires_y": False,
75-
}
73+
def __sklearn_tags__(self) -> Tags:
74+
return Tags(
75+
estimator_type="transformer",
76+
target_tags=TargetTags(required=False),
77+
transformer_tags=TransformerTags(),
78+
)

src/sklearn_ann/kneighbors/faiss.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from joblib import cpu_count
99
from scipy.sparse import csr_matrix
1010
from sklearn.base import BaseEstimator, TransformerMixin
11+
from sklearn.utils import Tags, TargetTags, TransformerTags
12+
from sklearn.utils.validation import validate_data
1113

1214
from ..utils import TransformerChecksMixin, postprocess_knn_csr
1315

@@ -85,7 +87,7 @@ def _metric_info(self):
8587

8688
def fit(self, X, y=None):
8789
normalize = self._metric_info.get("normalize", False)
88-
X = self._validate_data(X, dtype=np.float32, copy=normalize)
90+
X = validate_data(self, X, dtype=np.float32, copy=normalize)
8991
self.n_samples_fit_ = X.shape[0]
9092
if self.n_jobs == -1:
9193
n_jobs = cpu_count()
@@ -157,14 +159,11 @@ def _transform(self, X):
157159
def fit_transform(self, X, y=None):
158160
return self.fit(X, y=y)._transform(X=None)
159161

160-
def _more_tags(self):
161-
return {
162-
"_xfail_checks": {
163-
"check_estimators_pickle": "Cannot pickle FAISS index",
164-
"check_methods_subset_invariance": "Unable to reset FAISS internal RNG",
165-
},
166-
"requires_y": False,
167-
"preserves_dtype": [np.float32],
162+
def __sklearn_tags__(self) -> Tags:
163+
return Tags(
164+
estimator_type="transformer",
165+
target_tags=TargetTags(required=False),
166+
transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
168167
# Could be made deterministic *if* we could reset FAISS's internal RNG
169-
"non_deterministic": True,
170-
}
168+
non_deterministic=True,
169+
)

src/sklearn_ann/kneighbors/nmslib.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import numpy as np
33
from scipy.sparse import csr_matrix
44
from sklearn.base import BaseEstimator, TransformerMixin
5+
from sklearn.utils import Tags, TransformerTags
6+
from sklearn.utils.validation import validate_data
57

68
from ..utils import TransformerChecksMixin, check_metric
79

@@ -28,7 +30,7 @@ def __init__(
2830
self.n_jobs = n_jobs
2931

3032
def fit(self, X, y=None):
31-
X = self._validate_data(X)
33+
X = validate_data(self, X)
3234
self.n_samples_fit_ = X.shape[0]
3335

3436
check_metric(self.metric, METRIC_MAP)
@@ -62,8 +64,8 @@ def transform(self, X):
6264

6365
return kneighbors_graph
6466

65-
def _more_tags(self):
66-
return {
67-
"_xfail_checks": {"check_estimators_pickle": "Cannot pickle NMSLib index"},
68-
"preserves_dtype": [np.float32],
69-
}
67+
def __sklearn_tags__(self) -> Tags:
68+
return Tags(
69+
estimator_type="transformer",
70+
transformer_tags=TransformerTags(preserves_dtype=[np.float32]),
71+
)

src/sklearn_ann/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from scipy.sparse import csr_matrix
3+
from sklearn.utils.validation import validate_data
34

45

56
def check_metric(metric, metrics):
@@ -90,6 +91,6 @@ class TransformerChecksMixin:
9091
def _transform_checks(self, X, *fitted_props, **check_params):
9192
from sklearn.utils.validation import check_is_fitted
9293

93-
X = self._validate_data(X, reset=False, **check_params)
94+
X = validate_data(self, X, reset=False, **check_params)
9495
check_is_fitted(self, *fitted_props)
9596
return X

tests/test_common.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,15 @@
3131
pytest.param(KDTreeTransformer),
3232
]
3333

34+
PER_ESTIMATOR_XFAIL_CHECKS = {
35+
AnnoyTransformer: dict(check_estimators_pickle="Cannot pickle AnnoyIndex"),
36+
FAISSTransformer: dict(
37+
check_estimators_pickle="Cannot pickle FAISS index",
38+
check_methods_subset_invariance="Unable to reset FAISS internal RNG",
39+
),
40+
NMSlibTransformer: dict(check_estimators_pickle="Cannot pickle NMSLib index"),
41+
}
42+
3443

3544
def add_mark(param, mark):
3645
return pytest.param(*param.values, marks=[*param.marks, mark], id=param.id)
@@ -51,7 +60,10 @@ def add_mark(param, mark):
5160
],
5261
)
5362
def test_all_estimators(Estimator):
54-
check_estimator(Estimator())
63+
check_estimator(
64+
Estimator(),
65+
expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(Estimator, {}),
66+
)
5567

5668

5769
# The following criteria are from:

0 commit comments

Comments
 (0)