Skip to content

Commit 3f917d4

Browse files
committed
fix: more refactor
1 parent 62c8ddd commit 3f917d4

File tree

5 files changed

+302
-196
lines changed

5 files changed

+302
-196
lines changed

onedal/neighbors/neighbors.py

Lines changed: 28 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -65,18 +65,9 @@ def _get_onedal_params(self, X, y=None, n_neighbors=None):
6565
try:
6666
fptype = X.dtype
6767
except AttributeError:
68-
# For pandas DataFrames or other types without dtype attribute
69-
import numpy as np
70-
7168
fptype = np.float64
7269

73-
# _fit_method should be set by sklearnex level before calling oneDAL
74-
if not hasattr(self, "_fit_method") or self._fit_method is None:
75-
raise ValueError(
76-
"_fit_method must be set by sklearnex level before calling oneDAL. "
77-
"This indicates improper usage - oneDAL neighbors should not be called directly."
78-
)
79-
70+
# _fit_method should be validated at sklearnex level before calling oneDAL
8071
return {
8172
"fptype": fptype,
8273
"vote_weights": "uniform" if weights == "uniform" else "distance",
@@ -109,77 +100,35 @@ def __init__(
109100
self.metric_params = metric_params
110101

111102
def _fit(self, X, y):
103+
# Basic initialization - all validation and preprocessing should be done at sklearnex level
112104
self._onedal_model = None
113105
self._tree = None
114-
self._shape = None
115-
self.classes_ = None
116-
self.effective_metric_ = getattr(self, "effective_metric_", self.metric)
117-
self.effective_metric_params_ = getattr(
118-
self, "effective_metric_params_", self.metric_params
119-
)
120-
121-
_, xp, _ = _get_sycl_namespace(X)
122-
if y is not None or self.requires_y:
123-
shape = getattr(y, "shape", None)
124-
self._shape = shape if shape is not None else y.shape
125-
126-
if _is_classifier(self):
127-
if y.ndim == 1 or y.ndim == 2 and y.shape[1] == 1:
128-
self.outputs_2d_ = False
129-
y = y.reshape((-1, 1))
130-
else:
131-
self.outputs_2d_ = True
132-
133-
self.classes_ = []
134-
self._y = np.empty(y.shape, dtype=int)
135-
for k in range(self._y.shape[1]):
136-
classes, self._y[:, k] = np.unique(y[:, k], return_inverse=True)
137-
self.classes_.append(classes)
138-
139-
if not self.outputs_2d_:
140-
self.classes_ = self.classes_[0]
141-
self._y = self._y.ravel()
142-
else:
143-
self._y = y
144106

107+
# Set basic fitted attributes
145108
self.n_samples_fit_ = X.shape[0]
146109
self.n_features_in_ = X.shape[1]
147110
self._fit_X = X
148111

112+
# Prepare y for oneDAL (classification/regression handling done at sklearnex level)
149113
_fit_y = None
150114
queue = QM.get_global_queue()
151115
gpu_device = queue is not None and queue.sycl_device.is_gpu
152116

153117
if _is_classifier(self) or (_is_regressor(self) and gpu_device):
154118
_fit_y = y.astype(X.dtype).reshape((-1, 1)) if y is not None else None
119+
155120
result = self._onedal_fit(X, _fit_y)
156-
157-
if y is not None and _is_regressor(self):
158-
self._y = y if self._shape is None else xp.reshape(y, self._shape)
159-
160121
self._onedal_model = result
161-
result = self
162-
163-
return result
122+
123+
return self
164124

165125
def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
166126
_check_is_fitted(self)
167127

128+
# All validation and preprocessing should be done at sklearnex level
168129
if n_neighbors is None:
169130
n_neighbors = self.n_neighbors
170131

171-
if X is not None:
172-
query_is_train = False
173-
else:
174-
query_is_train = True
175-
X = self._fit_X
176-
# Include an extra neighbor to account for the sample itself being
177-
# returned, which is removed later
178-
n_neighbors += 1
179-
180-
n_samples_fit = self.n_samples_fit_
181-
182-
chunked_results = None
183132
# Use the fit method determined at sklearnex level
184133
method = getattr(self, "_fit_method", "brute")
185134

@@ -188,52 +137,18 @@ def _kneighbors(self, X=None, n_neighbors=None, return_distance=True):
188137
distances = from_table(prediction_results.distances)
189138
indices = from_table(prediction_results.indices)
190139

140+
# Sort results for kd_tree method
191141
if method == "kd_tree":
192142
for i in range(distances.shape[0]):
193143
seq = distances[i].argsort()
194144
indices[i] = indices[i][seq]
195145
distances[i] = distances[i][seq]
196146

147+
# Return raw results - all post-processing done at sklearnex level
197148
if return_distance:
198-
results = distances, indices
199-
else:
200-
results = indices
201-
202-
if chunked_results is not None:
203-
if return_distance:
204-
neigh_dist, neigh_ind = zip(*chunked_results)
205-
results = np.vstack(neigh_dist), np.vstack(neigh_ind)
206-
else:
207-
results = np.vstack(chunked_results)
208-
209-
if not query_is_train:
210-
return results
211-
212-
# If the query data is the same as the indexed data, we would like
213-
# to ignore the first nearest neighbor of every sample, i.e
214-
# the sample itself.
215-
if return_distance:
216-
neigh_dist, neigh_ind = results
149+
return distances, indices
217150
else:
218-
neigh_ind = results
219-
220-
n_queries, _ = X.shape
221-
sample_range = np.arange(n_queries)[:, None]
222-
sample_mask = neigh_ind != sample_range
223-
224-
# Corner case: When the number of duplicates are more
225-
# than the number of neighbors, the first NN will not
226-
# be the sample, but a duplicate.
227-
# In that case mask the first duplicate.
228-
dup_gr_nbrs = np.all(sample_mask, axis=1)
229-
sample_mask[:, 0][dup_gr_nbrs] = False
230-
231-
neigh_ind = np.reshape(neigh_ind[sample_mask], (n_queries, n_neighbors - 1))
232-
233-
if return_distance:
234-
neigh_dist = np.reshape(neigh_dist[sample_mask], (n_queries, n_neighbors - 1))
235-
return neigh_dist, neigh_ind
236-
return neigh_ind
151+
return indices
237152

238153

239154
class KNeighborsClassifier(NeighborsBase, ClassifierMixin):
@@ -303,40 +218,11 @@ def predict(self, X, queue=None):
303218

304219
@supports_queue
305220
def predict_proba(self, X, queue=None):
306-
neigh_dist, neigh_ind = self.kneighbors(X, queue=queue)
307-
308-
classes_ = self.classes_
309-
_y = self._y
310-
if not self.outputs_2d_:
311-
_y = self._y.reshape((-1, 1))
312-
classes_ = [self.classes_]
313-
314-
n_queries = _num_samples(X)
315-
316-
# Use uniform weights for now - weights calculation should be done at sklearnex level
317-
weights = np.ones_like(neigh_ind)
318-
319-
all_rows = np.arange(n_queries)
320-
probabilities = []
321-
for k, classes_k in enumerate(classes_):
322-
pred_labels = _y[:, k][neigh_ind]
323-
proba_k = np.zeros((n_queries, classes_k.size))
324-
325-
# a simple ':' index doesn't work right
326-
for i, idx in enumerate(pred_labels.T): # loop is O(n_neighbors)
327-
proba_k[all_rows, idx] += weights[:, i]
328-
329-
# normalize 'votes' into real [0,1] probabilities
330-
normalizer = proba_k.sum(axis=1)[:, np.newaxis]
331-
normalizer[normalizer == 0.0] = 1.0
332-
proba_k /= normalizer
333-
334-
probabilities.append(proba_k)
335-
336-
if not self.outputs_2d_:
337-
probabilities = probabilities[0]
338-
339-
return probabilities
221+
# This method should not be called directly - weights processing moved to sklearnex level
222+
raise NotImplementedError(
223+
"predict_proba weights processing moved to sklearnex level. "
224+
"Use sklearnex.neighbors.KNeighborsClassifier instead."
225+
)
340226

341227
@supports_queue
342228
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
@@ -427,38 +313,25 @@ def _predict_gpu(self, X):
427313
return result
428314

429315
def _predict_skl(self, X):
430-
neigh_dist, neigh_ind = self.kneighbors(X)
431-
432-
# Use uniform weights for now - weights calculation should be done at sklearnex level
433-
weights = None
434-
435-
_y = self._y
436-
if _y.ndim == 1:
437-
_y = _y.reshape((-1, 1))
438-
439-
if weights is None:
440-
y_pred = np.mean(_y[neigh_ind], axis=1)
441-
else:
442-
y_pred = np.empty((X.shape[0], _y.shape[1]), dtype=np.float64)
443-
denom = np.sum(weights, axis=1)
444-
445-
for j in range(_y.shape[1]):
446-
num = np.sum(_y[neigh_ind, j] * weights, axis=1)
447-
y_pred[:, j] = num / denom
448-
449-
if self._y.ndim == 1:
450-
y_pred = y_pred.ravel()
451-
452-
return y_pred
316+
# This method should not be called directly - weights processing moved to sklearnex level
317+
raise NotImplementedError(
318+
"Regression weights processing moved to sklearnex level. "
319+
"Use sklearnex.neighbors.KNeighborsRegressor instead."
320+
)
453321

454322
@supports_queue
455323
def predict(self, X, queue=None):
324+
# For GPU with uniform weights, use direct oneDAL prediction
456325
gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
457326
is_uniform_weights = getattr(self, "weights", "uniform") == "uniform"
458327
if gpu_device and is_uniform_weights:
459328
return self._predict_gpu(X)
460329
else:
461-
return self._predict_skl(X)
330+
# Weights processing should be handled at sklearnex level
331+
raise NotImplementedError(
332+
"Regression weights processing moved to sklearnex level. "
333+
"Use sklearnex.neighbors.KNeighborsRegressor instead."
334+
)
462335

463336

464337
class NearestNeighbors(NeighborsBase):

sklearnex/neighbors/common.py

Lines changed: 110 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,116 @@ def _validate_n_neighbors(self, n_neighbors):
140140
"enter integer value" % type(n_neighbors)
141141
)
142142

143+
def _validate_weights(self, weights):
144+
"""Validate weights parameter at sklearnex level."""
145+
if weights not in [None, "uniform", "distance"] and not callable(weights):
146+
raise ValueError(
147+
"weights not recognized: should be 'uniform', "
148+
"'distance', or a callable function"
149+
)
150+
151+
def _validate_fit_method(self, fit_method):
152+
"""Validate that fit_method is properly set before calling oneDAL."""
153+
if not hasattr(self, "_fit_method") or self._fit_method is None:
154+
raise ValueError(
155+
"_fit_method must be set by sklearnex level before calling oneDAL. "
156+
"This indicates improper usage - oneDAL neighbors should not be called directly."
157+
)
158+
159+
def _validate_kneighbors_params(self, n_neighbors, X=None):
160+
"""Validate parameters for kneighbors method."""
161+
if n_neighbors is not None:
162+
self._validate_n_neighbors(n_neighbors)
163+
164+
# Check bounds if we have fit data
165+
if hasattr(self, 'n_samples_fit_'):
166+
effective_n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
167+
if effective_n_neighbors > self.n_samples_fit_:
168+
raise ValueError(
169+
"Expected n_neighbors <= n_samples_fit, but n_samples_fit = %d, "
170+
"n_neighbors = %d" % (self.n_samples_fit_, effective_n_neighbors)
171+
)
172+
173+
def _process_kneighbors_results(self, results, X, n_neighbors, return_distance, query_is_train=None):
174+
"""Process kneighbors results at sklearnex level - handles chunking, self-neighbor removal, etc."""
175+
176+
# Determine if query is training data
177+
if query_is_train is None:
178+
query_is_train = X is None
179+
180+
# Handle chunked results (if any)
181+
chunked_results = None # This would be set by chunking logic if implemented
182+
if chunked_results is not None:
183+
if return_distance:
184+
neigh_dist, neigh_ind = zip(*chunked_results)
185+
results = np.vstack(neigh_dist), np.vstack(neigh_ind)
186+
else:
187+
results = np.vstack(chunked_results)
188+
189+
if not query_is_train:
190+
return results
191+
192+
# If the query data is the same as the indexed data, we need to
193+
# ignore the first nearest neighbor of every sample (the sample itself)
194+
if return_distance:
195+
neigh_dist, neigh_ind = results
196+
else:
197+
neigh_ind = results
198+
199+
n_queries = X.shape[0] if X is not None else self._fit_X.shape[0]
200+
sample_range = np.arange(n_queries)[:, None]
201+
sample_mask = neigh_ind != sample_range
202+
203+
# Corner case: When the number of duplicates are more
204+
# than the number of neighbors, the first NN will not
205+
# be the sample, but a duplicate.
206+
# In that case mask the first duplicate.
207+
dup_gr_nbrs = np.all(sample_mask, axis=1)
208+
sample_mask[:, 0][dup_gr_nbrs] = False
209+
210+
neigh_ind = np.reshape(neigh_ind[sample_mask], (n_queries, n_neighbors - 1))
211+
212+
if return_distance:
213+
neigh_dist = np.reshape(neigh_dist[sample_mask], (n_queries, n_neighbors - 1))
214+
return neigh_dist, neigh_ind
215+
return neigh_ind
216+
217+
def _compute_weights(self, distances, weights_param):
218+
"""Compute weights based on distances and weights parameter."""
219+
if weights_param in (None, "uniform"):
220+
return None
221+
elif weights_param == "distance":
222+
# if user attempts to classify a point that was zero distance from one
223+
# or more training points, those training points are weighted as 1.0
224+
# and the other points as 0.0
225+
if distances.dtype is np.dtype(object):
226+
for i, dist_row in enumerate(distances):
227+
zero_mask = dist_row == 0.0
228+
if np.any(zero_mask):
229+
distances[i] = zero_mask.astype(np.float64)
230+
else:
231+
distances[i] = 1.0 / dist_row
232+
else:
233+
zero_mask = distances == 0.0
234+
if np.any(zero_mask):
235+
# Handle the case where some distances are zero
236+
weights = np.where(zero_mask, 1.0, 1.0 / np.where(distances == 0.0, 1.0, distances))
237+
# Normalize so that zero distance points get all the weight
238+
for i in range(weights.shape[0]):
239+
if np.any(zero_mask[i]):
240+
weights[i] = zero_mask[i].astype(np.float64)
241+
return weights
242+
else:
243+
return 1.0 / distances
244+
return distances
245+
elif callable(weights_param):
246+
return weights_param(distances)
247+
else:
248+
raise ValueError(
249+
"weights not recognized: should be 'uniform', "
250+
"'distance', or a callable function"
251+
)
252+
143253
def _validate_feature_count(self, X, method_name=""):
144254
n_features = getattr(self, "n_features_in_", None)
145255
shape = getattr(X, "shape", None)
@@ -168,8 +278,6 @@ def _validate_kneighbors_bounds(self, n_neighbors, query_is_train, X):
168278

169279
def _process_classification_targets(self, y):
170280
"""Process classification targets and set class-related attributes."""
171-
import numpy as np
172-
173281
# Handle shape processing
174282
shape = getattr(y, "shape", None)
175283
self._shape = shape if shape is not None else y.shape

0 commit comments

Comments
 (0)