[Enhancement] enable array API support in EmpiricalCovariance and IncrementalEmpiricalCovariance #2207
Conversation
/intelci: run
/intelci: run
This PR is now dependent on the developments in #2096 (SPMD testing requires array_api bypassing on oneDAL offloading).
/intelci: run
/intelci: run
/intelci: run
/intelci: run
/intelci: run
/intelci: run
with config_context(array_api_dispatch=True):
    est.fit(X_df)

with pytest.raises(TypeError, match="Multiple namespaces for array inputs: .*"):
Would this work correctly if put under a config context with array_api_dispatch=True?
I used sklearn's Ridge with numpy and torch as an example of what to expect (https://scikit-learn.org/stable/modules/array_api.html#input-and-output-array-type-handling).
Attempting to use any non-numpy input after fitting with array_api_dispatch=True will lead to some sort of error associated with the fitted framework: get_namespace and validate_data will force the data to numpy by default (https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_array_api.py#L394), and comparing numpy against the array API framework will then fail.
If we were to use array_api_dispatch=True throughout, it will instead error at this point in an external package if get_namespace is used: https://github.com/data-apis/array-api-compat/blob/main/array_api_compat/common/_helpers.py#L665
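For reference, a minimal sketch of that scenario against stock sklearn (not from this PR; assumes torch and a scikit-learn build with array API support are installed, and the exact exception raised depends on the version):

```python
import numpy as np
import torch
from sklearn import config_context
from sklearn.linear_model import Ridge

X_t, y_t = torch.rand(100, 5), torch.rand(100)

with config_context(array_api_dispatch=True):
    est = Ridge().fit(X_t, y_t)  # fitted attributes remain torch tensors
    # Predicting with a numpy array now mixes namespaces with the
    # torch-backed fitted attributes, so an error is expected here.
    est.predict(np.random.rand(10, 5))
```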
Isn't the idea of array API support in stock sklearn to make it work in situations like this (e.g. fitting on a torch array, then predicting on a different kind of array)?
Actually no, I see that sklearn throws the same error.
/intelci: run
Private CI failure comes from an infrastructure timeout.
Description
Enables array API zero copy dispatching for EmpiricalCovariance and IncrementalEmpiricalCovariance. This required the following changes:

- added log_likelihood and pinvh functions (as they are unavailable in sklearn) to sklearnex.utils._array_api
- fixed sklearnex.preview.covariance.EmpiricalCovariance.mahalanobis.__doc__, which was a bug
- array API support for the mahalanobis, score, and error_norm methods; likely to change due to the nature of how we support dpnp and dpctl
- array API support for check_is_fitted (which is even missing from sklearn)
- use of get_namespace, with namespace support swapped away from numpy
- changed the get_precision function to use the internal pinvh; this is important for array API conformance, where attributes will no longer only be numpy arrays
- validate_params is now called before fit's dispatch (will be set as a design rule in a follow-up PR)
- pairwise_distances kwargs and support_input_format do not interact well; a follow-up development ticket for fixing this issue will be made
- use of return_type_constructor in IncrementalEmpiricalCovariance
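As a rough usage sketch of what the enabled dispatching aims to allow (illustrative only, not taken from the PR; assumes torch plus array API-capable scikit-learn and sklearnex installs):

```python
import torch
from sklearn import config_context
from sklearnex.preview.covariance import EmpiricalCovariance

X = torch.rand(1000, 10)

with config_context(array_api_dispatch=True):
    est = EmpiricalCovariance().fit(X)
    # With zero-copy array API dispatching, fitted attributes and method
    # results are expected to stay in the input namespace (torch here)
    # rather than being converted to numpy.
    dist = est.mahalanobis(X)
    log_lik = est.score(X)
```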