Skip to content

Commit b81f810

Browse files
authored
OneDAL python ifaces for kNN (#743)
* c++ code built * c++ and python built and method called * save work; python works, problem with indices * fix problem with indices * save work, feel free to comment * fix patching * light scope of tests were added * add gpu support; dispatching for sklearnex * apply part of comments * save work, comments applied, patching for classification and unsupervised has been added * fix for patching condition * fix multioutput, y for unsupervised fit is none * apply comment int32->64 * fix unsupervised patching, save work * apply comment * revert regression changes, codefactor fixes * save work * clear comments in test, need to add fresh tests * fix unsupervised tests, warnings import * apply @vlad-nazarov's comment * remove redundant astype * tests * fixes for pep8 * fix codefactor * fix pep8 hits concerning conditions * applied comments * fix condition * update copyrights * applied comment * fix private ci
1 parent 77f08a5 commit b81f810

File tree

17 files changed

+1709
-31
lines changed

17 files changed

+1709
-31
lines changed

onedal/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@
3232
import onedal._onedal_py_host as _backend
3333
_is_dpc_backend = False
3434

35-
__all__ = ['primitives', 'svm']
35+
__all__ = ['neighbors', 'primitives', 'svm']

onedal/common/_estimator_checks.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#===============================================================================
2+
# Copyright 2022 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#===============================================================================
16+
17+
18+
def _check_is_fitted(estimator, attributes=None, *, msg=None):
19+
if msg is None:
20+
msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
21+
"appropriate arguments before using this estimator.")
22+
23+
if not hasattr(estimator, 'fit'):
24+
raise TypeError("%s is not an estimator instance." % (estimator))
25+
26+
if attributes is not None:
27+
if not isinstance(attributes, (list, tuple)):
28+
attributes = [attributes]
29+
attrs = all([hasattr(estimator, attr) for attr in attributes])
30+
else:
31+
attrs = [v for v in vars(estimator)
32+
if v.endswith("_") and not v.startswith("__")]
33+
34+
if not attrs:
35+
raise AttributeError(msg % {'name': type(estimator).__name__})
36+
37+
38+
def _is_classifier(estimator):
39+
return getattr(estimator, "_estimator_type", None) == "classifier"
40+
41+
42+
def _is_regressor(estimator):
43+
return getattr(estimator, "_estimator_type", None) == "regressor"

onedal/dal.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ ONEDAL_PY_INIT_MODULE(sigmoid_kernel);
3434

3535
/* algorithms */
3636
ONEDAL_PY_INIT_MODULE(svm);
37+
ONEDAL_PY_INIT_MODULE(neighbors);
3738

3839
#ifdef ONEDAL_DATA_PARALLEL
3940
PYBIND11_MODULE(_onedal_py_dpc, m) {
@@ -49,6 +50,7 @@ PYBIND11_MODULE(_onedal_py_host, m) {
4950
init_sigmoid_kernel(m);
5051

5152
init_svm(m);
53+
init_neighbors(m);
5254
}
5355

5456
} // namespace oneapi::dal::python

onedal/datatypes/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@
1919
_validate_targets,
2020
_check_X_y,
2121
_check_array,
22-
_check_is_fitted,
2322
_check_classification_targets,
2423
_type_of_target,
2524
_is_integral_float,
2625
_is_multilabel,
2726
_check_n_features,
28-
_num_features
27+
_num_features,
28+
_num_samples
2929
)
3030

3131
__all__ = ['_column_or_1d', '_validate_targets', '_check_X_y',
32-
'_check_array', '_check_is_fitted',
33-
'_check_classification_targets', '_type_of_target', '_is_integral_float',
34-
'_is_multilabel', '_check_n_features', '_num_features']
32+
'_check_array', '_check_classification_targets',
33+
'_type_of_target', '_is_integral_float',
34+
'_is_multilabel', '_check_n_features', '_num_features',
35+
'_num_samples']

onedal/datatypes/numpy_helpers.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
_FUNCT(NPY_INT32); \
3434
break; \
3535
} \
36+
case dal::data_type::int64: { \
37+
_FUNCT(NPY_INT64); \
38+
break; \
39+
} \
3640
default: _EXCEPTION; \
3741
};
3842

@@ -50,6 +54,10 @@
5054
_FUNCT(NPY_INT32, std::int32_t); \
5155
break; \
5256
} \
57+
case dal::data_type::int64: { \
58+
_FUNCT(NPY_INT64, std::int64_t); \
59+
break; \
60+
} \
5361
default: _EXCEPTION; \
5462
};
5563

onedal/datatypes/validation.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from scipy.sparse import issparse, dok_matrix, lil_matrix
2121
from scipy.sparse.base import spmatrix
2222
from collections.abc import Sequence
23+
from numbers import Integral
2324

2425

2526
class DataConversionWarning(UserWarning):
@@ -141,26 +142,6 @@ def _check_X_y(X, y, dtype="numeric", accept_sparse=False, order=None, copy=Fals
141142
return X, y
142143

143144

144-
def _check_is_fitted(estimator, attributes=None, *, msg=None):
145-
if msg is None:
146-
msg = ("This %(name)s instance is not fitted yet. Call 'fit' with "
147-
"appropriate arguments before using this estimator.")
148-
149-
if not hasattr(estimator, 'fit'):
150-
raise TypeError("%s is not an estimator instance." % (estimator))
151-
152-
if attributes is not None:
153-
if not isinstance(attributes, (list, tuple)):
154-
attributes = [attributes]
155-
attrs = all([hasattr(estimator, attr) for attr in attributes])
156-
else:
157-
attrs = [v for v in vars(estimator)
158-
if v.endswith("_") and not v.startswith("__")]
159-
160-
if not attrs:
161-
raise AttributeError(msg % {'name': type(estimator).__name__})
162-
163-
164145
def _check_classification_targets(y):
165146
y_type = _type_of_target(y)
166147
if y_type not in ['binary', 'multiclass', 'multiclass-multioutput',
@@ -329,3 +310,31 @@ def _num_features(X):
329310
return len(first_sample)
330311
except Exception as err:
331312
raise TypeError(message) from err
313+
314+
315+
def _num_samples(x):
316+
message = "Expected sequence or array-like, got %s" % type(x)
317+
if hasattr(x, "fit") and callable(x.fit):
318+
# Don't get num_samples from an ensembles length!
319+
raise TypeError(message)
320+
321+
if not hasattr(x, "__len__") and not hasattr(x, "shape"):
322+
if hasattr(x, "__array__"):
323+
x = np.asarray(x)
324+
else:
325+
raise TypeError(message)
326+
327+
if hasattr(x, "shape") and x.shape is not None:
328+
if len(x.shape) == 0:
329+
raise TypeError(
330+
"Singleton array %r cannot be considered a valid collection." % x
331+
)
332+
# Check that shape is returning an integer or default to len
333+
# Dask dataframes may not return numeric shape[0] value
334+
if isinstance(x.shape[0], Integral):
335+
return x.shape[0]
336+
337+
try:
338+
return len(x)
339+
except TypeError as type_error:
340+
raise TypeError(message) from type_error

onedal/neighbors/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#===============================================================================
2+
# Copyright 2022 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#===============================================================================
16+
17+
from .neighbors import KNeighborsClassifier, NearestNeighbors
18+
19+
__all__ = ['KNeighborsClassifier', 'NearestNeighbors']

0 commit comments

Comments
 (0)