
Commit c35e567

Don't store RNG as class attribute and make methods private
1 parent 8c42ba9 commit c35e567

2 files changed: +29, -31 lines


sklearn_extra/kernel_approximation/_fastfood.py

Lines changed: 26 additions & 28 deletions
@@ -67,12 +67,12 @@ def __init__(self,
         self.tradeoff_mem_accuracy = tradeoff_mem_accuracy

     @staticmethod
-    def is_number_power_of_two(n):
+    def _is_number_power_of_two(n):
         return n != 0 and ((n & (n - 1)) == 0)

     @staticmethod
-    def enforce_dimensionality_constraints(d, n):
-        if not (Fastfood.is_number_power_of_two(d)):
+    def _enforce_dimensionality_constraints(d, n):
+        if not (Fastfood._is_number_power_of_two(d)):
             # find d that fulfills 2^l
             d = np.power(2, np.floor(np.log2(d)) + 1)
         divisor, remainder = divmod(n, d)
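
For context on the renamed helper: _is_number_power_of_two relies on the standard bit trick that a power of two has exactly one set bit, so n & (n - 1) clears it and leaves zero, while the n != 0 guard excludes zero. A minimal standalone sketch (plain Python, no project code involved):

def is_power_of_two(n):
    # n & (n - 1) clears the lowest set bit; only for powers of two does
    # that leave zero, and n != 0 rules out zero itself.
    return n != 0 and (n & (n - 1)) == 0

print([n for n in range(1, 17) if is_power_of_two(n)])  # [1, 2, 4, 8, 16]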
@@ -83,7 +83,7 @@ def enforce_dimensionality_constraints(d, n)
             times_to_stack_v = int(divisor+1)
         return int(d), int(n), times_to_stack_v

-    def pad_with_zeros(self, X):
+    def _pad_with_zeros(self, X):
         try:
             X_padded = np.pad(
                 X,
@@ -98,42 +98,42 @@ def pad_with_zeros(self, X):
         return X_padded

     @staticmethod
-    def approx_fourier_transformation_multi_dim(result):
+    def _approx_fourier_transformation_multi_dim(result):
         cyfht(result)

     @staticmethod
-    def l2norm_along_axis1(X):
+    def _l2norm_along_axis1(X):
         return np.sqrt(np.einsum('ij,ij->i', X, X))

-    def uniform_vector(self):
+    def _uniform_vector(self, rng):
         if self.tradeoff_mem_accuracy != 'accuracy':
-            return self._rng.uniform(0, 2 * np.pi, size=self._n)
+            return rng.uniform(0, 2 * np.pi, size=self._n)
         else:
             return None

-    def apply_approximate_gaussian_matrix(self, B, G, P, X):
+    def _apply_approximate_gaussian_matrix(self, B, G, P, X):
         """ Create mapping of all x_i by applying B, G and P step-wise """
         num_examples = X.shape[0]

         result = np.multiply(B, X.reshape((1, num_examples, 1, self._d)))
         result = result.reshape((num_examples*self._times_to_stack_v, self._d))
-        Fastfood.approx_fourier_transformation_multi_dim(result)
+        Fastfood._approx_fourier_transformation_multi_dim(result)
         result = result.reshape((num_examples, -1))
         np.take(result, P, axis=1, mode='wrap', out=result)
         np.multiply(np.ravel(G), result.reshape(num_examples, self._n),
                     out=result)
         result = result.reshape(num_examples*self._times_to_stack_v, self._d)
-        Fastfood.approx_fourier_transformation_multi_dim(result)
+        Fastfood._approx_fourier_transformation_multi_dim(result)
         return result

-    def scale_transformed_data(self, S, VX):
+    def _scale_transformed_data(self, S, VX):
         """ Scale mapped data VX to match kernel(e.g. RBF-Kernel) """
         VX = VX.reshape(-1, self._times_to_stack_v*self._d)

         return (1 / (self.sigma * np.sqrt(self._d)) *
                 np.multiply(np.ravel(S), VX))

-    def phi(self, X):
+    def _phi(self, X):
         if self.tradeoff_mem_accuracy == 'accuracy':
             return (1 / np.sqrt(X.shape[1])) * \
                 np.hstack([np.cos(X), np.sin(X)])
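
The _uniform_vector/_phi pair corresponds to the two standard random Fourier feature maps for the RBF kernel: the memory-friendly variant draws a phase offset U ~ Uniform(0, 2*pi) and uses a single cosine, while the 'accuracy' variant stacks cos and sin and needs no offset. A hedged numpy sketch of that idea, using a dense Gaussian matrix W in place of the Fastfood factorization (all names below are illustrative, not this module's API):

import numpy as np

rng = np.random.default_rng(0)
d, D, sigma = 4, 4096, 1.0
W = rng.normal(scale=1.0 / sigma, size=(D, d))  # dense stand-in for the structured HGPHB product
u = rng.uniform(0, 2 * np.pi, size=D)           # analogue of the uniform vector U

x, y = rng.normal(size=d), rng.normal(size=d)

# memory-friendly map: sqrt(2/D) * cos(Wx + u)
zx = np.sqrt(2.0 / D) * np.cos(W @ x + u)
zy = np.sqrt(2.0 / D) * np.cos(W @ y + u)

# 'accuracy' map: sqrt(1/D) * [cos(Wx), sin(Wx)]
zx2 = np.sqrt(1.0 / D) * np.hstack([np.cos(W @ x), np.sin(W @ x)])
zy2 = np.sqrt(1.0 / D) * np.hstack([np.cos(W @ y), np.sin(W @ y)])

rbf = np.exp(-np.sum((x - y) ** 2) / (2 * sigma ** 2))
print(rbf, zx @ zy, zx2 @ zy2)  # both inner products approximate the exact RBF value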
@@ -161,27 +161,27 @@ def fit(self, X, y=None):
         X = check_array(X, dtype=np.float64)

         d_orig = X.shape[1]
-        self._rng = check_random_state(self.random_state)
+        rng = check_random_state(self.random_state)

         self._d, self._n, self._times_to_stack_v = \
-            Fastfood.enforce_dimensionality_constraints(d_orig,
-                                                        self.n_components)
+            Fastfood._enforce_dimensionality_constraints(d_orig,
+                                                         self.n_components)
         self._number_of_features_to_pad_with_zeros = self._d - d_orig

-        self._G = self._rng.normal(size=(self._times_to_stack_v, self._d))
-        self._B = self._rng.choice(
+        self._G = rng.normal(size=(self._times_to_stack_v, self._d))
+        self._B = rng.choice(
             [-1, 1],
             size=(self._times_to_stack_v, self._d),
             replace=True)
-        self._P = np.hstack([(i*self._d)+self._rng.permutation(self._d)
+        self._P = np.hstack([(i*self._d) + rng.permutation(self._d)
                              for i in range(self._times_to_stack_v)])
-        self._S = np.multiply(1 / self.l2norm_along_axis1(self._G)
+        self._S = np.multiply(1 / self._l2norm_along_axis1(self._G)
                               .reshape((-1, 1)),
                               chi.rvs(self._d,
                                       size=(self._times_to_stack_v, self._d),
                                       random_state=self.random_state))

-        self._U = self.uniform_vector()
+        self._U = self._uniform_vector(rng)

         return self
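
The fit changes above follow the usual scikit-learn convention: validate random_state with check_random_state inside fit, keep the resulting RNG in a local variable, and store only the drawn parameters (_B, _G, _P, _S, _U) on the estimator, so the RNG itself never becomes part of the fitted state. A minimal sketch of the pattern with a toy transformer (illustrative only, not part of this module):

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import check_random_state


class RandomProjectionSketch(BaseEstimator, TransformerMixin):
    # Toy transformer demonstrating the 'local rng inside fit' pattern.

    def __init__(self, n_components=16, random_state=None):
        self.n_components = n_components
        self.random_state = random_state

    def fit(self, X, y=None):
        rng = check_random_state(self.random_state)  # local, not self._rng
        self.components_ = rng.normal(size=(X.shape[1], self.n_components))
        return self

    def transform(self, X):
        return X @ self.components_


Xt = RandomProjectionSketch(random_state=0).fit_transform(np.ones((5, 3)))
print(Xt.shape)  # (5, 16)

Refitting with the same random_state then reproduces the same projection, and a pickled estimator carries only the drawn components rather than a RandomState object.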

@@ -199,10 +199,8 @@ def transform(self, X):
         X_new : array-like, shape (n_samples, n_components)
         """
         X = check_array(X, dtype=np.float64)
-        X_padded = self.pad_with_zeros(X)
-        HGPHBX = self.apply_approximate_gaussian_matrix(self._B,
-                                                        self._G,
-                                                        self._P,
-                                                        X_padded)
-        VX = self.scale_transformed_data(self._S, HGPHBX)
-        return self.phi(VX)
+        X_padded = self._pad_with_zeros(X)
+        HGPHBX = self._apply_approximate_gaussian_matrix(
+            self._B, self._G, self._P, X_padded)
+        VX = self._scale_transformed_data(self._S, HGPHBX)
+        return self._phi(VX)
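
With the helpers made private, the public surface is just the estimator API. A usage sketch, assuming Fastfood is exported from sklearn_extra.kernel_approximation and that sigma is the RBF bandwidth (the other constructor parameters visible in this diff are n_components, tradeoff_mem_accuracy and random_state):

import numpy as np
from sklearn.metrics.pairwise import rbf_kernel
from sklearn_extra.kernel_approximation import Fastfood  # assumed import path

X = np.random.RandomState(0).normal(size=(20, 10))

ff = Fastfood(sigma=2.0, n_components=1024, random_state=42)
X_map = ff.fit(X).transform(X)  # explicit feature map for the approximated kernel

K_approx = X_map @ X_map.T
K_exact = rbf_kernel(X, gamma=1.0 / (2 * 2.0 ** 2))  # assumes k(x, y) = exp(-||x - y||^2 / (2 * sigma^2))
print(np.abs(K_approx - K_exact).max())  # approximation error shrinks as n_components grows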

sklearn_extra/kernel_approximation/test_fastfood.py

Lines changed: 3 additions & 3 deletions
@@ -206,14 +206,14 @@ def logging_histogram_kernel(x, y, log):

 def test_enforce_dimensionality_constraint():

-    for message, input, expected in [
+    for message, input_, expected in [
         ('test n is scaled to be a multiple of d', (16, 20), (16, 32, 2)),
         ('test n equals d', (16, 16), (16, 16, 1)),
         ('test n becomes power of two', (3, 16), (4, 16, 4)),
         ('test all', (7, 12), (8, 16, 2)),
     ]:
-        d, n = input
-        output = Fastfood.enforce_dimensionality_constraints(d, n)
+        d, n = input_
+        output = Fastfood._enforce_dimensionality_constraints(d, n)
         yield assert_equal, expected, output, message
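
To make the expected tuples above concrete: for the 'test all' case (d, n) = (7, 12), 7 is not a power of two, so d is rounded up to 2**(floor(log2(7)) + 1) = 8; 12 is not a multiple of 8, so two stacked blocks are needed and n grows to 2 * 8 = 16, giving (8, 16, 2). A quick check against the now-private helper (assuming the package is importable as below):

from sklearn_extra.kernel_approximation import Fastfood  # assumed import path

print(Fastfood._enforce_dimensionality_constraints(7, 12))   # (8, 16, 2)
print(Fastfood._enforce_dimensionality_constraints(16, 20))  # (16, 32, 2)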
