Skip to content

Commit e4a3750

Browse files
authored
[bug] fix to_table for a non-array input when a low-precision queue is used (#2271)
* Update data_conversion.cpp * clang-formatting * add tests * fix errors * fix tests * change the logic * oops * try again * darn it * fix regex * swap due to sparse failure * wrong reference * avoid Nones earlier * Update test_data.py * Update basic_statistics.py
1 parent 83617c6 commit e4a3750

File tree

3 files changed

+51
-12
lines changed

3 files changed

+51
-12
lines changed

onedal/basic_statistics/basic_statistics.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,7 @@ def fit(self, data, sample_weight=None, queue=None):
8383

8484
is_single_dim = data.ndim == 1
8585

86-
data_table = to_table(data, queue=queue)
87-
weights_table = (
88-
to_table(sample_weight, queue=queue)
89-
if sample_weight is not None
90-
else to_table(None)
91-
)
86+
data_table, weights_table = to_table(data, sample_weight, queue=queue)
9287

9388
dtype = data_table.dtype
9489
raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)

onedal/datatypes/data_conversion.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,16 @@ inline csr_table_t convert_to_csr_impl(PyObject *py_data,
154154

155155
dal::table convert_to_table(py::object inp_obj, py::object queue) {
156156
dal::table res;
157+
158+
PyObject *obj = inp_obj.ptr();
159+
160+
if (obj == nullptr || obj == Py_None) {
161+
return res;
162+
}
163+
157164
#ifdef ONEDAL_DATA_PARALLEL
158-
if (!queue.is(py::none()) && !queue.attr("sycl_device").attr("has_aspect_fp64").cast<bool>()) {
165+
if (!queue.is(py::none()) && !queue.attr("sycl_device").attr("has_aspect_fp64").cast<bool>() &&
166+
hasattr(inp_obj, "dtype")) {
159167
// If the queue exists, doesn't have the fp64 aspect, and the data is float64
160168
// then cast it to float32
161169
int type = reinterpret_cast<PyArray_Descr *>(inp_obj.attr("dtype").ptr())->type_num;
@@ -173,11 +181,6 @@ dal::table convert_to_table(py::object inp_obj, py::object queue) {
173181
}
174182
#endif // ONEDAL_DATA_PARALLEL
175183

176-
PyObject *obj = inp_obj.ptr();
177-
178-
if (obj == nullptr || obj == Py_None) {
179-
return res;
180-
}
181184
if (is_array(obj)) {
182185
PyArrayObject *ary = reinterpret_cast<PyArrayObject *>(obj);
183186

onedal/datatypes/tests/test_data.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,44 @@ class DummySyclDevice:
442442
assert X_table.dtype == np.float32
443443
if dtype == np.float32 and not sparse:
444444
assert_allclose(X, from_table(X_table))
445+
446+
447+
@pytest.mark.parametrize("X", [None, 5, "test", True, [], np.pi, lambda: None])
448+
@pytest.mark.parametrize("queue", get_queues())
449+
def test_non_array(X, queue):
450+
# Verify that to and from table doesn't raise errors
451+
# no guarantee is made about type or content
452+
err_str = ""
453+
454+
if np.isscalar(X):
455+
if np.atleast_2d(X).dtype not in [np.float64, np.float32, np.int64, np.int32]:
456+
err_str = "Found unsupported array type"
457+
elif not (X is None or isinstance(X, np.ndarray)):
458+
err_str = r"\[convert_to_table\] Not available input format for convert Python object to onedal table."
459+
460+
if err_str:
461+
with pytest.raises(ValueError, match=err_str):
462+
to_table(X)
463+
else:
464+
X_table = to_table(X, queue=queue)
465+
from_table(X_table)
466+
467+
468+
@pytest.mark.skipif(
469+
not _is_dpc_backend, reason="Requires DPC backend for dtype conversion"
470+
)
471+
@pytest.mark.parametrize("X", [None, 5, "test", True, [], np.pi, lambda: None])
472+
def test_low_precision_non_array(X):
473+
# Use a dummy queue as fp32 hardware is not in public testing
474+
475+
class DummySyclQueue:
476+
"""This class is designed to act like dpctl.SyclQueue
477+
to force dtype conversion"""
478+
479+
class DummySyclDevice:
480+
has_aspect_fp64 = False
481+
482+
sycl_device = DummySyclDevice()
483+
484+
queue = DummySyclQueue()
485+
test_non_array(X, queue)

0 commit comments

Comments
 (0)