Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions skglm/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from skglm.penalties import (L1, WeightedL1, L1_plus_L2, L2, WeightedGroupL2,
MCPenalty, WeightedMCPenalty, IndicatorBox, L2_1)
from skglm.utils.data import grp_converter
from sklearn.utils.validation import validate_data


def _glm_fit(X, y, model, datafit, penalty, solver):
Expand All @@ -51,8 +52,8 @@ def _glm_fit(X, y, model, datafit, penalty, solver):
accept_sparse='csc', copy=fit_intercept)
check_y_params = dict(ensure_2d=False, order='F')

X, y = model._validate_data(
X, y, validate_separately=(check_X_params, check_y_params))
X, y = validate_data(
model, X, y, validate_separately=(check_X_params, check_y_params))
X = check_array(X, 'csc', dtype=[np.float64, np.float32],
order='F', copy=False, accept_large_sparse=False)
y = check_array(y, 'csc', dtype=X.dtype.type, order='F', copy=False,
Expand Down Expand Up @@ -1498,7 +1499,7 @@ def fit(self, X, Y):
accept_sparse='csc',
copy=self.copy_X and self.fit_intercept)
check_Y_params = dict(ensure_2d=False, order='F')
X, Y = self._validate_data(X, Y, validate_separately=(check_X_params,
X, Y = validate_data(self, X, Y, validate_separately=(check_X_params,
check_Y_params))
Y = Y.astype(X.dtype)

Expand Down