Skip to content

Commit 7a1ac16

Browse files
committed
Fix tests
1 parent 44009a7 commit 7a1ac16

File tree

4 files changed

+45
-19
lines changed

4 files changed

+45
-19
lines changed

setup.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22
"""A template for scikit-learn compatible packages."""
33

44
import codecs
5+
import warnings
56
import os
67

7-
from setuptools import find_packages, setup
8+
from setuptools import find_packages, setup, Extension
9+
10+
import numpy as np
11+
12+
from Cython.Build import cythonize
13+
from Cython.Distutils import build_ext
814

915
# get __version__ from _version.py
1016
ver_file = os.path.join('sklearn_extra', '_version.py')
@@ -20,7 +26,7 @@
2026
URL = 'https://github.com/scikit-learn-contrib/scikit-learn-extra'
2127
LICENSE = 'new BSD'
2228
DOWNLOAD_URL = 'https://github.com/scikit-learn-contrib/scikit-learn-extra'
23-
VERSION = __version__
29+
VERSION = __version__ # noqa
2430
INSTALL_REQUIRES = ['numpy', 'scipy', 'scikit-learn']
2531
CLASSIFIERS = ['Intended Audience :: Science/Research',
2632
'Intended Audience :: Developers',
@@ -48,6 +54,21 @@
4854
]
4955
}
5056

57+
args = {
58+
"ext_modules": cythonize(
59+
[
60+
Extension(
61+
"sklearn_extra.utils._cyfht",
62+
["sklearn_extra/utils/_cyfht.pyx"],
63+
include_dirs=[np.get_include()]
64+
)
65+
],
66+
compiler_directives={"language_level": "3str"},
67+
),
68+
"cmdclass": dict(build_ext=build_ext),
69+
}
70+
71+
5172
setup(name=DISTNAME,
5273
maintainer=MAINTAINER,
5374
maintainer_email=MAINTAINER_EMAIL,
@@ -61,4 +82,5 @@
6182
classifiers=CLASSIFIERS,
6283
packages=find_packages(),
6384
install_requires=INSTALL_REQUIRES,
64-
extras_require=EXTRAS_REQUIRE)
85+
extras_require=EXTRAS_REQUIRE,
86+
**args)

sklearn_extra/kernel_approximation/_fastfood.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111

1212
import numpy as np
1313
import scipy.sparse as sp
14-
from sklearn.utils.cyfht import fht2 as cyfht
1514
from scipy.linalg import svd
1615
from scipy.stats import chi
17-
from sklearn.utils.random import choice
1816

19-
from .base import BaseEstimator
20-
from .base import TransformerMixin
21-
from .utils import check_array, check_random_state, as_float_array
22-
from .utils.extmath import safe_sparse_dot
23-
from .metrics.pairwise import pairwise_kernels
17+
from sklearn.base import BaseEstimator
18+
from sklearn.base import TransformerMixin
19+
from sklearn.utils import check_array, check_random_state, as_float_array
20+
from sklearn.utils.extmath import safe_sparse_dot
21+
from sklearn.metrics.pairwise import pairwise_kernels
22+
23+
from ..utils._cyfht import fht2 as cyfht
2424

2525

2626
class RBFSampler(BaseEstimator, TransformerMixin):
@@ -659,10 +659,10 @@ def fit(self, X, y=None):
659659
self.number_of_features_to_pad_with_zeros = self.d - d_orig
660660

661661
self.G = self.rng.normal(size=(self.times_to_stack_v, self.d))
662-
self.B = choice([-1, 1],
663-
size=(self.times_to_stack_v, self.d),
664-
replace=True,
665-
random_state=self.random_state)
662+
self.B = self.rng.choice(
663+
[-1, 1],
664+
size=(self.times_to_stack_v, self.d),
665+
replace=True)
666666
self.P = np.hstack([(i*self.d)+self.rng.permutation(self.d)
667667
for i in range(self.times_to_stack_v)])
668668
self.S = np.multiply(1 / self.l2norm_along_axis1(self.G)

sklearn_extra/kernel_approximation/test_fastfood.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from sklearn.kernel_approximation import Nystroem
1414
from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
1515

16+
from sklearn_extra.kernel_approximation import Fastfood
17+
1618

1719
# generate data
1820
rng = np.random.RandomState(0)

sklearn_extra/utils/tests/test_fht.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
22
import numpy.testing as npt
3-
import nose.tools as nt
43
from scipy.linalg import hadamard
5-
from sklearn.utils.cyfht import fht as cyfht
6-
from sklearn.utils.cyfht import fht2 as cyfht2
4+
5+
from sklearn.utils.testing import assert_raises
6+
7+
from sklearn_extra.utils._cyfht import fht as cyfht
8+
from sklearn_extra.utils._cyfht import fht2 as cyfht2
79

810

911
def test_wikipedia_example():
@@ -34,5 +36,5 @@ def test_numerical_fuzzing_fht2():
3436

3537

3638
def test_exception_when_input_not_power_two():
37-
nt.assert_raises(ValueError, cyfht, np.zeros(9, dtype=np.float64))
38-
nt.assert_raises(ValueError, cyfht2, np.zeros((2, 9), dtype=np.float64))
39+
assert_raises(ValueError, cyfht, np.zeros(9, dtype=np.float64))
40+
assert_raises(ValueError, cyfht2, np.zeros((2, 9), dtype=np.float64))

0 commit comments

Comments
 (0)