Skip to content

Commit 133c6a2

Browse files
fix compared function values (#2732)
1 parent 8353d67 commit 133c6a2

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

daal4py/sklearn/linear_model/tests/test_enet.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def fn_lasso(model, X, y, lambda_):
2929
resid = y - model.predict(X)
3030
fn_ssq = resid.reshape(-1) @ resid.reshape(-1)
3131
fn_l1 = np.abs(model.coef_).sum()
32-
return fn_ssq + lambda_ * fn_l1
32+
return (1 / (2 * X.shape[0])) * fn_ssq + lambda_ * fn_l1
3333

3434

3535
@pytest.mark.parametrize("nrows", [10, 20])
@@ -62,7 +62,7 @@ def test_enet_is_correct(nrows, ncols, n_targets, fit_intercept, positive, l1_ra
6262
# Note: lasso is not guaranteed to have a unique global optimum.
6363
# If the coefficients do not match, this makes another check on
6464
# the optimality of the function values instead. It checks that
65-
# the result from daal4py is no worse than 2% off scikit-learn's.
65+
# the result from daal4py is no worse than scikit-learn's.
6666

6767
tol = 1e-6 if n_targets == 1 else 1e-5
6868
try:
@@ -72,7 +72,7 @@ def test_enet_is_correct(nrows, ncols, n_targets, fit_intercept, positive, l1_ra
7272
raise e
7373
fn_d4p = fn_lasso(model_d4p, X, y, model_d4p.alpha)
7474
fn_skl = fn_lasso(model_skl, X, y, model_skl.alpha)
75-
assert fn_d4p <= fn_skl * 1.02
75+
assert fn_d4p <= fn_skl
7676

7777
if fit_intercept:
7878
np.testing.assert_allclose(
@@ -120,7 +120,7 @@ def test_lasso_is_correct(nrows, ncols, n_targets, fit_intercept, positive, alph
120120
except AssertionError as e:
121121
fn_d4p = fn_lasso(model_d4p, X, y, model_d4p.alpha)
122122
fn_skl = fn_lasso(model_skl, X, y, model_skl.alpha)
123-
assert fn_d4p <= fn_skl * 1.02
123+
assert fn_d4p <= fn_skl
124124

125125
if positive:
126126
assert np.all(model_d4p.coef_ >= 0)

0 commit comments

Comments
 (0)