Skip to content

Commit d769d14

Browse files
icfaustdavid-cortes-intelethanglaser
authored
[enhancement] Enable Array API in ensemble algos (#2201)
* add finiteness_checker pybind11 bindings * added finiteness checker * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Update finiteness_checker.cpp * Rename finiteness_checker.cpp to finiteness_checker.cpp * Update finiteness_checker.cpp * add next step * follow conventions * make xtable explicit * remove comment * Update validation.py * Update __init__.py * Update validation.py * Update __init__.py * Update __init__.py * Update validation.py * Update _data_conversion.py * Update _data_conversion.py * Update policy_common.cpp * Update policy_common.cpp * Update _policy.py * Update policy_common.cpp * Rename finiteness_checker.cpp to finiteness_checker.cpp * Create finiteness_checker.py * Update validation.py * Update __init__.py * attempt at fixing circular imports again * fix isort * remove __init__ changes * last move * Update policy_common.cpp * Update policy_common.cpp * Update policy_common.cpp * Update policy_common.cpp * Update validation.py * add testing * isort * attempt to fix module error * add fptype * fix typo * Update validation.py * remove sua_ifcae from to_table * isort and black * Update test_memory_usage.py * format * Update _data_conversion.py * Update _data_conversion.py * Update test_validation.py * remove unnecessary code * make reviewer changes * make dtype check change * add sparse testing * try again * try again * try again * temporary commit * first attempt * missing change? * modify DummyEstimator for testing * generalize DummyEstimator * switch test * further testing changes * add initial validate_data test, will be refactored * fixes for CI * Update validation.py * Update validation.py * Update test_memory_usage.py * Update base.py * Update base.py * improve tests * fix logic * fix logic * fix logic again * rename file * Revert "rename file" This reverts commit 8d47744. * remove duplication * fix imports * Rename test_finite.py to test_validation.py * Revert "Rename test_finite.py to test_validation.py" This reverts commit ee799f6. * updates * Update validation.py * fixes for some test failures * fix text * fixes for some failures * make consistent * fix bad logic * fix in string * attempt tp see if dataframe conversion is causing the issue * fix iter problem * fix testing issues * formatting * revert change * fixes for pandas * there is a slowdown with pandas that needs to be solved * swap to transpose for speed * more clarity * add _check_sample_weight * add more testing' * rename * remove unnecessary imports * fix test slowness * focus get_dataframes_and_queues * put config_context around * Update test_validation.py * Update base.py * Update test_validation.py * generalize regex * add fixes for sklearn 1.0 and input_name * fixes for test failures * Update validation.py * Update test_validation.py * Update validation.py * formattintg * make suggested changes * follow changes made in #2126 * fix future device problem * Update validation.py * finished movement * fix first error * next mistake * remove bad dtypes check * updates * remove array * solve onedal issues * solve onedal issues * updates * updates * further fixes * further fixes * fix issues to see how it goes * oops * updates * add finite checks for predict and predict_proba * updates * centralize * further reduce code * updates * remove sklearn conformance from onedal estimator init signature * remove more * fixes * change away from sklearn `max_samples` in onedal estimators * fix error * move things * Update forest.py * Update forest.py * Update _forest.py * further fixes to onedal side * further fixes to onedal side * simplifications * attempt at classifiers support * further changes * fix error on onedal side * fix error on onedal side * fixes * fix pandas related error * remove unnecessary code: * try to fix issues related to regressor data * fixes necessary for CI * fixes for formatting * updates * push * push * fixes * remove upon request * remove upon request * further fixes * try to fix classifiers for array API inputs * try again * Update array_api.rst * Update sklearnex/ensemble/_forest.py Co-authored-by: david-cortes-intel <[email protected]> * Update _forest.py * Update _forest.py * Update _forest.py * Update _forest.py * Update _forest.py * Update _forest.py * Update _forest.py * Update _forest.py * Update _forest.py * Update _forest.py * Update forest.py * Update _forest.py * Update sklearnex/ensemble/_forest.py Co-authored-by: ethanglaser <[email protected]> * Update _forest.py * Update _forest.py * Update array_api.rst * Update array_api.rst * remove sparse checks for sample_weight * Update deselected_tests.yaml * Update deselected_tests.yaml --------- Co-authored-by: david-cortes-intel <[email protected]> Co-authored-by: ethanglaser <[email protected]>
1 parent 803c7ad commit d769d14

File tree

6 files changed

+623
-1145
lines changed

6 files changed

+623
-1145
lines changed

deselected_tests.yaml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -528,12 +528,8 @@ gpu:
528528
- tests/test_common.py::test_estimators[ExtraTreesClassifier()-check_class_weight_classifiers]
529529
- tests/test_common.py::test_estimators[ExtraTreesRegressor()-check_sample_weights_invariance(kind=zeros)]
530530
- tests/test_common.py::test_estimators[RandomForestRegressor()-check_regressor_data_not_an_array]
531-
532-
# GPU implementation of Extra Trees doesn't support sample_weights
533-
# comparisons to GPU with sample weights will use different algorithms
534-
- tests/test_common.py::test_estimators[ExtraTreesClassifier()-check_sample_weights_invariance(kind=ones)]
535531
- tests/test_common.py::test_estimators[ExtraTreesClassifier()-check_sample_weights_invariance(kind=zeros)]
536-
- tests/test_common.py::test_estimators[ExtraTreesRegressor()-check_sample_weights_invariance(kind=ones)]
532+
- ensemble/tests/test_forest.py::test_min_weight_fraction_leaf
537533

538534
# RuntimeError: Device support is not implemented, failing as result of fallback to cpu false
539535
- svm/tests/test_svm.py::test_unfitted

doc/sources/array_api.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ The following patched classes have support for array API inputs:
9797
- :obj:`sklearn.covariance.EmpiricalCovariance`
9898
- :obj:`sklearnex.covariance.IncrementalEmpiricalCovariance`
9999
- :obj:`sklearn.decomposition.PCA`
100+
- :obj:`sklearn.ensemble.ExtraTreesClassifier`
101+
- :obj:`sklearn.ensemble.ExtraTreesRegressor`
102+
- :obj:`sklearn.ensemble.RandomForestClassifier`
103+
- :obj:`sklearn.ensemble.RandomForestRegressor`
100104
- :obj:`sklearn.linear_model.LinearRegression`
101105
- :obj:`sklearn.linear_model.Ridge`
102106
- :obj:`sklearnex.linear_model.IncrementalLinearRegression`
@@ -108,6 +112,11 @@ The following patched classes have support for array API inputs:
108112
enabled in |sklearn|, when passing these classes as inputs, data will be transferred to host and then back to
109113
device instead of being used directly.
110114

115+
Result attributes of |sklearnex| classes which contain |sklearn| or |sklearnex| classes may not themselves be
116+
array API compliant. For example, ensemble algorithms contain decision tree estimators result objects which
117+
do not comply with the array API standard.
118+
119+
111120

112121
Example usage
113122
=============

0 commit comments

Comments
 (0)