@@ -46,12 +46,13 @@ def test_sklearnex_fit_on_gold_data(dataframe, queue, fit_intercept, macro_block
4646 inclin .fit (X_df , y_df )
4747
4848 y_pred = inclin .predict (X_df )
49+ np_y_pred = _as_numpy (y_pred )
4950
50- tol = 2e-6 if y_pred . dtype == np .float32 else 1e-7
51+ tol = 2e-6 if dtype == np .float32 else 1e-7
5152 assert_allclose (inclin .coef_ , [1 ], atol = tol )
5253 if fit_intercept :
5354 assert_allclose (inclin .intercept_ , [0 ], atol = tol )
54- assert_allclose (_as_numpy ( y_pred ) , y , atol = tol )
55+ assert_allclose (np_y_pred , y , atol = tol )
5556
5657
5758@pytest .mark .parametrize ("dataframe,queue" , get_dataframes_and_queues ())
@@ -84,14 +85,15 @@ def test_sklearnex_partial_fit_on_gold_data(
8485
8586 X_df = _convert_to_dataframe (X , sycl_queue = queue , target_df = dataframe )
8687 y_pred = inclin .predict (X_df )
88+ np_y_pred = _as_numpy (y_pred )
8789
8890 assert inclin .n_features_in_ == 1
89- tol = 2e-6 if y_pred . dtype == np .float32 else 1e-7
91+ tol = 2e-6 if dtype == np .float32 else 1e-7
9092 assert_allclose (inclin .coef_ , [[1 ]], atol = tol )
9193 if fit_intercept :
9294 assert_allclose (inclin .intercept_ , 3 , atol = tol )
9395
94- assert_allclose (_as_numpy ( y_pred ) , y , atol = tol )
96+ assert_allclose (np_y_pred , y , atol = tol )
9597
9698
9799@pytest .mark .parametrize ("dataframe,queue" , get_dataframes_and_queues ())
@@ -124,14 +126,15 @@ def test_sklearnex_partial_fit_multitarget_on_gold_data(
124126
125127 X_df = _convert_to_dataframe (X , sycl_queue = queue , target_df = dataframe )
126128 y_pred = inclin .predict (X_df )
129+ np_y_pred = _as_numpy (y_pred )
127130
128131 assert inclin .n_features_in_ == 2
129- tol = 7e-6 if y_pred . dtype == np .float32 else 1e-7
132+ tol = 7e-6 if dtype == np .float32 else 1e-7
130133 assert_allclose (inclin .coef_ , [1.0 , 2.0 ], atol = tol )
131134 if fit_intercept :
132135 assert_allclose (inclin .intercept_ , 3.0 , atol = tol )
133136
134- assert_allclose (_as_numpy ( y_pred ) , y , atol = tol )
137+ assert_allclose (np_y_pred , y , atol = tol )
135138
136139
137140@pytest .mark .parametrize ("dataframe,queue" , get_dataframes_and_queues ())
0 commit comments