Skip to content

Commit 9d0a05b

Browse files
Alexsandrussnapetrov
authored andcommitted
Fix for n_iter_ SVM attribute and oneDAL kNN classifier result option (#1071)
* Add workaround for n_iter_ SVM attribute * Pytest onedal dir * Exclude test_policy * fix result_option fail for kneighbors method in knn classifier
1 parent 81559dd commit 9d0a05b

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

.ci/pipeline/ci.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,15 @@ jobs:
141141
pip install -r requirements-test.txt
142142
conda install -y /usr/share/miniconda/envs/CB/conda-bld/linux-64/daal4py*.tar.bz2
143143
python setup_sklearnex.py install --single-version-externally-managed --record=record1.txt
144+
displayName: install sklearnex
145+
- script: |
146+
. /usr/share/miniconda/etc/profile.d/conda.sh
147+
conda activate CB
144148
cd ..
145149
pytest --pyargs s/sklearnex/tests/
146-
displayName: install sklearnex
150+
# TODO: remove ignore of test_policy when device names are fixed
151+
pytest --pyargs s/onedal/ --ignore=s/onedal/common/tests/test_policy.py
152+
displayName: test sklearnex
147153
- script: |
148154
. /usr/share/miniconda/etc/profile.d/conda.sh
149155
conda activate CB

onedal/neighbors/neighbors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,8 @@ def _onedal_predict(self, model, X, params, queue):
425425
model = self._onedal_model
426426
else:
427427
model = self._create_model(_backend.neighbors.classification)
428+
if 'responses' not in params['result_option']:
429+
params['result_option'] += '|responses'
428430
result = _backend.neighbors.classification.infer(
429431
policy, params, model, to_table(X))
430432

onedal/svm/svm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,9 @@ def _get_sample_weight(self, X, y, sample_weight):
165165

166166
def _get_onedal_params(self, data):
167167
max_iter = 10000 if self.max_iter == -1 else self.max_iter
168+
# TODO: remove this workaround
169+
# when oneDAL SVM starts support of 'n_iterations' result
170+
self.n_iter_ = 1 if max_iter < 1 else max_iter
168171
class_count = 0 if self.classes_ is None else len(self.classes_)
169172
return {
170173
'fptype': 'float' if data.dtype is np.dtype('float32') else 'double',

0 commit comments

Comments
 (0)