diff --git a/pyproject.toml b/pyproject.toml index 691d172..37086a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ urls.Documentation = "https://sklearn-ann.readthedocs.io/" dynamic = ["version", "readme"] requires-python = "<3.13,>=3.9" # enforced by scipy dependencies = [ - "scikit-learn>=0.24.0", + "scikit-learn>=1.6.0", "scipy>=1.11.1,<2.0.0", ] @@ -23,7 +23,7 @@ tests = [ docs = [ "sphinx>=7", "sphinx-gallery>=0.8.2", - "sphinx-book-theme>=1.1.0rc1", + "sphinx-book-theme>=1.1.0", "sphinx-issues>=1.2.0", "numpydoc>=1.1.0", "matplotlib>=3.3.3", @@ -37,6 +37,7 @@ faiss = [ ] pynndescent = [ "pynndescent>=0.5.1,<1.0.0", + "numba>=0.52", ] nmslib = [ "nmslib>=2.1.1,<3.0.0 ; python_version < '3.11'", @@ -84,16 +85,14 @@ ignore = [ [tool.ruff.lint.isort] known-first-party = ["sklearn_ann"] -[tool.hatch.envs.default] -features = [ - "tests", - "docs", - "annlibs", -] +[tool.hatch.envs.docs] +installer = "uv" +features = ["docs", "annlibs"] +scripts.build = "sphinx-build -M html docs docs/_build" -[tool.hatch.envs.default.scripts] -test = "pytest {args:tests}" -build-docs = "sphinx-build -M html docs docs/_build" +[tool.hatch.envs.hatch-test] +default-args = [] +features = ["tests", "annlibs"] [tool.hatch.build.targets.wheel] packages = ["src/sklearn_ann"] diff --git a/src/sklearn_ann/cluster/rnn_dbscan.py b/src/sklearn_ann/cluster/rnn_dbscan.py index dc6a398..0ce7e3b 100644 --- a/src/sklearn_ann/cluster/rnn_dbscan.py +++ b/src/sklearn_ann/cluster/rnn_dbscan.py @@ -1,8 +1,11 @@ from collections import deque +from typing import cast import numpy as np from sklearn.base import BaseEstimator, ClusterMixin from sklearn.neighbors import KNeighborsTransformer +from sklearn.utils import Tags +from sklearn.utils.validation import validate_data from ..utils import get_sparse_row @@ -143,7 +146,7 @@ def __init__( self.keep_knns = keep_knns def fit(self, X, y=None): - X = self._validate_data(X, accept_sparse="csr") + X = validate_data(self, X, accept_sparse="csr") if self.input_guarantee == "none": algorithm = KNeighborsTransformer(n_neighbors=self.n_neighbors) X = algorithm.fit_transform(X) @@ -181,6 +184,11 @@ def drop_knns(self): del self.knns_ del self.rev_knns_ + def __sklearn_tags__(self) -> Tags: + tags = cast(Tags, super().__sklearn_tags__()) + tags.input_tags.sparse = True + return tags + def simple_rnn_dbscan_pipeline( neighbor_transformer, n_neighbors, n_jobs=None, keep_knns=None, **kwargs diff --git a/src/sklearn_ann/kneighbors/annoy.py b/src/sklearn_ann/kneighbors/annoy.py index efa95a1..67505eb 100644 --- a/src/sklearn_ann/kneighbors/annoy.py +++ b/src/sklearn_ann/kneighbors/annoy.py @@ -2,6 +2,8 @@ import numpy as np from scipy.sparse import csr_matrix from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils import Tags, TargetTags, TransformerTags +from sklearn.utils.validation import validate_data from ..utils import TransformerChecksMixin @@ -16,7 +18,7 @@ def __init__(self, n_neighbors=5, *, metric="euclidean", n_trees=10, search_k=-1 self.metric = metric def fit(self, X, y=None): - X = self._validate_data(X) + X = validate_data(self, X) self.n_samples_fit_ = X.shape[0] metric = self.metric if self.metric != "sqeuclidean" else "euclidean" self.annoy_ = annoy.AnnoyIndex(X.shape[1], metric=metric) @@ -68,8 +70,9 @@ def _transform(self, X): return kneighbors_graph - def _more_tags(self): - return { - "_xfail_checks": {"check_estimators_pickle": "Cannot pickle AnnoyIndex"}, - "requires_y": False, - } + def __sklearn_tags__(self) -> Tags: + return Tags( + estimator_type="transformer", + target_tags=TargetTags(required=False), + transformer_tags=TransformerTags(), + ) diff --git a/src/sklearn_ann/kneighbors/faiss.py b/src/sklearn_ann/kneighbors/faiss.py index 5e8dd2a..c349132 100644 --- a/src/sklearn_ann/kneighbors/faiss.py +++ b/src/sklearn_ann/kneighbors/faiss.py @@ -8,6 +8,8 @@ from joblib import cpu_count from scipy.sparse import csr_matrix from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils import Tags, TargetTags, TransformerTags +from sklearn.utils.validation import validate_data from ..utils import TransformerChecksMixin, postprocess_knn_csr @@ -85,7 +87,7 @@ def _metric_info(self): def fit(self, X, y=None): normalize = self._metric_info.get("normalize", False) - X = self._validate_data(X, dtype=np.float32, copy=normalize) + X = validate_data(self, X, dtype=np.float32, copy=normalize) self.n_samples_fit_ = X.shape[0] if self.n_jobs == -1: n_jobs = cpu_count() @@ -157,14 +159,11 @@ def _transform(self, X): def fit_transform(self, X, y=None): return self.fit(X, y=y)._transform(X=None) - def _more_tags(self): - return { - "_xfail_checks": { - "check_estimators_pickle": "Cannot pickle FAISS index", - "check_methods_subset_invariance": "Unable to reset FAISS internal RNG", - }, - "requires_y": False, - "preserves_dtype": [np.float32], + def __sklearn_tags__(self) -> Tags: + return Tags( + estimator_type="transformer", + target_tags=TargetTags(required=False), + transformer_tags=TransformerTags(preserves_dtype=[np.float32]), # Could be made deterministic *if* we could reset FAISS's internal RNG - "non_deterministic": True, - } + non_deterministic=True, + ) diff --git a/src/sklearn_ann/kneighbors/nmslib.py b/src/sklearn_ann/kneighbors/nmslib.py index 314d0c4..c10fb43 100644 --- a/src/sklearn_ann/kneighbors/nmslib.py +++ b/src/sklearn_ann/kneighbors/nmslib.py @@ -2,6 +2,8 @@ import numpy as np from scipy.sparse import csr_matrix from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.utils import Tags, TransformerTags +from sklearn.utils.validation import validate_data from ..utils import TransformerChecksMixin, check_metric @@ -28,7 +30,7 @@ def __init__( self.n_jobs = n_jobs def fit(self, X, y=None): - X = self._validate_data(X) + X = validate_data(self, X) self.n_samples_fit_ = X.shape[0] check_metric(self.metric, METRIC_MAP) @@ -62,8 +64,8 @@ def transform(self, X): return kneighbors_graph - def _more_tags(self): - return { - "_xfail_checks": {"check_estimators_pickle": "Cannot pickle NMSLib index"}, - "preserves_dtype": [np.float32], - } + def __sklearn_tags__(self) -> Tags: + return Tags( + estimator_type="transformer", + transformer_tags=TransformerTags(preserves_dtype=[np.float32]), + ) diff --git a/src/sklearn_ann/utils.py b/src/sklearn_ann/utils.py index 38284f9..e275e27 100644 --- a/src/sklearn_ann/utils.py +++ b/src/sklearn_ann/utils.py @@ -1,5 +1,6 @@ import numpy as np from scipy.sparse import csr_matrix +from sklearn.utils.validation import validate_data def check_metric(metric, metrics): @@ -90,6 +91,6 @@ class TransformerChecksMixin: def _transform_checks(self, X, *fitted_props, **check_params): from sklearn.utils.validation import check_is_fitted - X = self._validate_data(X, reset=False, **check_params) + X = validate_data(self, X, reset=False, **check_params) check_is_fitted(self, *fitted_props) return X diff --git a/tests/test_common.py b/tests/test_common.py index d493a9d..65e453c 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -31,6 +31,15 @@ pytest.param(KDTreeTransformer), ] +PER_ESTIMATOR_XFAIL_CHECKS = { + AnnoyTransformer: dict(check_estimators_pickle="Cannot pickle AnnoyIndex"), + FAISSTransformer: dict( + check_estimators_pickle="Cannot pickle FAISS index", + check_methods_subset_invariance="Unable to reset FAISS internal RNG", + ), + NMSlibTransformer: dict(check_estimators_pickle="Cannot pickle NMSLib index"), +} + def add_mark(param, mark): return pytest.param(*param.values, marks=[*param.marks, mark], id=param.id) @@ -51,7 +60,10 @@ def add_mark(param, mark): ], ) def test_all_estimators(Estimator): - check_estimator(Estimator()) + check_estimator( + Estimator(), + expected_failed_checks=PER_ESTIMATOR_XFAIL_CHECKS.get(Estimator, {}), + ) # The following critera are from: