Skip to content

Commit 056fdea

Browse files
FIX: Return correct shape of probabilities from GPU logreg (#2645)
* fix shapes of prediction outputs * more comprehensive testing for GPU cases
1 parent d9e5197 commit 056fdea

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

onedal/linear_model/logistic_regression.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ def _predict_proba(self, X):
217217
result = result = self._infer(X)
218218
_, xp, _ = _get_sycl_namespace(X)
219219
y = from_table(result.probabilities, like=X)
220+
y = xp.reshape(y, -1)
220221
return xp.stack([1 - y, y], axis=1)
221222

222223
def _predict_log_proba(self, X):

sklearnex/linear_model/tests/test_logreg.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,25 @@ def logistic_model_function(predicted_probabilities, coefs):
196196
model_sklearnex.predict_proba(X), model_sklearnex.coef_
197197
)
198198
assert fn_sklearnex <= fn_sklearn
199+
200+
201+
@pytest.mark.parametrize(
202+
"dataframe,queue", get_dataframes_and_queues(device_filter_="gpu")
203+
)
204+
def test_gpu_logreg_prediction_shapes(dataframe, queue):
205+
if not queue or not queue.sycl_device.is_gpu:
206+
pytest.skip("Test for GPU-only code branch")
207+
from sklearnex.linear_model import LogisticRegression
208+
209+
X, y = make_classification(random_state=123)
210+
X = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
211+
y = _convert_to_dataframe(y, sycl_queue=queue, target_df=dataframe)
212+
213+
model = LogisticRegression(solver="newton-cg").fit(X, y)
214+
pred = model.predict(X)
215+
pred_proba = model.predict_proba(X)
216+
pred_log_proba = model.predict_log_proba(X)
217+
218+
np.testing.assert_array_equal(pred.shape, (X.shape[0],))
219+
np.testing.assert_array_equal(pred_proba.shape, (X.shape[0], 2))
220+
np.testing.assert_array_equal(pred_log_proba.shape, (X.shape[0], 2))

0 commit comments

Comments
 (0)