Skip to content

Commit c5e9bbc

Browse files
committed
Remove tests copied from scikit-learn
1 parent c35e567 commit c5e9bbc

File tree

2 files changed

+22
-209
lines changed

2 files changed

+22
-209
lines changed

sklearn_extra/kernel_approximation/_fastfood.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def fit(self, X, y=None):
179179
.reshape((-1, 1)),
180180
chi.rvs(self._d,
181181
size=(self._times_to_stack_v, self._d),
182-
random_state=self.random_state))
182+
random_state=rng))
183183

184184
self._U = self._uniform_vector(rng)
185185

sklearn_extra/kernel_approximation/test_fastfood.py

Lines changed: 21 additions & 208 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
1+
import pytest
12
import numpy as np
2-
from scipy.sparse import csr_matrix
33

4-
from sklearn.utils.testing import assert_array_equal, assert_equal
5-
from sklearn.utils.testing import assert_not_equal
6-
from sklearn.utils.testing import assert_array_almost_equal, assert_raises
7-
8-
9-
from sklearn.metrics.pairwise import kernel_metrics
10-
from sklearn.kernel_approximation import RBFSampler
11-
from sklearn.kernel_approximation import AdditiveChi2Sampler
12-
from sklearn.kernel_approximation import SkewedChi2Sampler
13-
from sklearn.kernel_approximation import Nystroem
14-
from sklearn.metrics.pairwise import polynomial_kernel, rbf_kernel
4+
from sklearn.utils.testing import assert_equal
5+
from sklearn.utils.testing import assert_array_almost_equal
6+
from sklearn.metrics.pairwise import rbf_kernel
157

168
from sklearn_extra.kernel_approximation import Fastfood
179

@@ -24,200 +16,17 @@
2416
Y /= Y.sum(axis=1)[:, np.newaxis]
2517

2618

27-
def test_additive_chi2_sampler():
28-
"""test that AdditiveChi2Sampler approximates kernel on random data"""
29-
30-
# compute exact kernel
31-
# appreviations for easier formular
32-
X_ = X[:, np.newaxis, :]
33-
Y_ = Y[np.newaxis, :, :]
34-
35-
large_kernel = 2 * X_ * Y_ / (X_ + Y_)
36-
37-
# reduce to n_samples_x x n_samples_y by summing over features
38-
kernel = (large_kernel.sum(axis=2))
39-
40-
# approximate kernel mapping
41-
transform = AdditiveChi2Sampler(sample_steps=3)
42-
X_trans = transform.fit_transform(X)
43-
Y_trans = transform.transform(Y)
44-
45-
kernel_approx = np.dot(X_trans, Y_trans.T)
46-
47-
assert_array_almost_equal(kernel, kernel_approx, 1)
48-
49-
X_sp_trans = transform.fit_transform(csr_matrix(X))
50-
Y_sp_trans = transform.transform(csr_matrix(Y))
51-
52-
assert_array_equal(X_trans, X_sp_trans.A)
53-
assert_array_equal(Y_trans, Y_sp_trans.A)
54-
55-
# test error is raised on negative input
56-
Y_neg = Y.copy()
57-
Y_neg[0, 0] = -1
58-
assert_raises(ValueError, transform.transform, Y_neg)
59-
60-
# test error on invalid sample_steps
61-
transform = AdditiveChi2Sampler(sample_steps=4)
62-
assert_raises(ValueError, transform.fit, X)
63-
64-
# test that the sample interval is set correctly
65-
sample_steps_available = [1, 2, 3]
66-
for sample_steps in sample_steps_available:
67-
68-
# test that the sample_interval is initialized correctly
69-
transform = AdditiveChi2Sampler(sample_steps=sample_steps)
70-
assert_equal(transform.sample_interval, None)
71-
72-
# test that the sample_interval is changed in the fit method
73-
transform.fit(X)
74-
assert_not_equal(transform.sample_interval_, None)
75-
76-
# test that the sample_interval is set correctly
77-
sample_interval = 0.3
78-
transform = AdditiveChi2Sampler(sample_steps=4,
79-
sample_interval=sample_interval)
80-
assert_equal(transform.sample_interval, sample_interval)
81-
transform.fit(X)
82-
assert_equal(transform.sample_interval_, sample_interval)
83-
84-
85-
def test_skewed_chi2_sampler():
86-
"""test that RBFSampler approximates kernel on random data"""
87-
88-
# compute exact kernel
89-
c = 0.03
90-
# appreviations for easier formular
91-
X_c = (X + c)[:, np.newaxis, :]
92-
Y_c = (Y + c)[np.newaxis, :, :]
93-
94-
# we do it in log-space in the hope that it's more stable
95-
# this array is n_samples_x x n_samples_y big x n_features
96-
log_kernel = ((np.log(X_c) / 2.) + (np.log(Y_c) / 2.) + np.log(2.) -
97-
np.log(X_c + Y_c))
98-
# reduce to n_samples_x x n_samples_y by summing over features in log-space
99-
kernel = np.exp(log_kernel.sum(axis=2))
100-
101-
# approximate kernel mapping
102-
transform = SkewedChi2Sampler(skewedness=c, n_components=1000,
103-
random_state=42)
104-
X_trans = transform.fit_transform(X)
105-
Y_trans = transform.transform(Y)
106-
107-
kernel_approx = np.dot(X_trans, Y_trans.T)
108-
assert_array_almost_equal(kernel, kernel_approx, 1)
109-
110-
# test error is raised on negative input
111-
Y_neg = Y.copy()
112-
Y_neg[0, 0] = -1
113-
assert_raises(ValueError, transform.transform, Y_neg)
114-
115-
116-
def test_rbf_sampler():
117-
"""test that RBFSampler approximates kernel on random data"""
118-
# compute exact kernel
119-
gamma = 10.
120-
kernel = rbf_kernel(X, Y, gamma=gamma)
121-
122-
# approximate kernel mapping
123-
rbf_transform = RBFSampler(gamma=gamma, n_components=1000, random_state=42)
124-
X_trans = rbf_transform.fit_transform(X)
125-
Y_trans = rbf_transform.transform(Y)
126-
kernel_approx = np.dot(X_trans, Y_trans.T)
127-
128-
129-
assert_array_almost_equal(kernel, kernel_approx, 1)
130-
131-
132-
def test_input_validation():
133-
"""Regression test: kernel approx. transformers should work on lists
134-
135-
No assertions; the old versions would simply crash
136-
"""
137-
X = [[1, 2], [3, 4], [5, 6]]
138-
AdditiveChi2Sampler().fit(X).transform(X)
139-
SkewedChi2Sampler().fit(X).transform(X)
140-
RBFSampler().fit(X).transform(X)
141-
142-
X = csr_matrix(X)
143-
RBFSampler().fit(X).transform(X)
144-
145-
146-
def test_nystroem_approximation():
147-
# some basic tests
148-
rnd = np.random.RandomState(0)
149-
X = rnd.uniform(size=(10, 4))
150-
151-
# With n_components = n_samples this is exact
152-
X_transformed = Nystroem(n_components=X.shape[0]).fit_transform(X)
153-
K = rbf_kernel(X)
154-
assert_array_almost_equal(np.dot(X_transformed, X_transformed.T), K)
155-
156-
trans = Nystroem(n_components=2, random_state=rnd)
157-
X_transformed = trans.fit(X).transform(X)
158-
assert_equal(X_transformed.shape, (X.shape[0], 2))
159-
160-
# test callable kernel
161-
linear_kernel = lambda X, Y: np.dot(X, Y.T)
162-
trans = Nystroem(n_components=2, kernel=linear_kernel, random_state=rnd)
163-
X_transformed = trans.fit(X).transform(X)
164-
assert_equal(X_transformed.shape, (X.shape[0], 2))
165-
166-
# test that available kernels fit and transform
167-
kernels_available = kernel_metrics()
168-
for kern in kernels_available:
169-
trans = Nystroem(n_components=2, kernel=kern, random_state=rnd)
170-
X_transformed = trans.fit(X).transform(X)
171-
assert_equal(X_transformed.shape, (X.shape[0], 2))
172-
173-
174-
def test_nystroem_poly_kernel_params():
175-
"""Non-regression: Nystroem should pass other parameters beside gamma."""
176-
rnd = np.random.RandomState(37)
177-
X = rnd.uniform(size=(10, 4))
178-
179-
K = polynomial_kernel(X, degree=3.1, coef0=.1)
180-
nystroem = Nystroem(kernel="polynomial", n_components=X.shape[0],
181-
degree=3.1, coef0=.1)
182-
X_transformed = nystroem.fit_transform(X)
183-
assert_array_almost_equal(np.dot(X_transformed, X_transformed.T), K)
184-
185-
186-
def test_nystroem_callable():
187-
"""Test Nystroem on a callable."""
188-
rnd = np.random.RandomState(42)
189-
n_samples = 10
190-
X = rnd.uniform(size=(n_samples, 4))
191-
192-
def logging_histogram_kernel(x, y, log):
193-
"""Histogram kernel that writes to a log."""
194-
log.append(1)
195-
return np.minimum(x, y).sum()
196-
197-
kernel_log = []
198-
X = list(X) # test input validation
199-
Nystroem(kernel=logging_histogram_kernel,
200-
n_components=(n_samples - 1),
201-
kernel_params={'log': kernel_log}).fit(X)
202-
assert_equal(len(kernel_log), n_samples * (n_samples - 1) / 2)
203-
204-
# Fastfood
205-
206-
207-
def test_enforce_dimensionality_constraint():
208-
209-
for message, input_, expected in [
210-
('test n is scaled to be a multiple of d', (16, 20), (16, 32, 2)),
211-
('test n equals d', (16, 16), (16, 16, 1)),
212-
('test n becomes power of two', (3, 16), (4, 16, 4)),
213-
('test all', (7, 12), (8, 16, 2)),
214-
]:
215-
d, n = input_
216-
output = Fastfood._enforce_dimensionality_constraints(d, n)
217-
yield assert_equal, expected, output, message
218-
219-
220-
# Performance Analysis
19+
@pytest.mark.parametrize(
20+
"message, input_, expected",
21+
[('test n is scaled to be a multiple of d', (16, 20), (16, 32, 2)),
22+
('test n equals d', (16, 16), (16, 16, 1)),
23+
('test n becomes power of two', (3, 16), (4, 16, 4)),
24+
('test all', (7, 12), (8, 16, 2)),
25+
])
26+
def test_fastfood_enforce_dimensionality_constraint(message, input_, expected):
27+
d, n = input_
28+
output = Fastfood._enforce_dimensionality_constraints(d, n)
29+
assert_equal(expected, output, message)
22130

22231

22332
def test_fastfood():
@@ -257,7 +66,9 @@ def test_fastfood():
25766
#
25867
# fastfood_start = datetime.datetime.utcnow()
25968
# # Fastfood: approximate kernel mapping
260-
# rbf_transform = Fastfood(sigma=sigma, n_components=number_of_features_to_generate, tradeoff_less_mem_or_higher_accuracy='accuracy', random_state=42)
69+
# rbf_transform = Fastfood(
70+
# sigma=sigma, n_components=number_of_features_to_generate,
71+
# tradeoff_less_mem_or_higher_accuracy='accuracy', random_state=42)
26172
# _ = rbf_transform.fit_transform(X)
26273
# fastfood_end = datetime.datetime.utcnow()
26374
# fastfood_spent_time =fastfood_end- fastfood_start
@@ -266,7 +77,9 @@ def test_fastfood():
26677
#
26778
# fastfood_mem_start = datetime.datetime.utcnow()
26879
# # Fastfood: approximate kernel mapping
269-
# rbf_transform = Fastfood(sigma=sigma, n_components=number_of_features_to_generate, tradeoff_less_mem_or_higher_accuracy='mem', random_state=42)
80+
# rbf_transform = Fastfood(
81+
# sigma=sigma, n_components=number_of_features_to_generate,
82+
# tradeoff_less_mem_or_higher_accuracy='mem', random_state=42)
27083
# _ = rbf_transform.fit_transform(X)
27184
# fastfood_mem_end = datetime.datetime.utcnow()
27285
# fastfood_mem_spent_time = fastfood_mem_end- fastfood_mem_start

0 commit comments

Comments
 (0)