Skip to content

Commit e799ec4

Browse files
authored
[fix] Add forgotten _check_array to IncrementalBasicStatistics.partial_fit (#2022)
1 parent 4fd4568 commit e799ec4

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

onedal/basic_statistics/incremental_basic_statistics.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from daal4py.sklearn._utils import get_dtype
2020

2121
from ..datatypes import _convert_to_supported, from_table, to_table
22+
from ..utils import _check_array
2223
from .basic_statistics import BaseBasicStatistics
2324

2425

@@ -96,6 +97,17 @@ def partial_fit(self, X, weights=None, queue=None):
9697
policy = self._get_policy(queue, X)
9798
X, weights = _convert_to_supported(policy, X, weights)
9899

100+
X = _check_array(
101+
X, dtype=[np.float64, np.float32], ensure_2d=False, force_all_finite=False
102+
)
103+
if weights is not None:
104+
weights = _check_array(
105+
weights,
106+
dtype=[np.float64, np.float32],
107+
ensure_2d=False,
108+
force_all_finite=False,
109+
)
110+
99111
if not hasattr(self, "_onedal_params"):
100112
dtype = get_dtype(X)
101113
self._onedal_params = self._get_onedal_params(False, dtype=dtype)

0 commit comments

Comments
 (0)